Skip to content

Commit 400c46a

Browse files
committed
Temporally support to save the random states.
1 parent ade8bb9 commit 400c46a

File tree

4 files changed

+75
-42
lines changed

4 files changed

+75
-42
lines changed

graph_net/paddle/random_util.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import pickle
3+
import numpy as np
4+
import random
5+
import re
6+
import paddle
7+
8+
9+
def set_seed(random_seed):
10+
paddle.seed(random_seed)
11+
random.seed(random_seed)
12+
np.random.seed(random_seed)
13+
14+
15+
def _generate_random_state_filename(model_path):
16+
fields = model_path.rstrip("/").split(os.sep)
17+
pattern = r"^subgraph(_\d+)?$"
18+
return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]
19+
20+
21+
def save_random_states(model_path, output_dir, random_state_dict):
22+
filepath = os.path.join(output_dir, _generate_random_state_filename(model_path))
23+
print(f"Write to {filepath}.", flush=True)
24+
try:
25+
with open(filepath, "wb") as f:
26+
pickle.dump(random_state_dict, f)
27+
except Exception:
28+
print(f"Fail to open {filepath}.")
29+
30+
31+
def load_random_states(model_path, output_dir):
32+
filepath = os.path.join(output_dir, _generate_random_state_filename(model_path))
33+
print(f"Read from {filepath}.", flush=True)
34+
random_states = None
35+
try:
36+
with open(filepath, "rb") as f:
37+
random_states = pickle.load(f)
38+
except Exception:
39+
print(f"Fail to open {filepath}.")
40+
return random_states

graph_net/paddle/run_model.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,12 @@
33
import base64
44
import argparse
55
import numpy as np
6-
import random
76

87
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
98

109
import paddle
1110
from graph_net import imp_util
12-
from graph_net.paddle import utils
13-
14-
15-
def set_seed(random_seed):
16-
paddle.seed(random_seed)
17-
random.seed(random_seed)
18-
np.random.seed(random_seed)
11+
from graph_net.paddle import utils, random_util
1912

2013

2114
def load_class_from_file(file_path: str, class_name: str):
@@ -31,13 +24,20 @@ def get_input_dict(model_path):
3124
params = inputs_params["weight_info"]
3225
inputs = inputs_params["input_info"]
3326

34-
state_dict = {}
35-
for k, v in params.items():
36-
name = v["original_name"] if v.get("original_name", None) else k
37-
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=name)
38-
for k, v in inputs.items():
39-
state_dict[k] = utils.replay_tensor(v)
40-
return state_dict
27+
random_state_dict = {}
28+
input_dict = {}
29+
for name, meta in params.items():
30+
original_name = (
31+
meta["original_name"] if meta.get("original_name", None) else name
32+
)
33+
random_state_dict[name] = np.random.get_state()
34+
input_dict[name] = paddle.nn.parameter.Parameter(
35+
utils.replay_tensor(meta), name=original_name
36+
)
37+
for name, meta in inputs.items():
38+
random_state_dict[name] = np.random.get_state()
39+
input_dict[name] = utils.replay_tensor(meta)
40+
return input_dict, random_state_dict
4141

4242

4343
def _convert_to_dict(config_str):
@@ -71,9 +71,11 @@ def main(args):
7171
print(f"{model_path=}")
7272

7373
initalize_seed = 123
74-
set_seed(random_seed=initalize_seed)
74+
random_util.set_seed(random_seed=initalize_seed)
7575

76-
input_dict = get_input_dict(args.model_path)
76+
input_dict, random_state_dict = get_input_dict(args.model_path)
77+
output_dir = "/work/GraphNet/graph_net/test/outputs/pass_0"
78+
random_util.save_random_states(model_path, output_dir, random_state_dict)
7779
model = _get_decorator(args)(model)
7880
model(**input_dict)
7981

graph_net/paddle/test_compiler.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,15 @@
44
from pathlib import Path
55
import sys
66
import os
7-
from dataclasses import dataclass
8-
from contextlib import contextmanager
9-
import time
10-
import math
117
import numpy as np
12-
import random
138
import platform
149
import traceback
1510
import subprocess
1611
import re
1712

18-
from graph_net.paddle import utils
1913
from graph_net import path_utils
2014
from graph_net import test_compiler_util
21-
15+
from graph_net.paddle import utils, random_util
2216
from graph_net.paddle.backend.graph_compiler_backend import GraphCompilerBackend
2317
from graph_net.paddle.backend.cinn_backend import CinnBackend
2418
from graph_net.paddle.backend.nope_backend import NopeBackend
@@ -35,12 +29,6 @@ def get_compiler_backend(args) -> GraphCompilerBackend:
3529
return registry_backend[args.compiler]
3630

3731

38-
def set_seed(random_seed):
39-
paddle.seed(random_seed)
40-
random.seed(random_seed)
41-
np.random.seed(random_seed)
42-
43-
4432
def init_env(args):
4533
if test_compiler_util.is_gpu_device(args.device):
4634
paddle.set_flags({"FLAGS_cudnn_exhaustive_search": 1})
@@ -62,7 +50,7 @@ def get_hardward_name(args):
6250
)
6351
)
6452
)
65-
except Exception as e:
53+
except Exception:
6654
pass
6755
elif args.device == "cpu":
6856
hardware = platform.processor()
@@ -104,10 +92,18 @@ def get_input_dict(model_path):
10492
inputs_params = utils.load_converted_from_text(f"{model_path}")
10593
params = inputs_params["weight_info"]
10694
inputs = inputs_params["input_info"]
107-
10895
params.update(inputs)
109-
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
110-
return state_dict
96+
97+
output_dir = "/work/GraphNet/graph_net/test/outputs/pass_0"
98+
random_states = random_util.load_random_states(model_path, output_dir)
99+
100+
input_dict = {}
101+
for name, meta in params.items():
102+
if random_states is not None and random_states.get(name, None) is not None:
103+
np.random.set_state(random_states[name])
104+
tensor = utils.replay_tensor(meta)
105+
input_dict[name] = tensor
106+
return input_dict
111107

112108

113109
def get_input_spec(model_path):
@@ -476,7 +472,7 @@ def main(args):
476472
assert args.device in ["cuda", "dcu", "xpu", "cpu"]
477473

478474
initalize_seed = 123
479-
set_seed(random_seed=initalize_seed)
475+
random_util.set_seed(random_seed=initalize_seed)
480476

481477
if path_utils.is_single_model_dir(args.model_path):
482478
test_single_model(args)

graph_net/paddle/test_reference_device.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
import argparse
2-
import importlib.util
32
import paddle
4-
import time
5-
import numpy as np
6-
import random
73
import os
84
from pathlib import Path
95
from contextlib import redirect_stdout, redirect_stderr
106
import json
11-
import re
127
import sys
138
import traceback
149

1510
from graph_net import path_utils
1611
from graph_net import test_compiler_util
17-
from graph_net.paddle import test_compiler
12+
from graph_net.paddle import random_util, test_compiler
1813

1914

2015
def get_reference_log_path(reference_dir, model_path):
@@ -130,7 +125,7 @@ def main(args):
130125
assert args.compiler in {"cinn", "nope"}
131126
assert args.device in ["cuda"]
132127

133-
test_compiler.set_seed(random_seed=args.seed)
128+
random_util.set_seed(random_seed=args.seed)
134129
test_compiler.init_env(args)
135130

136131
ref_dump_dir = Path(args.reference_dir)

0 commit comments

Comments
 (0)