Commit e287fc1 (parent: 6a0035d)

Update to include five-crop, ten-crop, and resized-crop support; use placeholder transforms for flip and resize for now.

File tree: 3 files changed, +88 -15 lines

test/test_transforms_v2.py

Lines changed: 38 additions & 11 deletions
@@ -3551,11 +3551,8 @@ def test_functional_image_correctness(self, kwargs, make_input):
 
         actual = F.crop(image, **kwargs)
 
-        if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))
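Note: the tests now route CV-CUDA outputs through a shared cvcuda_to_pil_compatible_tensor helper instead of repeating the conversion inline. The helper's definition is outside this diff; a minimal sketch of what it plausibly does, mirroring the inline code it replaces (an assumption, not the helper's actual source):

def cvcuda_to_pil_compatible_tensor(image: "cvcuda.Tensor") -> torch.Tensor:
    # Assumed behavior: convert to a torch tensor, move to CPU, and drop
    # the batch dimension so F.to_pil_image accepts the result.
    tensor = F.cvcuda_to_tensor(image).to(device="cpu")
    return tensor.squeeze(0)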

@@ -3676,15 +3673,15 @@ def test_transform_image_correctness(self, param, value, seed, make_input):
 
         torch.manual_seed(seed)
 
-        if make_input == make_image_cvcuda:
+        if make_input is make_image_cvcuda:
             image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(transform(F.to_pil_image(image)))
 
         if make_input == make_image_cvcuda and will_pad:
             # when padding is applied, CV-CUDA will always fill with zeros
             # cannot use assert_equal since it will fail unless random is all zeros
-            torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
+            assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
         else:
             assert_equal(actual, expected)
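Note: the switch from torch.testing.assert_close to a bare assert_close suggests a test-local wrapper that normalizes CV-CUDA results before comparing. Its definition is not in this diff; a plausible shape, offered purely as an assumption:

def assert_close(actual, expected, **kwargs):
    # Assumed wrapper: bring cvcuda.Tensor outputs back to CPU torch
    # tensors so torch.testing.assert_close can compare them.
    if is_cvcuda_tensor(actual):
        actual = F.cvcuda_to_tensor(actual).squeeze(0).cpu()
    torch.testing.assert_close(actual, expected, **kwargs)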

@@ -4510,6 +4507,9 @@ def test_kernel(self, kernel, make_input):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -4526,9 +4526,16 @@ def test_functional(self, make_input):
             (F.resized_crop_mask, tv_tensors.Mask),
             (F.resized_crop_video, tv_tensors.Video),
             (F.resized_crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F.resized_crop_image,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type)
 
     @param_value_parametrization(
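Note: "cvcuda.Tensor" is passed as a string because parametrize arguments are evaluated at collection time, which would fail on machines without CV-CUDA; the string is resolved lazily inside the test via _import_cvcuda(). A sketch of what such a lazy importer typically looks like (assumed shape, not necessarily torchvision's implementation):

import importlib

def _import_cvcuda():
    # Defer the import so the test module still loads without CV-CUDA;
    # tests that need it are skipped via the CVCUDA_AVAILABLE marks.
    try:
        return importlib.import_module("cvcuda")
    except ImportError as exc:
        raise RuntimeError("CV-CUDA is not installed") from exc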
@@ -4545,6 +4552,9 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, param, value, make_input):
@@ -4556,20 +4566,37 @@ def test_transform(self, param, value, make_input):
 
     # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
     # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
-    def test_functional_image_correctness(self, interpolation):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
+    def test_functional_image_correctness(self, make_input, interpolation):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8)
 
         actual = F.resized_crop(
             image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
         )
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.to_image(
             F.resized_crop(
                 F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
             )
         )
 
-        torch.testing.assert_close(actual, expected, atol=1, rtol=0)
+        atol = 1
+        if make_input is make_image_cvcuda and interpolation == transforms.InterpolationMode.BICUBIC:
+            # CV-CUDA BICUBIC differs from PIL ground truth BICUBIC
+            atol = 10
+        assert_close(actual, expected, atol=atol, rtol=0)
 
     def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
         new_height, new_width = size
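Note: the widened atol=10 (out of 255 for uint8) reflects that CV-CUDA and PIL implement different bicubic kernels, so individual pixels can legitimately diverge. One way such a bound can be chosen empirically is to measure the worst per-pixel deviation between the two backends; a small illustrative helper:

import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> int:
    # Compare in a wider dtype so uint8 subtraction cannot wrap around.
    return int((a.to(torch.int16) - b.to(torch.int16)).abs().max())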
@@ -5044,7 +5071,7 @@ def test_image_correctness(self, output_size, make_input, fn):
 
         actual = fn(image, output_size)
 
-        if make_input == make_image_cvcuda:
+        if make_input is make_image_cvcuda:
             image = cvcuda_to_pil_compatible_tensor(image)
 
         expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 0 deletions
@@ -255,6 +255,8 @@ class RandomResizedCrop(Transform):
 
     _v1_transform_cls = _transforms.RandomResizedCrop
 
+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
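Note: RandomResizedCrop extends _transformed_types with the is_cvcuda_tensor predicate so the v2 dispatch treats CV-CUDA tensors as transformable inputs. The v2 machinery accepts both classes and callables in this tuple; a sketch of the kind of check involved (assumed to mirror check_type in torchvision.transforms.v2._utils):

def check_type(obj, types_or_checks) -> bool:
    for type_or_check in types_or_checks:
        # Entries may be classes (matched with isinstance) or predicates
        # such as is_cvcuda_tensor (called on the object).
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False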

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 48 additions & 4 deletions
@@ -618,6 +618,32 @@ def resize_video(
     return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 
 
+def _resize_cvcuda(
+    image: "cvcuda.Tensor",
+    size: Optional[list[int]],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = True,
+) -> "cvcuda.Tensor":
+    # Placeholder: a native CV-CUDA resize will land in a dedicated PR.
+    # Until then, round-trip through a torch tensor and reuse resize_image.
+    from ._type_conversion import cvcuda_to_tensor, to_cvcuda_tensor
+
+    return to_cvcuda_tensor(
+        resize_image(
+            cvcuda_to_tensor(image),
+            size=size,
+            interpolation=interpolation,
+            max_size=max_size,
+            antialias=antialias,
+        )
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_cvcuda)
+
+
 def affine(
     inpt: torch.Tensor,
     angle: Union[int, float],
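Note: with the registration above, calling the public F.resize on a cvcuda.Tensor routes to _resize_cvcuda. A hypothetical usage sketch (assuming a CUDA device, CV-CUDA installed, and F.to_cvcuda_tensor as the conversion entry point, as suggested by the placeholder's import):

import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8, device="cuda")
cv_img = F.to_cvcuda_tensor(img)          # wrap as a cvcuda.Tensor
out = F.resize(cv_img, size=[224, 224])   # dispatches to _resize_cvcuda

Because the placeholder round-trips through a torch tensor, it keeps the tensor kernel's semantics at the cost of conversion overhead until the dedicated resize PR lands.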
@@ -2959,6 +2985,24 @@ def resized_crop_video(
     )
 
 
+def _resized_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: list[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    antialias: Optional[bool] = True,
+) -> "cvcuda.Tensor":
+    image = _crop_cvcuda(image, top, left, height, width)
+    return _resize_cvcuda(image, size, interpolation=interpolation, antialias=antialias)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_cvcuda)
+
+
 def five_crop(
     inpt: torch.Tensor, size: list[int]
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
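Note: _resized_crop_cvcuda composes the existing CV-CUDA crop with the placeholder resize, mirroring how resized_crop_image composes crop_image and resize_image. Under that reading, the dispatched call should be equivalent to the two-step path (a sketch reusing cv_img from the example above; parameter values are illustrative):

out = F.resized_crop(cv_img, top=10, left=20, height=224, width=224,
                     size=[128, 128], antialias=True)
ref = F.resize(F.crop(cv_img, top=10, left=20, height=224, width=224),
               size=[128, 128], antialias=True)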
@@ -3037,15 +3081,15 @@ def _five_crop_cvcuda(
     size: list[int],
 ) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = image.shape[-2:]
+    image_height, image_width = image.shape[1], image.shape[2]
 
     if crop_width > image_width or crop_height > image_height:
         raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
     tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
-    tr = _crop_cvcuda(image, 0, image_width - crop_height, crop_width, crop_height)
-    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_width, crop_height)
-    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_width, crop_height)
+    tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
     center = _center_crop_cvcuda(image, [crop_height, crop_width])
 
     return tl, tr, bl, br, center
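Note: the five-crop fix corrects two things at once. CV-CUDA images are NHWC, so height and width live at dims 1 and 2 rather than the last two as in CHW, and the corner crops previously passed _crop_cvcuda's (top, left, height, width) arguments in the wrong order with a wrong offset for tr. A worked check with a 100x200 (HxW) image and a (50, 60) crop:

image_height, image_width = 100, 200
crop_height, crop_width = 50, 60

# Corrected (top, left) offsets for each (50, 60) corner crop:
tl = (0, 0)
tr = (0, image_width - crop_width)                           # (0, 140)
bl = (image_height - crop_height, 0)                         # (50, 0)
br = (image_height - crop_height, image_width - crop_width)  # (50, 140)

# The old code swapped height/width in the size arguments of tr, bl, and
# br (yielding 60x50 windows), and additionally offset tr's left edge by
# image_width - crop_height (150) instead of image_width - crop_width (140).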
