
Commit ade8bb9

Merge branch 'opt_saved_results' into add_original_names

2 parents: c8a9f68 + c067624

1 file changed (+61, -55 lines)


graph_net/subgraph_decompose_and_evaluation_step.py: 61 additions & 55 deletions
@@ -122,7 +122,7 @@ def save(self, work_dir):

         with open(config_path, "w") as f:
             json.dump(self.__dict__, f, indent=4)
-        print(f"\n[INFO] Save state to: {config_path}")
+        print(f"\n[INFO] Save state to: {config_path}\n")

     @classmethod
     def load(self, work_dir):
@@ -139,6 +139,19 @@ def load(self, work_dir):
     def get_config_path(self, work_dir) -> str:
         return os.path.join(work_dir, "decompose_config.json")

+    def update_running_states(self, pass_id, **kwargs):
+        pass_key = get_pass_name(pass_id)
+        if self.running_states.get(pass_key, None) is None:
+            self.running_states[pass_key] = {}
+
+        for key, value in kwargs.items():
+            assert key in [
+                "num_incorrect_models",
+                "incorrect_models",
+                "failed_decomposition_models",
+            ]
+            self.running_states[pass_key][key] = value
+

 def get_rectfied_model_path(model_path):
     graphnet_root = path_utils.get_graphnet_root()
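
The new update_running_states helper centralizes the per-pass bookkeeping that main previously did inline. A minimal usage sketch with made-up values, assuming get_pass_name(i) returns a key of the form f"pass_{i}" (which the inline code it replaces suggests):

    config = DecomposeConfig(max_subgraph_size=8, tasks_map={}, running_states={})
    # The first call for a pass creates its dict; later calls update it.
    config.update_running_states(0, failed_decomposition_models=["model_a"])
    config.update_running_states(1, num_incorrect_models=2, incorrect_models=["m1", "m2"])
    # config.running_states ->
    # {"pass_0": {"failed_decomposition_models": ["model_a"]},
    #  "pass_1": {"num_incorrect_models": 2, "incorrect_models": ["m1", "m2"]}}
    # Any other keyword argument fails the assert.
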
@@ -281,11 +294,10 @@ def run_evaluation(

 def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
     """Reconstructs subgraph size based on sorted split positions."""
-    deduplicated_splits = list(dict.fromkeys(split_positions))
+    deduplicated_splits = sorted(set(split_positions))

     subgraph_size = [
-        [deduplicated_splits[i], deduplicated_splits[i + 1]]
-        for i in range(len(deduplicated_splits) - 1)
+        deduplicated_splits[i : i + 2] for i in range(len(deduplicated_splits) - 1)
     ]
     return subgraph_size

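Note the semantic change here: dict.fromkeys deduplicated while preserving input order, whereas sorted(set(...)) also sorts, which matches the docstring. A quick illustration with a made-up position list:

    reconstruct_subgraph_size([8, 0, 4, 8])
    # before: order-preserving dedup -> [8, 0, 4] -> [[8, 0], [0, 4]]
    # after:  sorted dedup           -> [0, 4, 8] -> [[0, 4], [4, 8]]
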
@@ -297,10 +309,10 @@ def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
     start_pos, end_pos = subgraph_range
     end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos

-    split_positions = list(
+    split_positions = set(
         range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
     )
-    return sorted(list(set(split_positions)))
+    return list(sorted(split_positions))


 def generate_initial_tasks(args):
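
The rewrite builds a set directly and drops the redundant inner list() (sorted() already returns a list), so the return value is unchanged. A worked example with hypothetical arguments:

    calculate_split_positions_for_subgraph([0, 10], 4)
    # range(0, 10 + 4 - 1, 4) yields 0, 4, 8, 12
    # -> [0, 4, 8, 12]
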
@@ -341,7 +353,7 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx


-def generate_refined_tasks(base_output_dir, current_pass_id):
+def generate_successor_tasks(base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
@@ -390,7 +402,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
     if current_pass_id == 0:
         tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
-        tasks_map, max_subgraph_size, running_states = generate_refined_tasks(
+        tasks_map, max_subgraph_size, running_states = generate_successor_tasks(
            base_output_dir, current_pass_id
        )

@@ -448,15 +460,13 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
         os.makedirs(decomposed_samples_dir, exist_ok=True)
         max_subgraph_size = max(1, max_subgraph_size // 2)
         for model_name, task_info in tasks_map.items():
-            splits = task_info["split_positions"]
-            if not splits or len(splits) < 2:
+            split_positions = task_info["split_positions"]
+            if not split_positions or len(split_positions) < 2:
                 continue
-            start_pos = splits[0]
-            first_segment_end = splits[1]
-            new_splits = calculate_split_positions_for_subgraph(
-                [start_pos, first_segment_end], max_subgraph_size
+            new_split_positions = calculate_split_positions_for_subgraph(
+                split_positions[0:2], max_subgraph_size
             )
-            task_info["split_positions"] = new_splits
+            task_info["split_positions"] = new_split_positions
     else:
         need_decompose = False
     print()
@@ -467,6 +477,15 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
     return tasks_map, failed_decomposition, max_subgraph_size


+def count_unique_original_models(incorrect_models):
+    original_model_paths = set(
+        model_name
+        for subgraph_path in incorrect_models
+        for model_name, _ in [extract_model_name_and_subgraph_idx(subgraph_path)]
+    )
+    return len(original_model_paths)
+
+
 def print_summary_and_suggestion(next_round_models, max_subgraph_size):
     """Print suggestion/result."""
     print("\n" + "=" * 80)
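
This factors the nested comprehension out of main (see the Step 4 hunk below). A sketch with hypothetical paths, assuming extract_model_name_and_subgraph_idx maps a subgraph path back to its (model_name, subgraph_idx) pair:

    incorrect = [
        "out/resnet/subgraph_0",  # hypothetical path layout
        "out/resnet/subgraph_3",
        "out/bert/subgraph_1",
    ]
    count_unique_original_models(incorrect)  # -> 2 (resnet and bert)
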
@@ -493,9 +512,14 @@ def main(args):
     tasks_map, max_subgraph_size, running_states = prepare_tasks_and_verify(
         args, current_pass_id, base_output_dir
     )
-    pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
-    if not os.path.exists(pass_work_dir):
-        os.makedirs(pass_work_dir, exist_ok=True)
+    decompose_config = DecomposeConfig(
+        max_subgraph_size=max_subgraph_size,
+        tasks_map=tasks_map,
+        running_states=running_states,
+    )
+    work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
+    if not os.path.exists(work_dir):
+        os.makedirs(work_dir, exist_ok=True)

     # --- Step 2: Decomposition ---
     if task_controller.task_scheduler["run_decomposer"]:
@@ -505,63 +529,45 @@ def main(args):
             failed_decomposition,
             max_subgraph_size,
         ) = execute_decomposition_phase(
-            max_subgraph_size, tasks_map, args.framework, pass_work_dir
+            max_subgraph_size, tasks_map, args.framework, work_dir
+        )
+        decompose_config.update_running_states(
+            current_pass_id, failed_decomposition_models=list(failed_decomposition)
         )
-        running_states.get(f"pass_{current_pass_id}", {})[
-            "failed_decomposition_models"
-        ] = list(failed_decomposition)
     else:
         print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
-        config = DecomposeConfig.load(pass_work_dir)
-        max_subgraph_size = config.max_subgraph_size
-        tasks_map = config.tasks_map
-        running_states = config.running_states
+        decompose_config = DecomposeConfig.load(work_dir)

     # --- Step 3: Evaluation ---
-    pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
+    log_path = os.path.join(work_dir, f"log_{task_controller.test_module_name}.txt")
     if task_controller.task_scheduler["run_evaluation"]:
         print(f"\n--- Phase 2: Evaluation ({task_controller.test_module_name}) ---")
-        run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
+        run_evaluation(args.framework, args.test_config, work_dir, log_path)

     # --- Step 4: Analysis ---
-    next_round_models = set()
+    next_pass_incorrect_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         tolerance = (
             args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
         )
         print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---")
-        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
-        original_model_paths = set(
-            [
-                model_name
-                for subgraph_path in next_round_models
-                for model_name, _ in [
-                    extract_model_name_and_subgraph_idx(subgraph_path)
-                ]
-            ]
+        next_pass_incorrect_models = sorted(get_incorrect_models(tolerance, log_path))
+        num_original_models = count_unique_original_models(next_pass_incorrect_models)
+        decompose_config.update_running_states(
+            current_pass_id + 1,
+            num_incorrect_models=num_original_models,
+            incorrect_models=list(next_pass_incorrect_models),
         )
-
-        running_states[f"pass_{current_pass_id + 1}"] = {
-            "num_incorrect_models": len(original_model_paths),
-            "incorrect_models": list(next_round_models),
-        }
-
         print(
-            f"[Analysis] Found {len(next_round_models)} incorrect subgraphs ({len(original_model_paths)} original models)."
+            f"[Analysis] Found {len(next_pass_incorrect_models)} incorrect subgraphs ({num_original_models} original models)."
         )
-        for idx, model_path in enumerate(next_round_models):
+        for idx, model_path in enumerate(next_pass_incorrect_models):
             print(f"- [{idx}] {model_path}")
-
-        print_summary_and_suggestion(next_round_models, max_subgraph_size)
+        print_summary_and_suggestion(next_pass_incorrect_models, max_subgraph_size)

     # --- Step 5: Save States ---
-    config = DecomposeConfig(
-        max_subgraph_size=max_subgraph_size,
-        incorrect_models=list(next_round_models),
-        tasks_map=tasks_map,
-        running_states=running_states,
-    )
-    config.save(pass_work_dir)
+    decompose_config.incorrect_models = list(next_pass_incorrect_models)
+    decompose_config.save(work_dir)


 if __name__ == "__main__":
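
After this refactor a single decompose_config instance carries state through all phases, and save() serializes self.__dict__, so decompose_config.json ends up with roughly this shape (field values illustrative):

    {
        "max_subgraph_size": 4,
        "tasks_map": {"<model_name>": {"split_positions": [0, 4, 8]}},
        "running_states": {"pass_0": {"failed_decomposition_models": []}},
        "incorrect_models": []
    }
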
