Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 82 additions & 20 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -3778,17 +3778,17 @@ def test_kernel_image(self, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_inplace(self, dtype, device):
input = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
input_version = input._version
inpt = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
input_version = inpt._version

output_out_of_place = F.erase_image(input, **self.FUNCTIONAL_KWARGS)
assert output_out_of_place.data_ptr() != input.data_ptr()
assert output_out_of_place is not input
output_out_of_place = F.erase_image(inpt, **self.FUNCTIONAL_KWARGS)
assert output_out_of_place.data_ptr() != inpt.data_ptr()
assert output_out_of_place is not inpt

output_inplace = F.erase_image(input, **self.FUNCTIONAL_KWARGS, inplace=True)
assert output_inplace.data_ptr() == input.data_ptr()
output_inplace = F.erase_image(inpt, **self.FUNCTIONAL_KWARGS, inplace=True)
assert output_inplace.data_ptr() == inpt.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input
assert output_inplace is inpt

assert_equal(output_inplace, output_out_of_place)

Expand All @@ -3797,7 +3797,15 @@ def test_kernel_video(self):

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.erase, make_input(), **self.FUNCTIONAL_KWARGS)
Expand All @@ -3809,25 +3817,48 @@ def test_functional(self, make_input):
(F._augment._erase_image_pil, PIL.Image.Image),
(F.erase_image, tv_tensors.Image),
(F.erase_video, tv_tensors.Video),
pytest.param(
F._augment._erase_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Check that ``kernel`` has a signature compatible with the ``F.erase`` dispatcher."""
    # The parametrization supplies ``None`` as the input type for the CV-CUDA kernel,
    # so the actual tensor class is looked up here, when the test runs.
    is_cvcuda_kernel = kernel is F._augment._erase_image_cvcuda
    resolved_type = _import_cvcuda().Tensor if is_cvcuda_kernel else input_type
    check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=resolved_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
input = make_input(device=device)
inpt = make_input(device=device)

with pytest.warns(UserWarning, match="currently passing through inputs of type"):
# shouldn't get a warning for cvcuda
if make_input is make_image_cvcuda:
check_transform(
transforms.RandomErasing(p=1),
input,
check_v1_compatibility=not isinstance(input, PIL.Image.Image),
inpt,
check_v1_compatibility=False,
)
else:
with pytest.warns(UserWarning, match="currently passing through inputs of type"):
check_transform(
transforms.RandomErasing(p=1),
inpt,
check_v1_compatibility=not isinstance(inpt, PIL.Image.Image),
)

def _reference_erase_image(self, image, *, i, j, h, w, v):
mask = torch.zeros_like(image, dtype=torch.bool)
Expand All @@ -3842,16 +3873,38 @@ def _reference_erase_image(self, image, *, i, j, h, w, v):

return erased_image

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_functional_image_correctness(self, dtype, device):
image = make_image(dtype=dtype, device=device)
def test_functional_image_correctness(self, make_input, dtype, device):
image = make_input(dtype=dtype, device=device)

actual = F.erase(image, **self.FUNCTIONAL_KWARGS)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS)

assert_equal(actual, expected)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@param_value_parametrization(
scale=[(0.1, 0.2), [0.0, 1.0]],
ratio=[(0.3, 0.7), [0.1, 5.0]],
Expand All @@ -3860,10 +3913,10 @@ def test_functional_image_correctness(self, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_image_correctness(self, param, value, dtype, device, seed):
def test_transform_image_correctness(self, make_input, param, value, dtype, device, seed):
transform = transforms.RandomErasing(**{param: value}, p=1)

image = make_image(dtype=dtype, device=device)
image = make_input(dtype=dtype, device=device)

with freeze_rng_state():
torch.manual_seed(seed)
Expand All @@ -3874,9 +3927,18 @@ def test_transform_image_correctness(self, param, value, dtype, device, seed):
torch.manual_seed(seed)
actual = transform(image)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = self._reference_erase_image(image, **params)

assert_equal(actual, expected)
if make_input is make_image_cvcuda and value == "random":
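# CV-CUDA doesn't draw its random fill from the same distribution as torchvision,
# so the erased pixels cannot be compared exactly. The result is still deterministic:
# the kernel derives the CV-CUDA seed from torch.randint, which follows torch.manual_seed.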
assert_close(actual, expected, rtol=0, atol=256)
else:
assert_equal(actual, expected)

def test_transform_errors(self):
with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
Expand Down
7 changes: 7 additions & 0 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor

from ._transform import _RandomApplyTransform, Transform
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


CVCUDA_AVAILABLE = _is_cvcuda_available()


class RandomErasing(_RandomApplyTransform):
"""Randomly select a rectangle region in the input image or video and erase its pixels.

Expand Down Expand Up @@ -48,6 +52,9 @@ class RandomErasing(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomErasing

if CVCUDA_AVAILABLE:
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
Expand Down
101 changes: 100 additions & 1 deletion torchvision/transforms/v2/functional/_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
from types import SimpleNamespace
from typing import TYPE_CHECKING

import PIL.Image

Expand All @@ -8,7 +10,13 @@
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


# Evaluated once at import time; used further down this module to decide whether
# the CV-CUDA erase kernel gets registered.
CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    # Only imported for static type checkers, so the "cvcuda.Tensor" string
    # annotations resolve without importing cvcuda at runtime.
    import cvcuda  # type: ignore[import-not-found]


def erase(
Expand Down Expand Up @@ -58,6 +66,97 @@ def erase_video(
return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


def _erase_image_cvcuda(
    image: "cvcuda.Tensor",
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> "cvcuda.Tensor":
    """Erase a rectangle of a CV-CUDA tensor by calling ``cvcuda.erase``.

    Args:
        image: Tensor to erase from. ``image.shape[3]`` is read to build the
            erase mask (presumably the channel axis of an NHWC tensor - TODO confirm).
        i: First coordinate of the rectangle; passed as the second anchor component.
        j: Second coordinate of the rectangle; passed as the first anchor component.
        h: Height of the rectangle.
        w: Width of the rectangle.
        v: Fill values. If its last two dimensions are not ``(1, 1)`` the call is
            treated as a random fill: the contents of ``v`` are ignored and
            ``cvcuda.erase`` is invoked with ``random=True``. Otherwise ``v`` is
            flattened to float32 and fitted to exactly four values (one value is
            repeated, more than four are truncated, fewer are zero-padded).
        inplace: Must be ``False``.

    Returns:
        The tensor returned by ``cvcuda.erase``.

    Raises:
        ValueError: If ``inplace`` is ``True``.
    """
    cvcuda = _import_cvcuda()

    if inplace:
        raise ValueError("inplace is not supported for cvcuda.Tensor")

    def _packed_view(cv_tensor: cvcuda.Tensor, shape: tuple[int, ...], typestr: str) -> torch.Tensor:
        # torch has no equivalent of the packed CV-CUDA element types (_2S32, _3S32), so
        # re-describe the same memory through the CUDA array interface as a plain
        # contiguous array that torch can copy into.
        iface = cv_tensor.cuda().__cuda_array_interface__
        iface.update(shape=shape, typestr=typestr, strides=None)
        return torch.as_tensor(SimpleNamespace(__cuda_array_interface__=iface), device="cuda")

    # A fill tensor with a spatial extent other than 1x1 signals a random fill.
    use_random_fill = v.shape[-2:] != (1, 1)

    # One bit set per entry along dim 3 of the image.
    mask_bits = (1 << image.shape[3]) - 1
    anchor_src = torch.tensor([[j, i]], dtype=torch.int32, device="cuda")
    erasing_src = torch.tensor([[w, h, mask_bits]], dtype=torch.int32, device="cuda")
    index_src = torch.tensor([0], dtype=torch.int32, device="cuda")

    # CV-CUDA expects the fill values as a flat tensor of exactly four floats.
    if use_random_fill:
        # The values are not taken from ``v`` here, but a tensor still has to be passed.
        fill = torch.zeros(4, device="cuda", dtype=torch.float32)
    else:
        flat = v.flatten().to(dtype=torch.float32, device="cuda")
        count = flat.numel()
        if count == 1:
            fill = flat.expand(4)
        elif count >= 4:
            fill = flat[:4]
        else:
            padding = torch.zeros(4 - count, device="cuda", dtype=torch.float32)
            fill = torch.cat([flat, padding])
        fill = fill.contiguous()

    # Plain scalar-typed tensors can be wrapped by CV-CUDA directly.
    cv_imgIdx = cvcuda.as_tensor(index_src.reshape(1), "N")
    cv_values = cvcuda.as_tensor(fill.reshape(4), "N")

    # Packed-type tensors are allocated by CV-CUDA and then filled through a torch view.
    cv_anchor = cvcuda.Tensor((1,), cvcuda.Type._2S32, "N")
    cv_erasing = cvcuda.Tensor((1,), cvcuda.Type._3S32, "N")
    for packed, width, values in ((cv_anchor, 2, anchor_src), (cv_erasing, 3, erasing_src)):
        _packed_view(packed, (1, width), "<i4").copy_(values)

    # Draw the CV-CUDA seed from torch's RNG so that torch.manual_seed() makes the
    # random fill reproducible; a constant seed is passed when no random fill is requested.
    seed = int(torch.randint(0, 2147483648, (1,)).item()) if use_random_fill else 0

    return cvcuda.erase(
        src=image,
        anchor=cv_anchor,
        erasing=cv_erasing,
        values=cv_values,
        imgIdx=cv_imgIdx,
        random=use_random_fill,
        seed=seed,
    )


# Register the CV-CUDA kernel for `erase` only when cvcuda is available, since the
# registration needs the real `cvcuda.Tensor` class as its input type.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(erase, _import_cvcuda().Tensor)(_erase_image_cvcuda)


def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.JPEG` for details."""
if torch.jit.is_scripting():
Expand Down
Loading