Skip to content

Commit 8981ce8

Browse files
authored
[Cherry-Pick][RL] R3 Support RDMA Store (#5454)
* [RL] R3 support RDMA store * refine notes * refine code * support preempted tasks and put CPU tensors
1 parent 196d624 commit 8981ce8

File tree

3 files changed

+100
-3
lines changed

3 files changed

+100
-3
lines changed

fastdeploy/config.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1167,14 +1167,17 @@ class RoutingReplayConfig:
11671167
"""Configuration for Routing Replay used in RL training"""
11681168

11691169
def __init__(self, args) -> None:
1170+
11701171
self.enable_routing_replay: bool = False
1172+
1173+
# Routing store type: local/rdma
11711174
self.routing_store_type: str = "local"
11721175

11731176
# Local routing store
11741177
self.local_store_dir: str = "./routing_replay_output"
11751178

11761179
# RDMA routing store
1177-
pass
1180+
self.rdma_store_server: str = ""
11781181

11791182
if args is not None:
11801183
for key, value in args.items():

fastdeploy/model_executor/layers/moe/routing_indices_cache.py

Lines changed: 57 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,11 @@
1414
# limitations under the License.
1515
"""
1616

17+
import asyncio
1718
import copy
1819
import os
1920
import shutil
21+
import time
2022
from abc import ABC, abstractmethod
2123
from typing import Dict, List, Optional
2224

@@ -232,6 +234,11 @@ def split_request_id(self, request_id: str):
232234
rollout_id = reversed_tmp_str[-1][::-1]
233235
return rollout_id
234236

237+
def clear_request(self, batch_id: int):
238+
"""Clear the routing indices of the request"""
239+
self._clear_table_slot(batch_id)
240+
self.routing_batch_to_request.pop(batch_id, None)
241+
235242

236243
class RoutingStoreBase(ABC):
237244
"""Base class for routing store"""
@@ -268,6 +275,7 @@ class RoutingStoreLocal(RoutingStoreBase):
268275
def __init__(self, fd_config) -> None:
269276
super().__init__(fd_config=fd_config)
270277
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
278+
self.clear_store()
271279

272280
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
273281
"""Put the routing indices into store"""
@@ -315,8 +323,55 @@ def clear_store(self):
315323
class RoutingStoreRDMA(RoutingStoreBase):
316324
"""Routing Store using RDMA"""
317325

318-
def __init__(self) -> None:
319-
super().__init__()
326+
def __init__(self, fd_config) -> None:
327+
super().__init__(fd_config=fd_config)
328+
try:
329+
# Only used in RLHF
330+
from p2pstore import P2PClient, P2PConfig
331+
except ModuleNotFoundError:
332+
raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ")
333+
334+
rdma_store_server = fd_config.routing_replay_config.rdma_store_server
335+
p2pConfig = P2PConfig(metadata_server=rdma_store_server)
336+
self.p2p_client = P2PClient(p2pConfig)
337+
self.clear_store()
338+
339+
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
340+
"""Put the routing indices into store"""
341+
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
342+
343+
# async put
344+
time_before_put = time.perf_counter()
345+
routing_indices_pin = routing_indices.pin_memory()
346+
routing_indices_np = routing_indices_pin.numpy()
347+
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
348+
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
349+
350+
def get(
351+
self,
352+
rollout_id: str,
353+
layer_idx: int = None,
354+
) -> paddle.Tensor:
355+
"""Get the routing indices from store"""
356+
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
357+
# sync get
358+
tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
359+
return tmp_routing
360+
361+
def clear(
362+
self,
363+
rollout_id: str,
364+
layer_idx: int = None,
365+
) -> None:
366+
"""Clear the routing indices of the request"""
367+
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
368+
# sync delete
369+
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
370+
371+
def clear_store(self):
372+
"""Clear the routing indices store"""
373+
# sync clear routing store
374+
asyncio.run(self.p2p_client.clear())
320375

321376

322377
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:

fastdeploy/worker/gpu_model_runner.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,9 @@
7575
from fastdeploy import envs
7676
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
7777
from fastdeploy.model_executor.forward_meta import ForwardMeta
78+
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
79+
RoutingReplayManager,
80+
)
7881
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
7982
from fastdeploy.worker.model_runner_base import ModelRunnerBase
8083
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -163,6 +166,11 @@ def __init__(
163166
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
164167
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
165168

169+
# Rollout routing replay config
170+
self.routing_replay_manager = None
171+
if self.fd_config.routing_replay_config.enable_routing_replay:
172+
self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
173+
166174
def exist_prefill(self):
167175
"""
168176
check whether prefill stage exist
@@ -313,11 +321,18 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
313321
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
314322
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
315323
self.share_inputs["is_block_step"][idx : idx + 1] = False
324+
self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
316325
self.share_inputs["step_idx"][idx : idx + 1] = (
317326
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
318327
)
319328
self.share_inputs["pre_ids"][idx : idx + 1] = -1
320329
has_prefill_task = True
330+
331+
# Routing Replay
332+
if self.fd_config.routing_replay_config.enable_routing_replay:
333+
if prefill_start_index == 0:
334+
self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
335+
321336
elif request.task_type.value == RequestType.DECODE.value: # decode task
322337
logger.debug(f"Handle decode request {request} at idx {idx}")
323338
encoder_block_num = len(request.block_tables)
@@ -338,6 +353,11 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
338353
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
339354
self.share_inputs["is_block_step"][idx : idx + 1] = False
340355
has_preempted_task = True
356+
357+
# Routing Replay
358+
if self.fd_config.routing_replay_config.enable_routing_replay:
359+
self.routing_replay_manager.clear_request(batch_id=idx)
360+
341361
continue
342362

343363
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -716,6 +736,7 @@ def _init_share_inputs(self, max_num_seqs: int):
716736
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
717737
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
718738
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
739+
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
719740
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
720741
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
721742
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
@@ -972,6 +993,9 @@ def initialize_forward_meta(self):
972993
Initialize forward meta and attention meta data
973994
"""
974995
# Initialize forward meta
996+
routing_replay_table = None
997+
if self.routing_replay_manager is not None:
998+
routing_replay_table = self.routing_replay_manager.get_routing_table()
975999
self.forward_meta = ForwardMeta(
9761000
input_ids=self.share_inputs["input_ids"],
9771001
ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -989,6 +1013,7 @@ def initialize_forward_meta(self):
9891013
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
9901014
block_tables=self.share_inputs["block_tables"],
9911015
caches=self.share_inputs["caches"],
1016+
routing_replay_table=routing_replay_table,
9921017
)
9931018

9941019
# Update Batch type for cuda graph
@@ -1314,6 +1339,9 @@ def _dummy_run(
13141339
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
13151340
break
13161341

1342+
if self.fd_config.routing_replay_config.enable_routing_replay:
1343+
self.routing_replay_manager.clear_routing_table()
1344+
13171345
def _update_chunked_prefill(self, tasks):
13181346
"""
13191347
Update chunked prefill related parameters
@@ -1694,6 +1722,17 @@ class at the server level, which is too granular for ModelRunner.
16941722
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
16951723
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
16961724
)
1725+
1726+
# Routing replay
1727+
if self.fd_config.routing_replay_config.enable_routing_replay:
1728+
if (
1729+
not self.exist_prefill()
1730+
and not self.exist_decode()
1731+
and self.share_inputs["is_block_step"].sum() == 0
1732+
and self.share_inputs["is_chunk_step"].sum() == 0
1733+
):
1734+
self.routing_replay_manager.put_table_to_store()
1735+
16971736
return None
16981737

16991738
def _add_cache(self, model_forward_batch) -> None:

0 commit comments

Comments (0)