@@ -75,6 +75,9 @@
 from fastdeploy import envs
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
 from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
+    RoutingReplayManager,
+)
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -163,6 +166,11 @@ def __init__(
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
         logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")

+        # Rollout routing replay config
+        self.routing_replay_manager = None
+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
+
     def exist_prefill(self):
         """
         check whether prefill stage exist
@@ -313,11 +321,18 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
                 self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
                 self.share_inputs["is_block_step"][idx : idx + 1] = False
+                self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
                 self.share_inputs["step_idx"][idx : idx + 1] = (
                     len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
                 )
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
                 has_prefill_task = True
+
+                # Routing Replay
+                if self.fd_config.routing_replay_config.enable_routing_replay:
+                    if prefill_start_index == 0:
+                        self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
+
             elif request.task_type.value == RequestType.DECODE.value:  # decode task
                 logger.debug(f"Handle decode request {request} at idx {idx}")
                 encoder_block_num = len(request.block_tables)
@@ -338,6 +353,11 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["is_block_step"][idx : idx + 1] = False
                 has_preempted_task = True
+
+                # Routing Replay
+                if self.fd_config.routing_replay_config.enable_routing_replay:
+                    self.routing_replay_manager.clear_request(batch_id=idx)
+
                 continue

             assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
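Note: a request is registered with the replay manager only on its first prefill chunk (prefill_start_index == 0), so later chunks of a chunked prefill do not re-register it; the preemption branch directly above instead drops the slot's partial routing record via clear_request before the request is rescheduled.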
@@ -716,6 +736,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
+        self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
         self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
         self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
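The new is_chunk_step buffer is kept on the host via .cpu(), presumably so the is_chunk_step.sum() check in the flush condition later in this diff stays a cheap host-side read; it is set to True while a request still has prefill chunks outstanding.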
@@ -972,6 +993,9 @@ def initialize_forward_meta(self):
         Initialize forward meta and attention meta data
         """
         # Initialize forward meta
+        routing_replay_table = None
+        if self.routing_replay_manager is not None:
+            routing_replay_table = self.routing_replay_manager.get_routing_table()
         self.forward_meta = ForwardMeta(
             input_ids=self.share_inputs["input_ids"],
             ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -989,6 +1013,7 @@ def initialize_forward_meta(self):
             cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
             block_tables=self.share_inputs["block_tables"],
             caches=self.share_inputs["caches"],
+            routing_replay_table=routing_replay_table,
         )

         # Update Batch type for cuda graph
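Threading routing_replay_table through ForwardMeta gives the MoE layers a per-step table in which to record (or replay) expert-routing indices during the forward pass; with replay disabled the field is simply None.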
@@ -1314,6 +1339,9 @@ def _dummy_run(
             if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
                 break

+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            self.routing_replay_manager.clear_routing_table()
+
     def _update_chunked_prefill(self, tasks):
         """
         Update chunked prefill related parameters
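Clearing the table at the end of _dummy_run keeps warm-up forwards (e.g. CUDA-graph capture) from leaving stale routing entries behind for real requests.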
@@ -1694,6 +1722,17 @@ class at the server level, which is too granular for ModelRunner.
             self.seq_lens_this_time_buffer[:num_running_requests].copy_(
                 self.share_inputs["seq_lens_this_time"][:num_running_requests], False
             )
+
+        # Routing replay
+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            if (
+                not self.exist_prefill()
+                and not self.exist_decode()
+                and self.share_inputs["is_block_step"].sum() == 0
+                and self.share_inputs["is_chunk_step"].sum() == 0
+            ):
+                self.routing_replay_manager.put_table_to_store()
+
         return None

     def _add_cache(self, model_forward_batch) -> None:
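For orientation, here is a minimal, hypothetical sketch of the RoutingReplayManager interface implied by the call sites in this diff. The method names come from the code above; everything else (dict-backed store, tensor shapes, flush semantics) is an assumption for illustration, not the actual routing_indices_cache implementation.

# Hypothetical sketch reconstructed from the call sites in this diff;
# internals are assumptions, not the real FastDeploy implementation.
from typing import Dict, Optional

import paddle


class RoutingReplayManagerSketch:
    def __init__(self, fd_config) -> None:
        self.fd_config = fd_config
        # batch slot -> request id, for requests whose routing is being recorded
        self._slot_to_request: Dict[int, str] = {}
        # per-step table the MoE layers write routing indices into during forward
        self._routing_table: Optional[paddle.Tensor] = None
        # request id -> finished routing record, ready to be fetched later
        self._store: Dict[str, paddle.Tensor] = {}

    def register_request(self, batch_id: int, request_id: str) -> None:
        # Called once, on the first prefill chunk of a request.
        self._slot_to_request[batch_id] = request_id

    def clear_request(self, batch_id: int) -> None:
        # Called on preemption: discard the partial record for this slot.
        self._slot_to_request.pop(batch_id, None)

    def get_routing_table(self) -> Optional[paddle.Tensor]:
        # Handed to ForwardMeta each step so MoE layers can fill it in.
        return self._routing_table

    def clear_routing_table(self) -> None:
        # Reset after dummy/warm-up runs so they leave no stale entries.
        self._routing_table = None

    def put_table_to_store(self) -> None:
        # Called when no request is in flight: drain all registered slots.
        for batch_id, request_id in self._slot_to_request.items():
            if self._routing_table is not None:
                self._store[request_id] = self._routing_table[batch_id]
        self._slot_to_request.clear()

Under this reading, the flush condition above (no prefill, no decode, no blocked or chunked steps) maps naturally onto put_table_to_store draining every registered slot at once.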