Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
14f77b9
Support to save the running_states.
Xreki Dec 3, 2025
5678f1c
Define a dataclass DecomposeConfig.
Xreki Dec 3, 2025
7d9581f
Fix a sample.
Xreki Dec 3, 2025
3d30e86
Merge branch 'develop' into opt_saved_results
Xreki Dec 3, 2025
864e7b3
Record the number of original incorrect models.
Xreki Dec 3, 2025
4f954ce
Add original_name for ResNet18.
Xreki Dec 5, 2025
3b6c041
Support use original tensor_meta to recover the re-extracted samples.
Xreki Dec 5, 2025
a59fbba
Add original_name in meta for some paddle samples.
Xreki Dec 8, 2025
78010b8
Merge branch 'add_original_name_sample' into add_original_names
Xreki Dec 8, 2025
2543be1
Merge branch 'develop' into add_original_names
Xreki Dec 8, 2025
a7982d5
Optimize codes.
Xreki Dec 8, 2025
6d15fda
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 8, 2025
5c40420
Enable meta restorer in binary composer.
Xreki Dec 8, 2025
74f423e
Merge branch 'develop' into opt_saved_results
Xreki Dec 8, 2025
1196549
Optimize codes.
Xreki Dec 8, 2025
c8a9f68
Merge branch 'develop' into add_original_names
Xreki Dec 9, 2025
c067624
Merge branch 'develop' into opt_saved_results
Xreki Dec 9, 2025
ade8bb9
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 9, 2025
2b99941
Temporarily support to save the random states.
Xreki Dec 9, 2025
d310856
Change the initialization method of tensor back to truncated normal.
Xreki Dec 9, 2025
d7c91a2
Optimize codes.
Xreki Dec 10, 2025
00b070d
Support fixed-start method.
Xreki Dec 10, 2025
7e95d7f
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 10, 2025
7cfd4eb
Support fixed-start method.
Xreki Dec 10, 2025
a2f80c6
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion graph_net/imp_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import importlib.util as imp


Expand All @@ -6,5 +7,5 @@ def load_module(path, name="unnamed"):
module = imp.module_from_spec(spec)
module.__file__ = path
spec.loader.exec_module(module)
module.__graph_net_file_path__ = path
module.__graph_net_file_path__ = os.path.normpath(path)
return module
82 changes: 54 additions & 28 deletions graph_net/paddle/graph_decomposer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
from typing import List
import paddle
from graph_net import imp_util
from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor


Expand All @@ -12,54 +15,74 @@ def __init__(
input_spec=None,
):
self.model = model
self.name = name
self.name = name.replace("/", "_")
self.dynamic = dynamic
self.input_spec = input_spec
self.config = self.make_config(**config)

def make_config(
    self,
    split_positions=None,
    group_head_and_tail=False,
    chain_style=False,
    output_dir="./tmp/naive_decomposer_dir",
    post_extract_process_path=None,
    post_extract_process_class_name=None,
    post_extract_process_config=None,
):
    """Validate the decomposer options and pack them into a config dict.

    Args:
        split_positions: ``None`` or a tuple/list of ints.
        group_head_and_tail: Stored as-is in the config.
        chain_style: Must be falsy; ``True`` is rejected.
        output_dir: Stored as-is in the config.
        post_extract_process_path: Stored as-is in the config.
        post_extract_process_class_name: Stored as-is in the config.
        post_extract_process_config: Stored as-is in the config.

    Returns:
        A dict holding every option under its parameter name.

    Raises:
        AssertionError: If ``chain_style`` is truthy, or ``split_positions``
            is neither ``None`` nor a tuple/list of ints.
    """
    assert not chain_style, "chain_style=True is not supported now."
    if split_positions is not None:
        assert isinstance(
            split_positions, (tuple, list)
        ), f"split_positions is expected to be tuple or list, but received {split_positions=}"
        for pos in split_positions:
            assert isinstance(
                pos, int
            ), f"split_positions is expected to be tuple or list of int, but received {split_positions=}"
    return {
        "split_positions": split_positions,
        "group_head_and_tail": group_head_and_tail,
        "chain_style": chain_style,
        "output_dir": output_dir,
        "post_extract_process_path": post_extract_process_path,
        "post_extract_process_class_name": post_extract_process_class_name,
        "post_extract_process_config": post_extract_process_config,
    }

