@@ -62,17 +62,18 @@ def _capture_model_state(self):
             logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
             self.state_dict[name] = param
 
-    def update_parameters(self, pid: int = 0) -> None:
+    def update_parameters(self, pid: int = 0, restart_process_group: bool = False) -> None:
         """Core method to update model parameters based on strategy."""
         start_time = time.perf_counter()
         paddle.device.cuda.empty_cache()
 
         # step1: restart paddle process group
-        # if not self.first_load:
-        #     paddle.distributed.restart_process_group()
-        #     paddle.distributed.restart_process_group(self.parallel_config.tp_group)
-        #     if self.parallel_config.enable_expert_parallel:
-        #         paddle.distributed.restart_process_group(self.parallel_config.ep_group)
+        if not self.first_load:
+            if restart_process_group:
+                paddle.distributed.restart_process_group()
+                paddle.distributed.restart_process_group(self.parallel_config.tp_group)
+                if self.parallel_config.enable_expert_parallel:
+                    paddle.distributed.restart_process_group(self.parallel_config.ep_group)
 
         # step2: recreate deepep buffer when enable expert parallel
         if self.parallel_config.enable_expert_parallel and not self.first_load:
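The hunk above turns the previously commented-out restart logic into a real code path gated by the new `restart_process_group` flag. A minimal sketch of how a caller might pair the new flags across a full reload cycle; the `manager` object and the `reload_weights` helper are assumptions for illustration, not part of this diff:

    def reload_weights(manager, pid: int = 0) -> None:
        # Free the old weights and, with the flag set, tear down the
        # process groups so their NCCL resources are released too.
        manager.clear_parameters(pid, shutdown_process_group=True)
        # A group that was shut down must be restarted before the next
        # collective runs, so the update passes the matching restart flag.
        manager.update_parameters(pid, restart_process_group=True)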
@@ -132,7 +133,7 @@ def _update_ipc(self):
         self._update_model_from_state(state_dict, "raw")
         logger.info(f"IPC update parameters completed from file: {self.ipc_path}")
 
-    def clear_parameters(self, pid: int = 0) -> None:
+    def clear_parameters(self, pid: int = 0, shutdown_process_group: bool = False) -> None:
         """Clear all model parameters and free memory."""
 
         logger.info("start clear parameters")
@@ -144,8 +145,9 @@ def clear_parameters(self, pid: int = 0) -> None:
             DeepEPBufferManager.clear_buffer()
             # ep barrier
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # shutdown ep group
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+            if shutdown_process_group:
+                # shutdown ep group
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
 
         paddle.device.cuda.empty_cache()
         # step2: release model weight
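Since `shutdown_process_group` defaults to `False`, existing callers of `clear_parameters` keep the current behavior (barriers only); tearing groups down is strictly opt-in. Assumed call sites for illustration:

    manager.clear_parameters(pid)                               # default: groups stay alive
    manager.clear_parameters(pid, shutdown_process_group=True)  # opt in: also shut down groups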
@@ -158,11 +160,14 @@ def clear_parameters(self, pid: int = 0) -> None:
         if self.parallel_config.tensor_parallel_size > 1:
             # tp barrier
             paddle.distributed.barrier(self.parallel_config.tp_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
-            # paddle.distributed.shutdown_process_group()
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+        if shutdown_process_group:
+            paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
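Throughout `clear_parameters`, every `shutdown_process_group` call is preceded by a `barrier` on the same group, so no rank destroys a communicator while its peers still have a collective in flight (the global shutdown, with no argument, comes last). A hypothetical helper capturing that ordering, using `shutdown_process_group` as the diff itself does:

    import paddle.distributed

    def barrier_then_shutdown(group):
        # Wait until every rank has finished its last collective on this
        # group before any rank tears the communicator down.
        paddle.distributed.barrier(group)
        paddle.distributed.shutdown_process_group(group)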