Skip to content

Commit b5349f3

Browse files
committed
Temporally support to save the random states.
1 parent ade8bb9 commit b5349f3

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

graph_net/paddle/run_model.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import argparse
55
import numpy as np
66
import random
7+
import pickle
78

89
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
910

@@ -31,13 +32,27 @@ def get_input_dict(model_path):
3132
params = inputs_params["weight_info"]
3233
inputs = inputs_params["input_info"]
3334

34-
state_dict = {}
35-
for k, v in params.items():
36-
name = v["original_name"] if v.get("original_name", None) else k
37-
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=name)
38-
for k, v in inputs.items():
39-
state_dict[k] = utils.replay_tensor(v)
40-
return state_dict
35+
random_state_dict = {}
36+
input_dict = {}
37+
for name, meta in params.items():
38+
original_name = (
39+
meta["original_name"] if meta.get("original_name", None) else name
40+
)
41+
random_state_dict[name] = np.random.get_state()
42+
input_dict[name] = paddle.nn.parameter.Parameter(
43+
utils.replay_tensor(meta), name=original_name
44+
)
45+
for name, meta in inputs.items():
46+
random_state_dict[name] = np.random.get_state()
47+
input_dict[name] = utils.replay_tensor(meta)
48+
return input_dict, random_state_dict
49+
50+
51+
def save_random_states(output_dir, random_state_dict):
52+
filepath = os.path.join(output_dir, "random_states.pkl")
53+
print(f"Write to {filepath}.")
54+
with open(filepath, "wb") as f:
55+
pickle.dump(random_state_dict, f)
4156

4257

4358
def _convert_to_dict(config_str):
@@ -73,7 +88,9 @@ def main(args):
7388
initalize_seed = 123
7489
set_seed(random_seed=initalize_seed)
7590

76-
input_dict = get_input_dict(args.model_path)
91+
input_dict, random_state_dict = get_input_dict(args.model_path)
92+
output_dir = "/work/GraphNet/graph_net/test/outputs/pass_0"
93+
save_random_states(output_dir, random_state_dict)
7794
model = _get_decorator(args)(model)
7895
model(**input_dict)
7996

graph_net/paddle/test_compiler.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
from pathlib import Path
55
import sys
66
import os
7-
from dataclasses import dataclass
8-
from contextlib import contextmanager
9-
import time
10-
import math
117
import numpy as np
128
import random
9+
import pickle
1310
import platform
1411
import traceback
1512
import subprocess
@@ -62,7 +59,7 @@ def get_hardward_name(args):
6259
)
6360
)
6461
)
65-
except Exception as e:
62+
except Exception:
6663
pass
6764
elif args.device == "cpu":
6865
hardware = platform.processor()
@@ -100,14 +97,31 @@ def get_model(model_path):
10097
return model_class()
10198

10299

100+
def load_random_states(output_dir):
101+
filepath = os.path.join(output_dir, "random_states.pkl")
102+
print(f"Read from {filepath}.")
103+
random_states = None
104+
with open(filepath, "rb") as f:
105+
random_states = pickle.load(f)
106+
return random_states
107+
108+
103109
def get_input_dict(model_path):
104110
inputs_params = utils.load_converted_from_text(f"{model_path}")
105111
params = inputs_params["weight_info"]
106112
inputs = inputs_params["input_info"]
107-
108113
params.update(inputs)
109-
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
110-
return state_dict
114+
115+
output_dir = "/work/GraphNet/graph_net/test/outputs/pass_0"
116+
random_states = load_random_states(output_dir)
117+
118+
input_dict = {}
119+
for name, meta in params.items():
120+
if random_states is not None and random_states.get(name, None) is not None:
121+
np.random.set_state(random_states[name])
122+
tensor = utils.replay_tensor(meta)
123+
input_dict[name] = tensor
124+
return input_dict
111125

112126

113127
def get_input_spec(model_path):

0 commit comments

Comments
 (0)