def __call__(self, **input_dict):
    """Build a naive decomposer extractor and run it on the given inputs.

    The keyword arguments are forwarded unchanged to the extractor, and
    whatever the extractor returns is returned to the caller.
    """
    extractor = self.get_naive_decomposer_extractor()
    return extractor(**input_dict)

def get_naive_decomposer_extractor(self):
    """Create the extractor that performs the actual decomposition.

    The extractor receives this object's config, model, name and input
    spec as explicit keyword arguments rather than a reference to the
    decomposer itself.
    """
    return NaiveDecomposerExtractor(
        config=self.config,
        parent_model=self.model,
        parent_model_name=self.name,
        parent_input_spec=self.input_spec,
    )


class NaiveDecomposerExtractor:
def __init__(
    self,
    config: dict,
    parent_model: paddle.nn.Layer,
    parent_model_name: str,
    parent_input_spec: List[paddle.static.InputSpec],
):
    """Set up the builtin extractor and the post-extract hook.

    Args:
        config: Decomposer config; the keys ``output_dir``,
            ``split_positions`` and ``group_head_and_tail`` are read here,
            and the whole dict is passed to ``make_post_extract_process``.
        parent_model: Model to decompose. It must carry a
            ``__graph_net_file_path__`` attribute.
            NOTE(review): presumably set when the model is loaded through
            ``imp_util.load_module`` -- confirm against the caller.
        parent_model_name: Name handed to the builtin extractor.
        parent_input_spec: Input specs handed to the builtin extractor.
    """
    self.config = config
    self.extracted = False
    # Directory of the file the parent model was loaded from; must be set
    # before make_post_extract_process(), which reads it.
    self.parent_model_path = os.path.dirname(parent_model.__graph_net_file_path__)
    self.builtin_extractor = BuiltinGraphExtractor(
        model=parent_model,
        name=parent_model_name,
        dynamic=False,
        input_spec=parent_input_spec,
        workspace_path=self.config["output_dir"],
    )
    self.split_positions = self.config["split_positions"]
    self.group_head_and_tail = self.config["group_head_and_tail"]
    self.post_extract_process = self.make_post_extract_process(self.config)

def do_extract(self, **input_dict):
# 1. Run the model to dump pir programs
Expand Down Expand Up @@ -97,14 +120,17 @@ def __call__(self, **input_dict):
if not self.extracted:
extracted_model = self.do_extract(**input_dict)
self.extracted = True
# if self.extracted:
# for subgraph_path in self.subgraph_path_list:
# self.post_process(subgraph_path)

for subgraph_path in self.subgraph_path_list:
self._post_extract_process(subgraph_path)
return extracted_model

def make_post_process(self, config):
return None
# if config["post_process_path"] is None:
# return None
# module = imp_util.load_module(config["post_process_path"])
# return module.PostExtractProcess(config["post_process_config"])
def _post_extract_process(self, subgraph_path):
    """Run the configured post-extract hook on one subgraph path.

    Returns whatever the hook stored in ``self.post_extract_process``
    returns for ``subgraph_path``.
    """
    hook = self.post_extract_process
    return hook(subgraph_path)

def make_post_extract_process(self, config):
    """Build the callable applied to every extracted subgraph path.

    When ``config`` has no ``post_extract_process_path`` (missing or
    ``None``), a do-nothing callable that accepts any arguments and returns
    ``None`` is produced. Otherwise the module at that path is loaded, the
    class named by ``post_extract_process_class_name`` is looked up in it,
    and it is instantiated with ``post_extract_process_config`` and
    ``self.parent_model_path``.
    """
    process_path = config.get("post_extract_process_path")
    if process_path is not None:
        module = imp_util.load_module(process_path)
        process_cls = getattr(module, config["post_extract_process_class_name"])
        return process_cls(config["post_extract_process_config"], self.parent_model_path)

    def _noop(*args, **kwargs):
        return None

    return _noop
150 changes: 150 additions & 0 deletions graph_net/paddle/graph_meta_restorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os
from graph_net import path_utils
from graph_net.paddle import utils


