44from pathlib import Path
55import sys
66import os
7- from dataclasses import dataclass
8- from contextlib import contextmanager
9- import time
10- import math
117import numpy as np
12- import random
138import platform
149import traceback
1510import subprocess
1611import re
1712
18- from graph_net .paddle import utils
1913from graph_net import path_utils
2014from graph_net import test_compiler_util
21-
15+ from graph_net . paddle import utils , random_util
2216from graph_net .paddle .backend .graph_compiler_backend import GraphCompilerBackend
2317from graph_net .paddle .backend .cinn_backend import CinnBackend
2418from graph_net .paddle .backend .nope_backend import NopeBackend
@@ -35,12 +29,6 @@ def get_compiler_backend(args) -> GraphCompilerBackend:
3529 return registry_backend [args .compiler ]
3630
3731
38- def set_seed (random_seed ):
39- paddle .seed (random_seed )
40- random .seed (random_seed )
41- np .random .seed (random_seed )
42-
43-
4432def init_env (args ):
4533 if test_compiler_util .is_gpu_device (args .device ):
4634 paddle .set_flags ({"FLAGS_cudnn_exhaustive_search" : 1 })
@@ -62,7 +50,7 @@ def get_hardward_name(args):
6250 )
6351 )
6452 )
65- except Exception as e :
53+ except Exception :
6654 pass
6755 elif args .device == "cpu" :
6856 hardware = platform .processor ()
@@ -100,14 +88,25 @@ def get_model(model_path):
10088 return model_class ()
10189
10290
103- def get_input_dict (model_path ):
91+ def get_input_dict (model_path , random_states_path = None ):
10492 inputs_params = utils .load_converted_from_text (f"{ model_path } " )
10593 params = inputs_params ["weight_info" ]
10694 inputs = inputs_params ["input_info" ]
107-
10895 params .update (inputs )
109- state_dict = {k : utils .replay_tensor (v ) for k , v in params .items ()}
110- return state_dict
96+
97+ random_states = (
98+ random_util .load_random_states (model_path , random_states_path )
99+ if random_states_path
100+ else None
101+ )
102+
103+ input_dict = {}
104+ for name , meta in params .items ():
105+ if random_states is not None and random_states .get (name , None ) is not None :
106+ np .random .set_state (random_states [name ])
107+ tensor = utils .replay_tensor (meta )
108+ input_dict [name ] = tensor
109+ return input_dict
111110
112111
113112def get_input_spec (model_path ):
@@ -476,7 +475,7 @@ def main(args):
476475 assert args .device in ["cuda" , "dcu" , "xpu" , "cpu" ]
477476
478477 initalize_seed = 123
479- set_seed (random_seed = initalize_seed )
478+ random_util . set_seed (random_seed = initalize_seed )
480479
481480 if path_utils .is_single_model_dir (args .model_path ):
482481 test_single_model (args )
0 commit comments