|
4 | 4 | import argparse |
5 | 5 | import numpy as np |
6 | 6 | import random |
| 7 | +import pickle |
7 | 8 |
|
8 | 9 | os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump" |
9 | 10 |
|
@@ -31,13 +32,27 @@ def get_input_dict(model_path): |
31 | 32 | params = inputs_params["weight_info"] |
32 | 33 | inputs = inputs_params["input_info"] |
33 | 34 |
|
34 | | - state_dict = {} |
35 | | - for k, v in params.items(): |
36 | | - name = v["original_name"] if v.get("original_name", None) else k |
37 | | - state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=name) |
38 | | - for k, v in inputs.items(): |
39 | | - state_dict[k] = utils.replay_tensor(v) |
40 | | - return state_dict |
| 35 | + random_state_dict = {} |
| 36 | + input_dict = {} |
| 37 | + for name, meta in params.items(): |
| 38 | + original_name = ( |
| 39 | + meta["original_name"] if meta.get("original_name", None) else name |
| 40 | + ) |
| 41 | + random_state_dict[name] = np.random.get_state() |
| 42 | + input_dict[name] = paddle.nn.parameter.Parameter( |
| 43 | + utils.replay_tensor(meta), name=original_name |
| 44 | + ) |
| 45 | + for name, meta in inputs.items(): |
| 46 | + random_state_dict[name] = np.random.get_state() |
| 47 | + input_dict[name] = utils.replay_tensor(meta) |
| 48 | + return input_dict, random_state_dict |
| 49 | + |
| 50 | + |
| 51 | +def save_random_states(output_dir, random_state_dict): |
| 52 | + filepath = os.path.join(output_dir, "random_states.pkl") |
| 53 | + print(f"Write to {filepath}.") |
| 54 | + with open(filepath, "wb") as f: |
| 55 | + pickle.dump(random_state_dict, f) |
41 | 56 |
|
42 | 57 |
|
43 | 58 | def _convert_to_dict(config_str): |
@@ -73,7 +88,9 @@ def main(args): |
73 | 88 | initalize_seed = 123 |
74 | 89 | set_seed(random_seed=initalize_seed) |
75 | 90 |
|
76 | | - input_dict = get_input_dict(args.model_path) |
| 91 | + input_dict, random_state_dict = get_input_dict(args.model_path) |
| 92 | + output_dir = "/work/GraphNet/graph_net/test/outputs/pass_0" |
| 93 | + save_random_states(output_dir, random_state_dict) |
77 | 94 | model = _get_decorator(args)(model) |
78 | 95 | model(**input_dict) |
79 | 96 |
|
|
0 commit comments