class GraphMetaRestorer:
    """Restore tensor meta of a re-extracted sample from its parent sample.

    The parent sample's ``weight_meta.py`` and ``input_meta.py`` are loaded
    once at construction. Calling the restorer on a sample directory copies
    value/statistic attributes from matching parent meta classes onto that
    sample's meta classes and rewrites the meta files.

    Config keys read: ``update_inplace``, ``weight_meta_allow_partial_update``
    and ``input_meta_allow_partial_update``.
    """

    # Attributes copied from a matching parent meta class.
    _RESTORED_ATTR_NAMES = ("max_val", "min_val", "mean", "std", "data")

    def __init__(self, config, parent_model_path):
        """Load and index the parent sample's meta classes.

        Raises:
            AssertionError: If ``parent_model_path`` is not a single-model
                directory, or two parent meta classes share an
                ``original_name``.
        """
        self.config = config
        self.parent_model_path = parent_model_path
        print(f"parent_model_path: {self.parent_model_path}")

        assert path_utils.is_single_model_dir(
            parent_model_path
        ), f"{parent_model_path=} is not a graphnet sample."
        (
            parent_weight_meta_classes,
            parent_input_meta_classes,
        ) = self._load_weight_and_input_meta_classes(parent_model_path)
        self.original_name2parent_weight_meta_class = self._convert_to_dict(
            parent_weight_meta_classes
        )
        self.original_name2parent_input_meta_class = self._convert_to_dict(
            parent_input_meta_classes
        )

    def __call__(self, model_path):
        """Restore the weight and input meta of the sample at ``model_path``.

        Weights are matched to the parent by ``original_name``; inputs are
        matched by a unique (dtype, shape) pair. A meta file is rewritten
        when every class was restored, or when the corresponding
        ``*_allow_partial_update`` option is enabled.
        """
        assert path_utils.is_single_model_dir(
            model_path
        ), f"{model_path=} is not a graphnet sample."
        (
            weight_meta_classes,
            input_meta_classes,
        ) = self._load_weight_and_input_meta_classes(model_path)

        assert self.config["update_inplace"]
        is_weight_meta_fully_updated = self._update_by_original_name(
            weight_meta_classes, self.original_name2parent_weight_meta_class
        )
        if self._should_rewrite(
            self.config["weight_meta_allow_partial_update"],
            is_weight_meta_fully_updated,
        ):
            self._rewrite_meta_codes(model_path, weight_meta_classes, "weight_meta.py")

        is_input_meta_fully_updated = self._update_by_tensor_spec(
            input_meta_classes, self.original_name2parent_input_meta_class
        )
        if self._should_rewrite(
            self.config["input_meta_allow_partial_update"],
            is_input_meta_fully_updated,
        ):
            self._rewrite_meta_codes(model_path, input_meta_classes, "input_meta.py")

    @staticmethod
    def _should_rewrite(allow_partial_update, is_fully_updated):
        """Decide whether a meta file should be written back.

        A partially restored meta is only written when partial updates are
        allowed; a fully restored meta is always written.
        """
        return bool(allow_partial_update or is_fully_updated)

    def _load_meta_classes(self, meta_file_path):
        """Return the meta classes found in one meta file, dropping names."""
        return [meta_class for (_, meta_class) in utils.get_meta_classes(meta_file_path)]

    def _load_weight_and_input_meta_classes(self, model_path):
        """Load ``weight_meta.py`` and ``input_meta.py`` from ``model_path``.

        Returns:
            A ``(weight_meta_classes, input_meta_classes)`` tuple of lists.
        """
        weight_meta_classes = self._load_meta_classes(
            os.path.join(model_path, "weight_meta.py")
        )
        input_meta_classes = self._load_meta_classes(
            os.path.join(model_path, "input_meta.py")
        )
        return weight_meta_classes, input_meta_classes

    def _convert_to_dict(self, meta_classes):
        """Index meta classes by their ``original_name``.

        Raises:
            AssertionError: If two classes share the same ``original_name``.
        """
        original_name2meta_class = {}
        for meta_class in meta_classes:
            assert (
                meta_class.original_name not in original_name2meta_class
            ), f"Duplicated original_name: {meta_class.original_name!r}"
            original_name2meta_class[meta_class.original_name] = meta_class
        return original_name2meta_class

    def _update_tensor_meta(self, meta_class, parent_meta_class):
        """Copy value attributes from the parent when dtype and shape agree.

        An attribute present on either class is set on ``meta_class`` to the
        parent's value, or to ``None`` when the parent lacks it.

        Returns:
            True if ``meta_class`` was updated, False otherwise.
        """
        if (
            parent_meta_class
            and meta_class.dtype == parent_meta_class.dtype
            and meta_class.shape == parent_meta_class.shape
        ):
            for attr_name in self._RESTORED_ATTR_NAMES:
                if hasattr(meta_class, attr_name) or hasattr(
                    parent_meta_class, attr_name
                ):
                    attr_value = getattr(parent_meta_class, attr_name, None)
                    setattr(meta_class, attr_name, attr_value)
            return True
        return False

    def _update_by_original_name(self, meta_classes, original_name2parent_meta_class):
        """Restore each meta class from the parent with the same ``original_name``.

        Classes with a missing or empty ``original_name`` are skipped.

        Returns:
            True only if every class in ``meta_classes`` was restored.
        """
        updated_class_names = set()
        for meta_class in meta_classes:
            # Tolerate meta classes that do not define original_name at all.
            original_name = getattr(meta_class, "original_name", None)
            if not original_name:
                continue

            parent_meta_class = original_name2parent_meta_class.get(original_name)
            if self._update_tensor_meta(meta_class, parent_meta_class):
                updated_class_names.add(meta_class.name)

        print(
            f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes can be restored."
        )
        return len(meta_classes) == len(updated_class_names)

    def _update_by_tensor_spec(self, meta_classes, original_name2parent_meta_class):
        """Restore each meta class from the single parent with equal dtype and shape.

        A class is left untouched when zero or several parents match, since
        the match would be missing or ambiguous.

        Returns:
            True only if every class in ``meta_classes`` was restored.
        """
        updated_class_names = set()
        for meta_class in meta_classes:
            matched_parent_meta_class = [
                parent_meta_class
                for parent_meta_class in original_name2parent_meta_class.values()
                if meta_class.dtype == parent_meta_class.dtype
                and meta_class.shape == parent_meta_class.shape
            ]
            if len(matched_parent_meta_class) == 1:
                self._update_tensor_meta(meta_class, matched_parent_meta_class[0])
                updated_class_names.add(meta_class.name)

        print(
            f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes can be restored."
        )
        return len(meta_classes) == len(updated_class_names)

    def _generate_py_code_from_meta_class(self, meta_class):
        """Render a meta class as Python source text.

        Members whose names start with ``__`` are omitted. Floats are
        written as ``float('<repr>')`` so that ``inf`` and ``nan`` stay
        loadable.
        """
        lines = [f"class {meta_class.__name__}:"]
        members = {
            k: v for k, v in vars(meta_class).items() if not k.startswith("__")
        }

        if not members:
            return lines[0] + "\n pass"

        for name, value in members.items():
            if isinstance(value, float):
                # Normalise float subclasses (their repr may not be a plain
                # literal that float() can parse) to a builtin float first.
                value_str = f"float('{float(value)!r}')"
            else:
                value_str = repr(value)
            lines.append(f" {name} = {value_str}")
        return "\n".join(lines)

    def _rewrite_meta_codes(self, model_path, updated_meta_classes, filename):
        """Write the regenerated meta classes to ``model_path/filename``.

        Nothing is written unless ``config["update_inplace"]`` is truthy.
        """
        new_meta_codes = [
            self._generate_py_code_from_meta_class(meta_class)
            for meta_class in updated_meta_classes
        ]

        meta_file_path = os.path.join(model_path, filename)
        if self.config["update_inplace"]:
            print(f"[GraphMetaRestorer] Update {meta_file_path}")
            # Python source is read back as UTF-8, so write it as UTF-8
            # rather than the platform default encoding.
            with open(meta_file_path, "w", encoding="utf-8") as f:
                f.write("\n\n".join(new_meta_codes))
