Skip to content

Commit 49eceab

Browse files
authored
Merge branch 'main' into depr-to-tensor
2 parents d9265f0 + 60449a4 commit 49eceab

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -199,21 +199,31 @@ def resize_bounding_box():
199199
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
200200

201201

202-
class TestKernelsCommon:
203-
@pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
204-
def test_scriptable(self, functional_info):
205-
jit.script(functional_info.functional)
206-
207-
@pytest.mark.parametrize(
208-
("functional_info", "sample_input"),
209-
[
210-
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
211-
for functional_info in FUNCTIONAL_INFOS
212-
for idx, sample_input in enumerate(functional_info.sample_inputs())
213-
],
214-
)
215-
def test_eager_vs_scripted(self, functional_info, sample_input):
216-
eager = functional_info(sample_input)
217-
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
218-
219-
torch.testing.assert_close(eager, scripted)
202+
@pytest.mark.parametrize(
203+
"kernel",
204+
[
205+
pytest.param(kernel, id=name)
206+
for name, kernel in F.__dict__.items()
207+
if not name.startswith("_")
208+
and callable(kernel)
209+
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
210+
and "pil" not in name
211+
],
212+
)
213+
def test_scriptable(kernel):
214+
jit.script(kernel)
215+
216+
217+
@pytest.mark.parametrize(
218+
("functional_info", "sample_input"),
219+
[
220+
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
221+
for functional_info in FUNCTIONAL_INFOS
222+
for idx, sample_input in enumerate(functional_info.sample_inputs())
223+
],
224+
)
225+
def test_eager_vs_scripted(functional_info, sample_input):
226+
eager = functional_info(sample_input)
227+
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
228+
229+
torch.testing.assert_close(eager, scripted)

torchvision/prototype/transforms/functional/_type_conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest.mock
2-
from typing import Dict, Any, Tuple, cast
2+
from typing import Dict, Any, Tuple
33

44
import numpy as np
55
import PIL.Image
@@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
2222

2323

2424
def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
25-
return cast(torch.Tensor, one_hot(label, num_classes=num_categories))
25+
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]

0 commit comments

Comments
 (0)