diff --git a/custom_ops/gpu_ops/swap_cache_layout.cu b/custom_ops/gpu_ops/swap_cache_layout.cu new file mode 100644 index 00000000000..6af6a198f54 --- /dev/null +++ b/custom_ops/gpu_ops/swap_cache_layout.cu @@ -0,0 +1,136 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +// #define SWAP_DEBUG + +template +void SwapCacheImpLayout( + const std::vector& cache_gpu_tensors, // gpu + const int64_t& cache_cpu_pointer, // cpu + const std::vector& cache_shape, + const std::vector& swap_block_ids_gpu, + int mode) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + const int64_t layer_number = cache_gpu_tensors.size(); +#ifdef SWAP_DEBUG + std::cout << "layer_number " << layer_number << std::endl; + std::cout << "cache_shape size: test" << std::endl; + std::cout << "cache_shape:" << cache_shape[0] << ", " << cache_shape[1] + << ", " << cache_shape[2] << ", " << cache_shape[3] << std::endl; +#endif + + const int64_t max_block_num_gpu = cache_shape[0]; + const int64_t num_heads = cache_shape[1]; + const int64_t block_size = cache_shape[2]; + const int64_t head_dim = cache_shape[3]; + const int64_t cache_stride = num_heads * block_size * head_dim; +#ifdef SWAP_DEBUG + std::cout << "cache_stride " << cache_stride << std::endl; +#endif + + auto stream = cache_gpu_tensors[0].stream(); + const cudaMemcpyKind copy_kind = + (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; + // 【layer, block, block size,head num, head dim】 + // 【block, layer, block size,head num, head dim】 + auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); +#ifdef SWAP_DEBUG + std::cout << "cache_cpu_ptr: " << cache_cpu_ptr << std::endl; +#endif + + for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) { + const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; + data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); + auto stream = cache_gpu.stream(); + for (int id = 0; id < swap_block_ids_gpu.size(); id++) { +#ifdef SWAP_DEBUG + std::cout << "current block " << swap_block_ids_gpu[id] << std::endl; +#endif + + auto* cache_gpu_ptr_now = + cache_gpu_ptr + swap_block_ids_gpu[id] * cache_stride; + auto* cache_cpu_ptr_now = cache_cpu_ptr + + id * cache_stride * layer_number + + layer_idx * cache_stride; + +#ifdef SWAP_DEBUG + std::cout << "current data" << *cache_cpu_ptr_now << std::endl; +#endif + cudaError_t status = cudaMemcpyAsync( + (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now + : cache_gpu_ptr_now, + (copy_kind == cudaMemcpyDeviceToHost) ? 
cache_gpu_ptr_now + : cache_cpu_ptr_now, + cache_stride * sizeof(DataType_), + copy_kind, + stream); +#ifdef SWAP_DEBUG + std::cout << "current data11 " << *cache_cpu_ptr_now << std::endl; +#endif + } + } + cudaStreamSynchronize(stream); +#ifdef SWAP_DEBUG + std::cout << "finished " << std::endl; +#endif +} + +void SwapCacheLayout( + const std::vector& cache_gpu_tensors, // gpu + const int64_t& cache_cpu_ptrs, // cpu memory pointer + const std::vector& cache_shape, + const std::vector& swap_block_ids_gpu, + int rank, + int mode) { + cudaSetDevice(rank); // used for distributed launch + assert(cache_gpu_tensors.size() > 0); + switch (cache_gpu_tensors[0].dtype()) { + case paddle::DataType::BFLOAT16: + return SwapCacheImpLayout(cache_gpu_tensors, + cache_cpu_ptrs, + cache_shape, + swap_block_ids_gpu, + mode); + case paddle::DataType::FLOAT16: + return SwapCacheImpLayout(cache_gpu_tensors, + cache_cpu_ptrs, + cache_shape, + swap_block_ids_gpu, + mode); + case paddle::DataType::UINT8: + return SwapCacheImpLayout(cache_gpu_tensors, + cache_cpu_ptrs, + cache_shape, + swap_block_ids_gpu, + mode); + default: + PD_THROW("Unsupported data type."); + } +} + +PD_BUILD_STATIC_OP(swap_cache_layout) + .Inputs({paddle::Vec("cache_gpu_tensors")}) + .Attrs({ + "cache_cpu_ptrs: int64_t", + "cache_shape: std::vector", + "swap_block_ids_gpu: std::vector", + "rank: int", + "mode: int", + }) + .SetKernelFn(PD_KERNEL(SwapCacheLayout)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 40900b18771..b3b5ba3d000 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -288,6 +288,7 @@ def find_end_files(directory, end_str): "gpu_ops/tune_cublaslt_gemm.cu", "gpu_ops/swap_cache_batch.cu", "gpu_ops/swap_cache.cu", + "gpu_ops/swap_cache_layout.cu", "gpu_ops/step_system_cache.cu", "gpu_ops/cpp_extensions.cc", "gpu_ops/share_external_data.cu", diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 631f5efb05a..6dba01bdf7a 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -42,6 +42,8 @@ class CacheStatus(Enum): SWAP2CPU = 1 SWAP2GPU = 2 CPU = 3 + GPU2STORAGE = 4 + STORAGE2GPU = 5 class BlockNode: diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index b2b8218c805..826c0da904a 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -15,6 +15,7 @@ """ import argparse +import asyncio import concurrent.futures import gc import json @@ -36,8 +37,10 @@ set_device, share_external_data_, swap_cache_all_layers, + swap_cache_layout, unset_data_ipc, ) +from fastdeploy.cache_manager.transfer_factory import MooncakeStore from fastdeploy.config import SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.platforms import current_platform @@ -57,6 +60,7 @@ def parse_args(): ) parser.add_argument("--rank", type=int, default=0, help="current rank") parser.add_argument("--device_id", type=int, default=0, help="device id") + parser.add_argument("--max_model_length", type=int, default=32768, help="max model length") parser.add_argument("--num_layers", type=int, default=1, help="model num layers") parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel") parser.add_argument( @@ -94,11 +98,36 @@ def parse_args(): help="speculative config", ) parser.add_argument("--create_cache_tensor", 
action="store_true") + parser.add_argument( + "--kvcache_storage_backend", + type=str, + default="None", + choices=["mooncake", "None"], + help="The storage backend for kvcache storage.", + ) + parser.add_argument( + "--write_policy", + type=str, + choices=["write_through"], + default="write_through", + help="KVCache write policy", + ) args = parser.parse_args() return args +class TimeoutController: + def __init__(self): + self._stop_event = threading.Event() + + def stop(self): + self._stop_event.set() + + def should_stop(self): + return self._stop_event.is_set() + + class CacheTransferManager: """ 管理CPU和GPU之间缓存的交换传输 @@ -127,6 +156,9 @@ def __init__(self, args): self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self.swap_to_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self.write_to_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) self.transfer_task_queue = queue.Queue() # 用来接收传输任务 self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 self.n_ranks = args.mp_num @@ -186,8 +218,39 @@ def __init__(self, args): suffix=args.engine_worker_queue_port, create=False, ) + storage_backend = args.kvcache_storage_backend + + if storage_backend == "None": + self.storage_backend = None + elif storage_backend == "mooncake": + self.storage_backend = MooncakeStore() + self._init_storage_buffer() + else: + raise NotImplementedError(f"Unsupported storage backend: {storage_backend}") + + write_policy = args.write_policy + if write_policy not in [ + "write_through", + "write_back", + ]: + raise ValueError(f"Invalid write policy: {write_policy}") + self.write_policy = write_policy + threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start() + def _init_storage_buffer(self): + total_layers = args.num_layers + self.num_extra_layers + need_to_allocate_bytes = ( + args.max_model_length * self.key_cache_shape[1] * self.key_cache_shape[3] * total_layers * 2 + ) + self.cache_stride = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3] * total_layers + logger.info( + f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for alllayers {total_layers}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB" + ) + self.key_register_buffer = cuda_host_alloc(need_to_allocate_bytes * 2) + self.val_register_buffer = self.key_register_buffer + need_to_allocate_bytes + self.storage_backend.register_buffer(self.key_register_buffer, need_to_allocate_bytes * 2) + def _init_gpu_cache(self, args): try: @@ -359,6 +422,183 @@ def _init_cpu_cache(self, args): logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!") self.swap_space_ready_signal.value[self.rank] = 1 + async def _run_async_load(self, hash_keys, gpu_block_ids): + keys_k = [f"{key}_key_{self.rank}" for key in hash_keys] + keys_v = [f"{key}_value_{self.rank}" for key in hash_keys] + + target_location_k = [self.key_register_buffer + i * self.cache_stride for i in range(len(gpu_block_ids))] + target_location_v = [self.val_register_buffer + i * self.cache_stride for i in range(len(gpu_block_ids))] + + target_sizes = [self.cache_stride] * len(gpu_block_ids) * 2 + + keys = keys_k + keys_v + target_location = target_location_k + target_location_v + + self.storage_backend.get(keys, target_location=target_location, target_sizes=target_sizes) + + 
swap_cache_layout( + self.gpu_cache_k_tensors, + self.key_register_buffer, + self.key_cache_shape, + gpu_block_ids, + self.device, + 1, # cpu ==> gpu + ) + swap_cache_layout( + self.gpu_cache_v_tensors, self.val_register_buffer, self.value_cache_shape, gpu_block_ids, self.device, 1 + ) + + def load_storage_task(self, task_id, hash_keys, gpu_block_ids, timeout=0.1): + keys = [f"{key}_key_{self.rank}" for key in hash_keys] + results = self.storage_backend.exists(keys) + current_number = 0 + for _, exist in results.items(): + if exist: + current_number += 1 + else: + break + gpu_block_ids = gpu_block_ids[:current_number] + # TODO + # timeout 系数 自行调节 + # timeout = 0.100 * len(gpu_block_ids) + # 计算最小block 传输时间,应传尽传 + if current_number > 0: + try: + # Create new event loop if needed + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run with timeout + loop.run_until_complete( + asyncio.wait_for(self._run_async_load(hash_keys, gpu_block_ids), timeout=timeout) + ) + except Exception as e: + logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}") + gpu_block_ids = [] + + result = (hash_keys, gpu_block_ids, [], CacheStatus.STORAGE2GPU, task_id) + self.cache_task_queue.swap_storage_to_gpu_barrier.wait() + if self.rank == 0: + if current_number > 0: + logger.info( + f"[rank {self.rank}/{self.n_ranks}] {current_number} data found in storage for task {task_id}, finish loading." + ) + self.cache_task_queue.swap_storage_to_gpu_barrier.reset() + self.cache_task_queue.put_transfer_done_signal(result) + + async def _run_async_write(self, uncached_keys_k, uncached_keys_v, uncached_block_ids): + try: + # logger.info(f"[rank {self.rank}/{self.n_ranks}] write cache to storage {uncached_keys_k} {uncached_block_ids}") + key_cache_size = [ + self.key_cache_shape[0], + self.key_cache_shape[1], + self.key_cache_shape[2], + self.key_cache_shape[3], + ] + swap_cache_layout( + self.gpu_cache_k_tensors, + self.key_register_buffer, + key_cache_size, + uncached_block_ids, + self.device, + 0, # gpu ==> cpu + ) + swap_cache_layout( + self.gpu_cache_v_tensors, + self.val_register_buffer, + key_cache_size, + uncached_block_ids, + self.device, + 0, # gpu ==> cpu + ) + + # Prepare locations + target_location_k = [ + self.key_register_buffer + i * self.cache_stride for i in range(len(uncached_block_ids)) + ] + target_location_v = [ + self.val_register_buffer + i * self.cache_stride for i in range(len(uncached_block_ids)) + ] + + target_sizes = [self.cache_stride] * len(uncached_block_ids) * 2 + target_location = target_location_k + target_location_v + + logger.info(f"write cache to storage {uncached_keys_k + uncached_keys_v} {target_location} {target_sizes}") + + # Execute storage set operation + self.storage_backend.set( + uncached_keys_k + uncached_keys_v, target_location=target_location, target_sizes=target_sizes + ) + except Exception as e: + logger.error(f"An error occurred during writing to storage: {e}") + + def write_back_storage_task(self, keys, gpu_block_ids, transfer_task_id, timeout=0.1): + """ + writeback kv cache to storage with coroutine-based timeout control + """ + logger.debug(f"write cache to storage {keys} {gpu_block_ids} {transfer_task_id}") + if gpu_block_ids is None: + raise ValueError("gpu_block_ids cannot be None") + + keys_k = [f"{key}_key_{self.rank}" for key in keys] + result = self.storage_backend.exists(keys_k) + uncached_keys_k = [] + uncached_keys_v = [] + uncached_block_ids = [] + 
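+        # Keep only the blocks whose key the storage backend does not already hold;
+        # each missing key is paired with its value key and its GPU block id so the
+        # write below skips data that is already persisted.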
current_id = 0 + for k, v in result.items(): + if v == 0: + uncached_keys_k.append(k) + uncached_keys_v.append(f"{keys[current_id]}_value_{self.rank}") + uncached_block_ids.append(gpu_block_ids[current_id]) + current_id += 1 + # Run the async version synchronously + result = ( + keys, + [], + [], + CacheStatus.GPU2STORAGE, + transfer_task_id, + ) + if len(uncached_keys_k) > 0: + try: + # Create new event loop if needed + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run with timeout + loop.run_until_complete( + asyncio.wait_for( + self._run_async_write(uncached_keys_k, uncached_keys_v, uncached_block_ids), timeout=timeout + ) + ) + result = ( + keys, + uncached_block_ids, + [], + CacheStatus.GPU2STORAGE, + transfer_task_id, + ) + except asyncio.TimeoutError: + logger.error(f"Write back storage task timed out after {timeout} seconds") + except Exception as e: + logger.error(f"Error in write back storage task: {e}") + else: + logger.info(f"No uncached keys found for task {transfer_task_id}") + + self.cache_task_queue.swap_to_storage_barrier.wait() + if self.rank == 0: + self.cache_task_queue.swap_to_storage_barrier.reset() + self.cache_task_queue.put_transfer_done_signal(result) + logger.debug(f"_do_swap_to_storage: put_transfer_done_signal {result}") + logger.info(f"_do_swap_to_storage: put_transfer_done_signal for transfer_task_id {transfer_task_id}") + def _do_swap_to_cpu_task( self, swap_node_ids, @@ -457,6 +697,7 @@ def do_data_transfer(self): cpu_block_id, event_type, transfer_task_id, + timeout, ) = data if event_type.value == CacheStatus.SWAP2CPU.value: self.swap_to_cpu_thread_pool.submit( @@ -467,7 +708,7 @@ def do_data_transfer(self): event_type, transfer_task_id, ) - else: + elif event_type.value == CacheStatus.SWAP2GPU.value: self.swap_to_gpu_thread_pool.submit( self._do_swap_to_gpu_task, swap_node_ids, @@ -476,6 +717,23 @@ def do_data_transfer(self): event_type, transfer_task_id, ) + elif event_type.value == CacheStatus.STORAGE2GPU.value: + self.swap_to_storage_thread_pool.submit( + self.load_storage_task, + transfer_task_id, + swap_node_ids, + gpu_block_id, + timeout, + ) + elif event_type.value == CacheStatus.GPU2STORAGE.value: + # logger.info(f"GPU2STORAGE {swap_node_ids} {gpu_block_id} {transfer_task_id}") + self.write_to_storage_thread_pool.submit( + self.write_back_storage_task, + swap_node_ids, + gpu_block_id, + transfer_task_id, + timeout, + ) else: if self.n_ranks > 1: self.cache_task_queue.barrier2.wait() diff --git a/fastdeploy/cache_manager/ops.py b/fastdeploy/cache_manager/ops.py index 8e1ae6aa712..f78f5431e32 100644 --- a/fastdeploy/cache_manager/ops.py +++ b/fastdeploy/cache_manager/ops.py @@ -30,6 +30,7 @@ set_data_ipc, share_external_data, swap_cache_all_layers, + swap_cache_layout, unset_data_ipc, ) @@ -50,6 +51,7 @@ def get_peer_mem_addr(*args, **kwargs): ) unset_data_ipc = None + swap_cache_layout = None memory_allocated = paddle.device.xpu.memory_allocated def get_data_ptr_ipc(*args, **kwargs): @@ -102,6 +104,7 @@ def get_all_visible_devices(): ipc_sent_key_value_cache_by_remote_ptr_block_sync = None get_peer_mem_addr = None get_all_visible_devices = None + swap_cache_layout = None __all__ = [ @@ -119,4 +122,5 @@ def get_all_visible_devices(): "ipc_sent_key_value_cache_by_remote_ptr_block_sync", "get_peer_mem_addr", "get_all_visible_devices", + "swap_cache_layout", ] diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py 
b/fastdeploy/cache_manager/prefix_cache_manager.py index a3c610965a5..7cc249ba042 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -14,10 +14,8 @@ # limitations under the License. """ -import hashlib import heapq import os -import pickle import subprocess import sys import threading @@ -34,6 +32,7 @@ from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus from fastdeploy.cache_manager.cache_metrics import CacheMetrics from fastdeploy.cache_manager.ops import get_all_visible_devices +from fastdeploy.cache_manager.transfer_factory import get_hash_str from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import get_logger @@ -89,6 +88,15 @@ def __init__( self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None) + self.write_policy = self.cache_config.write_policy + if self.write_policy not in ["write_through"]: + raise ValueError(f"Invalid write policy: {self.write_policy}") + + self.storage_backend = self.cache_config.kvcache_storage_backend + self.task_write_back_event = {} + + self.cal_block_hash = get_hash_str + # gpu cache data structure self.gpu_lru_leaf_heap = [] self.gpu_lru_leaf_set = set() @@ -100,6 +108,8 @@ def __init__( # swap in/out data structure self.request_release_lock = Lock() self.task_swapping_event = {} + self.task_prefetch_event = {} + self.task_prefetch_blocks_ids = {} self.node_map = {} self.req_leaf_map = {} # {request_id: leaf node} @@ -280,7 +290,9 @@ def launch_cache_manager( + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" + (" --create_cache_tensor" if create_cache_tensor else "") - + f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1" + + f" --kvcache_storage_backend {cache_config.kvcache_storage_backend}" + + f" --write_policy {cache_config.write_policy}" + + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1" ) logger.info(f"Launch cache transfer manager, command:{launch_cmd}") cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) @@ -289,7 +301,7 @@ def launch_cache_manager( while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: time.sleep(1) - if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: + if self.num_cpu_blocks > 0: while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size: time.sleep(1) @@ -302,7 +314,7 @@ def launch_cache_manager( ) # Start additional threads - if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: + if cache_config.enable_hierarchical_kvcache or self.num_cpu_blocks > 0: logger.info("Enable hierarchical cache.") threading.Thread(target=self.recv_data_transfer_result).start() if cache_config.enable_prefix_caching: @@ -504,13 +516,7 @@ def issue_swap_task( self.task_swapping_event[transfer_task_id] = Event() self.cache_task_queue.put_transfer_task( - ( - swap_node_ids, - gpu_block_ids, - cpu_block_ids, - event_type, - transfer_task_id, - ) + (swap_node_ids, gpu_block_ids, cpu_block_ids, event_type, transfer_task_id, 0) ) if is_sync: self.sync_swap_task(transfer_task_id) @@ -566,6 +572,33 @@ def _prepare_cpu_cache( True, ) + def request_match_storage_blocks(self, request, extra_gpu_block_ids): + storage_block_ids = [] + task_id = request.request_id + 
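+        # Hash the not-yet-cached blocks of the prompt (keyed off the last matched
+        # node's hash), prefetch whatever the storage backend already holds into
+        # the extra GPU blocks, and return the ids of the prefetched blocks.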
input_ids = request.prompt_token_ids + prefix_block_key = "" + num_cached_tokens = 0 + if task_id in self.cache_info: + last_node, num_cached_tokens = self.cache_info[task_id] + prefix_block_key = last_node.hash_value + + block_size = self.cache_config.block_size + if self.storage_backend is not None: + keys = [] + current_tokens = num_cached_tokens + + while current_tokens < len(input_ids): + keys.append( + self.cal_block_hash(input_ids[current_tokens : current_tokens + block_size], [prefix_block_key]) + ) + current_tokens += block_size + + self.prefetch_kv_cache(task_id, keys, extra_gpu_block_ids, is_sync=False) + + storage_block_ids = self.sync_prefetch_task(task_id) + + return storage_block_ids + def _prepare_cache( self, req_id, @@ -596,6 +629,25 @@ def _prepare_cache( if gpu_extra_block_num > 0: gpu_extra_block_ids = self.allocate_gpu_blocks(gpu_extra_block_num) + storage_block_ids = [] + do_prefetch = False + if self.storage_backend is not None: + keys = [] + prefix_block_key = [] + num_cached_tokens = 0 + if req_id in self.cache_info: + last_node, num_cached_tokens = self.cache_info[req_id] + prefix_block_key = [last_node.hash_value] + current_tokens = num_cached_tokens + while current_tokens < len(input_ids): + keys.append( + self.cal_block_hash(input_ids[current_tokens : current_tokens + block_size], prefix_block_key) + ) + current_tokens += block_size + + self.prefetch_kv_cache(req_id, keys, gpu_extra_block_ids, is_sync=False) + do_prefetch = True + if len(gpu_recv_block_ids) > 0: self._prepare_cpu_cache( req_id, @@ -604,8 +656,10 @@ def _prepare_cache( cpu_recv_block_ids, match_cpu_block_ids, ) + if do_prefetch: + storage_block_ids = self.sync_prefetch_task(req_id) - return gpu_recv_block_ids, gpu_extra_block_ids + return gpu_recv_block_ids, gpu_extra_block_ids, storage_block_ids def get_required_block_num(self, input_token_num, block_size): """ @@ -628,6 +682,7 @@ def update_cache_blocks(self, task, block_size, num_computed_tokens): can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later self.leaf_req_map[last_node].remove(req_id) + logger.debug(f"update_cache_blocks: req_id {req_id} can_cache_computed_tokens {can_cache_computed_tokens}") with self.request_release_lock: leaf_node = self.mm_build_path( @@ -805,10 +860,7 @@ def request_block_ids(self, task, block_size, dec_token_num, *args): current_time = time.time() self._update_matched_node_info(req_id, match_block_node, current_time) # 2. 
prepare cache - ( - gpu_recv_block_ids, - gpu_extra_block_ids, - ) = self._prepare_cache( + (gpu_recv_block_ids, gpu_extra_block_ids, storage_cached_block_ids) = self._prepare_cache( req_id, input_ids, block_size, @@ -828,7 +880,7 @@ def request_block_ids(self, task, block_size, dec_token_num, *args): gpu_build_path_block_ids = [] gpu_build_path_block_ids = gpu_extra_block_ids - + logger.debug(f"request_block_ids: req_id {req_id} left_input_ids {len(left_input_ids)}") leaf_node = self.build_path( req_id, current_time, @@ -852,6 +904,7 @@ def request_block_ids(self, task, block_size, dec_token_num, *args): ) hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size + hit_info["L3_cache_blocks"] = len(storage_cached_block_ids) self.metrics._update_history_hit_metrics() if self.metrics.req_count % 10000 == 0: self.metrics.reset_metrics() @@ -882,6 +935,7 @@ def release_block_ids(self, task): with self.request_release_lock: try: req_id = task.request_id + keys = [] leaf_node = self.req_leaf_map.pop(req_id) if leaf_node in self.leaf_req_map: self.leaf_req_map[leaf_node].remove(req_id) @@ -892,8 +946,20 @@ def release_block_ids(self, task): if req_id in node.req_id_set: node.req_id_set.remove(req_id) node.decrement_shared_count() + keys.append(node.hash_value) node = node.parent + # To-DO, 异步写入 + output 写入 + if ( + self.cache_config.enable_hierarchical_kvcache + and self.cache_config.kvcache_storage_backend is not None + ): + if self.write_policy == "write_through" and keys: + logger.info(f"write_through {req_id} {keys} {task.block_tables[:len(keys)]}") + self.write_back_storage( + task_id=req_id, hash_keys=keys, gpu_block_ids=task.block_tables[: len(keys)], is_sync=True + ) + if req_id in self.cache_info: del self.cache_info[req_id] @@ -913,11 +979,54 @@ def release_block_ids(self, task): f"release_block_ids: req_id {req_id} has been finished, " + f"current gpu_lru_leaf_heap length {len(self.gpu_lru_leaf_heap)}" ) + return except Exception as e: logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}") raise e + def write_back_storage( + self, task_id, hash_keys, gpu_block_ids=None, cpu_block_ids=None, is_sync=True, timeout=0.1 + ): + + self.task_write_back_event[task_id] = Event() + self.cache_task_queue.put_transfer_task( + (hash_keys, gpu_block_ids, cpu_block_ids, CacheStatus.GPU2STORAGE, task_id, timeout) + ) # 发起数据传输任务 + + if is_sync: + self.sync_write_back_task(task_id) + return + + def sync_write_back_task(self, task_id): + """ + 同步ssd任务 + 当issue_ssd_task中设置is_sync为False时需主动调用该函数同步结果 + """ + self.task_write_back_event[task_id].wait() + del self.task_write_back_event[task_id] + + def prefetch_kv_cache(self, task_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.1): + storage_block_ids = [] + self.task_prefetch_event[task_id] = Event() + self.cache_task_queue.put_transfer_task( + (hash_keys, gpu_block_ids, None, CacheStatus.STORAGE2GPU, task_id, timeout) + ) # 发起数据传输任务 + if is_sync: + storage_block_ids = self.sync_prefetch_task(task_id) + return storage_block_ids + + def sync_prefetch_task(self, task_id): + """ + 同步ssd任务 + 当issue_ssd_task中设置is_sync为False时需主动调用该函数同步结果 + """ + self.task_prefetch_event[task_id].wait() + storage_block_ids = self.task_prefetch_blocks_ids[task_id] + del self.task_prefetch_event[task_id] + del self.task_prefetch_blocks_ids[task_id] + return storage_block_ids + def free_nodes_directly(self, node): with self.request_release_lock: try: @@ -1068,10 +1177,7 @@ def 
free_block_ids_async(self, need_block_num): break node = heapq.heappop(self.gpu_lru_leaf_heap) self.gpu_lru_leaf_set.remove(node) - if ( - not self.cache_config.enable_hierarchical_cache - or self.cache_config.num_cpu_blocks < need_block_num - ): + if self.cache_config.num_cpu_blocks < need_block_num: if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收 self._handle_free_gpu_node_without_cpu(node) total_gpu_free_count += 1 @@ -1194,12 +1300,6 @@ def free_cpu_block_ids(self, need_block_num): ) return total_cpu_free_count - def cal_block_hash(self, block): - """ - calculate hash value of a block - """ - return hash(tuple(block)) - def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx): """ Retrieves additional hash keys for block identification. @@ -1259,16 +1359,6 @@ def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx): hash_keys.append(mm_inputs["mm_hashes"][img_idx]) return len(mm_inputs["mm_positions"]) - 1, hash_keys - def hash_block_features(self, input_ids, extra_keys: list = []): - """ - calculate hash value of a block with additional keys - - Args: - input_ids: Input token IDs - extra_keys: Additional keys for block identification - """ - return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest() - def _revert_match_blocks( self, request, @@ -1362,6 +1452,7 @@ def mm_match_block(self, request, block_size): matche_nodes = [] has_modified_gpu_lru_leaf_heap = False has_modified_cpu_lru_leaf_heap = False + prefix_cache = [] with self.cache_status_lock: while match_token_num < total_token_num: @@ -1375,7 +1466,11 @@ def mm_match_block(self, request, block_size): end_idx=match_token_num + block_size, mm_idx=mm_idx, ) - hash_value = self.hash_block_features(token_block, extra_keys) + prefix_cache.extend(extra_keys) + hash_value = self.cal_block_hash(token_block, prefix_cache) + logger.debug(f"match_block: req_id {request.request_id} hash_value: {hash_value}") + prefix_cache = [hash_value] + if hash_value in current_match_node.children: child = current_match_node.children[hash_value] matche_nodes.append(child) @@ -1475,14 +1570,15 @@ def match_block(self, req_id, input_ids, block_size): matche_nodes = [] has_modified_gpu_lru_leaf_heap = False has_modified_cpu_lru_leaf_heap = False - + prefix_block_key = [] with self.cache_status_lock: while match_token_num < total_token_num: token_block = input_ids[match_token_num : match_token_num + block_size] token_num = len(token_block) if token_num != block_size: break - hash_value = self.cal_block_hash(token_block) + hash_value = self.cal_block_hash(token_block, prefix_block_key) + prefix_block_key = [hash_value] if hash_value in current_match_node.children: child = current_match_node.children[hash_value] matche_nodes.append(child) @@ -1576,8 +1672,12 @@ def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num has_unfilled_block = False current_time = time.time() - input_hash_value = self.hash_block_features(input_ids) + input_hash_value = self.cal_block_hash(input_ids) gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy() + if last_node.hash_value is None: + prefix_cache = [] + else: + prefix_cache = [last_node.hash_value] for i in range(num_cached_tokens, can_cache_computed_tokens, block_size): current_block = input_ids[i : i + block_size] current_block_size = len(current_block) # 最后一个block可能没填满 @@ -1590,7 +1690,9 @@ def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num end_idx=i + block_size, mm_idx=mm_idx, ) - 
hash_value = self.hash_block_features(current_block, extra_keys) + prefix_cache.extend(extra_keys) + hash_value = self.cal_block_hash(current_block, prefix_cache) + prefix_cache = [hash_value] allocated_block_id = gpu_block_ids.pop(0) node_id = self.node_id_pool.pop() unique_node_ids.append(node_id) @@ -1663,13 +1765,16 @@ def build_path( new_last_node = last_node has_unfilled_block = False + prefix_block_key = last_node.hash_value + for i in range(0, token_num, block_size): current_block = left_input_ids[i : i + block_size] current_block_size = len(current_block) # 最后一个block可能没填满 if current_block_size != block_size: has_unfilled_block = True else: - hash_value = self.cal_block_hash(current_block) + hash_value = self.cal_block_hash(current_block, [prefix_block_key]) + prefix_block_key = [hash_value] allocated_block_id = gpu_block_ids.pop(0) node_id = self.node_id_pool.pop() unique_node_ids.append(node_id) @@ -1771,20 +1876,40 @@ def recv_data_transfer_result(self): transfer_task_id, ) = data length = len(task_gpu_block_id) - for i in range(length): - self._handle_swap_result( - swap_node_ids[i], - task_gpu_block_id[i], - task_cpu_block_id[i], - event_type, + + if event_type.value == CacheStatus.STORAGE2GPU.value: + logger.info(f"{data}") + self.task_prefetch_blocks_ids[transfer_task_id] = task_gpu_block_id + if transfer_task_id in self.task_prefetch_event: + self.task_prefetch_event[transfer_task_id].set() + logger.info( + f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: " + + "from storage to GPU" + + f"task_gpu_block_id {task_gpu_block_id} done" + ) + elif event_type.value == CacheStatus.GPU2STORAGE.value: + if transfer_task_id in self.task_write_back_event: + self.task_write_back_event[transfer_task_id].set() + logger.info( + f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: " + + "from GPU to storage" + + f"task_gpu_block_id {task_gpu_block_id} done" + ) + else: + for i in range(length): + self._handle_swap_result( + swap_node_ids[i], + task_gpu_block_id[i], + task_cpu_block_id[i], + event_type, + ) + if transfer_task_id in self.task_swapping_event: + self.task_swapping_event[transfer_task_id].set() + logger.info( + f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: " + + f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} " + + f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done" ) - if transfer_task_id in self.task_swapping_event: - self.task_swapping_event[transfer_task_id].set() - logger.info( - f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: " - + f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} " - + f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done" - ) except Exception as e: logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}") raise e diff --git a/fastdeploy/cache_manager/transfer_factory/__init__.py b/fastdeploy/cache_manager/transfer_factory/__init__.py index 31298a918c1..94435316144 100644 --- a/fastdeploy/cache_manager/transfer_factory/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/__init__.py @@ -14,7 +14,23 @@ # limitations under the License. 
""" -from .ipc_cache_transfer import IPCCommManager +from fastdeploy.platforms import current_platform + +from .kvcache_storage import KVCacheStorage, get_hash_str +from .mooncake_store import MooncakeStore, get_hash_str_mooncake from .rdma_cache_transfer import RDMACommManager -__all__ = ["IPCCommManager", "RDMACommManager"] +if current_platform.is_cuda(): + from .ipc_cache_transfer import IPCCommManager +else: + IPCCommManager = None + + +__all__ = [ + "IPCCommManager", + "RDMACommManager", + "KVCacheStorage", + "get_hash_str", + "MooncakeStore", + "get_hash_str_mooncake", +] diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py b/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py new file mode 100644 index 00000000000..602477fb3e7 --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py @@ -0,0 +1,112 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import hashlib +import pickle +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +import paddle + +from fastdeploy.utils import get_logger + +logger = get_logger("kvcache_storage", "kvcache_storage.log") + + +def get_hash_str(token_ids: List[int], extra_keys: list = []) -> str: + """ + calculate hash value of a block with additional keys + + Args: + token_ids: Input token IDs + extra_keys: Additional keys for block identification + """ + return hashlib.sha256(pickle.dumps((token_ids, extra_keys))).hexdigest() + + +class KVCacheStorage(ABC): + """ + KVCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache. + """ + + @abstractmethod + def get( + self, + key: str, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> paddle.Tensor | None: + """ + Retrieve the value associated with the given key. + Returns None if the key does not exist. + """ + pass + + @abstractmethod + def batch_get( + self, + keys: List[str], + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> List[paddle.Tensor | None]: + """ + Retrieve values for multiple keys. + Returns a list of tensors or None for each key. + """ + pass + + @abstractmethod + def set( + self, + key: str, + value: Optional[Any] = None, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: + """ + Store the value associated with the given key. + Returns True if the operation was successful, False otherwise. + """ + pass + + @abstractmethod + def batch_set( + self, + keys: List[str], + values: Optional[Any] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: + """ + Store multiple key-value pairs. + Returns True if all operations were successful, False otherwise. + """ + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """ + Check if the key exists in the storage. + Returns True if the key exists, False otherwise. 
+ """ + pass + + @abstractmethod + def clear(self) -> bool: + """ + Clear all keys in storage + """ + pass diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/README.md b/fastdeploy/cache_manager/transfer_factory/mooncake_store/README.md new file mode 100644 index 00000000000..5dcf52be585 --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/README.md @@ -0,0 +1,79 @@ +# MooncakeStore for FastDeploy + +This document describes how to use MooncakeStore as the backend of FastDeploy for L3 Cache. + +## Installation + +### Install MooncakeStore with pip + +```bash +pip install mooncake-transfer-engine +``` + +### Install MooncakeStore from source + +```bash +git clone https://github.com/kvcache-ai/Mooncake --recursive +cd Mooncake +``` + +Install dependencies + +```bash +cd Mooncake +bash dependencies.sh +``` + +Build the project. For additional build options, please refer to [the official guide](https://kvcache-ai.github.io/Mooncake/getting_started/build.html). + +```bash +mkdir build +cd build +cmake .. +make -j +sudo make install +``` + +## Use Mooncake + +Launch Mooncake master server: + +```bash +mooncake_master \ + --enable_http_metadata_server=true \ + --http_metadata_server_host=0.0.0.0 \ + --http_metadata_server_port=7882 \ + --metrics_port=7883 \ + --port=7721 +``` + +### Command line options +``` +-metrics_port (Port for HTTP metrics server to listen on) type: int32 + default: 9003 +-enable_http_metadata_server (Enable HTTP metadata server instead of etcd) + type: bool default: false +-http_metadata_server_host (Host for HTTP metadata server to bind to) + type: string default: "0.0.0.0" +-http_metadata_server_port (Port for HTTP metadata server to listen on) + type: int32 default: 8080 +-eviction_high_watermark_ratio (Ratio of high watermark trigger eviction) + type: double default: 0.94999999999999996 +``` + +more parameter can be found in the [official guide](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/python-api-reference/transfer-engine.md). + +Start the Fastdeploy with Mooncake enabled. Mooncake configuration can be provided via environment variables: + +```bash +MOONCAKE_CONFIG_PATH="./mooncake_config.json" \ +python -m fastdeploy.entrypoints.openai.api_server \ + --enable-hierarchical-kvcache \ + --kvcache-storage-backend mooncake \ + --model-path [model_path] +``` + +## Troubleshooting + +For more details, please refer to: +https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/troubleshooting.md diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py new file mode 100644 index 00000000000..49220d207eb --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py @@ -0,0 +1,19 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from .mooncake_store import MooncakeStore, get_hash_str_mooncake + +__all__ = ["MooncakeStore", "get_hash_str_mooncake"] diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_config.json b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_config.json new file mode 100644 index 00000000000..ec0f6375a07 --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_config.json @@ -0,0 +1,7 @@ +{ + "local_hostname":"localhost", + "metadata_server":"http://127.0.0.1:7882/metadata", + "protocol":"rdma", + "device_name": "mlx5_2,mlx5_3", + "master_server_address":"127.0.0.1:7721" +} diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py new file mode 100644 index 00000000000..2aed0684798 --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -0,0 +1,260 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import json +import os +import time +import uuid +from dataclasses import dataclass +from typing import Any, List, Optional + +import paddle + +from fastdeploy.cache_manager.transfer_factory.kvcache_storage import ( + KVCacheStorage, + logger, +) + +DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB + + +def get_hash_str_mooncake(token_ids: List, extra_keys: list = [], prefix_block_key: str = "") -> str: + # TODO page 存储 计算hash + pass + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file() -> "MooncakeStoreConfig": + """Load the config from a JSON file.""" + file_path = os.getenv("MOONCAKE_CONFIG_PATH") + if file_path is None: + raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + with open(file_path) as fin: + config = json.load(fin) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE), + local_buffer_size=config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE), + protocol=config.get("protocol", "rdma"), + device_name=config.get("device_name", "auto"), + master_server_address=config.get("master_server_address"), + ) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + """Load config from a file specified in the environment variable. + export MOONCAKE_MASTER=10.13.3.232:50051 + export MOONCAKE_PROTOCOL="rdma" + export MOONCAKE_DEVICE="auto" + export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE" + """ + # other required environment variables... 
+ if not os.getenv("MOONCAKE_MASTER"): + raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.") + return MooncakeStoreConfig( + local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"), + metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"), + global_segment_size=int(os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)), + local_buffer_size=int(os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)), + protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"), + device_name=os.getenv("MOONCAKE_DEVICE", "auto"), + master_server_address=os.getenv("MOONCAKE_MASTER"), + ) + + def __post_init__(self): + # TODO check nic + if self.device_name == "auto": + os.environ["MC_MS_AUTO_DISC"] = "1" + os.environ["MC_MS_FILTERS"] = "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3" + + +class MooncakeStore(KVCacheStorage): + def __init__(self): + try: + from mooncake.store import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://kvcache-ai.github.io/Mooncake/getting_started/build.html" + "to run Fastdeploy with MooncakeConnector." + ) from e + + try: + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.from_file() + logger.info("Mooncake Configuration loaded from env successfully.") + + ret_code = self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) + + if ret_code: + logger.error(f"failed to setup mooncake store, error code: {ret_code}") + + logger.info("Connect to Mooncake store successfully.") + self.warmup() + logger.info("Mooncake store warmup successfully.") + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error("An error occurred while loading the configuration: %s", exc) + raise + + def warmup(self): + warmup_key = "fastdeploy_mooncake_store_warmup_key" + str(uuid.uuid4()) + # 1 MB + warmup_value = bytes(1 * 1024 * 1024) + self.store.put(warmup_key, warmup_value) + assert self.store.is_exist(warmup_key) == 1 + self.store.get(warmup_key) + self.store.remove(warmup_key) + # assert self.store.is_exist(warmup_key) == 0 + + def register_buffer(self, buffer_ptr, buffer_size) -> None: + try: + ret_code = self.store.register_buffer(buffer_ptr, buffer_size) + if ret_code: + logger.error(f"failed to register buffer, error code: {ret_code}") + except TypeError as err: + logger.error("Failed to register buffer to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Register Buffer Error.") from err + + def set( + self, + key, + value: Optional[Any] = None, + target_location: Optional[List[int]] = None, + target_sizes: Optional[List[int]] = None, + ) -> bool: + assert len(key) == len(target_location) == len(target_sizes) + if len(key) == 0: + return + + for i in range(len(key)): + if key[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + self._put_batch_zero_copy_impl(key, target_location, target_sizes) + + def batch_set( + self, + keys: List[str], + value: Optional[Any] = None, + target_location: Optional[List[int]] = None, + target_sizes: Optional[List[int]] = None, + ) -> bool: + assert len(keys) == len(target_location) == len(target_sizes) + if len(keys) == 0: + return + + for i in range(len(keys)): + if keys[i] is None or 
target_location[i] is None or target_sizes[i] is None: + return + + self._put_batch_zero_copy_impl(keys, target_location, target_sizes) + + def get( + self, + key, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> paddle.Tensor | None: + assert len(key) == len(target_location) == len(target_sizes) + if len(key) == 0: + return + + for i in range(len(key)): + if key[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + return self._get_batch_zero_copy_impl(key, target_location, target_sizes) + + def batch_get( + self, + keys: List[str], + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> paddle.Tensor | None: + assert len(keys) == len(target_location) == len(target_sizes) + if len(keys) == 0: + return + + for i in range(len(keys)): + if keys[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + return self._get_batch_zero_copy_impl(keys, target_location, target_sizes) + + def exists(self, keys): + result = {k: v for k, v in zip(keys, self.store.batch_is_exist(keys))} + return result + + def delete(self, key, timeout=5) -> None: + while timeout: + result = self.store.remove(key) + if result == 0: + logger.info("Successfully removed") + break + else: + time.sleep(1) + timeout -= 1 + return result + + def close(self): + # MooncakeDistributedStore will automatically call the destructor, so + # it is unnecessary to close it manually. + pass + + def clear(self) -> None: + """ + clear all the objects in the store + """ + count = self.store.remove_all() + logger.info(f"Removed {count} objects") + + def _put_batch_zero_copy_impl(self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]) -> None: + try: + self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes) + except TypeError as err: + logger.error("Failed to put value to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Put Type Error.") from err + + def _get_batch_zero_copy_impl(self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]) -> None: + try: + self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes) + except TypeError as err: + logger.error("Failed to get value from Mooncake Store: %s", err) + raise TypeError("Mooncake Store Get Type Error.") from err diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/test_mooncake_transfer.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/test_mooncake_transfer.py new file mode 100644 index 00000000000..699832112e9 --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/test_mooncake_transfer.py @@ -0,0 +1,79 @@ +import os +import time + +import numpy as np +import paddle + +from fastdeploy.cache_manager.ops import cuda_host_alloc, cuda_host_free +from fastdeploy.cache_manager.transfer_factory import MooncakeStore +from fastdeploy.model_executor.ops.gpu import swap_cache_layout + +MOONCAKE_CONFIG_PATH = "./mooncake_config.json" + + +class TestMooncakeStore: + def __init__(self): + os.environ["MOONCAKE_CONFIG_PATH"] = os.getenv("MOONCAKE_CONFIG_PATH", MOONCAKE_CONFIG_PATH) + self.storage_backend = MooncakeStore() + self.cache_shape = [4, 64, 128] + self.lay_number = 80 + self.max_model_len = 64 * 1024 + self.block_number = self.max_model_len // self.cache_shape[1] + + def test_register_buffer(self): + need_to_allocate_bytes = self.max_model_len * self.cache_shape[0] * self.lay_number * self.cache_shape[2] * 2 + self.cache_stride = self.cache_shape[1] * self.cache_shape[2] 
* self.cache_shape[0] * self.lay_number + print(f"creating cpu cache for alllayers {self.lay_number}: {need_to_allocate_bytes / 1024 ** 3:.2f}GB") + self.key_register_buffer = cuda_host_alloc(need_to_allocate_bytes) + self.storage_backend.register_buffer(self.key_register_buffer, need_to_allocate_bytes) + + def _init_gpu_blocks(self): + self.key_gpu_cache = [] + for i in range(self.lay_number): + key_cache = paddle.full( + shape=[self.block_number, self.cache_shape[0], self.cache_shape[1], self.cache_shape[2]], + fill_value=0, + dtype=paddle.bfloat16, + ) + self.key_gpu_cache.append(key_cache) + + def test_write_storage(self, test_block_num=16): + start_time = time.time() + keys = [f"test_key_{i}" for i in range(test_block_num)] + gpu_block_ids = np.arange(test_block_num) + + swap_cache_layout(self.key_gpu_cache, self.key_register_buffer, gpu_block_ids, 0, 1) # gpu ==> cpu + print("swap_cache_layout done", time.time() - start_time) + # import pdb; pdb.set_trace() + target_location = [self.key_register_buffer + i * self.cache_stride for i in range(test_block_num)] + + target_sizes = [self.cache_stride] * test_block_num + + self.storage_backend.set(keys, target_location=target_location, target_sizes=target_sizes) + print("write storage time: ", time.time() - start_time) + + def test_read_storage(self, test_block_num=16): + start_time = time.time() + keys = [f"test_key_{i}" for i in range(test_block_num)] + gpu_block_ids = np.arange(test_block_num) + target_location = [self.key_register_buffer + i * self.cache_stride for i in range(test_block_num)] + + target_sizes = [self.cache_stride] * test_block_num + + self.storage_backend.get(keys, target_location=target_location, target_sizes=target_sizes) + + swap_cache_layout(self.key_gpu_cache, self.key_register_buffer, gpu_block_ids, 0, 0) # cpu ==> gpu + print("read storage time: ", time.time() - start_time) + + def free(self): + cuda_host_free(self.key_register_buffer) + + +if __name__ == "__main__": + + tester = TestMooncakeStore() + tester.test_register_buffer() + tester._init_gpu_blocks() + tester.test_write_storage(64) + tester.test_read_storage(64) + tester.free() diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/unit_test.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/unit_test.py new file mode 100644 index 00000000000..b87d5bcd2e2 --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/unit_test.py @@ -0,0 +1,54 @@ +import paddle + +from fastdeploy.cache_manager.transfer_factory import MooncakeStore + + +def test_init_and_warmup(): + store = MooncakeStore() + assert store.store is not None + + +def test_store_basic_function(): + store = MooncakeStore() + buffer = paddle.zeros([1024, 1024], dtype=paddle.float32).cpu() + store.register_buffer(buffer.data_ptr(), 1024 * 1024 * buffer.element_size()) + + key = ["test_key_" + str(i) for i in range(2)] + buffer[0, :] = 1 + buffer[1, :] = 2 + ptrs = [buffer.data_ptr(), buffer.data_ptr() + 1024 * 4] + sizes = [1024, 1024] + + store.set(key, target_location=ptrs, target_sizes=sizes) + buffer[0, :] = 3 + buffer[1, :] = 4 + print(buffer[0, 0], buffer[1, 0]) + + store.get(key, target_location=ptrs, target_sizes=sizes) + print("key: ", key) + print("buffer: ", buffer[0, 0], buffer[1, 0]) + assert buffer[0, 0] == 1 + assert buffer[1, 0] == 2 + keys = ["test_key_0", "non_existent_key"] + + result = store.exists(keys) + assert isinstance(result, dict) + assert "test_key_0" in result + print(result) + assert result["test_key_0"] == 1 + assert 
result["non_existent_key"] == 0 + + res = store.delete("test_key_0", timeout=10) + assert res == 0 + + new_result = store.exists(["test_key_0"]) + print(new_result) + assert new_result["test_key_0"] == 0 + + +if __name__ == "__main__": + import os + + os.environ["MOONCAKE_CONFIG_PATH"] = "./mooncake_config.json" + test_init_and_warmup() + test_store_basic_function() diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 24b26161a47..d4afd98fbf1 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1267,6 +1267,9 @@ def __init__(self, args): self.enable_ssd_cache = False self.cache_queue_port = None self.swap_space = None + self.write_policy = None + self.enable_hierarchical_kvcache = False + self.kvcache_storage_backend = None self.max_encoder_cache = None self.max_processor_cache = None self.enable_output_caching = False @@ -1281,11 +1284,6 @@ def __init__(self, args): if self.pd_comm_port is not None and isinstance(self.pd_comm_port, str): self.pd_comm_port = [int(port) for port in self.pd_comm_port.split(",")] - if self.swap_space is None: - self.enable_hierarchical_cache = False - else: - self.enable_hierarchical_cache = True - if self.model_cfg is not None: if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict): self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index d2d7c6f908a..452642629ed 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -228,6 +228,15 @@ class EngineArgs: """ Port for cache queue. """ + kvcache_storage_backend: str = None + """ + The storage backend for kvcache storage. + """ + write_policy: str = "write_through" + """ + KVCache write policy + """ + enable_hierarchical_kvcache: bool = False # System configuration parameters use_warmup: int = 0 @@ -951,6 +960,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Static decoding blocks num.", ) + cache_group.add_argument( + "--enable-hierarchical-kvcache", + action="store_true", + default=EngineArgs.enable_hierarchical_kvcache, + help="Enable hierarchical kvcache.", + ) + + cache_group.add_argument( + "--kvcache-storage-backend", + type=str, + choices=["mooncake"], + default=EngineArgs.kvcache_storage_backend, + help="The storage backend for kvcache storage.", + ) + + cache_group.add_argument( + "--write-policy", + type=str, + choices=["write_through"], + default=EngineArgs.write_policy, + help="KVCache write policy", + ) + # Cluster system parameters group system_group = parser.add_argument_group("System Configuration") system_group.add_argument( diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 00f601c3121..322698ea198 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -628,17 +628,19 @@ def _allocate_decode_and_extend(): num_new_tokens = self._get_num_new_tokens(request, token_budget) num_new_block = self.get_new_block_nums(request, num_new_tokens) # Allocate blocks to prefill - if self.cache_manager.can_allocate_gpu_blocks(num_new_block): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) - # Prepare prefill task - scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) - else: # Not enough blocks to allocate, trigger preemption + if not 
self.cache_manager.can_allocate_gpu_blocks(num_new_block): + # Not enough blocks to allocate, trigger preemption can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) if not can_schedule: break - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) - # Prepare prefill task - scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) + if self.config.cache_config.enable_prefix_caching: + storage_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids) + num_new_tokens -= len(storage_block_ids) * self.config.cache_config.block_size + request.block_tables.extend(extra_gpu_block_ids) + # Prepare prefill task + scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if self.config.cache_config.enable_prefix_caching: @@ -675,10 +677,7 @@ def _allocate_decode_and_extend(): self._update_mm_hashes(request) # Enable prefix caching if self.config.cache_config.enable_prefix_caching: - if ( - self.config.cache_config.enable_hierarchical_cache - and self.cache_manager.num_cpu_blocks > 0 - ): + if self.cache_manager.num_cpu_blocks > 0: if not self.cache_manager.can_allocate_gpu_blocks( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size @@ -694,7 +693,12 @@ def _allocate_decode_and_extend(): # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(num_new_block): if not request.get("skip_allocate", False): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) + if self.config.cache_config.enable_prefix_caching: + storage_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids) + num_new_tokens -= len(storage_block_ids) * self.config.cache_config.block_size + request.block_tables.extend(extra_gpu_block_ids) + self.waiting.popleft() self.running.append(request) scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) @@ -720,10 +724,7 @@ def _allocate_decode_and_extend(): request.num_total_tokens ) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct if self.config.cache_config.enable_prefix_caching: - if ( - self.config.cache_config.enable_hierarchical_cache - and self.cache_manager.num_cpu_blocks > 0 - ): + if self.cache_manager.num_cpu_blocks > 0: if not self.cache_manager.can_allocate_gpu_blocks( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size @@ -738,7 +739,11 @@ def _allocate_decode_and_extend(): # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(num_new_block): if not request.get("skip_allocate", False): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) + if self.config.cache_config.enable_prefix_caching: + storage_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids) + num_new_tokens -= len(storage_block_ids) * self.config.cache_config.block_size + request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) 
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) @@ -875,6 +880,36 @@ def get_real_bsz(self) -> int: break return self.real_bsz + def get_storage_cached_blocks(self, request: Request, extra_gpu_block_ids: list = []): + """ + set prefix cached information for the given request + """ + try: + cache_prepare_time = time.time() + llm_logger.info(f"req {request.request_id}") + matched_block_ids = self.cache_manager.request_match_storage_blocks(request, extra_gpu_block_ids) + llm_logger.info( + f"storage backend: {self.config.cache_config.kvcache_storage_backend} matched block ids: {matched_block_ids}" + ) + + matched_token_num = len(matched_block_ids) * self.config.cache_config.block_size + + request.num_cached_tokens += matched_token_num + match_block_num, no_cache_block_num = request.cache_info + match_block_num += len(matched_block_ids) + no_cache_block_num -= len(matched_block_ids) + + request.cache_info = (match_block_num, no_cache_block_num) + + # Report the number of cached tokens to Prometheus metrics + main_process_metrics.prefix_cache_token_num.inc(matched_token_num) + request.num_computed_tokens += matched_token_num + request.cache_prepare_time += time.time() - cache_prepare_time + return matched_block_ids + except Exception as e: + llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...") + return [] + def get_prefix_cached_blocks(self, request: Request): """ set prefix cached information for the given request @@ -956,7 +991,7 @@ def preallocate_resource_in_p(self, request: Request): ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num if self.config.cache_config.enable_prefix_caching: # Enable prefix caching - if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0: + if self.cache_manager.num_cpu_blocks > 0: if not self.cache_manager.can_allocate_gpu_blocks( need_prealloc_prefill_blocks ): # to prevent block allocation for matching in hierarchical cache and cause dead lock @@ -968,7 +1003,10 @@ def preallocate_resource_in_p(self, request: Request): need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0] if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks)) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks) + if self.config.cache_config.enable_prefix_caching: + self.get_storage_cached_blocks(request, extra_gpu_block_ids) + request.block_tables.extend(extra_gpu_block_ids) allocated_position = self.get_available_position() request.idx = allocated_position self.tasks_list[request.idx] = request diff --git a/fastdeploy/inter_communicator/engine_cache_queue.py b/fastdeploy/inter_communicator/engine_cache_queue.py index f46929755e1..76f36feec3b 100644 --- a/fastdeploy/inter_communicator/engine_cache_queue.py +++ b/fastdeploy/inter_communicator/engine_cache_queue.py @@ -102,6 +102,12 @@ class QueueManager(BaseManager): self.swap_to_gpu_barrier2_init = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] + self.swap_storage_to_gpu_barrier_init = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + self.swap_to_storage_barrier_init = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] # Register shared objects with proxy 
types
         QueueManager.register(
@@ -148,7 +154,14 @@ class QueueManager(BaseManager):
                 "get_swap_to_gpu_barrier2",
                 callable=lambda idx: self.swap_to_gpu_barrier2_init[idx],
             )
-
+            QueueManager.register(
+                "get_swap_storage_to_gpu_barrier",
+                callable=lambda idx: self.swap_storage_to_gpu_barrier_init[idx],
+            )
+            QueueManager.register(
+                "get_swap_to_storage_barrier",
+                callable=lambda idx: self.swap_to_storage_barrier_init[idx],
+            )
             self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
             self.manager.start()

@@ -175,6 +188,8 @@ class QueueManager(BaseManager):
             QueueManager.register("get_swap_to_cpu_barrier2")
             QueueManager.register("get_swap_to_gpu_barrier1")
             QueueManager.register("get_swap_to_gpu_barrier2")
+            QueueManager.register("get_swap_storage_to_gpu_barrier")
+            QueueManager.register("get_swap_to_storage_barrier")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()

@@ -194,6 +209,8 @@ class QueueManager(BaseManager):
         self.swap_to_cpu_barrier2 = self.manager.get_swap_to_cpu_barrier2(self.local_data_parallel_id)
         self.swap_to_gpu_barrier1 = self.manager.get_swap_to_gpu_barrier1(self.local_data_parallel_id)
         self.swap_to_gpu_barrier2 = self.manager.get_swap_to_gpu_barrier2(self.local_data_parallel_id)
+        self.swap_storage_to_gpu_barrier = self.manager.get_swap_storage_to_gpu_barrier(self.local_data_parallel_id)
+        self.swap_to_storage_barrier = self.manager.get_swap_to_storage_barrier(self.local_data_parallel_id)
         self.total_num: int = (1 << self.num_client) - 1

         if not is_server:
diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py
index f09fc603325..7ac2c66fa5f 100644
--- a/tests/cache_manager/test_cache_transfer_manager.py
+++ b/tests/cache_manager/test_cache_transfer_manager.py
@@ -26,6 +26,8 @@ class Args:
     value_cache_shape = ""
     create_cache_tensor = False
     cache_dtype = "bfloat16"
+    kvcache_storage_backend = "None"
+    write_policy = "write_through"

 # ==========================
diff --git a/tests/operators/test_swap_layout.py b/tests/operators/test_swap_layout.py
new file mode 100644
index 00000000000..1faf6d57135
--- /dev/null
+++ b/tests/operators/test_swap_layout.py
@@ -0,0 +1,59 @@
+import time
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.cache_manager.ops import cuda_host_alloc, cuda_host_free
+from fastdeploy.model_executor.ops.gpu import swap_cache_layout
+
+
+class Test(unittest.TestCase):
+    def setUp(self):
+
+        self.cache_shape = [8, 64, 4, 128]
+        self.layer_num = 10
+        self.block_ids = np.arange(self.cache_shape[0])
+        self.key_register_buffer = cuda_host_alloc(np.prod(self.cache_shape) * self.layer_num * 2)
+
+    def release_buffer(self):
+        cuda_host_free(self.key_register_buffer)
+
+    def test_swap_cache_layout(self):
+
+        gpu_key_register_buffer = []
+        for i in range(self.layer_num):
+            gpu_key_register_buffer.append(paddle.full(self.cache_shape, fill_value=i, dtype=paddle.float16))
+
+        ss = time.time()
+        swap_cache_layout(
+            gpu_key_register_buffer,
+            self.key_register_buffer,
+            self.cache_shape,
+            self.block_ids,
+            0,
+            0,
+        )
+        print("swap cache layout (device to host): ", time.time() - ss)
+        ss = time.time()
+        del gpu_key_register_buffer
+        gpu_key_register_buffer = []
+        for i in range(self.layer_num):
+            gpu_key_register_buffer.append(paddle.zeros(self.cache_shape, dtype=paddle.float16))
+        swap_cache_layout(
+            gpu_key_register_buffer,
+            self.key_register_buffer,
+            self.cache_shape,
+            self.block_ids,
+            0,
+            1,
+        )
+        for i in range(self.layer_num):
+            assert gpu_key_register_buffer[i].numpy()[0, 0, 0, 0] == i
+        print("swap cache layout (host to device): ", time.time() - ss)
+
+        self.release_buffer()
+
+
+if __name__ == "__main__":
+    unittest.main()
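Taken together, swap_cache_layout and the MooncakeStore wrapper form the write/read path for the hierarchical KV cache: blocks are staged between the per-layer GPU tensors and a registered pinned host buffer, and the storage backend reads from and writes to that host buffer. The standalone sketch below (not part of the patch) ties the pieces exercised by the tests into one round trip. It assumes the Python binding keeps the argument order used in tests/operators/test_swap_layout.py, that MOONCAKE_CONFIG_PATH points at a valid config, and that the per-block key names and host-buffer layout are illustrative only.

# Sketch of one write/read round trip through the hierarchical KV-cache path.
# mode 0 stages GPU -> host, mode 1 restores host -> GPU (as exercised by the
# assert in tests/operators/test_swap_layout.py). Key names and the per-block
# host layout below are illustrative assumptions, not part of the patch.
import numpy as np
import paddle

from fastdeploy.cache_manager.ops import cuda_host_alloc, cuda_host_free
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
from fastdeploy.model_executor.ops.gpu import swap_cache_layout

layer_num = 2
cache_shape = [4, 8, 4, 128]  # [block_num, num_heads, block_size, head_dim]
dtype_bytes = 2  # bfloat16
block_bytes = int(np.prod(cache_shape[1:])) * dtype_bytes  # one block, one layer

# Per-layer GPU KV tensors, plus a pinned host buffer large enough to hold every
# swapped block across all layers (host layout: [block, layer, block_size, heads, dim]).
gpu_cache = [paddle.full(cache_shape, fill_value=1.0, dtype=paddle.bfloat16) for _ in range(layer_num)]
host_bytes = int(np.prod(cache_shape)) * layer_num * dtype_bytes
host_ptr = cuda_host_alloc(host_bytes)

store = MooncakeStore()  # expects MOONCAKE_CONFIG_PATH to point at a valid config
store.register_buffer(host_ptr, host_bytes)

block_ids = np.arange(cache_shape[0])
keys = [f"demo_block_{i}" for i in block_ids]  # hypothetical per-block keys
locs = [host_ptr + int(i) * block_bytes * layer_num for i in block_ids]
sizes = [block_bytes * layer_num] * len(block_ids)

# Write path: stage GPU blocks into the host buffer, then push them to storage.
swap_cache_layout(gpu_cache, host_ptr, cache_shape, block_ids, 0, 0)
store.set(keys, target_location=locs, target_sizes=sizes)

# Read path: pull the blocks back into the host buffer, then restore them on GPU.
store.get(keys, target_location=locs, target_sizes=sizes)
swap_cache_layout(gpu_cache, host_ptr, cache_shape, block_ids, 0, 1)

cuda_host_free(host_ptr)

Keying storage by block (one key spanning all layers of a block) matches the block-major host layout the op produces, which is why a single contiguous set/get per block suffices in this sketch.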