Skip to content

Commit 8be574c

Browse files
committed
fix nightly test
1 parent 0832ac5 commit 8be574c

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

pytorch_pfn_extras/distributed/_distributed_validation_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ def __init__(
2626
seed: int = 0,
2727
) -> None:
2828
if num_replicas is None:
29-
if not dist.is_available(): # type: ignore[no-untyped-call]
29+
if not dist.is_available() or not dist.is_initialized(): # type: ignore[no-untyped-call]
3030
raise RuntimeError(
3131
"Requires distributed package to be available"
3232
)
3333
num_replicas = dist.get_world_size() # type: ignore[no-untyped-call]
3434
if rank is None:
35-
if not dist.is_available(): # type: ignore[no-untyped-call]
35+
if not dist.is_available() or not dist.is_initialized(): # type: ignore[no-untyped-call]
3636
raise RuntimeError(
3737
"Requires distributed package to be available"
3838
)

tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_default(base_dataset):
1818
sample_idxs = []
1919
with mock.patch.object(
2020
dist, "get_world_size", return_value=_world_size
21-
), mock.patch.object(dist, "is_available", return_value=True):
21+
), mock.patch.object(dist, "is_initialized", return_value=True):
2222
for rank in range(_world_size):
2323
with mock.patch.object(dist, "get_rank", return_value=rank):
2424
sampler = DistributedValidationSampler(base_dataset)
@@ -41,7 +41,7 @@ def test_no_shuffle(base_dataset):
4141
]
4242
with mock.patch.object(
4343
dist, "get_world_size", return_value=_world_size
44-
), mock.patch.object(dist, "is_available", return_value=True):
44+
), mock.patch.object(dist, "is_initialized", return_value=True):
4545
for rank in range(_world_size):
4646
with mock.patch.object(dist, "get_rank", return_value=rank):
4747
sampler = DistributedValidationSampler(
@@ -57,7 +57,7 @@ def test_manual_num_replicas_and_ranks(base_dataset):
5757
with mock.patch.object(
5858
dist, "get_world_size", side_effect=AssertionError()
5959
), mock.patch.object(
60-
dist, "is_available", side_effect=AssertionError()
60+
dist, "is_initialized", side_effect=AssertionError()
6161
), mock.patch.object(
6262
dist, "get_rank", side_effect=AssertionError()
6363
):
@@ -78,16 +78,20 @@ def test_seed(base_dataset):
7878
assert list(sampler1) != list(sampler2)
7979

8080

81-
def test_no_distributed_available(base_dataset):
82-
with pytest.raises(ValueError):
81+
def test_no_distributed_initialized(base_dataset):
82+
with pytest.raises(RuntimeError):
8383
DistributedValidationSampler(base_dataset, num_replicas=_world_size)
84-
with pytest.raises(ValueError):
84+
with pytest.raises(RuntimeError):
8585
DistributedValidationSampler(base_dataset, rank=0)
8686

8787

8888
def test_invalid_rank(base_dataset):
8989
with mock.patch.object(dist, "get_world_size", return_value=_world_size):
9090
with pytest.raises(ValueError):
91-
DistributedValidationSampler(base_dataset, rank=-1)
91+
DistributedValidationSampler(
92+
base_dataset, num_replicas=_world_size, rank=-1
93+
)
9294
with pytest.raises(ValueError):
93-
DistributedValidationSampler(base_dataset, rank=_world_size)
95+
DistributedValidationSampler(
96+
base_dataset, num_replicas=_world_size, rank=_world_size
97+
)

tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def update_fn(engine, batch):
272272
evaluator = Engine(update_fn)
273273
return evaluator
274274

275+
275276
@pytest.mark.filterwarnings("ignore::UserWarning")
276277
def test_ignite_evaluator_reporting_metrics():
277278
try:

tests/pytorch_pfn_extras_tests/training_tests/test_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def test_extensions_manager_state_dict_future_ppe_version():
313313
with pytest.warns(UserWarning, match="version"):
314314
manager_2.load_state_dict(state_dict)
315315

316+
316317
@pytest.mark.filterwarnings("ignore::UserWarning")
317318
def test_ignite_extensions_manager_state_dict():
318319
try:

0 commit comments

Comments
 (0)