@@ -122,7 +122,7 @@ def save(self, work_dir):
122122
123123 with open (config_path , "w" ) as f :
124124 json .dump (self .__dict__ , f , indent = 4 )
125- print (f"\n [INFO] Save state to: { config_path } " )
125+ print (f"\n [INFO] Save state to: { config_path } \n " )
126126
127127 @classmethod
128128 def load (self , work_dir ):
@@ -139,6 +139,19 @@ def load(self, work_dir):
139139 def get_config_path (self , work_dir ) -> str :
140140 return os .path .join (work_dir , "decompose_config.json" )
141141
142+ def update_running_states (self , pass_id , ** kwargs ):
143+ pass_key = get_pass_name (pass_id )
144+ if self .running_states .get (pass_key , None ) is None :
145+ self .running_states [pass_key ] = {}
146+
147+ for key , value in kwargs .items ():
148+ assert key in [
149+ "num_incorrect_models" ,
150+ "incorrect_models" ,
151+ "failed_decomposition_models" ,
152+ ]
153+ self .running_states [pass_key ][key ] = value
154+
142155
143156def get_rectfied_model_path (model_path ):
144157 graphnet_root = path_utils .get_graphnet_root ()
@@ -281,11 +294,10 @@ def run_evaluation(
281294
282295def reconstruct_subgraph_size (split_positions : List [int ]) -> List [list ]:
283296 """Reconstructs subgraph size based on sorted split positions."""
284- deduplicated_splits = list ( dict . fromkeys (split_positions ))
297+ deduplicated_splits = sorted ( set (split_positions ))
285298
286299 subgraph_size = [
287- [deduplicated_splits [i ], deduplicated_splits [i + 1 ]]
288- for i in range (len (deduplicated_splits ) - 1 )
300+ deduplicated_splits [i : i + 2 ] for i in range (len (deduplicated_splits ) - 1 )
289301 ]
290302 return subgraph_size
291303
@@ -297,10 +309,10 @@ def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
297309 start_pos , end_pos = subgraph_range
298310 end_pos = kMaxGraphSize if end_pos == float ("inf" ) else end_pos
299311
300- split_positions = list (
312+ split_positions = set (
301313 range (start_pos , end_pos + max_subgraph_size - 1 , max_subgraph_size )
302314 )
303- return sorted ( list (set (split_positions ) ))
315+ return list (sorted (split_positions ))
304316
305317
306318def generate_initial_tasks (args ):
@@ -341,7 +353,7 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
341353 return model_name , subgraph_idx
342354
343355
344- def generate_refined_tasks (base_output_dir , current_pass_id ):
356+ def generate_successor_tasks (base_output_dir , current_pass_id ):
345357 """Generates tasks for Pass > 0 based on previous pass results."""
346358 prev_pass_dir = get_decompose_workspace_path (base_output_dir , current_pass_id - 1 )
347359 print (f"[Init] Resuming from Pass_{ current_pass_id - 1 } (Dir: { prev_pass_dir } )..." )
@@ -390,7 +402,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
390402 if current_pass_id == 0 :
391403 tasks_map , max_subgraph_size , running_states = generate_initial_tasks (args )
392404 else :
393- tasks_map , max_subgraph_size , running_states = generate_refined_tasks (
405+ tasks_map , max_subgraph_size , running_states = generate_successor_tasks (
394406 base_output_dir , current_pass_id
395407 )
396408
@@ -448,15 +460,13 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
448460 os .makedirs (decomposed_samples_dir , exist_ok = True )
449461 max_subgraph_size = max (1 , max_subgraph_size // 2 )
450462 for model_name , task_info in tasks_map .items ():
451- splits = task_info ["split_positions" ]
452- if not splits or len (splits ) < 2 :
463+ split_positions = task_info ["split_positions" ]
464+ if not split_positions or len (split_positions ) < 2 :
453465 continue
454- start_pos = splits [0 ]
455- first_segment_end = splits [1 ]
456- new_splits = calculate_split_positions_for_subgraph (
457- [start_pos , first_segment_end ], max_subgraph_size
466+ new_split_positions = calculate_split_positions_for_subgraph (
467+ split_positions [0 :2 ], max_subgraph_size
458468 )
459- task_info ["split_positions" ] = new_splits
469+ task_info ["split_positions" ] = new_split_positions
460470 else :
461471 need_decompose = False
462472 print ()
@@ -467,6 +477,15 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
467477 return tasks_map , failed_decomposition , max_subgraph_size
468478
469479
480+ def count_unique_original_models (incorrect_models ):
481+ original_model_paths = set (
482+ model_name
483+ for subgraph_path in incorrect_models
484+ for model_name , _ in [extract_model_name_and_subgraph_idx (subgraph_path )]
485+ )
486+ return len (original_model_paths )
487+
488+
470489def print_summary_and_suggestion (next_round_models , max_subgraph_size ):
471490 """Print suggestion/result."""
472491 print ("\n " + "=" * 80 )
@@ -493,9 +512,14 @@ def main(args):
493512 tasks_map , max_subgraph_size , running_states = prepare_tasks_and_verify (
494513 args , current_pass_id , base_output_dir
495514 )
496- pass_work_dir = get_decompose_workspace_path (base_output_dir , current_pass_id )
497- if not os .path .exists (pass_work_dir ):
498- os .makedirs (pass_work_dir , exist_ok = True )
515+ decompose_config = DecomposeConfig (
516+ max_subgraph_size = max_subgraph_size ,
517+ tasks_map = tasks_map ,
518+ running_states = running_states ,
519+ )
520+ work_dir = get_decompose_workspace_path (base_output_dir , current_pass_id )
521+ if not os .path .exists (work_dir ):
522+ os .makedirs (work_dir , exist_ok = True )
499523
500524 # --- Step 2: Decomposition ---
501525 if task_controller .task_scheduler ["run_decomposer" ]:
@@ -505,63 +529,45 @@ def main(args):
505529 failed_decomposition ,
506530 max_subgraph_size ,
507531 ) = execute_decomposition_phase (
508- max_subgraph_size , tasks_map , args .framework , pass_work_dir
532+ max_subgraph_size , tasks_map , args .framework , work_dir
533+ )
534+ decompose_config .update_running_states (
535+ current_pass_id , failed_decomposition_models = list (failed_decomposition )
509536 )
510- running_states .get (f"pass_{ current_pass_id } " , {})[
511- "failed_decomposition_models"
512- ] = list (failed_decomposition )
513537 else :
514538 print ("\n --- Phase 1: Decomposition (skipped) ---" , flush = True )
515- config = DecomposeConfig .load (pass_work_dir )
516- max_subgraph_size = config .max_subgraph_size
517- tasks_map = config .tasks_map
518- running_states = config .running_states
539+ decompose_config = DecomposeConfig .load (work_dir )
519540
520541 # --- Step 3: Evaluation ---
521- pass_log_path = os .path .join (pass_work_dir , "batch_test_result.log " )
542+ log_path = os .path .join (work_dir , f"log_ { task_controller . test_module_name } .txt " )
522543 if task_controller .task_scheduler ["run_evaluation" ]:
523544 print (f"\n --- Phase 2: Evaluation ({ task_controller .test_module_name } ) ---" )
524- run_evaluation (args .framework , args .test_config , pass_work_dir , pass_log_path )
545+ run_evaluation (args .framework , args .test_config , work_dir , log_path )
525546
526547 # --- Step 4: Analysis ---
527- next_round_models = set ()
548+ next_pass_incorrect_models = set ()
528549 if task_controller .task_scheduler ["post_analysis" ]:
529550 tolerance = (
530551 args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
531552 )
532553 print (f"\n --- Phase 3: Analysis (torlance={ tolerance } ) ---" )
533- next_round_models = sorted (get_incorrect_models (tolerance , pass_log_path ))
534- original_model_paths = set (
535- [
536- model_name
537- for subgraph_path in next_round_models
538- for model_name , _ in [
539- extract_model_name_and_subgraph_idx (subgraph_path )
540- ]
541- ]
554+ next_pass_incorrect_models = sorted (get_incorrect_models (tolerance , log_path ))
555+ num_original_models = count_unique_original_models (next_pass_incorrect_models )
556+ decompose_config .update_running_states (
557+ current_pass_id + 1 ,
558+ num_incorrect_models = num_original_models ,
559+ incorrect_models = list (next_pass_incorrect_models ),
542560 )
543-
544- running_states [f"pass_{ current_pass_id + 1 } " ] = {
545- "num_incorrect_models" : len (original_model_paths ),
546- "incorrect_models" : list (next_round_models ),
547- }
548-
549561 print (
550- f"[Analysis] Found { len (next_round_models )} incorrect subgraphs ({ len ( original_model_paths ) } original models)."
562+ f"[Analysis] Found { len (next_pass_incorrect_models )} incorrect subgraphs ({ num_original_models } original models)."
551563 )
552- for idx , model_path in enumerate (next_round_models ):
564+ for idx , model_path in enumerate (next_pass_incorrect_models ):
553565 print (f"- [{ idx } ] { model_path } " )
554-
555- print_summary_and_suggestion (next_round_models , max_subgraph_size )
566+ print_summary_and_suggestion (next_pass_incorrect_models , max_subgraph_size )
556567
557568 # --- Step 5: Save States ---
558- config = DecomposeConfig (
559- max_subgraph_size = max_subgraph_size ,
560- incorrect_models = list (next_round_models ),
561- tasks_map = tasks_map ,
562- running_states = running_states ,
563- )
564- config .save (pass_work_dir )
569+ decompose_config .incorrect_models = list (next_pass_incorrect_models )
570+ decompose_config .save (work_dir )
565571
566572
567573if __name__ == "__main__" :
0 commit comments