
Commit 7019afb

[BugFix] fix instability after clearing weight (#5487)
* [BugFix] fix instability after clearing weight
* [chore] add todo
1 parent bcde798 commit 7019afb

2 files changed: +23 additions, -15 deletions

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 17 additions & 12 deletions
@@ -62,17 +62,18 @@ def _capture_model_state(self):
             logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
             self.state_dict[name] = param
 
-    def update_parameters(self, pid: int = 0) -> None:
+    def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
         """Core method to update model parameters based on strategy."""
         start_time = time.perf_counter()
         paddle.device.cuda.empty_cache()
 
         # step1 : restart paddle process group
-        # if not self.first_load:
-        #     paddle.distributed.restart_process_group()
-        #     paddle.distributed.restart_process_group(self.parallel_config.tp_group)
-        #     if self.parallel_config.enable_expert_parallel:
-        #         paddle.distributed.restart_process_group(self.parallel_config.ep_group)
+        if not self.first_load:
+            if restart_process_group:
+                paddle.distributed.restart_process_group()
+                paddle.distributed.restart_process_group(self.parallel_config.tp_group)
+                if self.parallel_config.enable_expert_parallel:
+                    paddle.distributed.restart_process_group(self.parallel_config.ep_group)
 
         # step2 : recreat deepep buffer when enable expert parallel
         if self.parallel_config.enable_expert_parallel and not self.first_load:
@@ -132,7 +133,7 @@ def _update_ipc(self):
         self._update_model_from_state(state_dict, "raw")
         logger.info(f"IPC update parameters completed from file: {self.ipc_path}")
 
-    def clear_parameters(self, pid: int = 0) -> None:
+    def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
         """Clear all model parameters and free memory."""
 
         logger.info("start clear paramaters")
@@ -144,8 +145,9 @@ def clear_parameters(self, pid: int = 0) -> None:
             DeepEPBufferManager.clear_buffer()
             # ep barrier
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # shutdown ep group
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+            if shutdown_process_group:
+                # shutdown ep group
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
 
         paddle.device.cuda.empty_cache()
         # step2: release model weight
@@ -158,11 +160,14 @@ def clear_parameters(self, pid: int = 0) -> None:
         if self.parallel_config.tensor_parallel_size > 1:
             # tp barrier
             paddle.distributed.barrier(self.parallel_config.tp_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
-        # paddle.distributed.shutdown_process_group()
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+        if shutdown_process_group:
+            paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
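
For context, a minimal caller-side sketch of how the new flags are intended to be paired (the weight_manager name and the surrounding helper functions are illustrative, not part of this commit): a clear_parameters call with shutdown_process_group=True should be matched by a later update_parameters call with restart_process_group=True, so communicators torn down during the clear are rebuilt before the next load.

# Illustrative sketch only; assumes a DynamicWeightManager instance named
# weight_manager and that the caller drives the rollout/training cycle.

def release_weights_for_training(weight_manager, pid: int = 0) -> None:
    # Free inference weights and shut down the TP/EP process groups so the
    # trainer can take over the GPUs; mirrors clear_parameters above.
    weight_manager.clear_parameters(pid, shutdown_process_group=True)

def reload_weights_for_rollout(weight_manager, pid: int = 0) -> None:
    # Restart the process groups that were shut down, then reload the
    # updated weights; mirrors update_parameters above.
    weight_manager.update_parameters(pid, restart_process_group=True)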

fastdeploy/worker/worker_process.py

Lines changed: 6 additions & 3 deletions
@@ -418,6 +418,7 @@ def event_loop_normal(self) -> None:
         num_running_requests = 0
         tp_rank = self.local_rank % tp_size
 
+        # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
         self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             # run eplb
@@ -459,7 +460,7 @@ def event_loop_normal(self) -> None:
                 else:
                     paddle.distributed.barrier(self.parallel_config.tp_group)
                 if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                    logger.debug(
+                    logger.info(
                         f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
                     )
                     from fastdeploy.rl.dynamic_weight_manager import (
@@ -473,10 +474,12 @@ def event_loop_normal(self) -> None:
                         self.worker.model_runner,
                         self.parallel_config.engine_worker_queue_port,
                     )
-                    logger.debug(f"current task queue data: {self.task_queue.num_tasks()}")
+                    logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
                     self.task_queue.clear_data()
                     self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
-                    logger.debug(f"Rank: {self.local_rank} has updated or cleared parameters.")
+                    logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
+                    while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
+                        time.sleep(0.01)
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
