diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 24b26161a47..cc9eec47ca9 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1488,14 +1488,17 @@ class RoutingReplayConfig: """Configuration for Routing Replay used in RL training""" def __init__(self, args) -> None: + self.enable_routing_replay: bool = False + + # Routing store type: local/rdma self.routing_store_type: str = "local" # Local routing store self.local_store_dir: str = "./routing_replay_output" # RDMA routing store - # TODO: Add RDMA routing store configuration attributes here when the feature is implemented. + self.rdma_store_server: str = "" if args is not None: for key, value in args.items(): @@ -1688,7 +1691,9 @@ def postprocess(self): self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs) if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.cache_config.enable_prefix_caching = False - + if self.routing_replay_config is not None and self.routing_replay_config.enable_routing_replay: + # TODO(gongshaotian): R3 support prefix caching + self.cache_config.enable_prefix_caching = False if ( self.structured_outputs_config is not None and self.structured_outputs_config.guided_decoding_backend != "off" diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index e95a3d8569f..ad633207fad 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import asyncio import copy import os import shutil @@ -330,8 +331,49 @@ def clear_store(self): class RoutingStoreRDMA(RoutingStoreBase): """Routing Store using RDMA""" - def __init__(self) -> None: - super().__init__() + def __init__(self, fd_config) -> None: + super().__init__(fd_config=fd_config) + try: + # Only used in RLHF + from p2pstore import P2PClient, P2PConfig + except ModuleNotFoundError: + raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") + + rdma_store_server = fd_config.routing_replay_config.rdma_store_server + p2pConfig = P2PConfig(metadata_server=rdma_store_server) + self.p2p_client = P2PClient(p2pConfig) + + def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None: + """Put the routing indices into store""" + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + # sync put + asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices)) + + def get( + self, + rollout_id: str, + layer_idx: int = None, + ) -> paddle.Tensor: + """Get the routing indices from store""" + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + # sync get + tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key)) + return tmp_routing + + def clear( + self, + rollout_id: str, + layer_idx: int = None, + ) -> None: + """Clear the routing indices of the request""" + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + # sync delete + asyncio.run(self.p2p_client.delete(rdma_rollout_key)) + + def clear_store(self): + """Clear the routing indices store""" + # sync clear routing store + asyncio.run(self.p2p_client.clear()) def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase: