@@ -18,7 +18,7 @@ def test_default(base_dataset):
1818 sample_idxs = []
1919 with mock .patch .object (
2020 dist , "get_world_size" , return_value = _world_size
21- ), mock .patch .object (dist , "is_available " , return_value = True ):
21+ ), mock .patch .object (dist , "is_initialized " , return_value = True ):
2222 for rank in range (_world_size ):
2323 with mock .patch .object (dist , "get_rank" , return_value = rank ):
2424 sampler = DistributedValidationSampler (base_dataset )
@@ -41,7 +41,7 @@ def test_no_shuffle(base_dataset):
4141 ]
4242 with mock .patch .object (
4343 dist , "get_world_size" , return_value = _world_size
44- ), mock .patch .object (dist , "is_available " , return_value = True ):
44+ ), mock .patch .object (dist , "is_initialized " , return_value = True ):
4545 for rank in range (_world_size ):
4646 with mock .patch .object (dist , "get_rank" , return_value = rank ):
4747 sampler = DistributedValidationSampler (
@@ -57,7 +57,7 @@ def test_manual_num_replicas_and_ranks(base_dataset):
5757 with mock .patch .object (
5858 dist , "get_world_size" , side_effect = AssertionError ()
5959 ), mock .patch .object (
60- dist , "is_available " , side_effect = AssertionError ()
60+ dist , "is_initialized " , side_effect = AssertionError ()
6161 ), mock .patch .object (
6262 dist , "get_rank" , side_effect = AssertionError ()
6363 ):
@@ -78,16 +78,20 @@ def test_seed(base_dataset):
7878 assert list (sampler1 ) != list (sampler2 )
7979
8080
81- def test_no_distributed_available (base_dataset ):
82- with pytest .raises (ValueError ):
81+ def test_no_distributed_initialized (base_dataset ):
82+ with pytest .raises (RuntimeError ):
8383 DistributedValidationSampler (base_dataset , num_replicas = _world_size )
84- with pytest .raises (ValueError ):
84+ with pytest .raises (RuntimeError ):
8585 DistributedValidationSampler (base_dataset , rank = 0 )
8686
8787
8888def test_invalid_rank (base_dataset ):
8989 with mock .patch .object (dist , "get_world_size" , return_value = _world_size ):
9090 with pytest .raises (ValueError ):
91- DistributedValidationSampler (base_dataset , rank = - 1 )
91+ DistributedValidationSampler (
92+ base_dataset , num_replicas = _world_size , rank = - 1
93+ )
9294 with pytest .raises (ValueError ):
93- DistributedValidationSampler (base_dataset , rank = _world_size )
95+ DistributedValidationSampler (
96+ base_dataset , num_replicas = _world_size , rank = _world_size
97+ )
0 commit comments