Check if a graph can be fully fused into a single CUDA kernel #381
base: develop
New file, shell script (@@ -0,0 +1,39 @@):

```bash
#!/bin/bash
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# Add the project root to the Python path
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")

# Input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
decorator_config_json_str=$(cat <<EOF
{
  "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
  "decorator_config": {
    "name": "$MODEL_NAME",
    "custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
    "custom_extractor_config": {
      "output_dir": "/work/.BCloud/countkernels/",
      "split_positions": [8, 16, 32],
      "group_head_and_tail": true,
      "filter_path": "$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
      "filter_config": {},
      "post_extract_process_path": "$GRAPH_NET_ROOT/torch/post_extract_process.py",
      "post_extract_process_config": {
        "decorator_path": "$GRAPH_NET_ROOT/torch/shape_prop.py",
        "decorator_class_name": "ShapePropagate"
      }
    }
  }
}
EOF
)
DECORATOR_CONFIG=$(echo "$decorator_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.run_model \
  --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" \
  --decorator-config="$DECORATOR_CONFIG"
```
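For reference, the decorator config travels between the shell script and the Python side as base64-encoded JSON. A minimal sketch of the round-trip (the config values are placeholders; the decoding logic mirrors `_convert_to_dict` in the post-extract process below):

```python
import base64
import json

# Encode, as the shell script does with `base64 -w 0`.
config = {"decorator_path": "torch/shape_prop.py", "decorator_class_name": "ShapePropagate"}
encoded = base64.b64encode(json.dumps(config).encode("utf-8")).decode("utf-8")

# Decode, as _convert_to_dict does on the Python side.
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded == config
```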
Modified shell script:

```diff
@@ -1,4 +1,13 @@
 #!/bin/bash
+SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
+GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
+PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")
+
+# Add the project root to the Python path
+export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
+
 GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
```
Modified file (the naive graph decomposer):

```diff
@@ -32,6 +32,8 @@ def make_config(
         output_dir="./tmp/naive_decomposer_dir",
         filter_path=None,
         filter_config=None,
+        post_extract_process_path=None,
+        post_extract_process_config=None,
     ):
         for pos in split_positions:
             assert isinstance(
@@ -44,6 +46,8 @@ def make_config(
             "output_dir": output_dir,
             "filter_path": filter_path,
             "filter_config": filter_config if filter_config is not None else {},
+            "post_extract_process_path": post_extract_process_path,
+            "post_extract_process_config": post_extract_process_config,
         }

     def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -71,6 +75,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
         self.seq_no = seq_no
         self.extracted = False
         name = f"{parent_graph_extractor.name}_{self.seq_no}"
+        self.modelname = name
         self.builtin_extractor = BuiltinGraphExtractor(
             name=name,
             dynamic=False,
@@ -79,21 +84,45 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
             workspace_path=self.parent_graph_extractor.config["output_dir"],
         )
         self.filter = self.make_filter(self.parent_graph_extractor.config)
+        self.post_extract_process = self.make_post_extract_process(
+            self.parent_graph_extractor.config
+        )

     def forward(self, *args):
         if not self.extracted:
             if self.need_extract(self.submodule, args):
                 self.builtin_extractor(self.submodule, args)
+                self.get_post_extract_process(self.submodule, args)
                 self.extracted = True
         return self.submodule(*args)

     def need_extract(self, gm, sample_inputs):
         if self.filter is None:
             return True
         return self.filter(gm, sample_inputs)

+    def get_post_extract_process(self, gm, sample_inputs):
+        # Run the configured post-extract check on the directory that the
+        # builtin extractor just wrote for this submodule.
+        if self.post_extract_process is None:
+            return None
+        model_path = os.path.join(
+            self.parent_graph_extractor.config["output_dir"], self.modelname
+        )
+        return self.post_extract_process(model_path)
+
     def make_filter(self, config):
         if config["filter_path"] is None:
             return None
         module = imp_util.load_module(config["filter_path"])
         return module.GraphFilter(config["filter_config"])
+
+    def make_post_extract_process(self, config):
+        # Guard on post_extract_process_path (not filter_path), so the
+        # processor is only built when one is actually configured.
+        if config["post_extract_process_path"] is None:
+            return None
+        module = imp_util.load_module(config["post_extract_process_path"])
+        return module.PostExtractProcess(config["post_extract_process_config"])
```
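The decomposer's internals are not shown in this diff. As an illustration of the underlying idea of splitting a traced graph at fixed node positions, here is a minimal sketch using stock `torch.fx.passes.split_module` (this is not the PR's implementation; `split_positions` is borrowed from the config above):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(torch.sin(x) + 1).sum()


net = Net()
traced = symbolic_trace(net)
split_positions = [2]  # split after the 2nd node, analogous to the config key

# Assign each node a partition id; split_module groups nodes with the same
# id into submodules submod_0, submod_1, ...
node_index = {node: i for i, node in enumerate(traced.graph.nodes)}


def partition(node):
    return sum(node_index[node] >= p for p in split_positions)


parts = split_module(traced, net, partition)
print(parts)  # a GraphModule whose children are the extracted subgraphs
```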
Modified file (the naive subgraph filter):

```diff
@@ -3,5 +3,6 @@ def __init__(self, config):
         self.config = config

     def __call__(self, gm, sample_inputs):
-        print(f"GraphFilter\n{gm.code}")
+        print("GraphFilter")
+        # print(f"GraphFilter\n{gm.code}")
         return True
```
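The shipped filter accepts every subgraph. As a hypothetical sketch of a more selective `GraphFilter`, one could reject subgraphs containing ops that typically prevent whole-graph fusion (the breaker list is illustrative, not part of this PR, and assumes aten-level node targets):

```python
import torch

# Ops that commonly force a kernel break under Inductor (matmul/conv go to
# dedicated kernels or external libraries). Purely illustrative; tune per backend.
FUSION_BREAKERS = {torch.ops.aten.convolution.default, torch.ops.aten.mm.default}


class GraphFilter:
    def __init__(self, config):
        self.config = config

    def __call__(self, gm, sample_inputs):
        # Keep only subgraphs whose call_function nodes avoid known breakers.
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target in FUSION_BREAKERS:
                return False
        return True
```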
New file, the post-extract process (@@ -0,0 +1,144 @@):

```python
from graph_net.torch import utils
import argparse
import base64
import importlib.util
import json
import shutil
import torch
from typing import Type

from torch.profiler import profile, record_function, ProfilerActivity


class PostExtractProcess:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_path=None):
        print("PostExtractProcess")
        if model_path is None:
            return False

        # Re-encode this processor's own config so the nested decorator
        # (e.g. shape_prop.ShapePropagate) receives it, instead of
        # re-parsing the outer run_model command line.
        json_bytes = json.dumps(self.config).encode("utf-8")
        decorator_config = base64.b64encode(json_bytes).decode("utf-8")
        args = argparse.Namespace(
            model_path=model_path, decorator_config=decorator_config
        )

        # Load the extracted GraphModule and wrap it with the decorator.
        model_class = load_class_from_file(
            f"{model_path}/model.py", class_name="GraphModule"
        )
        assert model_class is not None
        model = model_class()
        print(f"{model_path=}")

        model = _get_decorator(args)(model)

        # Replay the recorded weights for this subgraph.
        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

        compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)
        print("compiled: num_of_kernels =", compiled_num_of_kernels)
        if compiled_num_of_kernels == 1:
            print("Graph is fully fusionable")
            return True
        else:
            print(f"Graph is not fully fusionable ({compiled_num_of_kernels} kernels)")
            # Discard subgraphs that do not fuse into a single kernel.
            shutil.rmtree(model_path)
            return False


def _convert_to_dict(config_str):
    if config_str is None:
        return {}
    config_str = base64.b64decode(config_str).decode("utf-8")
    config = json.loads(config_str)
    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
    return config


def _get_decorator(args):
    if args.decorator_config is None:
        return lambda model: model
    decorator_config = _convert_to_dict(args.decorator_config)
    if "decorator_path" not in decorator_config:
        return lambda model: model
    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
    decorator_class = load_class_from_file(
        decorator_config["decorator_path"],
        class_name=class_name,
    )
    return decorator_class(decorator_config.get("decorator_config", {}))


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    unnamed = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(unnamed)
    model_class = getattr(unnamed, class_name, None)
    return model_class


def compile_and_count_kernels(gm, sample_inputs) -> int:
    """
    Count the number of CUDA kernel launches performed during a model's
    forward pass.

    Args:
        gm: the graph module to compile and run
        sample_inputs: a dict of input tensors passed as keyword arguments

    Returns:
        int: the number of kernel launches; -1 if the compiled path never
        ran; 0 if it ran but launched no CUDA kernels.

    Behavior:
        - Compiles the model, warms it up once, then runs it inside a
          PyTorch profiler context.
        - Checks for a "TorchDynamo Cache Lookup" event to confirm the
          compiled path was actually taken.
        - Reads the count of the "cuLaunchKernel" event, the CUDA driver
          API call through which compiled kernels are launched.
    """
    gm.eval()
    compiled_gm = torch.compile(gm)
    # Warm-up run so compilation happens outside the profiled region.
    _ = compiled_gm(**sample_inputs)

    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            _ = compiled_gm(**sample_inputs)
    print(prof.key_averages().table())  # print a table of profiler results
    events = prof.key_averages()

    compile_worked = any(e.key == "TorchDynamo Cache Lookup" for e in events)
    if not compile_worked:
        print("Compile failed")
        return -1
    for e in events:
        if e.key == "cuLaunchKernel":
            return e.count
    return 0
```
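As a usage sketch, the counting routine applied to a toy module built only from elementwise ops, which Inductor should fuse into a single Triton kernel on a CUDA build (`compile_and_count_kernels` is the function above; the module and shapes are placeholders):

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        # Elementwise chain: a good candidate for single-kernel fusion.
        return torch.relu(x * 2.0 + 1.0)


if torch.cuda.is_available():
    model = Toy().cuda()
    inputs = {"x": torch.randn(1024, 1024, device="cuda")}
    n = compile_and_count_kernels(model, inputs)
    print("fully fusionable:", n == 1)
```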