Skip to content

Commit b65ecf3

Browse files
author
emcastillo
authored
Merge pull request #788 from linshokaku/fix-nightly-test
Fix nightly test
2 parents 8500fe2 + 8be574c commit b65ecf3

File tree

5 files changed

+15
-8
lines changed

5 files changed

+15
-8
lines changed

pytorch_pfn_extras/distributed/_distributed_validation_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ def __init__(
2626
seed: int = 0,
2727
) -> None:
2828
if num_replicas is None:
29-
if not dist.is_available(): # type: ignore[no-untyped-call]
29+
if not dist.is_available() or not dist.is_initialized(): # type: ignore[no-untyped-call]
3030
raise RuntimeError(
3131
"Requires distributed package to be available"
3232
)
3333
num_replicas = dist.get_world_size() # type: ignore[no-untyped-call]
3434
if rank is None:
35-
if not dist.is_available(): # type: ignore[no-untyped-call]
35+
if not dist.is_available() or not dist.is_initialized(): # type: ignore[no-untyped-call]
3636
raise RuntimeError(
3737
"Requires distributed package to be available"
3838
)

tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_default(base_dataset):
1818
sample_idxs = []
1919
with mock.patch.object(
2020
dist, "get_world_size", return_value=_world_size
21-
), mock.patch.object(dist, "is_available", return_value=True):
21+
), mock.patch.object(dist, "is_initialized", return_value=True):
2222
for rank in range(_world_size):
2323
with mock.patch.object(dist, "get_rank", return_value=rank):
2424
sampler = DistributedValidationSampler(base_dataset)
@@ -41,7 +41,7 @@ def test_no_shuffle(base_dataset):
4141
]
4242
with mock.patch.object(
4343
dist, "get_world_size", return_value=_world_size
44-
), mock.patch.object(dist, "is_available", return_value=True):
44+
), mock.patch.object(dist, "is_initialized", return_value=True):
4545
for rank in range(_world_size):
4646
with mock.patch.object(dist, "get_rank", return_value=rank):
4747
sampler = DistributedValidationSampler(
@@ -57,7 +57,7 @@ def test_manual_num_replicas_and_ranks(base_dataset):
5757
with mock.patch.object(
5858
dist, "get_world_size", side_effect=AssertionError()
5959
), mock.patch.object(
60-
dist, "is_available", side_effect=AssertionError()
60+
dist, "is_initialized", side_effect=AssertionError()
6161
), mock.patch.object(
6262
dist, "get_rank", side_effect=AssertionError()
6363
):
@@ -78,7 +78,7 @@ def test_seed(base_dataset):
7878
assert list(sampler1) != list(sampler2)
7979

8080

81-
def test_no_distributed_available(base_dataset):
81+
def test_no_distributed_initialized(base_dataset):
8282
with pytest.raises(RuntimeError):
8383
DistributedValidationSampler(base_dataset, num_replicas=_world_size)
8484
with pytest.raises(RuntimeError):
@@ -88,6 +88,10 @@ def test_no_distributed_available(base_dataset):
8888
def test_invalid_rank(base_dataset):
8989
with mock.patch.object(dist, "get_world_size", return_value=_world_size):
9090
with pytest.raises(ValueError):
91-
DistributedValidationSampler(base_dataset, rank=-1)
91+
DistributedValidationSampler(
92+
base_dataset, num_replicas=_world_size, rank=-1
93+
)
9294
with pytest.raises(ValueError):
93-
DistributedValidationSampler(base_dataset, rank=_world_size)
95+
DistributedValidationSampler(
96+
base_dataset, num_replicas=_world_size, rank=_world_size
97+
)

tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def update_fn(engine, batch):
273273
return evaluator
274274

275275

276+
@pytest.mark.filterwarnings("ignore::UserWarning")
276277
def test_ignite_evaluator_reporting_metrics():
277278
try:
278279
from ignite.metrics import MeanSquaredError

tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_run_progress_bar_notebook():
4747
reason="progress bar notebook import failed, "
4848
"maybe ipython is not installed",
4949
)
50+
@pytest.mark.filterwarnings("ignore::UserWarning")
5051
def test_ignite_extensions_manager_with_progressbar_notebook():
5152
try:
5253
from ignite.engine import create_supervised_trainer

tests/pytorch_pfn_extras_tests/training_tests/test_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def test_extensions_manager_state_dict_future_ppe_version():
314314
manager_2.load_state_dict(state_dict)
315315

316316

317+
@pytest.mark.filterwarnings("ignore::UserWarning")
317318
def test_ignite_extensions_manager_state_dict():
318319
try:
319320
from ignite.engine import create_supervised_trainer

0 commit comments

Comments
 (0)