Skip to content

Commit fdfd358

Browse files
author
emcastillo
authored
Merge pull request #687 from take-cheeze/pfto_test_names
[onnx] Output pfto tests to out/ dir like export_testcase
2 parents 8b087d4 + aa00145 commit fdfd358

File tree

3 files changed

+63
-54
lines changed

3 files changed

+63
-54
lines changed

pytorch_pfn_extras/onnx/export_testcase.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def export_testcase(
328328
if isinstance(outs, torch.Tensor):
329329
outs = outs,
330330
assert outs is not None
331+
outs = torch._C._jit_flatten(outs)[0]
331332
# Remove unused inputs
332333
# - When keep_initializers_as_inputs=True, inputs contains initializers.
333334
# So we have to filt initializers.

tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323

2424
output_dir = 'out'
25+
output_counter = {}
2526

2627

2728
class Net(nn.Module):
@@ -47,14 +48,24 @@ def _get_output_dir(d, **kwargs):
4748
output_dir_base = 'out'
4849
opset_ver = kwargs.get('opset_version', pytorch_pfn_extras.onnx._constants.onnx_default_opset)
4950

51+
test_name = os.getenv("PYTEST_CURRENT_TEST").split(':')[-1].split(' ')[0]
52+
if d:
53+
test_name = f"{test_name}_{d}"
54+
if "model_overwrite" not in kwargs:
55+
if test_name in output_counter:
56+
output_counter[test_name] += 1
57+
test_name = f"{test_name}_{output_counter[test_name]}"
58+
else:
59+
output_counter[test_name] = 0
60+
5061
output_dir = os.path.join(
51-
output_dir_base, 'opset{}'.format(opset_ver), d)
62+
output_dir_base, 'opset{}'.format(opset_ver), test_name)
5263
os.makedirs(output_dir, exist_ok=True)
5364
return output_dir
5465

5566

5667
def _helper(model, args, d, use_pfto=True, **kwargs):
57-
output_dir = _get_output_dir(d)
68+
output_dir = _get_output_dir(d, **kwargs)
5869
if 'training' not in kwargs:
5970
kwargs['training'] = model.training
6071
if 'do_constant_folding' not in kwargs:
Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
import tempfile
2-
from typing import Callable, List, Optional
1+
import os
2+
from typing import Callable
33

44
import onnx
55
import onnxruntime as ort
66
import torch
77
import pytorch_pfn_extras.onnx.pfto_exporter.export as pfto
8+
from pytorch_pfn_extras.onnx import export_testcase
9+
from pytorch_pfn_extras_tests.onnx_tests.test_export_testcase import _get_output_dir
810

911

1012
def run_model_test(
@@ -13,8 +15,6 @@ def run_model_test(
1315
check_torch_export=True,
1416
rtol=1e-05,
1517
atol=1e-08,
16-
input_names: Optional[List[str]] = None,
17-
output_names: Optional[List[str]] = None,
1818
skip_oxrt=False,
1919
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
2020
strict_trace=True,
@@ -35,63 +35,60 @@ def run_model_test(
3535

3636
if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN:
3737
skip_oxrt = True
38-
with tempfile.NamedTemporaryFile() as f:
39-
f.close()
40-
rng_state = torch.get_rng_state()
41-
with pfto._force_tracing():
42-
expected = model(*args)
43-
if not isinstance(expected, tuple):
44-
expected = (expected,)
4538

46-
te_model = None
47-
if check_torch_export:
48-
torch.set_rng_state(rng_state)
49-
with tempfile.NamedTemporaryFile() as torch_f:
50-
torch_f.close()
51-
torch.onnx.export(
52-
model,
53-
args,
54-
torch_f.name,
55-
input_names=input_names,
56-
output_names=output_names,
57-
**kwargs,
58-
)
59-
te_model = onnx.load(torch_f.name)
39+
rng_state = torch.get_rng_state()
40+
with pfto._force_tracing():
41+
expected = model(*args)
42+
if isinstance(expected, torch.Tensor):
43+
expected = (expected,)
44+
expected = torch._C._jit_flatten(expected)[0]
6045

61-
if input_names is None:
62-
input_names = [f"input_{idx}" for idx, _ in enumerate(args)]
63-
if output_names is None:
64-
output_names = [f"output_{idx}" for idx, _ in enumerate(expected)]
46+
te_model = None
47+
if check_torch_export:
6548
torch.set_rng_state(rng_state)
66-
actual = pfto.export(
49+
pt_dir = _get_output_dir("pt", **kwargs)
50+
export_testcase(
6751
model,
6852
args,
69-
f.name,
70-
input_names=input_names,
71-
output_names=output_names,
72-
strict_trace=strict_trace,
53+
pt_dir,
54+
use_pfto=False,
7355
**kwargs,
7456
)
75-
if not isinstance(actual, tuple):
76-
actual = (actual,)
77-
assert len(actual) == len(expected)
57+
te_model = onnx.load(os.path.join(pt_dir, "model.onnx"))
7858

79-
for a, e in zip(actual, expected):
80-
if isinstance(a, torch.Tensor) and isinstance(e, torch.Tensor):
81-
assert torch.isclose(a, e, rtol=rtol, atol=atol).all()
59+
torch.set_rng_state(rng_state)
60+
pf_dir = _get_output_dir("pf", **kwargs)
61+
actual = export_testcase(
62+
model,
63+
args,
64+
pf_dir,
65+
strict_trace=strict_trace,
66+
return_output=True,
67+
use_pfto=True,
68+
**kwargs,
69+
)
70+
if isinstance(actual, torch.Tensor):
71+
actual = (actual,)
72+
expected = torch._C._jit_flatten(expected)[0]
73+
assert len(actual) == len(expected)
8274

83-
pfto_model = onnx.load(f.name)
84-
if te_model is not None:
85-
assert len(te_model.graph.output) == len(pfto_model.graph.output)
86-
assert len(te_model.graph.input) == len(pfto_model.graph.input)
75+
for a, e in zip(actual, expected):
76+
if isinstance(a, torch.Tensor) and isinstance(e, torch.Tensor):
77+
assert torch.isclose(a, e, rtol=rtol, atol=atol).all()
8778

88-
if skip_oxrt:
89-
return pfto_model
90-
91-
ort_session = ort.InferenceSession(f.name)
92-
actual = ort_session.run(None, {k: v.cpu().numpy() for k, v in zip(input_names, args)})
93-
for a, e in zip(actual, expected):
94-
cmp = torch.isclose(torch.tensor(a), e.cpu(), rtol=rtol, atol=atol)
95-
assert cmp.all(), f"{cmp.logical_not().count_nonzero()} / {cmp.numel()} values failed"
79+
pfto_model = onnx.load(os.path.join(pf_dir, "model.onnx"))
80+
if te_model is not None:
81+
assert len(te_model.graph.output) == len(pfto_model.graph.output)
82+
assert len(te_model.graph.input) == len(pfto_model.graph.input)
9683

84+
if skip_oxrt:
9785
return pfto_model
86+
87+
ort_session = ort.InferenceSession(os.path.join(pf_dir, "model.onnx"))
88+
input_names = [i.name for i in pfto_model.graph.input]
89+
actual = ort_session.run(None, {k: v.cpu().numpy() for k, v in zip(input_names, args)})
90+
for a, e in zip(actual, expected):
91+
cmp = torch.isclose(torch.tensor(a), e.cpu(), rtol=rtol, atol=atol)
92+
assert cmp.all(), f"{cmp.logical_not().count_nonzero()} / {cmp.numel()} values failed"
93+
94+
return pfto_model

0 commit comments

Comments
 (0)