61 changes: 61 additions & 0 deletions graph_net/paddle/random_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import pickle
import numpy as np
import random
import re
import paddle

from graph_net.paddle import samples_util


def set_seed(random_seed):
    """Seed Paddle, Python's ``random`` and NumPy with the same value."""
    for seed_fn in (paddle.seed, random.seed, np.random.seed):
        seed_fn(random_seed)


def _extract_model_name_for_original_sample(model_path):
    """Derive a model name from the path of an original sample.

    The name is the last path component. When that component is
    ``subgraph`` or ``subgraph_<digits>`` and a parent component exists,
    the parent is prefixed: ``.../bert/subgraph_3`` -> ``bert_subgraph_3``.

    The path is normalised first, so trailing or doubled separators do not
    produce empty components, and a lone ``subgraph`` component is returned
    as-is instead of raising ``IndexError``.
    """
    fields = [
        part for part in os.path.normpath(model_path).split(os.sep) if part
    ]
    if not fields:
        return ""
    leaf = fields[-1]
    if len(fields) >= 2 and re.fullmatch(r"subgraph(_\d+)?", leaf):
        return f"{fields[-2]}_{leaf}"
    return leaf


def _extract_model_name_for_decomposed_subgraph(model_path):
    """Derive the model name from a decomposed subgraph directory.

    The last path component is expected to look like
    ``<model_name>_<subgraph_idx>``; everything before the final underscore
    is returned. A component without an underscore is returned whole so the
    resulting name is never empty for an ordinary directory name.
    """
    # normpath removes trailing separators regardless of the platform.
    leaf = os.path.basename(os.path.normpath(model_path))
    model_name, separator, _subgraph_idx = leaf.rpartition("_")
    return model_name if separator else leaf


