5 changes: 4 additions & 1 deletion fastdeploy/config.py
@@ -1167,14 +1167,17 @@ class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""

def __init__(self, args) -> None:

self.enable_routing_replay: bool = False

# Routing store type: local/rdma
self.routing_store_type: str = "local"

# Local routing store
self.local_store_dir: str = "./routing_replay_output"

# RDMA routing store
pass
self.rdma_store_server: str = ""

if args is not None:
for key, value in args.items():
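The tail of the constructor is collapsed in this view. A minimal sketch of the usual pattern, assuming the truncated loop only overrides attributes that already exist on the config (an assumption, not the merged code), is:

# Sketch only: the override loop and the usage below are illustrative assumptions.
class RoutingReplayConfigSketch:
    def __init__(self, args=None) -> None:
        self.enable_routing_replay: bool = False
        self.routing_store_type: str = "local"  # "local" or "rdma"
        self.local_store_dir: str = "./routing_replay_output"
        self.rdma_store_server: str = ""
        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)

# Hypothetical usage: enable replay backed by an RDMA store.
cfg = RoutingReplayConfigSketch(
    {"enable_routing_replay": True, "routing_store_type": "rdma", "rdma_store_server": "10.0.0.1:2379"}
)
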
59 changes: 57 additions & 2 deletions fastdeploy/model_executor/layers/moe/routing_indices_cache.py
@@ -14,9 +14,11 @@
# limitations under the License.
"""

import asyncio
import copy
import os
import shutil
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

@@ -232,6 +234,11 @@ def split_request_id(self, request_id: str):
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id

def clear_request(self, batch_id: int):
"""Clear the routing indices of the request"""
self._clear_table_slot(batch_id)
self.routing_batch_to_request.pop(batch_id, None)


class RoutingStoreBase(ABC):
"""Base class for routing store"""
@@ -268,6 +275,7 @@ class RoutingStoreLocal(RoutingStoreBase):
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
self.clear_store()

def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
@@ -315,8 +323,55 @@ def clear_store(self):
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""

def __init__(self) -> None:
super().__init__()
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
try:
# Only used in RLHF
from p2pstore import P2PClient, P2PConfig
except ModuleNotFoundError:
raise ModuleNotFoundError("RoutingStoreRDMA and p2pstore are only supported in RLHF.")

rdma_store_server = fd_config.routing_replay_config.rdma_store_server
p2pConfig = P2PConfig(metadata_server=rdma_store_server)
self.p2p_client = P2PClient(p2pConfig)
self.clear_store()

def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"

# Put via the async p2p client; asyncio.run blocks until the transfer completes
time_before_put = time.perf_counter()
routing_indices_pin = routing_indices.pin_memory()
routing_indices_np = routing_indices_pin.numpy()
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
print(f"Successfully put key {rdma_rollout_key}, time cost {time.perf_counter() - time_before_put} s")

def get(
self,
rollout_id: str,
layer_idx: Optional[int] = None,
) -> paddle.Tensor:
"""Get the routing indices from store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync get
tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
return tmp_routing

def clear(
self,
rollout_id: str,
layer_idx: Optional[int] = None,
) -> None:
"""Clear the routing indices of the request"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync delete
asyncio.run(self.p2p_client.delete(rdma_rollout_key))

def clear_store(self):
"""Clear the routing indices store"""
# sync clear routing store
asyncio.run(self.p2p_client.clear())
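
For orientation, a hypothetical round trip through the RDMA store looks like the following (fd_config and a reachable p2pstore metadata server are assumed; the rollout id and tensor shape are illustrative only):

# Hypothetical usage of RoutingStoreRDMA.
store = RoutingStoreRDMA(fd_config)
topk_ids = paddle.randint(low=0, high=64, shape=[256, 8], dtype="int32")  # [num_tokens, top_k] expert ids
store.put(topk_ids, rollout_id="rollout_0", layer_idx=3)   # stored under key "rollout_0_3"
replayed = store.get(rollout_id="rollout_0", layer_idx=3)  # fetches the same key
store.clear(rollout_id="rollout_0", layer_idx=3)           # deletes the key

Note that each method wraps the async p2p client in asyncio.run, so put/get/clear block until the transfer finishes.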


def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
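The factory body is collapsed here; presumably it dispatches on routing_store_type roughly as follows (a sketch, not the merged code):

# Sketch of the assumed dispatch inside get_routing_store.
def get_routing_store_sketch(fd_config: FDConfig) -> RoutingStoreBase:
    store_type = fd_config.routing_replay_config.routing_store_type
    if store_type == "local":
        return RoutingStoreLocal(fd_config)
    if store_type == "rdma":
        return RoutingStoreRDMA(fd_config)
    raise ValueError(f"Unknown routing_store_type: {store_type}")
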
39 changes: 39 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -75,6 +75,9 @@
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -163,6 +166,11 @@ def __init__(
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")

# Rollout routing replay config
self.routing_replay_manager = None
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)

def exist_prefill(self):
"""
check whether prefill stage exist
@@ -313,11 +321,18 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
self.share_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
self.share_inputs["pre_ids"][idx : idx + 1] = -1
has_prefill_task = True

# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if prefill_start_index == 0:
self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)

elif request.task_type.value == RequestType.DECODE.value: # decode task
logger.debug(f"Handle decode request {request} at idx {idx}")
encoder_block_num = len(request.block_tables)
@@ -338,6 +353,11 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
has_preempted_task = True

# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager.clear_request(batch_id=idx)

continue

assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
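
Taken together, the two insert_tasks_v1 hooks tie a routing-table slot to the request lifecycle: a slot is registered only on the first prefill chunk (prefill_start_index == 0) and released when the request is preempted. A condensed sketch of that lifecycle (names as in this PR, control flow simplified):

# Condensed sketch of the lifecycle handled above; not the full insert_tasks_v1 logic.
def handle_request_sketch(manager, request, idx, prefill_start_index, task_is_preempted):
    if task_is_preempted:
        # A preempted request gives its table slot back so it can be reused.
        manager.clear_request(batch_id=idx)
        return
    if prefill_start_index == 0:
        # Register once per request, on its first prefill chunk only,
        # so later chunks of a chunked prefill reuse the same slot.
        manager.register_request(batch_id=idx, request_id=request.request_id)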
@@ -716,6 +736,7 @@ def _init_share_inputs(self, max_num_seqs: int):
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
@@ -972,6 +993,9 @@ def initialize_forward_meta(self):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
routing_replay_table = None
if self.routing_replay_manager is not None:
routing_replay_table = self.routing_replay_manager.get_routing_table()
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -989,6 +1013,7 @@
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
routing_replay_table=routing_replay_table,
)

# Update Batch type for cuda graph
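
Passing routing_replay_table through ForwardMeta is what lets the MoE layers record their top-k routing choices during the forward pass. A heavily hedged sketch of that write path (the real code lives in the MoE layers and the replay manager, neither of which is shown in this diff; the table shape is an assumption):

# Sketch only: assumes the table is indexable per layer and sized to hold this step's top-k ids.
def record_routing_sketch(forward_meta, layer_idx: int, topk_ids: paddle.Tensor) -> None:
    if forward_meta.routing_replay_table is None:
        return  # routing replay disabled
    # topk_ids: [num_tokens, top_k] expert indices chosen by this layer's router.
    forward_meta.routing_replay_table[layer_idx, : topk_ids.shape[0]] = topk_ids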
@@ -1314,6 +1339,9 @@ def _dummy_run(
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
break

if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager.clear_routing_table()

def _update_chunked_prefill(self, tasks):
"""
Update chunked prefill related parameters
@@ -1694,6 +1722,17 @@ class at the server level, which is too granular for ModelRunner.
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
)

# Routing replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if (
not self.exist_prefill()
and not self.exist_decode()
and self.share_inputs["is_block_step"].sum() == 0
and self.share_inputs["is_chunk_step"].sum() == 0
):
self.routing_replay_manager.put_table_to_store()

return None
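
The guard above flushes the routing table to the store only once the batch has fully drained: no prefill work, no decode work, no blocked steps, and no pending chunked-prefill steps. Factored out as a standalone predicate (a sketch; attribute names follow this file):

# Sketch: the flush-eligibility check from the block above, expressed as a helper.
def batch_fully_drained(runner) -> bool:
    return (
        not runner.exist_prefill()
        and not runner.exist_decode()
        and int(runner.share_inputs["is_block_step"].sum()) == 0
        and int(runner.share_inputs["is_chunk_step"].sum()) == 0
    )

# if runner.fd_config.routing_replay_config.enable_routing_replay and batch_fully_drained(runner):
#     runner.routing_replay_manager.put_table_to_store()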

def _add_cache(self, model_forward_batch) -> None: