11import os
2+ from typing import List
3+ import paddle
4+ from graph_net import imp_util
25from graph_net .paddle .extractor import GraphExtractor as BuiltinGraphExtractor
36
47
@@ -19,47 +22,66 @@ def __init__(
1922
2023 def make_config (
2124 self ,
22- split_positions = () ,
25+ split_positions = None ,
2326 group_head_and_tail = False ,
2427 chain_style = False ,
2528 output_dir = "./tmp/naive_decomposer_dir" ,
29+ post_extract_process_path = None ,
30+ post_extract_process_class_name = None ,
31+ post_extract_process_config = None ,
2632 ):
27- for pos in split_positions :
33+ assert not chain_style , "chain_style=True is not supported now."
34+ if split_positions is not None :
2835 assert isinstance (
29- pos , int
30- ), f"split_positions should be list of int, { split_positions = } "
36+ split_positions , (tuple , list )
37+ ), f"split_positions is expected to be tuple or list, but recived { split_positions = } "
38+ for pos in split_positions :
39+ assert isinstance (
40+ pos , int
41+ ), f"split_positions is expected to be tuple or list of int, but recived { split_positions = } "
3142 return {
3243 "split_positions" : split_positions ,
3344 "group_head_and_tail" : group_head_and_tail ,
3445 "chain_style" : chain_style ,
3546 "output_dir" : output_dir ,
47+ "post_extract_process_path" : post_extract_process_path ,
48+ "post_extract_process_class_name" : post_extract_process_class_name ,
49+ "post_extract_process_config" : post_extract_process_config ,
3650 }
3751
3852 def __call__ (self , ** input_dict ):
3953 extracted_model = self .get_naive_decomposer_extractor ()(** input_dict )
4054 return extracted_model
4155
4256 def get_naive_decomposer_extractor (self ):
43- return NaiveDecomposerExtractor (self )
57+ return NaiveDecomposerExtractor (
58+ config = self .config ,
59+ parent_model = self .model ,
60+ parent_model_name = self .name ,
61+ parent_input_spec = self .input_spec ,
62+ )
4463
4564
4665class NaiveDecomposerExtractor :
47- def __init__ (self , parent_graph_extractor ):
48- super ().__init__ ()
49- self .parent_graph_extractor = parent_graph_extractor
66+ def __init__ (
67+ self ,
68+ config : dict ,
69+ parent_model : paddle .nn .Layer ,
70+ parent_model_name : str ,
71+ parent_input_spec : List [paddle .static .InputSpec ],
72+ ):
73+ self .config = config
5074 self .extracted = False
5175 self .builtin_extractor = BuiltinGraphExtractor (
52- model = parent_graph_extractor . model ,
53- name = parent_graph_extractor . name ,
54- dynamic = parent_graph_extractor . dynamic ,
55- input_spec = parent_graph_extractor . input_spec ,
56- workspace_path = self .parent_graph_extractor . config ["output_dir" ],
76+ model = parent_model ,
77+ name = parent_model_name ,
78+ dynamic = False ,
79+ input_spec = parent_input_spec ,
80+ workspace_path = self .config ["output_dir" ],
5781 )
58- self .split_positions = self .parent_graph_extractor .config ["split_positions" ]
59- self .group_head_and_tail = self .parent_graph_extractor .config [
60- "group_head_and_tail"
61- ]
62- self .post_process = self .make_post_process (self .parent_graph_extractor .config )
82+ self .split_positions = self .config ["split_positions" ]
83+ self .group_head_and_tail = self .config ["group_head_and_tail" ]
84+ self .post_extract_process = self .make_post_extract_process (self .config )
6385
6486 def do_extract (self , ** input_dict ):
6587 # 1. Run the model to dump pir programs
@@ -97,14 +119,19 @@ def __call__(self, **input_dict):
97119 if not self .extracted :
98120 extracted_model = self .do_extract (** input_dict )
99121 self .extracted = True
100- # if self.extracted:
101- # for subgraph_path in self.subgraph_path_list:
102- # self.post_process (subgraph_path)
122+
123+ for subgraph_path in self .subgraph_path_list :
124+ self ._post_extract_process (subgraph_path )
103125 return extracted_model
104126
105- def make_post_process (self , config ):
127+ def _post_extract_process (self , model_path ):
106128 return None
107- # if config["post_process_path"] is None:
108- # return None
109- # module = imp_util.load_module(config["post_process_path"])
110- # return module.PostExtractProcess(config["post_process_config"])
129+ # model_path = os.path.join(self.config["output_dir"], self.model_name)
130+ # return self.post_extract_process(model_path)
131+
132+ def make_post_extract_process (self , config ):
133+ if config .get ("post_extract_process_path" ) is None :
134+ return lambda * args , ** kwargs : None
135+ module = imp_util .load_module (config ["post_extract_process_path" ])
136+ cls = getattr (module , config ["post_extract_process_class_name" ])
137+ return cls (config ["post_extract_process_config" ])
0 commit comments