Skip to content

Commit dd654cb

Browse files
author
twata
committed
[onnx] Fix grad op domain
1 parent b65ecf3 commit dd654cb

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pytorch_pfn_extras/onnx/_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _grad( # type: ignore
8585
@staticmethod
8686
def symbolic(g, output, grad_output, *inputs): # type: ignore
8787
return g.op(
88-
"ai.onnx.preview::Gradient",
88+
"ai.onnx.preview.training::Gradient",
8989
*inputs,
9090
xs_s=input_names,
9191
zs_s=[],

tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward(self, x):
6262

6363

6464
@pytest.mark.parametrize("use_pfto", [False, True])
65-
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
65+
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
6666
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
6767
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
6868
def test_grad(use_pfto: bool):
@@ -103,6 +103,7 @@ def forward(self, x):
103103
)
104104

105105
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
106+
print(actual_onnx)
106107
named_nodes = {n.name: n for n in actual_onnx.graph.node}
107108
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
108109
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
@@ -136,7 +137,7 @@ def forward(self, x):
136137

137138

138139
@pytest.mark.parametrize("use_pfto", [False, True])
139-
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
140+
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
140141
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
141142
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
142143
def test_grad_multiple_times(use_pfto: bool):
@@ -218,7 +219,7 @@ def forward(self, x):
218219

219220

220221
@pytest.mark.parametrize("use_pfto", [False, True])
221-
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
222+
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
222223
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
223224
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
224225
def test_grad_with_multiple_inputs(use_pfto: bool):

0 commit comments

Comments
 (0)