
Commit 3ad7f41

Merge branch 'master' into pfto_reconstruct
2 parents 0f17be2 + fdfd358 commit 3ad7f41

File tree: 13 files changed, +450 -105 lines


.readthedocs.yaml

Lines changed: 4 additions & 1 deletion
@@ -2,10 +2,13 @@
 
 version: 2
 formats: all
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.8"
 sphinx:
   configuration: docs/source/conf.py
 python:
-  version: 3.8
   install:
     - method: pip
       path: .

pytorch_pfn_extras/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.6.6'
+__version__ = '0.6.7'

pytorch_pfn_extras/handler/_logic.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def consume_options(self, options: Dict[str, Any]) -> None:
         self._grad_scaler = options.pop('grad_scaler', None)
 
         self._backward_fn = options.pop('backward_function', None)
-        autocast_options = options.get("autocast", False)
+        autocast_options = options.pop("autocast", False)
         if isinstance(autocast_options, bool):
             autocast_options = {"enabled": autocast_options, "device_type": "cuda"}
         self._autocast = _autocast._AutocastManager(
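
Note on this hunk: `consume_options` is expected to remove every option it handles from the dict, so `pop` (rather than `get`) keeps the `autocast` key from lingering and being treated as unconsumed later. A minimal stand-alone sketch of that pattern (not PPE's actual implementation):

    from typing import Any, Dict

    def consume_options(options: Dict[str, Any]) -> Dict[str, Any]:
        # pop() both reads the value and removes the key from the caller's dict
        settings = {
            "grad_scaler": options.pop("grad_scaler", None),
            "backward_function": options.pop("backward_function", None),
            "autocast": options.pop("autocast", False),  # get() would leave this key behind
        }
        if options:  # anything still present was never consumed
            raise ValueError(f"Unknown options: {sorted(options)}")
        return settings

    print(consume_options({"autocast": True}))  # consumes every key without error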

pytorch_pfn_extras/onnx/_as_output.py

Lines changed: 5 additions & 0 deletions
@@ -90,7 +90,12 @@ def trace(
 ) -> Generator[Tuple[torch.nn.Module, _Outputs], None, None]:
     _outputs.outputs = _Outputs()
     if not isinstance(module, torch.jit.ScriptModule):
+        orig = module
         module = _ModuleWithAdditionalOutputs(module, _outputs.outputs)
+        if orig.training:
+            module.train()
+        else:
+            module.eval()
     try:
         yield module, _outputs.outputs
     finally:
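
Note on this hunk: `_ModuleWithAdditionalOutputs` wraps the user's module, and a freshly constructed `torch.nn.Module` always starts with `training=True`, so the added lines copy the original train/eval mode onto the wrapper. A toy wrapper (not the PPE class) showing the effect:

    import torch

    class Wrapper(torch.nn.Module):
        def __init__(self, wrapped: torch.nn.Module) -> None:
            super().__init__()
            self.wrapped = wrapped

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.wrapped(x)

    orig = torch.nn.Dropout(p=0.5)
    orig.eval()
    wrapper = Wrapper(orig)
    print(wrapper.training, wrapper.wrapped.training)  # True False: modes disagree
    wrapper.train(orig.training)  # copy the original mode, as the added lines do
    print(wrapper.training, wrapper.wrapped.training)  # False False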

pytorch_pfn_extras/onnx/export_testcase.py

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ def export_testcase(
     if isinstance(outs, torch.Tensor):
         outs = outs,
     assert outs is not None
+    outs = torch._C._jit_flatten(outs)[0]
     # Remove unused inputs
     # - When keep_initializers_as_inputs=True, inputs contains initializers.
     # So we have to filt initializers.
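
Note on this hunk: model outputs can be nested tuples/lists of tensors, while the dumped expected-output files want a flat tensor sequence, so the added line flattens them first. `torch._C._jit_flatten` is an internal PyTorch helper, so the snippet below only illustrates the assumed behavior rather than a documented contract:

    import torch

    nested = (torch.zeros(2), [torch.ones(3), (torch.arange(4),)])
    flat, _structure = torch._C._jit_flatten(nested)
    print([t.shape for t in flat])  # e.g. [torch.Size([2]), torch.Size([3]), torch.Size([4])]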

pytorch_pfn_extras/onnx/pfto_exporter/export.py

Lines changed: 121 additions & 12 deletions
@@ -3,6 +3,7 @@
 import typing
 import warnings
 from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
+from contextlib import contextmanager
 
 import onnx
 import onnx.checker
@@ -11,6 +12,7 @@
 import onnx.shape_inference
 import pytorch_pfn_extras
 import pytorch_pfn_extras.onnx._constants
+from pytorch_pfn_extras.onnx import _grad as grad
 from pytorch_pfn_extras.onnx._globals import GLOBALS
 from pytorch_pfn_extras.torchscript import run_jit_pass
 import torch
@@ -28,6 +30,89 @@
 
 _ppe_ignore_scope: str = "_ppe_as_out_module"
 _list_create_ops: List[str] = ["prim::ListConstruct", "onnx::SequenceConstruct", "onnx::SequenceEmpty"]
+_fix_ir_version = 8
+
+# Original from https://github.com/pytorch/pytorch/blob/52a36a98d9425479f62b6e2d1a59e434b85f7f7e/torch/csrc/jit/passes/normalize_ops.cpp#L85-L162
+_op_normalize_table: Dict[str, str] = {
+    "absolute": "abs",
+    "absolute_": "abs_",
+    "clip": "clamp",
+    "clip_": "clamp_",
+    "det": "linalg_det",
+    "matrix_power": "linalg_matrix_power",
+    "matrix_exp": "linalg_matrix_exp",
+    "ger": "outer",
+    "arccos": "acos",
+    "arccos_": "acos_",
+    "arcsin": "asin",
+    "arcsin_": "asin_",
+    "arctan": "atan",
+    "arctan_": "atan_",
+    "arctan2": "atan2",
+    "arctan2_": "atan2_",
+    "arccosh": "acosh",
+    "arccosh_": "acosh_",
+    "arcsinh": "asinh",
+    "arcsinh_": "asinh_",
+    "arctanh": "atanh",
+    "arctanh_": "atanh_",
+    "fix": "trunc",
+    "fix_": "trunc_",
+    "negative": "neg",
+    "negative_": "neg_",
+    "subtract": "sub",
+    "subtract_": "sub_",
+    "greater_equal": "ge",
+    "greater_equal_": "ge_",
+    "greater": "gt",
+    "greater_": "gt_",
+    "less_equal": "le",
+    "less_equal_": "le_",
+    "less": "lt",
+    "less_": "lt_",
+    "not_equal": "ne",
+    "not_equal_": "ne_",
+    "divide": "div",
+    "divide_": "div_",
+    "multiply": "mul",
+    "multiply_": "mul_",
+    "linalg_matmul": "matmul",
+    "inverse": "linalg_inv",
+    "true_divide": "div",
+    "true_divide_": "div_",
+    "concat": "cat",
+    "concatenate": "cat",
+    "row_stack": "vstack",
+    "swapdims": "transpose",
+    "swapdims_": "transpose_",
+    "swapaxes": "transpose",
+    "swapaxes_": "transpose_",
+    "moveaxis": "movedim",
+    "special_erf": "erf",
+    "special_erfc": "erfc",
+    "special_erfinv": "erfinv",
+    "special_expit": "sigmoid",
+    "special_exp2": "exp2",
+    "special_expm1": "expm1",
+    "special_logit": "logit",
+    "special_logsumexp": "logsumexp",
+    "special_round": "round",
+    "special_log1p": "log1p",
+    "special_sinc": "sinc",
+    "special_digamma": "digamma",
+    "special_psi": "digamma",
+    "special_i0": "i0",
+    "special_xlogy": "xlogy",
+    "special_log_softmax": "log_softmax",
+    "orgqr": "linalg_householder_product",
+    "adjoint": "mH",
+    "special_multigammaln": "mvlgamma",
+    "special_polygamma": "polygamma",
+    "special_softmax": "softmax",
+    "special_gammainc": "igamma",
+    "special_gammaincc": "igammac",
+    "special_gammaln": "lgamma",
+}
 
 if pytorch_pfn_extras.requires("1.13"):
     from torch.onnx._internal import jit_utils
@@ -157,6 +242,20 @@ def _apply_tensor_info_to_value_info(v: onnx.ValueInfoProto, t: torch.Tensor) ->
         a.dim_value = i
 
 
+@contextmanager
+def _force_tracing() -> Any:
+    old_is_tracing = torch.jit.is_tracing
+
+    def is_tracing() -> bool:
+        return True
+
+    try:
+        torch.jit.is_tracing = is_tracing
+        yield
+    finally:
+        torch.jit.is_tracing = old_is_tracing
+
+
 @dataclasses.dataclass
 class _ExporterOptions:
     opset_version: int = 12
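
Note on this hunk: some `forward()` implementations branch on `torch.jit.is_tracing()`, so the eager run used to collect reference outputs should take the same branch as the traced run; `_force_tracing` temporarily swaps in an `is_tracing` that returns True. A hypothetical model (not from the PPE code base) showing the kind of branch this targets:

    import torch

    class Gate(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not torch.jit.is_tracing():
                # eager-only checks/side effects that should not influence the
                # reference outputs compared against the traced graph
                assert x.is_floating_point()
            return x * 2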
@@ -230,25 +329,30 @@ def _restore_state(self) -> None:
         if torch.cuda.is_available():
             torch.cuda.set_rng_state_all(self.cuda_rng_state)
 
+    # TODO(twata): Use `self.traced` instead or use traced result outputs
+    def _get_original_outputs(self) -> None:
+        self._restore_state()
+        with _force_tracing(), grad.init_grad_state():
+            self.original_outputs = self.original_model(*self.inputs)
+        self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
+
     def _run_trace(self) -> None:
         # TODO(twata): Use `torch._C._craete_graph_by_tracing` instead.
         # So that we don't need to run heavy models multiple times
-        self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace(  # type: ignore
-            self.original_model,
-            self.inputs,
-            check_trace=self.check_trace,
-            strict=self.strict_trace,
-            _force_outplace=self.force_outplace_trace,
-        )
+        self._restore_state()
+        with grad.init_grad_state():
+            self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace(  # type: ignore
+                self.original_model,
+                self.inputs,
+                check_trace=self.check_trace,
+                strict=self.strict_trace,
+                _force_outplace=self.force_outplace_trace,
+            )
 
         self.graph_doc_string = f"""
 # Model: {self.traced.original_name}
 """
 
-        # TODO(twata): Use `self.traced` instead or use traced result outputs
-        self._restore_state()
-        self.original_outputs = self.original_model(*self.inputs)
-        self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
         self.g: torch._C.Graph = self.traced.inlined_graph
         """
         `self.trace` ignores the override of `state_dict` method in `self.original_model`.
@@ -554,6 +658,9 @@ def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]:
 
         import pytorch_pfn_extras.onnx.symbolic_registry as sym_reg
 
+        if not sym_reg.is_registered_op(op, domain, self.opset_version) and op in _op_normalize_table:
+            op = _op_normalize_table[op]
+
         if sym_reg.is_registered_op(op, domain, self.opset_version):  # type: ignore[no-untyped-call]
             return cast(  # type: ignore[redundant-cast]
                 Callable, sym_reg.get_registered_op(op, domain, self.opset_version)  # type: ignore[no-untyped-call]
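
Note on this hunk: when no symbolic is registered under the op's name, the exporter retries with the canonical ATen name from `_op_normalize_table` (e.g. `clip` -> `clamp`). A minimal sketch of that lookup with stand-in names (`registry` and `find_symbolic` are not the pfto API):

    from typing import Any, Callable, Dict, Optional

    _aliases: Dict[str, str] = {"clip": "clamp", "greater": "gt", "divide": "div"}
    registry: Dict[str, Callable[..., Any]] = {"clamp": lambda *a: None, "gt": lambda *a: None, "div": lambda *a: None}

    def find_symbolic(op: str) -> Optional[Callable[..., Any]]:
        if op not in registry and op in _aliases:
            op = _aliases[op]  # fall back to the canonical op name
        return registry.get(op)

    assert find_symbolic("clip") is registry["clamp"]
    assert find_symbolic("unknown_op") is None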
@@ -971,10 +1078,11 @@ def get_model_opset_imports(graph: onnx.GraphProto) -> List[onnx.OperatorSetIdProto]:
                 opset_imports.append(onnx.helper.make_opsetid(domain, version))
             return opset_imports
 
-        model: onnx.ModelProto = onnx.helper.make_model(
+        model: onnx.ModelProto = onnx.helper.make_model_gen_version(
             graph,
             opset_imports=get_model_opset_imports(graph),
             producer_name="pfto",
+            ir_version=_fix_ir_version,
         )
         model = self.check_model(model)
 
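Note on this hunk: building the model with an explicit `ir_version` (pinned to `_fix_ir_version = 8`) keeps exported files loadable by runtimes that reject IR versions newer than they support, instead of inheriting whatever the installed onnx package defaults to. A rough stand-alone illustration of pinning the IR version on a toy graph (not the pfto code path; setting the field directly is assumed to have the same effect as passing it at construction):

    import onnx
    from onnx import TensorProto, helper

    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
    node = helper.make_node("Identity", ["x"], ["y"])
    graph = helper.make_graph([node], "toy", [x], [y])

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])
    model.ir_version = 8  # pin the IR version explicitly
    print(model.ir_version, model.opset_import[0].version)  # 8 12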
@@ -1006,6 +1114,7 @@ def _convert(self) -> None:
             sym_hel._set_onnx_shape_inference(  # type: ignore[no-untyped-call]
                 False  # TODO(twata): Use `self.onnx_shape_inference`
             )
+            self._get_original_outputs()
             self._run_trace()
             self.model: onnx.ModelProto = self.generate_onnx()
         finally:

pytorch_pfn_extras/training/extensions/_snapshot.py

Lines changed: 5 additions & 6 deletions
@@ -337,20 +337,19 @@ def __init__(
     def initialize(  # type: ignore[override]
             self, manager: ExtensionsManagerProtocol) -> Optional[str]:
         target = manager if self._target is None else self._target
-        outdir = manager.out
         writer = manager.writer if self.writer is None else self.writer
         self.writer = writer
         loaded_fn = None
         if self.autoload:
-            # If ``autoload`` is on, this code scans the ``outdir``
+            # If ``autoload`` is on, this code scans the ``writer.out_dir``
             # for potential snapshot files by matching the file names
             # from ``filename`` format, picks up the latest one in
             # terms of mtime, and tries to load it it the target or
             # manager.
             assert writer is not None
-            loaded_fn = _find_latest_snapshot(self.filename, outdir, writer.fs)
+            loaded_fn = _find_latest_snapshot(self.filename, writer.out_dir, writer.fs)
             if loaded_fn:
-                snapshot_file = writer.fs.open(os.path.join(outdir, loaded_fn), 'rb')
+                snapshot_file = writer.fs.open(os.path.join(writer.out_dir, loaded_fn), 'rb')
                 # As described above (at ``autoload`` option),
                 # snapshot files to be autoloaded must be saved by
                 # ``save_npz`` . In order to support general format,
@@ -376,10 +375,10 @@ def initialize(  # type: ignore[override]
         # injected here.
         def _cleanup() -> None:
             assert writer is not None
-            files = _find_stale_snapshots(self.filename, outdir,
+            files = _find_stale_snapshots(self.filename, writer.out_dir,
                                           self.n_retains, writer.fs)
            for file in files:
-                writer.fs.remove(os.path.join(outdir, file))
+                writer.fs.remove(os.path.join(writer.out_dir, file))
 
         assert writer is not None
         writer._add_cleanup_hook(_cleanup)
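
Note on this change: the extension now resolves snapshot paths against `writer.out_dir` instead of the manager's output directory, so autoload and stale-snapshot cleanup look where the writer actually puts files. A toy illustration with stand-in classes (not the PPE API):

    import os

    class Writer:
        def __init__(self, out_dir: str) -> None:
            self.out_dir = out_dir

    class Manager:
        def __init__(self, out: str, writer: Writer) -> None:
            self.out = out
            self.writer = writer

    manager = Manager(out="results", writer=Writer(out_dir="snapshots"))
    # Autoload should scan where snapshots are actually written:
    snapshot_dir = manager.writer.out_dir  # "snapshots", not manager.out
    print(os.path.join(snapshot_dir, "snapshot_iter_100"))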
