@@ -7,7 +7,8 @@
 import argparse
 import subprocess
 import glob
-from typing import List, Set, Dict, Any, Union
+from dataclasses import dataclass, field
+from typing import List, Dict, Union
 from graph_net.analysis_util import get_incorrect_models
 from graph_net import path_utils
 
@@ -18,9 +19,16 @@ def convert_b64_string_to_json(b64str):
     return json.loads(base64.b64decode(b64str).decode("utf-8"))
 
 
+def convert_json_to_b64_string(config):
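+    # Editor's illustrative note (not in the original commit): inverse of
+    # convert_b64_string_to_json above, e.g.
+    #   convert_json_to_b64_string({"a": 1}) -> "eyJhIjogMX0="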
+    return base64.b64encode(json.dumps(config).encode()).decode()
+
+
+def get_pass_name(pass_id):
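+    # e.g. get_pass_name(3) -> "pass_3" (editor's illustrative note)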
+    return f"pass_{pass_id}"
+
+
 def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
-    if not os.path.exists(log_path):
-        return set()
+    assert os.path.exists(log_path)
 
     t_start = tolerance_args[0]
     models_start = set(get_incorrect_models(t_start, log_path))
@@ -31,14 +39,10 @@ def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set
     t_end = tolerance_args[1]
     models_end = set(get_incorrect_models(t_end, log_path))
 
-    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
     print(
-        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
+        f"[Init] number of incorrect models: {len(models_start)} (tolerance={t_start}) - {len(models_end)} (tolerance={t_end})"
     )
-
-    diff_set = models_start - models_end
-
-    return diff_set
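+    # i.e. models that fail at tolerance t_start but no longer fail at t_end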
+    return models_start - models_end
 
 
 class TaskController:
@@ -105,6 +109,37 @@ def _print(self):
         print()
 
 
+@dataclass
+class DecomposeConfig:
+    max_subgraph_size: int = -1
+    incorrect_models: List[str] = field(default_factory=list)
+    tasks_map: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
+    running_states: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
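+    # Editor's note on the shape inferred from usage below:
+    #   running_states["pass_<n>"] = {"num_incorrect_models": int,
+    #       "incorrect_models": [...], "failed_decomposition_models": [...]}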
+
+    def save(self, work_dir):
+        """Save the current config to a JSON file."""
+        config_path = self.get_config_path(work_dir)
+
+        with open(config_path, "w") as f:
+            json.dump(self.__dict__, f, indent=4)
+        print(f"\n[INFO] Save state to: {config_path}")
+
+    @classmethod
+    def load(cls, work_dir):
+        """Initialize an object from a JSON file."""
+        config_path = cls.get_config_path(work_dir)
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Missing configuration file: {config_path}")
+
+        with open(config_path, "r") as f:
+            data = json.load(f)
+        return cls(**data)
+
+    @classmethod
+    def get_config_path(cls, work_dir) -> str:
+        return os.path.join(work_dir, "decompose_config.json")
+
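+# Editor's sketch of the intended save/load round-trip (assumes the default
+# dataclass equality; pass_work_dir is a placeholder name):
+#   cfg = DecomposeConfig(max_subgraph_size=8)
+#   cfg.save(pass_work_dir)
+#   assert DecomposeConfig.load(pass_work_dir) == cfg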
+
 def get_rectfied_model_path(model_path):
     graphnet_root = path_utils.get_graphnet_root()
     return os.path.join(graphnet_root, model_path.split("GraphNet/")[-1])
@@ -118,53 +153,10 @@ def count_samples(samples_dir):
     return num_samples
 
 
-def get_decompose_config_path(output_dir: str) -> str:
-    """Returns the full path to the decompose configuration file."""
-    return os.path.join(output_dir, "decompose_config.json")
-
-
 def get_decompose_workspace_path(output_dir, pass_id):
     return os.path.join(output_dir, f"pass_{pass_id}")
 
 
-def load_decompose_config(work_dir: str) -> Dict[str, Any]:
-    """Loads the configuration file from the previous pass."""
-    config_path = get_decompose_config_path(work_dir)
-
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(f"Missing configuration file: {config_path}")
-    with open(config_path, "r") as f:
-        return json.load(f)
-
-
-def save_decompose_config(
-    workspace: str,
-    max_subgraph_size: int,
-    tasks_map: Dict[str, Union[int, str, list, dict]],
-    incorrect_paths: Union[List[str], Set[str]],
-    failed_decomposition_models: Union[List[str], Set[str]],
-):
-    """Saves the current state to a JSON file."""
-
-    tasks_map_copy = {}
-    for model_name, task_info in tasks_map.items():
-        tasks_map_copy[model_name] = {}
-        for key in ["original_path", "split_positions"]:
-            tasks_map_copy[model_name][key] = task_info[key]
-
-    config = {
-        "max_subgraph_size": max_subgraph_size,
-        "incorrect_models": list(incorrect_paths),
-        "tasks_map": tasks_map_copy,
-        "failed_decomposition_models": list(failed_decomposition_models),
-    }
-    config_path = get_decompose_config_path(workspace)
-
-    with open(config_path, "w") as f:
-        json.dump(config, f, indent=4)
-    print(f"\n[INFO] Save state to: {config_path}")
-
-
 def get_model_name_with_subgraph_tag(model_path):
     fields = model_path.rstrip("/").split(os.sep)
     pattern = r"^subgraph(_\d+)?$"
@@ -195,9 +187,7 @@ def run_decomposer_for_single_model(
             },
         },
     }
-    decorator_config_b64 = base64.b64encode(
-        json.dumps(decorator_config).encode()
-    ).decode()
+    decorator_config_b64 = convert_json_to_b64_string(decorator_config)
 
     print(
         f"[Decomposition] model_path: {model_path}, split_positions: {split_positions}"
@@ -287,16 +277,16 @@ def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
     return subgraph_size
 
 
-def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
-    assert isinstance(subgraph_size, (list, tuple)) and len(subgraph_size) == 2
+def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
+    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2
 
-    # subgraph_size: the start and end position in original model.
-    start_pos, end_pos = subgraph_size
+    # subgraph_range: the start and end position in the original model.
+    start_pos, end_pos = subgraph_range
     end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
 
-    split_positions = list(range(start_pos, end_pos + 1, max_subgraph_size))
-    if split_positions[-1] != end_pos:
-        split_positions.append(end_pos)
+    split_positions = list(
+        range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
+    )
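+    # Editor's note: e.g. subgraph_range=(0, 10) with max_subgraph_size=4
+    # yields [0, 4, 8, 12]; the last position can overshoot end_pos.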
     return sorted(list(set(split_positions)))
 
 
@@ -317,38 +307,42 @@ def generate_initial_tasks(args):
         )
 
         tasks_map[model_name] = {
-            "subgraph_path": model_path,
             "original_path": model_path,
             "split_positions": initial_splits,
         }
 
-    for task in tasks_map.values():
-        task["split_positions"] = sorted(list(task["split_positions"]))
+    running_states = {
+        "pass_0": {
+            "num_incorrect_models": len(initial_failures),
+            "incorrect_models": list(sorted(initial_failures)),
+        }
+    }
+    return tasks_map, max_subgraph_size, running_states
 
-    return tasks_map, max_subgraph_size
+
+def extract_model_name_and_subgraph_idx(subgraph_path):
+    # Parse model name and subgraph index
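+    # Editor's illustrative note: "out/resnet50_3/" -> ("resnet50", 3)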
+    model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
+    model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
+    subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+    return model_name, subgraph_idx
 
 
 def generate_refined_tasks(base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
 
-    prev_config = load_decompose_config(prev_pass_dir)
-    prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
-    prev_tasks_map = prev_config.get("tasks_map", {})
-
-    prev_max_subgraph_size = prev_config.get("max_subgraph_size")
-    max_subgraph_size = prev_max_subgraph_size // 2
-
-    if not prev_incorrect_subgraphs:
-        return {}, max_subgraph_size
+    prev_config = DecomposeConfig.load(prev_pass_dir)
+    max_subgraph_size = prev_config.max_subgraph_size // 2
+    if not prev_config.incorrect_models:
+        return {}, max_subgraph_size, prev_config.running_states
 
     tasks_map = {}
-    for subgraph_path in prev_incorrect_subgraphs:
-        # Parse model name and subgraph index
-        model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
-        model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
-        subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+    prev_tasks_map = prev_config.tasks_map
+
+    for subgraph_path in sorted(prev_config.incorrect_models):
+        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
 
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
@@ -360,40 +354,38 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
             subgraph_ranges
         ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
 
-        current_fail_range = subgraph_ranges[subgraph_idx]
-
-        new_splits = calculate_split_positions_for_subgraph(
-            current_fail_range, max_subgraph_size
+        split_positions = calculate_split_positions_for_subgraph(
+            subgraph_ranges[subgraph_idx], max_subgraph_size
        )
-
         if model_name not in tasks_map:
             tasks_map[model_name] = {
-                "subgraph_path": subgraph_path,
                 "original_path": pre_task_for_model["original_path"],
-                "split_positions": set(new_splits),
+                "split_positions": list(sorted(split_positions)),
             }
         else:
-            tasks_map[model_name]["split_positions"].update(new_splits)
-
-    for task in tasks_map.values():
-        task["split_positions"] = sorted(list(task["split_positions"]))
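+            # A model may contribute several failing subgraphs; merge and
+            # dedupe their split positions.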
+            merged_split_positions = (
+                tasks_map[model_name]["split_positions"] + split_positions
+            )
+            tasks_map[model_name]["split_positions"] = list(
+                sorted(set(merged_split_positions))
+            )
 
-    return tasks_map, max_subgraph_size
+    return tasks_map, max_subgraph_size, prev_config.running_states
 
 
 def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
     if current_pass_id == 0:
-        tasks_map, max_subgraph_size = generate_initial_tasks(args)
+        tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
-        tasks_map, max_subgraph_size = generate_refined_tasks(
+        tasks_map, max_subgraph_size, running_states = generate_refined_tasks(
             base_output_dir, current_pass_id
         )
 
-    print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
-    print(f"[INFO] number of incorrect models: {len(tasks_map)}")
-    for model_name, task_info in tasks_map.items():
+    print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")
+    print(f"[Init] number of incorrect models: {len(tasks_map)}")
+    for idx, (model_name, task_info) in enumerate(tasks_map.items()):
         original_path = task_info["original_path"]
-        print(f"- {original_path}")
+        print(f"- [{idx}] {original_path}")
 
     if not tasks_map:
         print("[FINISHED] No models need processing.")
@@ -405,7 +397,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         )
         sys.exit(0)
 
-    return tasks_map, max_subgraph_size
+    return tasks_map, max_subgraph_size, running_states
 
 
 def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspace):
@@ -485,15 +477,14 @@ def main(args):
     print("=" * 80)
 
     # --- Step 1: Prepare Tasks and Workspace ---
-    tasks_map, max_subgraph_size = prepare_tasks_and_verify(
+    tasks_map, max_subgraph_size, running_states = prepare_tasks_and_verify(
         args, current_pass_id, base_output_dir
     )
     pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
     if not os.path.exists(pass_work_dir):
        os.makedirs(pass_work_dir, exist_ok=True)
 
     # --- Step 2: Decomposition ---
-    failed_decomposition = []
     if task_controller.task_scheduler["run_decomposer"]:
         print("\n--- Phase 1: Decomposition ---", flush=True)
         (
@@ -503,44 +494,61 @@
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
         )
+        # setdefault (rather than .get with a default) so the update persists
+        # when running_states has no entry for this pass yet.
+        running_states.setdefault(f"pass_{current_pass_id}", {})[
+            "failed_decomposition_models"
+        ] = list(failed_decomposition)
     else:
-        config = load_decompose_config(pass_work_dir)
-        max_subgraph_size = config["max_subgraph_size"]
-        failed_decomposition = config["failed_decomposition_models"]
-        tasks_map = config.get("tasks_map", {})
+        print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
+        config = DecomposeConfig.load(pass_work_dir)
+        max_subgraph_size = config.max_subgraph_size
+        tasks_map = config.tasks_map
+        running_states = config.running_states
 
     # --- Step 3: Evaluation ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
-        print("\n--- Phase 2: Evaluation ---")
+        print(f"\n--- Phase 2: Evaluation ({task_controller.test_module_name}) ---")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
 
     # --- Step 4: Analysis ---
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
-        print("\n--- Phase 3: Analysis ---")
-        analysis_tolerance = (
+        tolerance = (
             args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
         )
-        next_round_models = get_incorrect_models(analysis_tolerance, pass_log_path)
+        print(f"\n--- Phase 3: Analysis (tolerance={tolerance}) ---")
+        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
+        original_model_paths = {
+            extract_model_name_and_subgraph_idx(subgraph_path)[0]
+            for subgraph_path in next_round_models
+        }
+
+        running_states[f"pass_{current_pass_id + 1}"] = {
+            "num_incorrect_models": len(original_model_paths),
+            "incorrect_models": list(next_round_models),
+        }
+
+        print(
+            f"[Analysis] Found {len(next_round_models)} incorrect subgraphs ({len(original_model_paths)} original models)."
+        )
+        for idx, model_path in enumerate(next_round_models):
+            print(f"- [{idx}] {model_path}")
 
-    print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
-    if len(next_round_models) > 0:
-        print("[DEBUG] List of detected incorrect models:")
-        for idx, model_path in enumerate(sorted(list(next_round_models))):
-            print(f"  [{idx}] {model_path}")
-    else:
-        print("[DEBUG] No incorrect models detected.")
     print_summary_and_suggestion(next_round_models, max_subgraph_size)
 
     # --- Step 5: Save States ---
-    save_decompose_config(
-        pass_work_dir,
-        max_subgraph_size,
-        tasks_map,
-        next_round_models,
-        failed_decomposition,
+    config = DecomposeConfig(
+        max_subgraph_size=max_subgraph_size,
+        incorrect_models=list(next_round_models),
+        tasks_map=tasks_map,
+        running_states=running_states,
     )
+    config.save(pass_work_dir)
 
 
 if __name__ == "__main__":