1- import tempfile
2- from typing import Callable , List , Optional
1+ import os
2+ from typing import Callable
33
44import onnx
55import onnxruntime as ort
66import torch
77import pytorch_pfn_extras .onnx .pfto_exporter .export as pfto
8+ from pytorch_pfn_extras .onnx import export_testcase
9+ from pytorch_pfn_extras_tests .onnx_tests .test_export_testcase import _get_output_dir
810
911
1012def run_model_test (
@@ -13,8 +15,6 @@ def run_model_test(
1315 check_torch_export = True ,
1416 rtol = 1e-05 ,
1517 atol = 1e-08 ,
16- input_names : Optional [List [str ]] = None ,
17- output_names : Optional [List [str ]] = None ,
1818 skip_oxrt = False ,
1919 operator_export_type = torch .onnx .OperatorExportTypes .ONNX ,
2020 strict_trace = True ,
@@ -35,63 +35,60 @@ def run_model_test(
3535
3636 if operator_export_type == torch .onnx .OperatorExportTypes .ONNX_ATEN :
3737 skip_oxrt = True
38- with tempfile .NamedTemporaryFile () as f :
39- f .close ()
40- rng_state = torch .get_rng_state ()
41- with pfto ._force_tracing ():
42- expected = model (* args )
43- if not isinstance (expected , tuple ):
44- expected = (expected ,)
4538
46- te_model = None
47- if check_torch_export :
48- torch .set_rng_state (rng_state )
49- with tempfile .NamedTemporaryFile () as torch_f :
50- torch_f .close ()
51- torch .onnx .export (
52- model ,
53- args ,
54- torch_f .name ,
55- input_names = input_names ,
56- output_names = output_names ,
57- ** kwargs ,
58- )
59- te_model = onnx .load (torch_f .name )
39+ rng_state = torch .get_rng_state ()
40+ with pfto ._force_tracing ():
41+ expected = model (* args )
42+ if isinstance (expected , torch .Tensor ):
43+ expected = (expected ,)
44+ expected = torch ._C ._jit_flatten (expected )[0 ]
6045
61- if input_names is None :
62- input_names = [f"input_{ idx } " for idx , _ in enumerate (args )]
63- if output_names is None :
64- output_names = [f"output_{ idx } " for idx , _ in enumerate (expected )]
46+ te_model = None
47+ if check_torch_export :
6548 torch .set_rng_state (rng_state )
66- actual = pfto .export (
49+ pt_dir = _get_output_dir ("pt" , ** kwargs )
50+ export_testcase (
6751 model ,
6852 args ,
69- f .name ,
70- input_names = input_names ,
71- output_names = output_names ,
72- strict_trace = strict_trace ,
53+ pt_dir ,
54+ use_pfto = False ,
7355 ** kwargs ,
7456 )
75- if not isinstance (actual , tuple ):
76- actual = (actual ,)
77- assert len (actual ) == len (expected )
57+ te_model = onnx .load (os .path .join (pt_dir , "model.onnx" ))
7858
79- for a , e in zip (actual , expected ):
80- if isinstance (a , torch .Tensor ) and isinstance (e , torch .Tensor ):
81- assert torch .isclose (a , e , rtol = rtol , atol = atol ).all ()
59+ torch .set_rng_state (rng_state )
60+ pf_dir = _get_output_dir ("pf" , ** kwargs )
61+ actual = export_testcase (
62+ model ,
63+ args ,
64+ pf_dir ,
65+ strict_trace = strict_trace ,
66+ return_output = True ,
67+ use_pfto = True ,
68+ ** kwargs ,
69+ )
70+ if isinstance (actual , torch .Tensor ):
71+ actual = (actual ,)
72+ expected = torch ._C ._jit_flatten (expected )[0 ]
73+ assert len (actual ) == len (expected )
8274
83- pfto_model = onnx .load (f .name )
84- if te_model is not None :
85- assert len (te_model .graph .output ) == len (pfto_model .graph .output )
86- assert len (te_model .graph .input ) == len (pfto_model .graph .input )
75+ for a , e in zip (actual , expected ):
76+ if isinstance (a , torch .Tensor ) and isinstance (e , torch .Tensor ):
77+ assert torch .isclose (a , e , rtol = rtol , atol = atol ).all ()
8778
88- if skip_oxrt :
89- return pfto_model
90-
91- ort_session = ort .InferenceSession (f .name )
92- actual = ort_session .run (None , {k : v .cpu ().numpy () for k , v in zip (input_names , args )})
93- for a , e in zip (actual , expected ):
94- cmp = torch .isclose (torch .tensor (a ), e .cpu (), rtol = rtol , atol = atol )
95- assert cmp .all (), f"{ cmp .logical_not ().count_nonzero ()} / { cmp .numel ()} values failed"
79+ pfto_model = onnx .load (os .path .join (pf_dir , "model.onnx" ))
80+ if te_model is not None :
81+ assert len (te_model .graph .output ) == len (pfto_model .graph .output )
82+ assert len (te_model .graph .input ) == len (pfto_model .graph .input )
9683
84+ if skip_oxrt :
9785 return pfto_model
86+
87+ ort_session = ort .InferenceSession (os .path .join (pf_dir , "model.onnx" ))
88+ input_names = [i .name for i in pfto_model .graph .input ]
89+ actual = ort_session .run (None , {k : v .cpu ().numpy () for k , v in zip (input_names , args )})
90+ for a , e in zip (actual , expected ):
91+ cmp = torch .isclose (torch .tensor (a ), e .cpu (), rtol = rtol , atol = atol )
92+ assert cmp .all (), f"{ cmp .logical_not ().count_nonzero ()} / { cmp .numel ()} values failed"
93+
94+ return pfto_model
0 commit comments