@@ -36,34 +36,20 @@ def __call__(self, model_path):
3636 is_weight_meta_fully_updated = self ._update_by_original_name (
3737 weight_meta_classes , self .original_name2parent_weight_meta_class
3838 )
39- if is_weight_meta_fully_updated :
40- new_weight_meta_codes = []
41- for meta_class in weight_meta_classes :
42- new_weight_meta_codes .append (
43- self ._generate_py_code_from_meta_class (meta_class )
44- )
45-
46- weight_meta_file_path = os .path .join (model_path , "weight_meta.py" )
47- if self .config ["update_inplace" ]:
48- print (f"[GraphMetaRestorer] Update { weight_meta_file_path } " )
49- with open (weight_meta_file_path , "w" ) as f :
50- f .write ("\n \n " .join (new_weight_meta_codes ))
39+ if (
40+ not self .config ["weight_meta_allow_partial_update" ]
41+ or is_weight_meta_fully_updated
42+ ):
43+ self ._rewrite_meta_codes (model_path , weight_meta_classes , "weight_meta.py" )
5144
5245 is_input_meta_fully_updated = self ._update_by_tensor_spec (
5346 input_meta_classes , self .original_name2parent_input_meta_class
5447 )
55- if is_input_meta_fully_updated :
56- new_input_meta_codes = []
57- for meta_class in input_meta_classes :
58- new_input_meta_codes .append (
59- self ._generate_py_code_from_meta_class (meta_class )
60- )
61-
62- input_meta_file_path = os .path .join (model_path , "input_meta.py" )
63- if self .config ["update_inplace" ]:
64- print (f"[GraphMetaRestorer] Update { input_meta_file_path } " )
65- with open (input_meta_file_path , "w" ) as f :
66- f .write ("\n \n " .join (new_input_meta_codes ))
48+ if (
49+ not self .config ["input_meta_allow_partial_update" ]
50+ or is_input_meta_fully_updated
51+ ):
52+ self ._rewrite_meta_codes (model_path , input_meta_classes , "input_meta.py" )
6753
6854 def _load_weight_and_input_meta_classes (self , model_path ):
6955 weight_meta_file_path = os .path .join (model_path , "weight_meta.py" )
@@ -115,7 +101,7 @@ def _update_by_original_name(self, meta_classes, original_name2parent_meta_class
115101 updated_class_names .add (meta_class .name )
116102
117103 print (
118- f"[GraphMetaRestorer] { len (updated_class_names )} /{ len (meta_classes )} classes are updated ."
104+ f"[GraphMetaRestorer] { len (updated_class_names )} /{ len (meta_classes )} classes can be restored ."
119105 )
120106 return len (meta_classes ) == len (updated_class_names )
121107
@@ -133,7 +119,7 @@ def _update_by_tensor_spec(self, meta_classes, original_name2parent_meta_class):
133119 updated_class_names .add (meta_class .name )
134120
135121 print (
136- f"[GraphMetaRestorer] { len (updated_class_names )} /{ len (meta_classes )} classes are updated ."
122+ f"[GraphMetaRestorer] { len (updated_class_names )} /{ len (meta_classes )} classes can be restored ."
137123 )
138124 return len (meta_classes ) == len (updated_class_names )
139125
@@ -151,3 +137,14 @@ def _generate_py_code_from_meta_class(self, meta_class):
151137 )
152138 lines .append (f" { name } = { value_str } " )
153139 return "\n " .join (lines )
140+
141+ def _rewrite_meta_codes (self , model_path , updated_meta_classes , filename ):
142+ new_meta_codes = []
143+ for meta_class in updated_meta_classes :
144+ new_meta_codes .append (self ._generate_py_code_from_meta_class (meta_class ))
145+
146+ meta_file_path = os .path .join (model_path , filename )
147+ if self .config ["update_inplace" ]:
148+ print (f"[GraphMetaRestorer] Update { meta_file_path } " )
149+ with open (meta_file_path , "w" ) as f :
150+ f .write ("\n \n " .join (new_meta_codes ))
0 commit comments