Skip to content

Commit 7870fdf

Browse files
author
twata
committed
[onnx] Inherit training status in as output module
1 parent bcb5f37 commit 7870fdf

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

pytorch_pfn_extras/onnx/_as_output.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ def trace(
9090
) -> Generator[Tuple[torch.nn.Module, _Outputs], None, None]:
9191
_outputs.outputs = _Outputs()
9292
if not isinstance(module, torch.jit.ScriptModule):
93+
orig = module
9394
module = _ModuleWithAdditionalOutputs(module, _outputs.outputs)
95+
if orig.training:
96+
module.train()
97+
else:
98+
module.eval()
9499
try:
95100
yield module, _outputs.outputs
96101
finally:

0 commit comments

Comments
 (0)