
Commit 22a2772

add a script to check if a given model is fully fusable

1 parent adff744
File tree

3 files changed: +176 −0 lines changed
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
#!/bin/bash

# Locate the installed graph_net package so the script works from any directory.
GRAPH_NET_ROOT=$(python3 -c "import graph_net, os; print(os.path.dirname(graph_net.__file__))")

# Input model path, relative to the samples directory.
MODEL_NAME=resnet18d.ra2_in1k
MODEL_PATH_IN_SAMPLES=timm/$MODEL_NAME

# Checker configuration: point the post-extract process at the kernel-counting
# checker, then base64-encode the JSON so it survives shell argument parsing.
checker_config_json_str=$(cat <<EOF
{
    "post_extract_process_config": {
        "post_extract_process_path": "$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
        "post_extract_process_class_name": "GraphFullyFusionable"
    }
}
EOF
)
CHECKER_CONFIG=$(printf '%s' "$checker_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.check_model_fusability \
    --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" \
    --checker-config="$CHECKER_CONFIG"
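The checker configuration is plain base64-encoded JSON, so the same payload can be produced from Python instead of the heredoc above. A minimal sketch, with the graph_net root as a placeholder path:

import base64
import json

# Placeholder; in the script above this comes from graph_net.__file__.
graph_net_root = "/path/to/graph_net"

config = {
    "post_extract_process_config": {
        "post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
        "post_extract_process_class_name": "GraphFullyFusionable",
    }
}

# Equivalent of the `base64 -w 0` pipeline: compact JSON, one-line base64.
checker_config = base64.b64encode(json.dumps(config).encode("utf-8")).decode("ascii")
print(checker_config)  # pass as --checker-config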
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
import argparse
import base64
import json
import sys

from graph_net.imp_util import load_module


def _load_class_from_file(file_path, class_name):
    module = load_module(file_path)
    return getattr(module, class_name)


def _convert_to_dict(config_str):
    if config_str is None:
        return {}
    # The config arrives base64-encoded so the JSON survives shell quoting.
    config_str = base64.b64decode(config_str).decode("utf-8")
    config = json.loads(config_str)
    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
    return config


def _get_checker(args):
    if args.checker_config is None:
        # No checker configured: fall back to an identity no-op.
        return lambda model_path: model_path
    checker_config = _convert_to_dict(args.checker_config).get(
        "post_extract_process_config"
    )
    checker_class = _load_class_from_file(
        checker_config["post_extract_process_path"],
        class_name=checker_config["post_extract_process_class_name"],
    )
    return checker_class(checker_config.get("checker_config", {}))


def main(args):
    checker = _get_checker(args)
    model_path = args.model_path
    print(f"{model_path=}")
    try:
        checker(model_path)
    except KeyboardInterrupt:
        sys.exit(-1)
    except Exception as e:
        print(e)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="check whether a given model is fully fusable"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to a model folder, e.g. '../../samples/torch/resnet18'",
    )
    parser.add_argument(
        "--checker-config",
        type=str,
        required=False,
        default=None,
        help="base64-encoded JSON checker configuration",
    )
    args = parser.parse_args()
    main(args=args)
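Since main only needs an object with model_path and checker_config attributes, the checker can also be driven from Python without the shell wrapper. A minimal sketch, assuming the module is importable and using placeholder paths:

import argparse
import base64
import json

from graph_net.torch import check_model_fusability

config = {
    "post_extract_process_config": {
        # Placeholder path to the checker file shipped with graph_net.
        "post_extract_process_path": "/path/to/graph_net/torch/post_extract_process_count_kernels.py",
        "post_extract_process_class_name": "GraphFullyFusionable",
    }
}
args = argparse.Namespace(
    model_path="/path/to/samples/timm/resnet18d.ra2_in1k",  # placeholder
    checker_config=base64.b64encode(json.dumps(config).encode()).decode(),
)
check_model_fusability.main(args)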
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import importlib.util
from typing import Type

import torch
from torch.profiler import ProfilerActivity, profile, record_function

from graph_net.torch import utils


class GraphFullyFusionable:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_path=None):
        if model_path is None:
            return False
        # Load the extracted graph model.
        model_class = load_class_from_file(
            f"{model_path}/model.py", class_name="GraphModule"
        )
        assert model_class is not None
        model = model_class()
        print(f"{model_path=}")

        # Rebuild the recorded input and weight tensors.
        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

        # Try to run the model eagerly.
        try:
            model(**state_dict)
        except Exception as e:
            print(f"failed to run model: {e}")
            # print(f"removing: {model_path}")
            # shutil.rmtree(model_path)
            return False
        # Try to compile the model.
        try:
            compiled_model = torch.compile(model)
        except Exception as e:
            print(f"failed to compile model: {e}")
            # print(f"removing: {model_path}")
            # shutil.rmtree(model_path)
            return False
        compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
        if compiled_num_of_kernels == 1:
            print(model_path, "can be fully fused")
            return True
        else:
            print(f"{model_path} can not be fully fused")
            # print(f"removing: {model_path}")
            # shutil.rmtree(model_path)
            return False


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
    # Import the file as an anonymous module and pull out the requested class.
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    unnamed = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(unnamed)
    model_class = getattr(unnamed, class_name, None)
    return model_class


def count_kernels(model, sample_inputs) -> int:
    """
    Count the number of CUDA kernel launches performed during a model's forward pass.

    Args:
        model: the (possibly compiled) graph model to profile.
        sample_inputs: keyword tensor arguments forwarded to the model.

    Returns:
        int: the number of kernel launches observed.

    Behavior:
        - Runs the model once inside a PyTorch profiler context.
        - Sums the counts of the 'cuLaunchKernel' and 'cudaLaunchKernel'
          events, which correspond to CUDA kernel launches.
    """
    model.eval()
    # Use the PyTorch profiler to record a single inference pass.
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            _ = model(**sample_inputs)
    events = prof.key_averages()

    total_count = 0
    for e in events:
        if e.key == "cuLaunchKernel" or e.key == "cudaLaunchKernel":
            total_count += e.count
    return total_count
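The kernel-counting idea is easy to try in isolation. A minimal, self-contained sketch (assuming a CUDA device is available; it inlines the counting logic rather than importing the file above) showing how torch.compile typically collapses an elementwise chain into a single launch:

import torch
from torch.profiler import ProfilerActivity, profile


def count_launches(fn, x):
    # Count CUDA kernel-launch events for a single call of fn(x).
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn(x)
    return sum(
        e.count
        for e in prof.key_averages()
        if e.key in ("cuLaunchKernel", "cudaLaunchKernel")
    )


def chain(x):
    # Three elementwise ops: eager mode launches roughly one kernel per op,
    # while torch.compile can fuse them into one.
    return torch.relu(x).sin().cos()


if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    compiled = torch.compile(chain)
    compiled(x)  # warm up so compilation cost stays out of the profile
    print("eager launches:", count_launches(chain, x))
    print("compiled launches:", count_launches(compiled, x))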
