Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,14 +1488,17 @@ class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""

def __init__(self, args) -> None:

self.enable_routing_replay: bool = False

# Routing store type: local/rdma
self.routing_store_type: str = "local"

# Local routing store
self.local_store_dir: str = "./routing_replay_output"

# RDMA routing store
# TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
self.rdma_store_server: str = ""

if args is not None:
for key, value in args.items():
Expand Down Expand Up @@ -1688,7 +1691,9 @@ def postprocess(self):
self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.cache_config.enable_prefix_caching = False

if self.routing_replay_config is not None and self.routing_replay_config.enable_routing_replay:
# TODO(gongshaotian): R3 support prefix caching
self.cache_config.enable_prefix_caching = False
if (
self.structured_outputs_config is not None
and self.structured_outputs_config.guided_decoding_backend != "off"
Expand Down
46 changes: 44 additions & 2 deletions fastdeploy/model_executor/layers/moe/routing_indices_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""

import asyncio
import copy
import os
import shutil
Expand Down Expand Up @@ -330,8 +331,49 @@ def clear_store(self):
class RoutingStoreRDMA(RoutingStoreBase):
    """Routing Store using RDMA.

    Routing indices are stored through a ``p2pstore.P2PClient`` under keys of
    the form ``"{rollout_id}_{layer_idx}"``. Every client call is awaited
    synchronously with ``asyncio.run``; per the asyncio documentation,
    ``asyncio.run`` cannot be called while another event loop is already
    running in the same thread, so these methods must be invoked from
    synchronous code.
    """

    def __init__(self, fd_config) -> None:
        """Create the P2P client from the routing replay configuration.

        Args:
            fd_config: Config object; ``routing_replay_config.rdma_store_server``
                is read from it and passed to ``P2PConfig`` as ``metadata_server``.

        Raises:
            ModuleNotFoundError: If the optional ``p2pstore`` package is not installed.
        """
        super().__init__(fd_config=fd_config)
        try:
            # Only used in RLHF, so the dependency is imported lazily here
            # instead of at module import time.
            from p2pstore import P2PClient, P2PConfig
        except ModuleNotFoundError as err:
            # Chain explicitly so the original import failure stays visible as the cause.
            raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") from err

        rdma_store_server = fd_config.routing_replay_config.rdma_store_server
        p2p_config = P2PConfig(metadata_server=rdma_store_server)
        self.p2p_client = P2PClient(p2p_config)

    @staticmethod
    def _rdma_rollout_key(rollout_id: str, layer_idx) -> str:
        """Build the store key shared by put/get/clear.

        NOTE(review): when ``layer_idx`` is None the key is literally
        ``"{rollout_id}_None"`` -- confirm that callers relying on the
        ``layer_idx=None`` default expect this.
        """
        return f"{rollout_id}_{layer_idx}"

    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
        """Put the routing indices into store"""
        rdma_rollout_key = self._rdma_rollout_key(rollout_id, layer_idx)
        # sync put: block until the client's put coroutine completes
        asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices))

    def get(
        self,
        rollout_id: str,
        layer_idx: int = None,
    ) -> paddle.Tensor:
        """Get the routing indices from store

        Returns whatever ``p2p_client.get`` resolves to for the key
        (annotated as a paddle.Tensor -- TODO confirm against p2pstore).
        """
        rdma_rollout_key = self._rdma_rollout_key(rollout_id, layer_idx)
        # sync get: block until the client's get coroutine completes
        return asyncio.run(self.p2p_client.get(rdma_rollout_key))

    def clear(
        self,
        rollout_id: str,
        layer_idx: int = None,
    ) -> None:
        """Clear the routing indices of the request"""
        rdma_rollout_key = self._rdma_rollout_key(rollout_id, layer_idx)
        # sync delete: block until the client's delete coroutine completes
        asyncio.run(self.p2p_client.delete(rdma_rollout_key))

    def clear_store(self):
        """Clear the routing indices store"""
        # sync clear routing store
        asyncio.run(self.p2p_client.clear())


def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
Expand Down
Loading