Skip to content

Commit bd5c9db

Browse files
committed
Merge branch 'feat/adjust_saturation_cvcuda' into brightness_contrast_hue_cvcuda
2 parents e91c6e2 + bd8cf13 commit bd5c9db

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

test/test_transforms_v2.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6225,7 +6225,18 @@ def test_kernel_image(self, dtype, device):
62256225
def test_kernel_video(self):
62266226
check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5)
62276227

6228-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
6228+
@pytest.mark.parametrize(
6229+
"make_input",
6230+
[
6231+
make_image_tensor,
6232+
make_image,
6233+
make_image_pil,
6234+
make_video,
6235+
pytest.param(
6236+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6237+
),
6238+
],
6239+
)
62296240
def test_functional(self, make_input):
62306241
check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5)
62316242

@@ -6236,9 +6247,16 @@ def test_functional(self, make_input):
62366247
(F._color._adjust_saturation_image_pil, PIL.Image.Image),
62376248
(F.adjust_saturation_image, tv_tensors.Image),
62386249
(F.adjust_saturation_video, tv_tensors.Video),
6250+
pytest.param(
6251+
F._color._adjust_saturation_cvcuda,
6252+
"cvcuda.Tensor",
6253+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
6254+
),
62396255
],
62406256
)
62416257
def test_functional_signature(self, kernel, input_type):
6258+
if input_type == "cvcuda.Tensor":
6259+
input_type = _import_cvcuda().Tensor
62426260
check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type)
62436261

62446262
def test_functional_error(self):
@@ -6248,11 +6266,28 @@ def test_functional_error(self):
62486266
with pytest.raises(ValueError, match="is not non-negative"):
62496267
F.adjust_saturation(make_image(), saturation_factor=-1)
62506268

6269+
@pytest.mark.parametrize(
6270+
"make_input",
6271+
[
6272+
make_image,
6273+
pytest.param(
6274+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6275+
),
6276+
],
6277+
)
6278+
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
62516279
@pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0])
6252-
def test_correctness_image(self, saturation_factor):
6253-
image = make_image(dtype=torch.uint8, device="cpu")
6280+
def test_correctness_image(self, make_input, color_space, saturation_factor):
6281+
image = make_input(dtype=torch.uint8, color_space=color_space, device="cpu")
62546282

62556283
actual = F.adjust_saturation(image, saturation_factor=saturation_factor)
6284+
6285+
if make_input is make_image_cvcuda:
6286+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
6287+
actual = actual.squeeze(0)
6288+
image = F.cvcuda_to_tensor(image)
6289+
image = image.squeeze(0)
6290+
62566291
expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor))
62576292

62586293
assert_close(actual, expected, rtol=0, atol=1)

torchvision/transforms/v2/functional/_color.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,37 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
197197
return adjust_saturation_image(video, saturation_factor=saturation_factor)
198198

199199

200+
def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) -> "cvcuda.Tensor":
201+
if saturation_factor < 0:
202+
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
203+
204+
c = image.shape[3]
205+
if c not in [1, 3, 4]:
206+
raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}")
207+
208+
if c == 1: # Match PIL behaviour
209+
return image
210+
211+
# grayscale weights
212+
sf = saturation_factor
213+
r, g, b = 0.2989, 0.587, 0.114
214+
twist_data = [
215+
[sf + (1 - sf) * r, (1 - sf) * g, (1 - sf) * b, 0.0],
216+
[(1 - sf) * r, sf + (1 - sf) * g, (1 - sf) * b, 0.0],
217+
[(1 - sf) * r, (1 - sf) * g, sf + (1 - sf) * b, 0.0],
218+
]
219+
twist_tensor = cvcuda.as_tensor(
220+
torch.tensor(twist_data, dtype=torch.float32, device="cuda"),
221+
"HW",
222+
)
223+
224+
return cvcuda.color_twist(image, twist_tensor)
225+
226+
227+
if CVCUDA_AVAILABLE:
228+
_register_kernel_internal(adjust_saturation, cvcuda.Tensor)(_adjust_saturation_cvcuda)
229+
230+
200231
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
201232
"""See :class:`~torchvision.transforms.RandomAutocontrast`"""
202233
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)