@@ -6225,7 +6225,18 @@ def test_kernel_image(self, dtype, device):
62256225 def test_kernel_video (self ):
62266226 check_kernel (F .adjust_saturation_video , make_video (), saturation_factor = 0.5 )
62276227
6228- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image , make_image_pil , make_video ])
6228+ @pytest .mark .parametrize (
6229+ "make_input" ,
6230+ [
6231+ make_image_tensor ,
6232+ make_image ,
6233+ make_image_pil ,
6234+ make_video ,
6235+ pytest .param (
6236+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
6237+ ),
6238+ ],
6239+ )
62296240 def test_functional (self , make_input ):
62306241 check_functional (F .adjust_saturation , make_input (), saturation_factor = 0.5 )
62316242
@@ -6236,9 +6247,16 @@ def test_functional(self, make_input):
62366247 (F ._color ._adjust_saturation_image_pil , PIL .Image .Image ),
62376248 (F .adjust_saturation_image , tv_tensors .Image ),
62386249 (F .adjust_saturation_video , tv_tensors .Video ),
6250+ pytest .param (
6251+ F ._color ._adjust_saturation_cvcuda ,
6252+ "cvcuda.Tensor" ,
6253+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" ),
6254+ ),
62396255 ],
62406256 )
62416257 def test_functional_signature (self , kernel , input_type ):
6258+ if input_type == "cvcuda.Tensor" :
6259+ input_type = _import_cvcuda ().Tensor
62426260 check_functional_kernel_signature_match (F .adjust_saturation , kernel = kernel , input_type = input_type )
62436261
62446262 def test_functional_error (self ):
@@ -6248,11 +6266,28 @@ def test_functional_error(self):
62486266 with pytest .raises (ValueError , match = "is not non-negative" ):
62496267 F .adjust_saturation (make_image (), saturation_factor = - 1 )
62506268
6269+ @pytest .mark .parametrize (
6270+ "make_input" ,
6271+ [
6272+ make_image ,
6273+ pytest .param (
6274+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
6275+ ),
6276+ ],
6277+ )
6278+ @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
62516279 @pytest .mark .parametrize ("saturation_factor" , [0.1 , 0.5 , 1.0 ])
6252- def test_correctness_image (self , saturation_factor ):
6253- image = make_image (dtype = torch .uint8 , device = "cpu" )
6280+ def test_correctness_image (self , make_input , color_space , saturation_factor ):
6281+ image = make_input (dtype = torch .uint8 , color_space = color_space , device = "cpu" )
62546282
62556283 actual = F .adjust_saturation (image , saturation_factor = saturation_factor )
6284+
6285+ if make_input is make_image_cvcuda :
6286+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
6287+ actual = actual .squeeze (0 )
6288+ image = F .cvcuda_to_tensor (image )
6289+ image = image .squeeze (0 )
6290+
62566291 expected = F .to_image (F .adjust_saturation (F .to_pil_image (image ), saturation_factor = saturation_factor ))
62576292
62586293 assert_close (actual , expected , rtol = 0 , atol = 1 )
0 commit comments