We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 5aef3b0 + 7870fdf commit 8b087d4Copy full SHA for 8b087d4
pytorch_pfn_extras/onnx/_as_output.py
@@ -90,7 +90,12 @@ def trace(
90
) -> Generator[Tuple[torch.nn.Module, _Outputs], None, None]:
91
_outputs.outputs = _Outputs()
92
if not isinstance(module, torch.jit.ScriptModule):
93
+ orig = module
94
module = _ModuleWithAdditionalOutputs(module, _outputs.outputs)
95
+ if orig.training:
96
+ module.train()
97
+ else:
98
+ module.eval()
99
try:
100
yield module, _outputs.outputs
101
finally:
0 commit comments