Skip to content

Commit 8b087d4

Browse files
author
emcastillo
authored
Merge pull request #695 from take-cheeze/as_out_train_mode
[onnx] Inherit training status in as output module
2 parents 5aef3b0 + 7870fdf commit 8b087d4

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

pytorch_pfn_extras/onnx/_as_output.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ def trace(
9090
) -> Generator[Tuple[torch.nn.Module, _Outputs], None, None]:
9191
_outputs.outputs = _Outputs()
9292
if not isinstance(module, torch.jit.ScriptModule):
93+
orig = module
9394
module = _ModuleWithAdditionalOutputs(module, _outputs.outputs)
95+
if orig.training:
96+
module.train()
97+
else:
98+
module.eval()
9499
try:
95100
yield module, _outputs.outputs
96101
finally:

0 commit comments

Comments
 (0)