def _generate_random_state_filename(model_path):
    """Return the ``<model_name>.random_states.pkl`` file name for a sample.

    Paths located inside the default samples directory are named as
    original samples; any other path is treated as a decomposed subgraph.
    """
    samples_dir = os.path.abspath(samples_util.get_default_samples_directory())
    abs_model_path = os.path.abspath(model_path)
    try:
        # Compare whole path components: a plain string prefix test would
        # also accept siblings such as "<samples_dir>_extra".
        is_original_sample = (
            os.path.commonpath([abs_model_path, samples_dir]) == samples_dir
        )
    except ValueError:
        # Raised e.g. for paths on different drives: not inside samples_dir.
        is_original_sample = False

    if is_original_sample:
        model_name = _extract_model_name_for_original_sample(model_path)
    else:
        model_name = _extract_model_name_for_decomposed_subgraph(model_path)
    return f"{model_name}.random_states.pkl"


def save_random_states(model_path, output_dir, random_state_dict):
    """Pickle ``random_state_dict`` into ``output_dir``.

    The file name is derived from ``model_path``. This is best-effort: any
    failure is reported on stdout and swallowed, and nothing is returned.
    """
    filepath = os.path.join(output_dir, _generate_random_state_filename(model_path))
    print(f"Write to {filepath}.", flush=True)
    try:
        # Serialise first so a pickling error cannot leave a truncated file.
        payload = pickle.dumps(random_state_dict)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(filepath, "wb") as f:
            f.write(payload)
    except Exception as e:
        # Deliberately non-fatal; include the cause so failures are diagnosable.
        print(f"Fail to write {filepath}: {e!r}", flush=True)


def load_random_states(model_path, output_dir):
    """Load the pickled random states saved for ``model_path``.

    Returns:
        The unpickled object, or ``None`` if the file is missing or cannot
        be loaded. Failures are reported on stdout and never raised.
    """
    filepath = os.path.join(output_dir, _generate_random_state_filename(model_path))
    print(f"Read from {filepath}.", flush=True)
    random_states = None
    try:
        with open(filepath, "rb") as f:
            # NOTE: pickle can execute arbitrary code while loading; only
            # read files from an output_dir this tool wrote itself.
            random_states = pickle.load(f)
    except FileNotFoundError:
        print(f"Fail to open {filepath}.")
    except Exception as e:
        # Keep the best-effort contract but surface the actual cause.
        print(f"Fail to load {filepath}: {e!r}")
    return random_states
Loading
Loading