
Commit 6d15fda

Merge branch 'opt_saved_results' into add_original_names
2 parents a7982d5 + 864e7b3

File tree

2 files changed: +130 −124 lines changed


graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 129 additions & 121 deletions
@@ -7,7 +7,8 @@
 import argparse
 import subprocess
 import glob
-from typing import List, Set, Dict, Any, Union
+from dataclasses import dataclass, field
+from typing import List, Dict, Union
 from graph_net.analysis_util import get_incorrect_models
 from graph_net import path_utils

@@ -18,9 +19,16 @@ def convert_b64_string_to_json(b64str):
     return json.loads(base64.b64decode(b64str).decode("utf-8"))
 
 
+def convert_json_to_b64_string(config):
+    return base64.b64encode(json.dumps(config).encode()).decode()
+
+
+def get_pass_name(pass_id):
+    return f"pass_{pass_id}"
+
+
 def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
-    if not os.path.exists(log_path):
-        return set()
+    assert os.path.exists(log_path)
 
     t_start = tolerance_args[0]
     models_start = set(get_incorrect_models(t_start, log_path))
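The new `convert_json_to_b64_string` helper is the inverse of the existing `convert_b64_string_to_json`, letting decorator configs travel as a single shell-safe string. A minimal round-trip sketch (the sample config dict is invented for illustration):

```python
import base64
import json


def convert_json_to_b64_string(config):
    # JSON-encode, then base64-encode: yields a plain ASCII string.
    return base64.b64encode(json.dumps(config).encode()).decode()


def convert_b64_string_to_json(b64str):
    # Inverse transform: base64-decode, then JSON-decode.
    return json.loads(base64.b64decode(b64str).decode("utf-8"))


config = {"max_subgraph_size": 8, "split_positions": [0, 8, 16]}
assert convert_b64_string_to_json(convert_json_to_b64_string(config)) == config
```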
@@ -31,14 +39,10 @@ def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set
     t_end = tolerance_args[1]
     models_end = set(get_incorrect_models(t_end, log_path))
 
-    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
     print(
-        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
+        f"[Init] number of incorrect models: {len(models_start)} (tolerance={t_start}) - {len(models_end)} (tolerance={t_end})"
     )
-
-    diff_set = models_start - models_end
-
-    return diff_set
+    return models_start - models_end
 
 
 class TaskController:
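`get_ranged_incorrect_models` now asserts that the log exists instead of silently returning an empty set; the result is still the set difference between the two tolerance levels: models flagged at `t_start` but no longer flagged at `t_end`. A toy illustration (model names invented):

```python
# Hypothetical outputs of get_incorrect_models at the two tolerances:
models_start = {"model_a", "model_b", "model_c"}  # incorrect at t_start
models_end = {"model_b"}                          # still incorrect at t_end

# The helper returns models_start - models_end: models that fail only
# at the first tolerance.
assert models_start - models_end == {"model_a", "model_c"}
```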
@@ -105,6 +109,37 @@ def _print(self):
         print()
 
 
+@dataclass
+class DecomposeConfig:
+    max_subgraph_size: int = -1
+    incorrect_models: List[str] = field(default_factory=list)
+    tasks_map: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
+    running_states: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
+
+    def save(self, work_dir):
+        """Save the current config to a JSON file."""
+        config_path = self.get_config_path(work_dir)
+
+        with open(config_path, "w") as f:
+            json.dump(self.__dict__, f, indent=4)
+        print(f"\n[INFO] Save state to: {config_path}")
+
+    @classmethod
+    def load(cls, work_dir):
+        """Initialize an object from a JSON file."""
+        config_path = cls.get_config_path(work_dir)
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Missing configuration file: {config_path}")
+
+        with open(config_path, "r") as f:
+            data = json.load(f)
+        return cls(**data)
+
+    @classmethod
+    def get_config_path(cls, work_dir) -> str:
+        return os.path.join(work_dir, "decompose_config.json")
+
+
 def get_rectfied_model_path(model_path):
     graphnet_root = path_utils.get_graphnet_root()
     return os.path.join(graphnet_root, model_path.split("GraphNet/")[-1])
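The new `DecomposeConfig` dataclass bundles the state that the removed `load_decompose_config`/`save_decompose_config` helpers (below) used to shuttle around. A save/load round-trip sketch, assuming the class is importable from this module; all field values are invented:

```python
import tempfile

from graph_net.subgraph_decompose_and_evaluation_step import DecomposeConfig

with tempfile.TemporaryDirectory() as work_dir:
    cfg = DecomposeConfig(
        max_subgraph_size=16,
        incorrect_models=["pass_0/model_a_3"],
        tasks_map={
            "model_a": {
                "original_path": "samples/model_a",
                "split_positions": [0, 16, 32],
            }
        },
        running_states={"pass_0": {"num_incorrect_models": 1}},
    )
    cfg.save(work_dir)  # writes <work_dir>/decompose_config.json
    restored = DecomposeConfig.load(work_dir)
    assert restored == cfg  # dataclass equality compares all four fields
```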
@@ -118,53 +153,10 @@ def count_samples(samples_dir):
     return num_samples
 
 
-def get_decompose_config_path(output_dir: str) -> str:
-    """Returns the full path to the decompose configuration file."""
-    return os.path.join(output_dir, "decompose_config.json")
-
-
 def get_decompose_workspace_path(output_dir, pass_id):
     return os.path.join(output_dir, f"pass_{pass_id}")
 
 
-def load_decompose_config(work_dir: str) -> Dict[str, Any]:
-    """Loads the configuration file from the previous pass."""
-    config_path = get_decompose_config_path(work_dir)
-
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(f"Missing configuration file: {config_path}")
-    with open(config_path, "r") as f:
-        return json.load(f)
-
-
-def save_decompose_config(
-    workspace: str,
-    max_subgraph_size: int,
-    tasks_map: Dict[str, Union[int, str, list, dict]],
-    incorrect_paths: Union[List[str], Set[str]],
-    failed_decomposition_models: Union[List[str], Set[str]],
-):
-    """Saves the current state to a JSON file."""
-
-    tasks_map_copy = {}
-    for model_name, task_info in tasks_map.items():
-        tasks_map_copy[model_name] = {}
-        for key in ["original_path", "split_positions"]:
-            tasks_map_copy[model_name][key] = task_info[key]
-
-    config = {
-        "max_subgraph_size": max_subgraph_size,
-        "incorrect_models": list(incorrect_paths),
-        "tasks_map": tasks_map_copy,
-        "failed_decomposition_models": list(failed_decomposition_models),
-    }
-    config_path = get_decompose_config_path(workspace)
-
-    with open(config_path, "w") as f:
-        json.dump(config, f, indent=4)
-    print(f"\n[INFO] Save state to: {config_path}")
-
-
 def get_model_name_with_subgraph_tag(model_path):
     fields = model_path.rstrip("/").split(os.sep)
     pattern = r"^subgraph(_\d+)?$"
@@ -195,9 +187,7 @@ def run_decomposer_for_single_model(
             },
         },
     }
-    decorator_config_b64 = base64.b64encode(
-        json.dumps(decorator_config).encode()
-    ).decode()
+    decorator_config_b64 = convert_json_to_b64_string(decorator_config)
 
     print(
         f"[Decomposition] model_path: {model_path}, split_positions: {split_positions}"
@@ -287,16 +277,16 @@ def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
     return subgraph_size
 
 
-def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
-    assert isinstance(subgraph_size, (list, tuple)) and len(subgraph_size) == 2
+def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
+    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2
 
     # subgraph_size: the start and end position in original model.
-    start_pos, end_pos = subgraph_size
+    start_pos, end_pos = subgraph_range
     end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
 
-    split_positions = list(range(start_pos, end_pos + 1, max_subgraph_size))
-    if split_positions[-1] != end_pos:
-        split_positions.append(end_pos)
+    split_positions = list(
+        range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
+    )
     return sorted(list(set(split_positions)))
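The rewritten `calculate_split_positions_for_subgraph` drops the explicit `append(end_pos)` in favor of a `range` stop that runs past `end_pos`. A worked restatement of the new arithmetic (simplified for illustration: the assert and the `kMaxGraphSize` clamp are omitted):

```python
def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
    start_pos, end_pos = subgraph_range
    # The stop is end_pos + max_subgraph_size - 1, so the final split may
    # land past end_pos instead of exactly on it.
    split_positions = list(
        range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
    )
    return sorted(set(split_positions))


print(calculate_split_positions_for_subgraph((0, 10), 4))   # [0, 4, 8, 12]
print(calculate_split_positions_for_subgraph((16, 32), 8))  # [16, 24, 32]
```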
@@ -317,38 +307,42 @@ def generate_initial_tasks(args):
         )
 
         tasks_map[model_name] = {
-            "subgraph_path": model_path,
             "original_path": model_path,
             "split_positions": initial_splits,
         }
 
-    for task in tasks_map.values():
-        task["split_positions"] = sorted(list(task["split_positions"]))
+    running_states = {
+        "pass_0": {
+            "num_incorrect_models": len(initial_failures),
+            "incorrect_models": list(sorted(initial_failures)),
+        }
+    }
+    return tasks_map, max_subgraph_size, running_states
 
-    return tasks_map, max_subgraph_size
+
+def extract_model_name_and_subgraph_idx(subgraph_path):
+    # Parse model name and subgraph index
+    model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
+    model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
+    subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+    return model_name, subgraph_idx
 
 
 def generate_refined_tasks(base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
 
-    prev_config = load_decompose_config(prev_pass_dir)
-    prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
-    prev_tasks_map = prev_config.get("tasks_map", {})
-
-    prev_max_subgraph_size = prev_config.get("max_subgraph_size")
-    max_subgraph_size = prev_max_subgraph_size // 2
-
-    if not prev_incorrect_subgraphs:
-        return {}, max_subgraph_size
+    prev_config = DecomposeConfig.load(prev_pass_dir)
+    max_subgraph_size = prev_config.max_subgraph_size // 2
+    if not prev_config.incorrect_models:
+        return {}, max_subgraph_size, prev_config.running_states
 
     tasks_map = {}
-    for subgraph_path in prev_incorrect_subgraphs:
-        # Parse model name and subgraph index
-        model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
-        model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
-        subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+    prev_tasks_map = prev_config.tasks_map
+
+    for subgraph_path in sorted(prev_config.incorrect_models):
+        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
 
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
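The name-parsing logic was factored out into `extract_model_name_and_subgraph_idx` so the analysis phase in `main` can reuse it. Illustrative behavior, assuming the helper is in scope (the path is invented):

```python
# The trailing "_<idx>" of the last path component is the subgraph index;
# everything before it is the original model name.
path = "output/pass_1/resnet50_v2_3/"
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(path)
assert (model_name, subgraph_idx) == ("resnet50_v2", 3)
```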
@@ -360,40 +354,38 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
             subgraph_ranges
         ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
 
-        current_fail_range = subgraph_ranges[subgraph_idx]
-
-        new_splits = calculate_split_positions_for_subgraph(
-            current_fail_range, max_subgraph_size
+        split_positions = calculate_split_positions_for_subgraph(
+            subgraph_ranges[subgraph_idx], max_subgraph_size
         )
-
         if model_name not in tasks_map:
             tasks_map[model_name] = {
-                "subgraph_path": subgraph_path,
                 "original_path": pre_task_for_model["original_path"],
-                "split_positions": set(new_splits),
+                "split_positions": list(sorted(split_positions)),
             }
         else:
-            tasks_map[model_name]["split_positions"].update(new_splits)
-
-    for task in tasks_map.values():
-        task["split_positions"] = sorted(list(task["split_positions"]))
+            merged_split_positions = (
+                tasks_map[model_name]["split_positions"] + split_positions
+            )
+            tasks_map[model_name]["split_positions"] = list(
+                sorted(set(merged_split_positions))
+            )
 
-    return tasks_map, max_subgraph_size
+    return tasks_map, max_subgraph_size, prev_config.running_states
 
 
 def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
     if current_pass_id == 0:
-        tasks_map, max_subgraph_size = generate_initial_tasks(args)
+        tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
-        tasks_map, max_subgraph_size = generate_refined_tasks(
+        tasks_map, max_subgraph_size, running_states = generate_refined_tasks(
             base_output_dir, current_pass_id
         )
 
-    print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
-    print(f"[INFO] number of incorrect models: {len(tasks_map)}")
-    for model_name, task_info in tasks_map.items():
+    print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")
+    print(f"[Init] number of incorrect models: {len(tasks_map)}")
+    for idx, (model_name, task_info) in enumerate(tasks_map.items()):
         original_path = task_info["original_path"]
-        print(f"- {original_path}")
+        print(f"- [{idx}] {original_path}")
 
     if not tasks_map:
         print("[FINISHED] No models need processing.")
@@ -405,7 +397,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         )
         sys.exit(0)
 
-    return tasks_map, max_subgraph_size
+    return tasks_map, max_subgraph_size, running_states
 
 
 def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspace):
@@ -485,15 +477,14 @@ def main(args):
     print("=" * 80)
 
     # --- Step 1: Prepare Tasks and Workspace ---
-    tasks_map, max_subgraph_size = prepare_tasks_and_verify(
+    tasks_map, max_subgraph_size, running_states = prepare_tasks_and_verify(
         args, current_pass_id, base_output_dir
     )
    pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
     if not os.path.exists(pass_work_dir):
         os.makedirs(pass_work_dir, exist_ok=True)
 
     # --- Step 2: Decomposition ---
-    failed_decomposition = []
     if task_controller.task_scheduler["run_decomposer"]:
         print("\n--- Phase 1: Decomposition ---", flush=True)
         (
@@ -503,44 +494,61 @@ def main(args):
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
         )
+        running_states.get(f"pass_{current_pass_id}", {})[
+            "failed_decomposition_models"
+        ] = list(failed_decomposition)
     else:
-        config = load_decompose_config(pass_work_dir)
-        max_subgraph_size = config["max_subgraph_size"]
-        failed_decomposition = config["failed_decomposition_models"]
-        tasks_map = config.get("tasks_map", {})
+        print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
+        config = DecomposeConfig.load(pass_work_dir)
+        max_subgraph_size = config.max_subgraph_size
+        tasks_map = config.tasks_map
+        running_states = config.running_states
 
     # --- Step 3: Evaluation ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
-        print("\n--- Phase 2: Evaluation ---")
+        print(f"\n--- Phase 2: Evaluation ({task_controller.test_module_name}) ---")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
 
     # --- Step 4: Analysis ---
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
-        print("\n--- Phase 3: Analysis ---")
-        analysis_tolerance = (
+        tolerance = (
             args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
         )
-        next_round_models = get_incorrect_models(analysis_tolerance, pass_log_path)
+        print(f"\n--- Phase 3: Analysis (tolerance={tolerance}) ---")
+        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
+        original_model_paths = set(
+            [
+                model_name
+                for subgraph_path in next_round_models
+                for model_name, _ in [
+                    extract_model_name_and_subgraph_idx(subgraph_path)
+                ]
+            ]
+        )
+
+        running_states[f"pass_{current_pass_id + 1}"] = {
+            "num_incorrect_models": len(original_model_paths),
+            "incorrect_models": list(next_round_models),
+        }
+
+        print(
+            f"[Analysis] Found {len(next_round_models)} incorrect subgraphs ({len(original_model_paths)} original models)."
+        )
+        for idx, model_path in enumerate(next_round_models):
+            print(f"- [{idx}] {model_path}")
 
-        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
-        if len(next_round_models) > 0:
-            print("[DEBUG] List of detected incorrect models:")
-            for idx, model_path in enumerate(sorted(list(next_round_models))):
-                print(f" [{idx}] {model_path}")
-        else:
-            print("[DEBUG] No incorrect models detected.")
         print_summary_and_suggestion(next_round_models, max_subgraph_size)
 
     # --- Step 5: Save States ---
-    save_decompose_config(
-        pass_work_dir,
-        max_subgraph_size,
-        tasks_map,
-        next_round_models,
-        failed_decomposition,
+    config = DecomposeConfig(
+        max_subgraph_size=max_subgraph_size,
+        incorrect_models=list(next_round_models),
+        tasks_map=tasks_map,
+        running_states=running_states,
     )
+    config.save(pass_work_dir)
 
 
 if __name__ == "__main__":
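With these changes, the `decompose_config.json` written at the end of each pass records per-pass history under `running_states` instead of a flat `failed_decomposition_models` list. A sketch of the file's shape as a Python dict (all values invented for illustration):

```python
expected_state = {
    "max_subgraph_size": 8,
    "incorrect_models": ["output/pass_1/model_a_2"],
    "tasks_map": {
        "model_a": {
            "original_path": "samples/model_a",
            "split_positions": [0, 8, 16],
        },
    },
    "running_states": {
        "pass_0": {
            "num_incorrect_models": 1,
            "incorrect_models": ["samples/model_a"],
            "failed_decomposition_models": [],
        },
        "pass_1": {
            "num_incorrect_models": 1,
            "incorrect_models": ["output/pass_1/model_a_2"],
        },
    },
}
```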
