3 | 3 | import typing |
4 | 4 | import warnings |
5 | 5 | from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast |
| 6 | +from contextlib import contextmanager |
6 | 7 |
7 | 8 | import onnx |
8 | 9 | import onnx.checker |
11 | 12 | import onnx.shape_inference |
12 | 13 | import pytorch_pfn_extras |
13 | 14 | import pytorch_pfn_extras.onnx._constants |
| 15 | +from pytorch_pfn_extras.onnx import _grad as grad |
14 | 16 | from pytorch_pfn_extras.onnx._globals import GLOBALS |
15 | 17 | from pytorch_pfn_extras.torchscript import run_jit_pass |
16 | 18 | import torch |
28 | 30 |
29 | 31 | _ppe_ignore_scope: str = "_ppe_as_out_module" |
30 | 32 | _list_create_ops: List[str] = ["prim::ListConstruct", "onnx::SequenceConstruct", "onnx::SequenceEmpty"] |
| 33 | +_fix_ir_version = 8  # pin exported ModelProto.ir_version (used in generate_onnx) |
| 34 | + |
| 35 | +# Adapted from https://github.com/pytorch/pytorch/blob/52a36a98d9425479f62b6e2d1a59e434b85f7f7e/torch/csrc/jit/passes/normalize_ops.cpp#L85-L162 |
| 36 | +_op_normalize_table: Dict[str, str] = { |
| 37 | + "absolute": "abs", |
| 38 | + "absolute_": "abs_", |
| 39 | + "clip": "clamp", |
| 40 | + "clip_": "clamp_", |
| 41 | + "det": "linalg_det", |
| 42 | + "matrix_power": "linalg_matrix_power", |
| 43 | + "matrix_exp": "linalg_matrix_exp", |
| 44 | + "ger": "outer", |
| 45 | + "arccos": "acos", |
| 46 | + "arccos_": "acos_", |
| 47 | + "arcsin": "asin", |
| 48 | + "arcsin_": "asin_", |
| 49 | + "arctan": "atan", |
| 50 | + "arctan_": "atan_", |
| 51 | + "arctan2": "atan2", |
| 52 | + "arctan2_": "atan2_", |
| 53 | + "arccosh": "acosh", |
| 54 | + "arccosh_": "acosh_", |
| 55 | + "arcsinh": "asinh", |
| 56 | + "arcsinh_": "asinh_", |
| 57 | + "arctanh": "atanh", |
| 58 | + "arctanh_": "atanh_", |
| 59 | + "fix": "trunc", |
| 60 | + "fix_": "trunc_", |
| 61 | + "negative": "neg", |
| 62 | + "negative_": "neg_", |
| 63 | + "subtract": "sub", |
| 64 | + "subtract_": "sub_", |
| 65 | + "greater_equal": "ge", |
| 66 | + "greater_equal_": "ge_", |
| 67 | + "greater": "gt", |
| 68 | + "greater_": "gt_", |
| 69 | + "less_equal": "le", |
| 70 | + "less_equal_": "le_", |
| 71 | + "less": "lt", |
| 72 | + "less_": "lt_", |
| 73 | + "not_equal": "ne", |
| 74 | + "not_equal_": "ne_", |
| 75 | + "divide": "div", |
| 76 | + "divide_": "div_", |
| 77 | + "multiply": "mul", |
| 78 | + "multiply_": "mul_", |
| 79 | + "linalg_matmul": "matmul", |
| 80 | + "inverse": "linalg_inv", |
| 81 | + "true_divide": "div", |
| 82 | + "true_divide_": "div_", |
| 83 | + "concat": "cat", |
| 84 | + "concatenate": "cat", |
| 85 | + "row_stack": "vstack", |
| 86 | + "swapdims": "transpose", |
| 87 | + "swapdims_": "transpose_", |
| 88 | + "swapaxes": "transpose", |
| 89 | + "swapaxes_": "transpose_", |
| 90 | + "moveaxis": "movedim", |
| 91 | + "special_erf": "erf", |
| 92 | + "special_erfc": "erfc", |
| 93 | + "special_erfinv": "erfinv", |
| 94 | + "special_expit": "sigmoid", |
| 95 | + "special_exp2": "exp2", |
| 96 | + "special_expm1": "expm1", |
| 97 | + "special_logit": "logit", |
| 98 | + "special_logsumexp": "logsumexp", |
| 99 | + "special_round": "round", |
| 100 | + "special_log1p": "log1p", |
| 101 | + "special_sinc": "sinc", |
| 102 | + "special_digamma": "digamma", |
| 103 | + "special_psi": "digamma", |
| 104 | + "special_i0": "i0", |
| 105 | + "special_xlogy": "xlogy", |
| 106 | + "special_log_softmax": "log_softmax", |
| 107 | + "orgqr": "linalg_householder_product", |
| 108 | + "adjoint": "mH", |
| 109 | + "special_multigammaln": "mvlgamma", |
| 110 | + "special_polygamma": "polygamma", |
| 111 | + "special_softmax": "softmax", |
| 112 | + "special_gammainc": "igamma", |
| 113 | + "special_gammaincc": "igammac", |
| 114 | + "special_gammaln": "lgamma", |
| 115 | +} |
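The table above mirrors the alias list in PyTorch's `normalize_ops.cpp` (linked in the comment): keys are `aten`-level aliases as they appear in traced graphs, values are the canonical names that symbolic functions are registered under. A minimal sketch of the lookup it enables, using the `_op_normalize_table` dict defined above:

```python
# Resolve an aten alias to its canonical name before symbolic lookup.
op = "clip"  # alias as it may appear in a traced graph
canonical = _op_normalize_table.get(op, op)
assert canonical == "clamp"  # only the canonical name has a registered symbolic
```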
31 | 116 |
32 | 117 | if pytorch_pfn_extras.requires("1.13"): |
33 | 118 | from torch.onnx._internal import jit_utils |
@@ -157,6 +242,20 @@ def _apply_tensor_info_to_value_info(v: onnx.ValueInfoProto, t: torch.Tensor) -> |
157 | 242 | a.dim_value = i |
158 | 243 |
159 | 244 |
| 245 | +@contextmanager |
| 246 | +def _force_tracing() -> Any: |
| 247 | + old_is_tracing = torch.jit.is_tracing |
| 248 | + |
| 249 | + def is_tracing() -> bool: |
| 250 | + return True |
| 251 | + |
| 252 | + try: |
| 253 | + torch.jit.is_tracing = is_tracing |
| 254 | + yield |
| 255 | + finally: |
| 256 | + torch.jit.is_tracing = old_is_tracing |
| 257 | + |
| 258 | + |
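`_force_tracing` temporarily swaps `torch.jit.is_tracing` for a function that always returns `True`, so model code that branches on tracing state takes its tracing path even when no real trace is active; the `try`/`finally` restores the original function even if the body raises. A minimal usage sketch:

```python
import torch

# Outside the context (and outside a real trace) the genuine check applies.
assert not torch.jit.is_tracing()
with _force_tracing():
    assert torch.jit.is_tracing()  # patched: always reports tracing
assert not torch.jit.is_tracing()  # original function restored
```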
160 | 259 | @dataclasses.dataclass |
161 | 260 | class _ExporterOptions: |
162 | 261 | opset_version: int = 12 |
@@ -230,25 +329,30 @@ def _restore_state(self) -> None: |
230 | 329 | if torch.cuda.is_available(): |
231 | 330 | torch.cuda.set_rng_state_all(self.cuda_rng_state) |
232 | 331 |
| 332 | + # TODO(twata): Use `self.traced` instead or use traced result outputs |
| 333 | + def _get_original_outputs(self) -> None: |
| 334 | + self._restore_state() |
| 335 | + with _force_tracing(), grad.init_grad_state(): |
| 336 | + self.original_outputs = self.original_model(*self.inputs) |
| 337 | + self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0]) |
| 338 | + |
233 | 339 | def _run_trace(self) -> None: |
234 | 340 | # TODO(twata): Use `torch._C._create_graph_by_tracing` instead. |
235 | 341 | # So that we don't need to run heavy models multiple times |
236 | | - self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore |
237 | | - self.original_model, |
238 | | - self.inputs, |
239 | | - check_trace=self.check_trace, |
240 | | - strict=self.strict_trace, |
241 | | - _force_outplace=self.force_outplace_trace, |
242 | | - ) |
| 342 | + self._restore_state() |
| 343 | + with grad.init_grad_state(): |
| 344 | + self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore |
| 345 | + self.original_model, |
| 346 | + self.inputs, |
| 347 | + check_trace=self.check_trace, |
| 348 | + strict=self.strict_trace, |
| 349 | + _force_outplace=self.force_outplace_trace, |
| 350 | + ) |
243 | 351 |
244 | 352 | self.graph_doc_string = f""" |
245 | 353 | # Model: {self.traced.original_name} |
246 | 354 | """ |
247 | 355 |
248 | | - # TODO(twata): Use `self.traced` instead or use traced result outputs |
249 | | - self._restore_state() |
250 | | - self.original_outputs = self.original_model(*self.inputs) |
251 | | - self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0]) |
252 | 356 | self.g: torch._C.Graph = self.traced.inlined_graph |
253 | 357 | """ |
254 | 358 | `self.traced` ignores the override of the `state_dict` method in `self.original_model`. |
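The hunk above splits the reference forward pass out of `_run_trace` into `_get_original_outputs`, and both methods now call `_restore_state()` first so the eager run and the trace replay identical RNG state. `torch._C._jit_flatten` then reduces the possibly nested output structure to a flat tuple of tensors. A minimal sketch of that flattening (the nested value is illustrative):

```python
import torch

# _jit_flatten returns (flat tensor tuple, structure descriptor); taking
# index [0], as _get_original_outputs does, keeps only the tensors.
nested = (torch.ones(2), (torch.zeros(3), torch.ones(1)))
flat, _spec = torch._C._jit_flatten(nested)
assert len(flat) == 3  # nesting removed, tensor order preserved
```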
@@ -554,6 +658,9 @@ def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]: |
554 | 658 |
555 | 659 | import pytorch_pfn_extras.onnx.symbolic_registry as sym_reg |
556 | 660 |
| 661 | + if not sym_reg.is_registered_op(op, domain, self.opset_version) and op in _op_normalize_table: |
| 662 | + op = _op_normalize_table[op] |
| 663 | + |
557 | 664 | if sym_reg.is_registered_op(op, domain, self.opset_version): # type: ignore[no-untyped-call] |
558 | 665 | return cast( # type: ignore[redundant-cast] |
559 | 666 | Callable, sym_reg.get_registered_op(op, domain, self.opset_version) # type: ignore[no-untyped-call] |
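With the table in place, `symbolic_function` retries the registry lookup under the canonical name whenever the alias itself is unregistered, so e.g. `aten::greater_equal` falls back to the symbolic registered for `aten::ge`. A sketch of the fallback in isolation, where the hypothetical `registered` set stands in for `sym_reg.is_registered_op`:

```python
registered = {"ge", "clamp"}  # registry knows only canonical names
op = "greater_equal"
if op not in registered and op in _op_normalize_table:
    op = _op_normalize_table[op]
assert op == "ge"
```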
@@ -971,10 +1078,11 @@ def get_model_opset_imports(graph: onnx.GraphProto) -> List[onnx.OperatorSetIdPr |
971 | 1078 | opset_imports.append(onnx.helper.make_opsetid(domain, version)) |
972 | 1079 | return opset_imports |
973 | 1080 |
974 | | - model: onnx.ModelProto = onnx.helper.make_model( |
| 1081 | + model: onnx.ModelProto = onnx.helper.make_model_gen_version( |
975 | 1082 | graph, |
976 | 1083 | opset_imports=get_model_opset_imports(graph), |
977 | 1084 | producer_name="pfto", |
| 1085 | + ir_version=_fix_ir_version, |
978 | 1086 | ) |
979 | 1087 | model = self.check_model(model) |
980 | 1088 |
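Switching to `make_model_gen_version` with `ir_version=_fix_ir_version` pins the exported model to IR version 8 rather than whatever `onnx.IR_VERSION` the installed onnx package would stamp, so newer onnx installations do not produce models that older runtimes refuse to load. A minimal sketch of the same pinning with plain `onnx.helper` calls (the one-node Identity graph is illustrative):

```python
import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
graph = helper.make_graph([helper.make_node("Identity", ["x"], ["y"])], "g", [x], [y])
model = helper.make_model(
    graph,
    producer_name="pfto",
    opset_imports=[helper.make_opsetid("", 12)],  # exporter's default opset
)
model.ir_version = 8  # same effect as ir_version=_fix_ir_version above
onnx.checker.check_model(model)
```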
@@ -1006,6 +1114,7 @@ def _convert(self) -> None: |
1006 | 1114 | sym_hel._set_onnx_shape_inference( # type: ignore[no-untyped-call] |
1007 | 1115 | False # TODO(twata): Use `self.onnx_shape_inference` |
1008 | 1116 | ) |
| 1117 | + self._get_original_outputs() |
1009 | 1118 | self._run_trace() |
1010 | 1119 | self.model: onnx.ModelProto = self.generate_onnx() |
1011 | 1120 | finally: |
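In `_convert`, reference outputs are now captured by `_get_original_outputs()` before tracing starts in `_run_trace()`; since both begin with `_restore_state()`, the two forward passes consume the same RNG sequence and stochastic models stay comparable. A minimal sketch of that save/restore pattern:

```python
import torch

state = torch.get_rng_state()
a = torch.rand(3)          # stands in for the reference forward pass
torch.set_rng_state(state)
b = torch.rand(3)          # stands in for the traced forward pass
assert torch.equal(a, b)   # both runs saw identical randomness
```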