Skip to content

Commit c132364

Browse files
committed
Support using the original tensor_meta to recover the re-extracted samples.
1 parent 4f954ce commit c132364

File tree

3 files changed

+57
-36
lines changed

3 files changed

+57
-36
lines changed

graph_net/paddle/naive_graph_decomposer.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import os
2+
from typing import List
3+
import paddle
4+
from graph_net import imp_util
25
from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor
36

47

@@ -19,47 +22,66 @@ def __init__(
1922

2023
def make_config(
2124
self,
22-
split_positions=(),
25+
split_positions=None,
2326
group_head_and_tail=False,
2427
chain_style=False,
2528
output_dir="./tmp/naive_decomposer_dir",
29+
post_extract_process_path=None,
30+
post_extract_process_class_name=None,
31+
post_extract_process_config=None,
2632
):
27-
for pos in split_positions:
33+
assert not chain_style, "chain_style=True is not supported now."
34+
if split_positions is not None:
2835
assert isinstance(
29-
pos, int
30-
), f"split_positions should be list of int, {split_positions=}"
36+
split_positions, (tuple, list)
37+
), f"split_positions is expected to be tuple or list, but recived {split_positions=}"
38+
for pos in split_positions:
39+
assert isinstance(
40+
pos, int
41+
), f"split_positions is expected to be tuple or list of int, but recived {split_positions=}"
3142
return {
3243
"split_positions": split_positions,
3344
"group_head_and_tail": group_head_and_tail,
3445
"chain_style": chain_style,
3546
"output_dir": output_dir,
47+
"post_extract_process_path": post_extract_process_path,
48+
"post_extract_process_class_name": post_extract_process_class_name,
49+
"post_extract_process_config": post_extract_process_config,
3650
}
3751

3852
def __call__(self, **input_dict):
3953
extracted_model = self.get_naive_decomposer_extractor()(**input_dict)
4054
return extracted_model
4155

4256
def get_naive_decomposer_extractor(self):
43-
return NaiveDecomposerExtractor(self)
57+
return NaiveDecomposerExtractor(
58+
config=self.config,
59+
parent_model=self.model,
60+
parent_model_name=self.name,
61+
parent_input_spec=self.input_spec,
62+
)
4463

4564

4665
class NaiveDecomposerExtractor:
47-
def __init__(self, parent_graph_extractor):
48-
super().__init__()
49-
self.parent_graph_extractor = parent_graph_extractor
66+
def __init__(
67+
self,
68+
config: dict,
69+
parent_model: paddle.nn.Layer,
70+
parent_model_name: str,
71+
parent_input_spec: List[paddle.static.InputSpec],
72+
):
73+
self.config = config
5074
self.extracted = False
5175
self.builtin_extractor = BuiltinGraphExtractor(
52-
model=parent_graph_extractor.model,
53-
name=parent_graph_extractor.name,
54-
dynamic=parent_graph_extractor.dynamic,
55-
input_spec=parent_graph_extractor.input_spec,
56-
workspace_path=self.parent_graph_extractor.config["output_dir"],
76+
model=parent_model,
77+
name=parent_model_name,
78+
dynamic=False,
79+
input_spec=parent_input_spec,
80+
workspace_path=self.config["output_dir"],
5781
)
58-
self.split_positions = self.parent_graph_extractor.config["split_positions"]
59-
self.group_head_and_tail = self.parent_graph_extractor.config[
60-
"group_head_and_tail"
61-
]
62-
self.post_process = self.make_post_process(self.parent_graph_extractor.config)
82+
self.split_positions = self.config["split_positions"]
83+
self.group_head_and_tail = self.config["group_head_and_tail"]
84+
self.post_extract_process = self.make_post_extract_process(self.config)
6385

6486
def do_extract(self, **input_dict):
6587
# 1. Run the model to dump pir programs
@@ -97,14 +119,19 @@ def __call__(self, **input_dict):
97119
if not self.extracted:
98120
extracted_model = self.do_extract(**input_dict)
99121
self.extracted = True
100-
# if self.extracted:
101-
# for subgraph_path in self.subgraph_path_list:
102-
# self.post_process(subgraph_path)
122+
123+
for subgraph_path in self.subgraph_path_list:
124+
self._post_extract_process(subgraph_path)
103125
return extracted_model
104126

105-
def make_post_process(self, config):
127+
def _post_extract_process(self, model_path):
106128
return None
107-
# if config["post_process_path"] is None:
108-
# return None
109-
# module = imp_util.load_module(config["post_process_path"])
110-
# return module.PostExtractProcess(config["post_process_config"])
129+
# model_path = os.path.join(self.config["output_dir"], self.model_name)
130+
# return self.post_extract_process(model_path)
131+
132+
def make_post_extract_process(self, config):
133+
if config.get("post_extract_process_path") is None:
134+
return lambda *args, **kwargs: None
135+
module = imp_util.load_module(config["post_extract_process_path"])
136+
cls = getattr(module, config["post_extract_process_class_name"])
137+
return cls(config["post_extract_process_config"])

graph_net/paddle/run_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import os
2-
import sys
32
import json
43
import base64
54
import argparse
6-
from typing import Type
75

86
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
97

@@ -26,7 +24,8 @@ def get_input_dict(model_path):
2624

2725
state_dict = {}
2826
for k, v in params.items():
29-
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=k)
27+
name = v["original_name"] if v.get("original_name", None) else k
28+
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=name)
3029
for k, v in inputs.items():
3130
state_dict[k] = utils.replay_tensor(v)
3231
return state_dict
@@ -83,4 +82,5 @@ def main(args):
8382
help="decorator configuration string",
8483
)
8584
args = parser.parse_args()
85+
print(args)
8686
main(args=args)

graph_net/paddle/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
import re
2-
from collections import OrderedDict
3-
import uuid
4-
import json
5-
import os
6-
import argparse
71
import importlib
8-
import inspect
92
import ast
103
import math
114
import numpy as np
@@ -169,6 +162,7 @@ def convert_meta_classes_to_tensors(file_path):
169162
},
170163
"data": data_value,
171164
"name": attrs.get("name"),
165+
"original_name": attrs.get("original_name", None),
172166
}
173167

174168

0 commit comments

Comments
 (0)