Skip to content

Commit a7982d5

Browse files
committed
Optimize codes.
1 parent 2543be1 commit a7982d5

File tree

2 files changed

+24
-27
lines changed

2 files changed

+24
-27
lines changed

graph_net/paddle/graph_meta_restorer.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,34 +36,20 @@ def __call__(self, model_path):
3636
is_weight_meta_fully_updated = self._update_by_original_name(
3737
weight_meta_classes, self.original_name2parent_weight_meta_class
3838
)
39-
if is_weight_meta_fully_updated:
40-
new_weight_meta_codes = []
41-
for meta_class in weight_meta_classes:
42-
new_weight_meta_codes.append(
43-
self._generate_py_code_from_meta_class(meta_class)
44-
)
45-
46-
weight_meta_file_path = os.path.join(model_path, "weight_meta.py")
47-
if self.config["update_inplace"]:
48-
print(f"[GraphMetaRestorer] Update {weight_meta_file_path}")
49-
with open(weight_meta_file_path, "w") as f:
50-
f.write("\n\n".join(new_weight_meta_codes))
39+
if (
40+
not self.config["weight_meta_allow_partial_update"]
41+
or is_weight_meta_fully_updated
42+
):
43+
self._rewrite_meta_codes(model_path, weight_meta_classes, "weight_meta.py")
5144

5245
is_input_meta_fully_updated = self._update_by_tensor_spec(
5346
input_meta_classes, self.original_name2parent_input_meta_class
5447
)
55-
if is_input_meta_fully_updated:
56-
new_input_meta_codes = []
57-
for meta_class in input_meta_classes:
58-
new_input_meta_codes.append(
59-
self._generate_py_code_from_meta_class(meta_class)
60-
)
61-
62-
input_meta_file_path = os.path.join(model_path, "input_meta.py")
63-
if self.config["update_inplace"]:
64-
print(f"[GraphMetaRestorer] Update {input_meta_file_path}")
65-
with open(input_meta_file_path, "w") as f:
66-
f.write("\n\n".join(new_input_meta_codes))
48+
if (
49+
not self.config["input_meta_allow_partial_update"]
50+
or is_input_meta_fully_updated
51+
):
52+
self._rewrite_meta_codes(model_path, input_meta_classes, "input_meta.py")
6753

6854
def _load_weight_and_input_meta_classes(self, model_path):
6955
weight_meta_file_path = os.path.join(model_path, "weight_meta.py")
@@ -115,7 +101,7 @@ def _update_by_original_name(self, meta_classes, original_name2parent_meta_class
115101
updated_class_names.add(meta_class.name)
116102

117103
print(
118-
f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes are updated."
104+
f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes can be restored."
119105
)
120106
return len(meta_classes) == len(updated_class_names)
121107

@@ -133,7 +119,7 @@ def _update_by_tensor_spec(self, meta_classes, original_name2parent_meta_class):
133119
updated_class_names.add(meta_class.name)
134120

135121
print(
136-
f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes are updated."
122+
f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes can be restored."
137123
)
138124
return len(meta_classes) == len(updated_class_names)
139125

@@ -151,3 +137,14 @@ def _generate_py_code_from_meta_class(self, meta_class):
151137
)
152138
lines.append(f" {name} = {value_str}")
153139
return "\n".join(lines)
140+
141+
def _rewrite_meta_codes(self, model_path, updated_meta_classes, filename):
142+
new_meta_codes = []
143+
for meta_class in updated_meta_classes:
144+
new_meta_codes.append(self._generate_py_code_from_meta_class(meta_class))
145+
146+
meta_file_path = os.path.join(model_path, filename)
147+
if self.config["update_inplace"]:
148+
print(f"[GraphMetaRestorer] Update {meta_file_path}")
149+
with open(meta_file_path, "w") as f:
150+
f.write("\n\n".join(new_meta_codes))

graph_net/paddle/naive_graph_decomposer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
input_spec=None,
1616
):
1717
self.model = model
18-
self.name = name
18+
self.name = name.replace("/", "_")
1919
self.dynamic = dynamic
2020
self.input_spec = input_spec
2121
self.config = self.make_config(**config)

0 commit comments

Comments
 (0)