Commits (45)
986793a
1119
roll-away Nov 19, 2025
282a32d
1120
roll-away Nov 20, 2025
0187ae0
1120.2
roll-away Nov 20, 2025
5d46f55
model_path
roll-away Nov 20, 2025
39b4139
remove unnecessary files and pre-committed
roll-away Nov 20, 2025
b775e46
remove unnecessary files and pre-committed
roll-away Nov 21, 2025
44ad76f
1121 remove unnecessary files
roll-away Nov 21, 2025
0fc84c4
modify rev version
roll-away Nov 21, 2025
19dc60b
modify rev version
roll-away Nov 21, 2025
d6eda81
modify rev version
roll-away Nov 21, 2025
956ad33
accuracy issues targeted
roll-away Nov 21, 2025
8c8070b
test script and modify feature
roll-away Nov 21, 2025
ef7d4b6
return set[str]
roll-away Nov 21, 2025
181b293
add logfile for test
roll-away Nov 21, 2025
2aac268
filter can get the number of kernels in naive_graph_decomposer
roll-away Nov 24, 2025
00d5b4b
Merge branch 'PaddlePaddle:develop' into develop
roll-away Nov 24, 2025
75c3e61
post extract process feature
roll-away Nov 25, 2025
fe89add
remove unnecessary code blocks and variables
roll-away Nov 25, 2025
ca860b3
modify the way of counting kernels used
roll-away Nov 25, 2025
c21717f
modify the way of counting kernels used
roll-away Nov 25, 2025
de54e88
modify script, rename files and variables
roll-away Nov 25, 2025
9363023
add failure protection and log output when removing directories
roll-away Nov 26, 2025
adff744
Merge branch 'PaddlePaddle:develop' into develop
roll-away Nov 27, 2025
ca20508
add a script to check fusability of a given model
roll-away Dec 1, 2025
fc0071c
Merge branch 'PaddlePaddle:develop' into develop
roll-away Dec 1, 2025
9a28d45
Merge branch 'develop' of github.com:roll-away/GraphNet into develop
roll-away Dec 1, 2025
513cc38
add a script to check if a given model is fully fusable
roll-away Dec 1, 2025
4847ee3
Merge branch 'PaddlePaddle:develop' into develop
roll-away Dec 1, 2025
6538119
Merge branch 'PaddlePaddle:develop' into develop
roll-away Dec 1, 2025
22a2772
add a script to check if a given model is fully fusable
roll-away Dec 1, 2025
684dba9
a script to check if a given model is fully fusable
roll-away Dec 1, 2025
bfe0848
Merge branch 'PaddlePaddle:develop' into develop
roll-away Dec 1, 2025
f8cc102
add a script to check if a given model is fully fusionable
roll-away Dec 1, 2025
f131cfb
add a script to check if a given model is fully fusionable
roll-away Dec 1, 2025
f7f3d2a
add a script to find fully fusionable subgraph
roll-away Dec 1, 2025
353e7bd
find the biggest fully fusionable subgraph
roll-away Dec 2, 2025
b703458
update new codes
roll-away Dec 8, 2025
0b687cf
get fusible subgraph test
roll-away Dec 8, 2025
e70b44b
get fusible subgraph test
roll-away Dec 8, 2025
7dbb6e9
modify get fully fusible subgraph
roll-away Dec 9, 2025
f71b56b
improve fully_fusible_subgraph_extractor.py efficiency
lixinqi Dec 9, 2025
93fabbf
Merge pull request #1 from lixinqi/lxq_fusibletest
roll-away Dec 9, 2025
6df0cd0
backup code
lixinqi Dec 9, 2025
babdde5
Improve efficiency of test/fully_fusible_subgraph_extractor_test.sh
lixinqi Dec 9, 2025
48467f7
Merge pull request #2 from lixinqi/lxq_fusibletest
roll-away Dec 9, 2025
39 changes: 39 additions & 0 deletions graph_net/test/naive_decomposer_and_post_extract_process_test.sh
@@ -0,0 +1,39 @@
#!/bin/bash
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# Add the project root to the Python path
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
decorator_config_json_str=$(cat <<EOF
{
"decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
"decorator_config": {
"name": "$MODEL_NAME",
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
"custom_extractor_config": {
"output_dir": "/work/.BCloud/countkernels/",
"split_positions": [8, 16, 32],
"group_head_and_tail": true,
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
"filter_config": {},
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process.py",
"post_extract_process_config": {
"decorator_path": "$GRAPH_NET_ROOT/torch/shape_prop.py",
"decorator_class_name": "ShapePropagate"
Collaborator:
L28 - L30 are unrelated to this feature; please remove them.
}
}
}
}
EOF
)
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)

python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
9 changes: 9 additions & 0 deletions graph_net/test/naive_graph_decomposer_test.sh
100644 → 100755
@@ -1,4 +1,13 @@
#!/bin/bash
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# Add the project root to the Python path
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
Collaborator:
Comment out lines 2 to 7. It's not a good practice to force users to modify PYTHONPATH; if the script fails, it is the user's responsibility to set PYTHONPATH in .bashrc.

Contributor (author):
Deleted. A comment now explains how to run this script.

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
29 changes: 29 additions & 0 deletions graph_net/torch/naive_graph_decomposer.py
@@ -32,6 +32,8 @@ def make_config(
output_dir="./tmp/naive_decomposer_dir",
filter_path=None,
filter_config=None,
post_extract_process_path=None,
post_extract_process_config=None,
):
for pos in split_positions:
assert isinstance(
@@ -44,6 +46,8 @@ def make_config(
"output_dir": output_dir,
"filter_path": filter_path,
"filter_config": filter_config if filter_config is not None else {},
"post_extract_process_path": post_extract_process_path,
"post_extract_process_config": post_extract_process_config,
}

def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -71,6 +75,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
self.seq_no = seq_no
self.extracted = False
name = f"{parent_graph_extractor.name}_{self.seq_no}"
self.modelname = name
self.builtin_extractor = BuiltinGraphExtractor(
name=name,
dynamic=False,
@@ -79,21 +84,45 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
workspace_path=self.parent_graph_extractor.config["output_dir"],
)
self.filter = self.make_filter(self.parent_graph_extractor.config)
self.post_extract_process = self.make_post_extract_process(
self.parent_graph_extractor.config
)

def forward(self, *args):
if not self.extracted:
if self.need_extract(self.submodule, args):
self.builtin_extractor(self.submodule, args)
self.get_post_extract_process(self.submodule, args)
self.extracted = True
return self.submodule(*args)

def need_extract(self, gm, sample_inputs):
# print("need_extract")
if self.filter is None:
return True
# if self.fusionablity_filter is not None:
# print("fusionablity of this model is ", self.fusionablity_filter(gm, sample_inputs))
return self.filter(gm, sample_inputs)

def get_post_extract_process(self, gm, sample_inputs):
Collaborator:
Remove the unused parameters gm and sample_inputs, and clean up the debug code.
# print("modelname: ",self.modelname)
# print("parent_graph_extractor.config: ",self.parent_graph_extractor.config['output_dir'])
# print("get_post_extract_process")
model_path = os.path.join(
self.parent_graph_extractor.config["output_dir"], self.modelname
)
return self.post_extract_process(model_path)

def make_filter(self, config):
# print("make_filter")
if config["filter_path"] is None:
return None
module = imp_util.load_module(config["filter_path"])
return module.GraphFilter(config["filter_config"])

def make_post_extract_process(self, config):
# print("make post_extract_process")
if config["filter_path"] is None:
Collaborator:
This should not be filter_path here.
return None
module = imp_util.load_module(config["post_extract_process_path"])
return module.PostExtractProcess(config["post_extract_process_config"])
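
A minimal sketch of how the two review comments above could be addressed, i.e. dropping the unused gm and sample_inputs parameters and checking the post_extract_process_path key instead of filter_path. This is an illustrative reading of the suggestions (the call site in forward would become self.get_post_extract_process()), not the code merged in this PR:

    def get_post_extract_process(self):
        # Path of the submodel that the builtin extractor has just written out.
        model_path = os.path.join(
            self.parent_graph_extractor.config["output_dir"], self.modelname
        )
        return self.post_extract_process(model_path)

    def make_post_extract_process(self, config):
        # Check the post_extract_process_path key, not filter_path.
        if config["post_extract_process_path"] is None:
            return None
        module = imp_util.load_module(config["post_extract_process_path"])
        return module.PostExtractProcess(config["post_extract_process_config"])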
3 changes: 2 additions & 1 deletion graph_net/torch/naive_subgraph_filter.py
@@ -3,5 +3,6 @@ def __init__(self, config):
self.config = config

def __call__(self, gm, sample_inputs):
print(f"GraphFilter\n{gm.code}")
print("GraphFilter")
# print(f"GraphFilter\n{gm.code}")
Collaborator:
Delete this.

return True
144 changes: 144 additions & 0 deletions graph_net/torch/post_extract_process.py
@@ -0,0 +1,144 @@
from graph_net.torch import utils
import argparse
import importlib.util
import inspect
import shutil
import torch
import logging
from pathlib import Path
from typing import Type, Any
import sys
import json
import base64
from contextlib import contextmanager

from torch.profiler import profile, record_function, ProfilerActivity


class PostExtractProcess:
def __init__(self, config):
self.config = config

def __call__(self, model_path=None):
print("PostExtractProcess")
if model_path is None:
return False
import json
import base64
import sys
import os

json_string = json.dumps(self.config)
json_bytes = json_string.encode("utf-8")
b64_encoded_bytes = base64.b64encode(json_bytes)
decorator_config = b64_encoded_bytes.decode("utf-8")

# args
parser = argparse.ArgumentParser(description="load and run model")
parser.add_argument(
"--model-path",
type=str,
required=True,
help="Path to folder e.g '../../samples/torch/resnet18'",
)
parser.add_argument(
"--decorator-config",
type=str,
required=False,
default=None,
help="decorator configuration string",
)
args = parser.parse_args()

# model
model_class = load_class_from_file(
f"{model_path}/model.py", class_name="GraphModule"
)
assert model_class is not None
model = model_class()
print(f"{model_path=}")

model = _get_decorator(args)(model)

inputs_params = utils.load_converted_from_text(f"{model_path}")
params = inputs_params["weight_info"]
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)
print("compiled: nums_of_kernels = ", compiled_num_of_kernels)
if compiled_num_of_kernels == 1:
print("Graph is fully fusionable")
return True
else:
print(f"Graph is not fully fusionable ({compiled_num_of_kernels} kernels)")
shutil.rmtree(model_path)
return False


def _convert_to_dict(config_str):
if config_str is None:
return {}
config_str = base64.b64decode(config_str).decode("utf-8")
config = json.loads(config_str)
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
return config


def _get_decorator(args):
if args.decorator_config is None:
return lambda model: model
decorator_config = _convert_to_dict(args.decorator_config)
if "decorator_path" not in decorator_config:
return lambda model: model
class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
decorator_class = load_class_from_file(
decorator_config["decorator_path"],
class_name=class_name,
)
return decorator_class(decorator_config.get("decorator_config", {}))


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
spec = importlib.util.spec_from_file_location("unnamed", file_path)
unnamed = importlib.util.module_from_spec(spec)
spec.loader.exec_module(unnamed)
model_class = getattr(unnamed, class_name, None)
return model_class


def compile_and_count_kernels(gm, sample_inputs) -> int:
Collaborator:
gm -> model; this argument is generally a torch.nn.Module here, not the fx.GraphModule produced inside torch.compile.

Collaborator:
Keep torch.compile outside this function so it stays single-purpose: it should only count the number of CUDA kernel launches. There may also be scenarios where we need to count the CUDA kernels of eager execution.

"""
Count the number of CUDA kernel launches performed during a model's forward pass.

Args:
gm(graph models)
sample_inputs(tensors)

Returns:
int: The number of kernels used.

Behavior:
- Runs the model once inside a PyTorch profiler context.
- Identifies the event with key = 'cudaLaunchKernel', which corresponds
to the number of CUDA kernel launches.
"""
gm.eval()
# Use PyTorch Profiler
compiled_gm = torch.compile(gm)
_ = compiled_gm(**sample_inputs)
Collaborator:
What is the purpose of running it once here?

Contributor (author):
Is a warm-up needed?

with profile(
activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
record_shapes=True,
) as prof:
with record_function("model_inference"):
output = compiled_gm(**sample_inputs)
print(prof.key_averages().table()) # print a table of profiler result
events = prof.key_averages()
if_compile_work = any(e.key == "TorchDynamo Cache Lookup" for e in events)
Collaborator:
The field names reported by the profiler are very likely to differ between versions; a check based on this field will probably break after a version upgrade.

Contributor (author):
Changed to no longer check whether compilation took place; the problem currently shows up inside the compile context.
if not if_compile_work:
print("Compile failed")
return -1
for e in events:
if e.key == "cuLaunchKernel":
Collaborator:
We should count the sum of the calls to both cudaLaunchKernel and cuLaunchKernel.
return e.count
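
A minimal sketch of the direction the review comments above suggest: torch.compile stays outside, a warm-up run precedes profiling, there is no reliance on the version-dependent "TorchDynamo Cache Lookup" field, and the counts of cudaLaunchKernel and cuLaunchKernel are summed. The function name and call pattern are illustrative assumptions, not the code merged in this PR:

import torch
from torch.profiler import profile, record_function, ProfilerActivity


def count_cuda_kernel_launches(model, sample_inputs) -> int:
    # `model` may be an eager torch.nn.Module or the result of torch.compile;
    # compilation is deliberately left to the caller.
    model.eval()
    with torch.no_grad():
        _ = model(**sample_inputs)  # warm-up, so compilation/autotuning is not profiled
        with profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU]) as prof:
            with record_function("model_inference"):
                _ = model(**sample_inputs)
    events = prof.key_averages()
    # Sum both launch APIs, as suggested in the review.
    return sum(e.count for e in events if e.key in ("cudaLaunchKernel", "cuLaunchKernel"))


# Hypothetical call site, with compilation kept outside the counter:
# compiled_model = torch.compile(model)
# num_kernels = count_cuda_kernel_launches(compiled_model, state_dict)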