From e6265d1e25633741045c0e93fa2eaffe7a86ae2b Mon Sep 17 00:00:00 2001 From: jayfeather9 Date: Tue, 18 Mar 2025 17:10:21 +0800 Subject: [PATCH 01/13] initial hicache support not finished --- .../router/dynamic_prompt/cache_controller.py | 230 ++++++++++++++++++ .../router/dynamic_prompt/hiradix_cache.py | 64 +++++ 2 files changed, 294 insertions(+) create mode 100644 lightllm/server/router/dynamic_prompt/cache_controller.py create mode 100644 lightllm/server/router/dynamic_prompt/hiradix_cache.py diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py new file mode 100644 index 000000000..44fb79af1 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -0,0 +1,230 @@ +import torch +import threading +import time +import json +from typing import Dict, List, Tuple, Optional, Set, Any +from queue import Queue +from lightllm.common.mem_manager import MemoryManager + +BLOCK_SIZE = 16384 + +def get_torch_tensor_size(tensor: torch.Tensor): + return tensor.nelement() * tensor.element_size() + +class CacheNode: + def __init__(self, parent=None, split_token_idx=None): + self.parent = parent # 父节点 + self.split_token_idx = split_token_idx # 从父节点分裂的位置 + self.children = {} # (token_id, split_position) -> (child_node, split_position) + self.cache_indices = [] # 存储kv cache在mem_manager中的索引 + self.token_ids = [] # 当前节点存储的token ids + self.hash = None # 存储在磁盘上的唯一标识 + + def serialize(self): + """将节点数据序列化为JSON""" + data = { + "children": {f"{k[0]}_{k[1]}": [c.hash, p] for k, (c, p) in self.children.items()}, + "cache_indices": self.cache_indices, + "token_ids": self.token_ids, + "split_token_idx": self.split_token_idx + } + return json.dumps(data) + + @classmethod + def deserialize(cls, data_str, parent=None): + """从JSON反序列化节点数据""" + data = json.loads(data_str) + node = cls(parent=parent, split_token_idx=data["split_token_idx"]) + node.cache_indices = data["cache_indices"] + node.token_ids = data["token_ids"] + # 子节点需要单独加载 + return node, {(int(k.split('_')[0]), int(k.split('_')[1])): (v[0], v[1]) for k, v in data["children"].items()} + + +class HiCacheController: + def __init__(self, mem_manager: MemoryManager): + self.mem_manager = mem_manager + self.service = None # 将由外部代码初始化 + + self.root = CacheNode() + self.root.hash = "root" + + self.node_cache = {self.root.hash: self.root} # hash -> node + self.read_queue = Queue() + self.write_queue = Queue() + + self.token_kvcache_size = None # 每个token的kvcache大小 + + # 启动后台线程处理读写任务 + self.running = True + self.poll_thread = threading.Thread(target=self._poll_tasks) + self.poll_thread.daemon = True + self.poll_thread.start() + + def reset(self): + """重置缓存控制器""" + self.running = False + self.poll_thread.join(timeout=1) + + self.root = CacheNode() + self.root.hash = "root" + self.node_cache = {self.root.hash: self.root} + + self.read_queue = Queue() + self.write_queue = Queue() + + self.running = True + self.poll_thread = threading.Thread(target=self._poll_tasks) + self.poll_thread.daemon = True + self.poll_thread.start() + + def _poll_tasks(self): + """轮询读写任务,检查是否完成""" + while self.running: + # 处理读任务 + pending_reads = [] + while not self.read_queue.empty(): + task = self.read_queue.get() + if task.ready(): + # TODO: 将读到的内容存入 memory manager 中 + pass + else: + pending_reads.append(task) + + for task in pending_reads: + self.read_queue.put(task) + + # 处理写任务 + pending_writes = [] + while not self.write_queue.empty(): + task = self.write_queue.get() + if not task.ready(): + 
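# unfinished write tasks are collected in pending_writes here and re-queued below,
# so the poll loop will check them again on its next pass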
pending_writes.append(task) + + for task in pending_writes: + self.write_queue.put(task) + + time.sleep(0.01) # 避免CPU过度使用 + + def _ensure_node_loaded(self, node_hash): + """确保节点已加载到内存中""" + if node_hash not in self.node_cache and node_hash != "root": + task = self.service.create(hashs=[node_hash], mode="r") + self.service.commit(task) + self.read_queue.put(task) + # 需要等待节点加载完成 + while not task.ready() or node_hash not in self.node_cache: + time.sleep(0.01) + + def _persist_node(self, node): + """将节点持久化到磁盘""" + if not node.hash: + # 为新节点生成hash + node.hash = f"node_{id(node)}_{time.time()}" + + # TODO: 将对应的kvcache写入磁盘 + task = self.service.create(hashs=[node.hash], mode="w") + self.service.commit(task) + self.write_queue.put(task) + self.node_cache[node.hash] = node + + def write(self, key: torch.Tensor, value: torch.Tensor): + """ + 写入token序列及其对应的KV缓存索引 + key: token_ids序列 + value: 对应的KV缓存索引 + """ + token_ids = key.cpu().tolist() + indices = value.cpu().tolist() + + # 首次计算每个token的kvcache大小 + if self.token_kvcache_size is None: + kvcache = self.mem_manager.to_kvcache(indices[:1]) # 计算单个token的kvcache + self.token_kvcache_size = get_torch_tensor_size(kvcache) + + current = self.root + position = 0 + relative_position = 0 + + while position < len(token_ids): + token_id = token_ids[position] + child_key = (token_id, relative_position) + + if child_key in current.children: + child_info = current.children[child_key] + assert isinstance(child_info[0], CacheNode) + child_hash = child_info[0].hash + self._ensure_node_loaded(child_hash) + current = self.node_cache[child_hash] + position += 1 + relative_position = 0 # next time relative pos is 0 + else: + # 计算当前节点剩余空间 + remaining_space = BLOCK_SIZE - len(current.cache_indices) * self.token_kvcache_size + + if self.token_kvcache_size <= remaining_space: + # 当前节点有足够空间 + current.token_ids.append(token_ids[position]) + current.cache_indices.append(indices[position]) + position += 1 + relative_position += 1 + self._persist_node(current) + else: + # 当前节点已满,需要创建新节点 + new_node = CacheNode(parent=current, split_token_idx=len(current.token_ids)) + + # 将token添加到新节点 + new_node.token_ids.append(token_ids[position]) + new_node.cache_indices.append(indices[position]) + position += 1 + relative_position = 0 # next time relative pos is 0, not affecting child_key + + # 建立父子关系 + current.children[child_key] = (new_node, len(current.cache_indices)) + + # 持久化 + self._persist_node(new_node) + self._persist_node(current) + + current = new_node + + # 确保最后修改的节点被持久化 + self._persist_node(current) + + def read(self, key: torch.Tensor) -> torch.Tensor: + """ + 读取token序列对应的KV缓存索引 + key: token_ids序列 + 返回: 对应的KV缓存索引 + """ + token_ids = key.cpu().tolist() + result_indices = [] + + current = self.root + position = 0 + relative_position = 0 + + while position < len(token_ids): + token_id = token_ids[position] + + # 检查当前节点的token + if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: + # TODO: 将读到的东西存到 result_indices 中 + position += 1 + relative_position += 1 + continue + + # 查找子节点 + child_key = (token_id, relative_position) + if child_key in current.children: + child_info = current.children[child_key] + assert isinstance(child_info[0], CacheNode) + child_hash = child_info[0].hash + self._ensure_node_loaded(child_hash) + current = self.node_cache[child_hash] + relative_position = 0 + else: + # 未找到匹配的路径 + return torch.tensor(result_indices, dtype=torch.int64) + + return torch.tensor(result_indices, dtype=torch.int64) diff --git 
a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py new file mode 100644 index 000000000..151bcd53a --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -0,0 +1,64 @@ +import torch +from .cache_controller import HiCacheController +from .radix_cache import RadixCache, TreeNode, match +from typing import Tuple, Dict, Set, List +from lightllm.common.mem_manager import MemoryManager + + +class HiRadixCache(RadixCache): + def __init__(self, cache_controller: HiCacheController, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None): + super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + self.cache_controller = cache_controller + + def _insert_helper(self, node: TreeNode, key, value): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + try: + first_key_id = key[0].item() + if first_key_id in node.children.keys(): + child: TreeNode = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0 + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) From 382ab2fb526eadc28d2919968775bbc0911bf17d Mon Sep 17 00:00:00 2001 From: jayfeather9 Date: Tue, 18 Mar 2025 22:02:09 +0800 Subject: [PATCH 02/13] add debug outputs --- .../router/dynamic_prompt/cache_controller.py | 14 +- test/server/test_hicache.py | 156 ++++++++++++++++++ 2 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 test/server/test_hicache.py diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py index 44fb79af1..e14106923 100644 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -87,7 +87,11 @@ def _poll_tasks(self): task = self.read_queue.get() if task.ready(): # TODO: 将读到的内容存入 memory manager 中 - pass + node_hash = task.hashs[0] + if node_hash in self.node_cache: + node = self.node_cache[node_hash] + node.cache_indices = self.mem_manager.store(node.cache_indices, task.value) + print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") else: pending_reads.append(task) @@ -118,12 +122,13 @@ def _ensure_node_loaded(self, node_hash): def 
_persist_node(self, node): """将节点持久化到磁盘""" + print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") if not node.hash: # 为新节点生成hash node.hash = f"node_{id(node)}_{time.time()}" # TODO: 将对应的kvcache写入磁盘 - task = self.service.create(hashs=[node.hash], mode="w") + task = self.service.create(hashs=[node.hash], value=self.mem_manager.to_kvcache(node.cache_indices), mode="w") self.service.commit(task) self.write_queue.put(task) self.node_cache[node.hash] = node @@ -141,6 +146,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): if self.token_kvcache_size is None: kvcache = self.mem_manager.to_kvcache(indices[:1]) # 计算单个token的kvcache self.token_kvcache_size = get_torch_tensor_size(kvcache) + print(f"Single token KV cache size: {self.token_kvcache_size} bytes, Block size: {BLOCK_SIZE}") current = self.root position = 0 @@ -148,9 +154,11 @@ def write(self, key: torch.Tensor, value: torch.Tensor): while position < len(token_ids): token_id = token_ids[position] + print(f"Writing token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") child_key = (token_id, relative_position) if child_key in current.children: + print(f"Child key {child_key} found in current.children") child_info = current.children[child_key] assert isinstance(child_info[0], CacheNode) child_hash = child_info[0].hash @@ -172,6 +180,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): else: # 当前节点已满,需要创建新节点 new_node = CacheNode(parent=current, split_token_idx=len(current.token_ids)) + print(f"Creating new node at split position {new_node.split_token_idx}, parent hash: {current.hash}") # 将token添加到新节点 new_node.token_ids.append(token_ids[position]) @@ -206,6 +215,7 @@ def read(self, key: torch.Tensor) -> torch.Tensor: while position < len(token_ids): token_id = token_ids[position] + print(f"Reading token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") # 检查当前节点的token if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py new file mode 100644 index 000000000..b65b4f9c7 --- /dev/null +++ b/test/server/test_hicache.py @@ -0,0 +1,156 @@ +# test_hicache.py +import torch +import time +import random +from threading import Thread, Event +from queue import Queue +from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE + +class MockMemoryManager: + """模拟内存管理器,仅返回连续的索引值""" + def __init__(self): + self.current_idx = 0 + self.kvcache_store = {} + + def alloc(self, size): + indices = list(range(self.current_idx, self.current_idx + size)) + self.current_idx += size + self.store(indices, torch.tensor([[0] * 512 for _ in range(size)])) + return indices + + def to_kvcache(self, indices): + return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices]) + + def store(self, indices, value): + for idx, val in zip(indices, value): + self.kvcache_store[idx] = val + + def free(self, indices): + for idx in indices: + del self.kvcache_store[idx] + +class MockTask: + def __init__(self, hashs, mode, value=None): + self.hashs = hashs + self.mode = mode + self._ready = Event() + self.data = value + + def ready(self): + return self._ready.is_set() + + def set_ready(self): + self._ready.set() + +class MockService: + def __init__(self): + self.tasks = Queue() + self.running = True + self.worker = Thread(target=self.process_tasks) + self.worker.daemon = True + self.worker.start() + + def 
process_tasks(self): + while self.running: + if not self.tasks.empty(): + task = self.tasks.get() + # 模拟随机延迟后完成任务 + delay = random.uniform(0.01, 0.1) + time.sleep(delay) + task.set_ready() + print(f"Task for {task.hashs} completed after {delay:.2f}s") + else: + time.sleep(0.01) + + def create(self, hashs, mode, value=None): + task = MockTask(hashs, mode, value) + self.tasks.put(task) + return task + + def commit(self, task): + pass # 在Mock中不需要实现 + + def shutdown(self): + self.running = False + self.worker.join() + +def setup(): + mem_manager = MockMemoryManager() + service = MockService() + hicache = HiCacheController(mem_manager) + hicache.service = service # 注入模拟服务 + + # 预先计算单token大小 + dummy_indices = mem_manager.alloc(1) + kvcache = mem_manager.to_kvcache(dummy_indices[:1]) + token_size = kvcache.nelement() * kvcache.element_size() + print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}") + + return mem_manager, service, hicache, token_size + +def test_basic_write_read(mem_manager, hicache, token_size): + # 计算每个块可容纳的token数量 + tokens_per_block = BLOCK_SIZE // token_size + print(f"[TEST] Each block can hold {tokens_per_block} tokens") + + # 生成测试数据:刚好占满一个块 + token_ids = list(range(tokens_per_block)) + indices = mem_manager.alloc(len(token_ids)) + + # 写入缓存 + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + + # 等待任务完成 + time.sleep(0.5) # 确保后台线程处理完成 + + # 读取验证 + result = hicache.read(torch.tensor(token_ids)) + assert result.tolist() == indices, f"Retrieved indices: {result.tolist()}, Expected indices: {indices}" + print(f"[TEST] Basic test passed. Retrieved indices: {result.tolist()}") + +def test_node_splitting(mem_manager, hicache, token_size): + tokens_per_block = BLOCK_SIZE // token_size + # 生成超过一个块的数据 + token_ids = list(range(tokens_per_block + 1)) + indices = mem_manager.alloc(len(token_ids)) + + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(0.5) + + # 验证根节点应该有子节点 + root = hicache.root + assert len(root.children) > 0 + print(f"\nRoot node has {len(root.children)} children") + + # 读取完整序列 + result = hicache.read(torch.tensor(token_ids)) + assert result.tolist() == indices + print(f"[TEST] Node splitting test passed. 
Retrieved indices: {result.tolist()}") + +def test_partial_read(mem_manager, hicache): + token_ids = [1,2,3,4,5] + indices = mem_manager.alloc(len(token_ids)) + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(0.2) + + # 查询存在的部分前缀 + result = hicache.read(torch.tensor([1,2,3])) + assert result.tolist() == indices[:3] + print(f"[TEST] Partial read result: {result.tolist()}") + + # 查询不存在的前缀 + result = hicache.read(torch.tensor([1,2,9])) + assert len(result) == 0 + print(f"[TEST] Non-existent prefix returned: {result.tolist()}") + +def main(): + mem_manager, service, hicache, token_size = setup() + try: + test_basic_write_read(mem_manager, hicache, token_size) + test_node_splitting(mem_manager, hicache, token_size) + test_partial_read(mem_manager, hicache) + finally: + service.shutdown() + +if __name__ == "__main__": + main() \ No newline at end of file From f4bd76ee87d3f560033223b92c4411e631b940c5 Mon Sep 17 00:00:00 2001 From: jinbiaoyu Date: Tue, 8 Jul 2025 19:54:50 +0800 Subject: [PATCH 03/13] support hicache --- lightllm/server/api_cli.py | 3 +- lightllm/server/api_start.py | 4 + lightllm/server/core/objs/start_args_type.py | 1 + .../router/dynamic_prompt/cache_controller.py | 240 ------------------ .../router/dynamic_prompt/hiradix_cache.py | 166 ++++++++---- .../router/dynamic_prompt/radix_cache.py | 2 + lightllm/server/router/manager.py | 1 + .../server/router/model_infer/infer_batch.py | 1 + .../model_infer/mode_backend/base_backend.py | 13 +- test/server/test_hicache.py | 155 ++++++----- 10 files changed, 215 insertions(+), 371 deletions(-) delete mode 100644 lightllm/server/router/dynamic_prompt/cache_controller.py diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d904c727f..af0473194 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -220,6 +220,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") + parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true", help="enable hierachy prompt cache") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") @@ -326,7 +327,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" ) parser.add_argument( - "--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2" + "--visual_gpu_ids", nargs="+", type=int, default=[0, 1, 2, 3, 4, 5, 6, 7], help="List of GPU IDs to use, e.g., 0 1 2" ) parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT") parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6e6c27b5e..099987308 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -173,6 +173,10 @@ def normal_or_p_d_start(args): args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" + # if 
use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True + if args.use_hi_dynamic_prompt_cache: + assert not args.disable_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" + # help to manage data stored on Ceph if "s3://" in args.model_dir: from lightllm.utils.petrel_helper import s3_model_prepare diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index f76fbc8c8..0bb3d1912 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -46,6 +46,7 @@ class StartArgs: dp_prefill_wait_step: int = field(default=0) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) + use_hi_dynamic_prompt_cache: bool = field(default=False) chunked_prefill_size: int = field(default=8192) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py deleted file mode 100644 index e14106923..000000000 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ /dev/null @@ -1,240 +0,0 @@ -import torch -import threading -import time -import json -from typing import Dict, List, Tuple, Optional, Set, Any -from queue import Queue -from lightllm.common.mem_manager import MemoryManager - -BLOCK_SIZE = 16384 - -def get_torch_tensor_size(tensor: torch.Tensor): - return tensor.nelement() * tensor.element_size() - -class CacheNode: - def __init__(self, parent=None, split_token_idx=None): - self.parent = parent # 父节点 - self.split_token_idx = split_token_idx # 从父节点分裂的位置 - self.children = {} # (token_id, split_position) -> (child_node, split_position) - self.cache_indices = [] # 存储kv cache在mem_manager中的索引 - self.token_ids = [] # 当前节点存储的token ids - self.hash = None # 存储在磁盘上的唯一标识 - - def serialize(self): - """将节点数据序列化为JSON""" - data = { - "children": {f"{k[0]}_{k[1]}": [c.hash, p] for k, (c, p) in self.children.items()}, - "cache_indices": self.cache_indices, - "token_ids": self.token_ids, - "split_token_idx": self.split_token_idx - } - return json.dumps(data) - - @classmethod - def deserialize(cls, data_str, parent=None): - """从JSON反序列化节点数据""" - data = json.loads(data_str) - node = cls(parent=parent, split_token_idx=data["split_token_idx"]) - node.cache_indices = data["cache_indices"] - node.token_ids = data["token_ids"] - # 子节点需要单独加载 - return node, {(int(k.split('_')[0]), int(k.split('_')[1])): (v[0], v[1]) for k, v in data["children"].items()} - - -class HiCacheController: - def __init__(self, mem_manager: MemoryManager): - self.mem_manager = mem_manager - self.service = None # 将由外部代码初始化 - - self.root = CacheNode() - self.root.hash = "root" - - self.node_cache = {self.root.hash: self.root} # hash -> node - self.read_queue = Queue() - self.write_queue = Queue() - - self.token_kvcache_size = None # 每个token的kvcache大小 - - # 启动后台线程处理读写任务 - self.running = True - self.poll_thread = threading.Thread(target=self._poll_tasks) - self.poll_thread.daemon = True - self.poll_thread.start() - - def reset(self): - """重置缓存控制器""" - self.running = False - self.poll_thread.join(timeout=1) - - self.root = CacheNode() - self.root.hash = "root" - self.node_cache = {self.root.hash: self.root} - - self.read_queue = Queue() - self.write_queue = Queue() - - self.running = True - self.poll_thread = threading.Thread(target=self._poll_tasks) - 
self.poll_thread.daemon = True - self.poll_thread.start() - - def _poll_tasks(self): - """轮询读写任务,检查是否完成""" - while self.running: - # 处理读任务 - pending_reads = [] - while not self.read_queue.empty(): - task = self.read_queue.get() - if task.ready(): - # TODO: 将读到的内容存入 memory manager 中 - node_hash = task.hashs[0] - if node_hash in self.node_cache: - node = self.node_cache[node_hash] - node.cache_indices = self.mem_manager.store(node.cache_indices, task.value) - print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") - else: - pending_reads.append(task) - - for task in pending_reads: - self.read_queue.put(task) - - # 处理写任务 - pending_writes = [] - while not self.write_queue.empty(): - task = self.write_queue.get() - if not task.ready(): - pending_writes.append(task) - - for task in pending_writes: - self.write_queue.put(task) - - time.sleep(0.01) # 避免CPU过度使用 - - def _ensure_node_loaded(self, node_hash): - """确保节点已加载到内存中""" - if node_hash not in self.node_cache and node_hash != "root": - task = self.service.create(hashs=[node_hash], mode="r") - self.service.commit(task) - self.read_queue.put(task) - # 需要等待节点加载完成 - while not task.ready() or node_hash not in self.node_cache: - time.sleep(0.01) - - def _persist_node(self, node): - """将节点持久化到磁盘""" - print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") - if not node.hash: - # 为新节点生成hash - node.hash = f"node_{id(node)}_{time.time()}" - - # TODO: 将对应的kvcache写入磁盘 - task = self.service.create(hashs=[node.hash], value=self.mem_manager.to_kvcache(node.cache_indices), mode="w") - self.service.commit(task) - self.write_queue.put(task) - self.node_cache[node.hash] = node - - def write(self, key: torch.Tensor, value: torch.Tensor): - """ - 写入token序列及其对应的KV缓存索引 - key: token_ids序列 - value: 对应的KV缓存索引 - """ - token_ids = key.cpu().tolist() - indices = value.cpu().tolist() - - # 首次计算每个token的kvcache大小 - if self.token_kvcache_size is None: - kvcache = self.mem_manager.to_kvcache(indices[:1]) # 计算单个token的kvcache - self.token_kvcache_size = get_torch_tensor_size(kvcache) - print(f"Single token KV cache size: {self.token_kvcache_size} bytes, Block size: {BLOCK_SIZE}") - - current = self.root - position = 0 - relative_position = 0 - - while position < len(token_ids): - token_id = token_ids[position] - print(f"Writing token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") - child_key = (token_id, relative_position) - - if child_key in current.children: - print(f"Child key {child_key} found in current.children") - child_info = current.children[child_key] - assert isinstance(child_info[0], CacheNode) - child_hash = child_info[0].hash - self._ensure_node_loaded(child_hash) - current = self.node_cache[child_hash] - position += 1 - relative_position = 0 # next time relative pos is 0 - else: - # 计算当前节点剩余空间 - remaining_space = BLOCK_SIZE - len(current.cache_indices) * self.token_kvcache_size - - if self.token_kvcache_size <= remaining_space: - # 当前节点有足够空间 - current.token_ids.append(token_ids[position]) - current.cache_indices.append(indices[position]) - position += 1 - relative_position += 1 - self._persist_node(current) - else: - # 当前节点已满,需要创建新节点 - new_node = CacheNode(parent=current, split_token_idx=len(current.token_ids)) - print(f"Creating new node at split position {new_node.split_token_idx}, parent hash: {current.hash}") - - # 将token添加到新节点 - new_node.token_ids.append(token_ids[position]) - new_node.cache_indices.append(indices[position]) - position += 1 - relative_position = 0 # next time 
relative pos is 0, not affecting child_key - - # 建立父子关系 - current.children[child_key] = (new_node, len(current.cache_indices)) - - # 持久化 - self._persist_node(new_node) - self._persist_node(current) - - current = new_node - - # 确保最后修改的节点被持久化 - self._persist_node(current) - - def read(self, key: torch.Tensor) -> torch.Tensor: - """ - 读取token序列对应的KV缓存索引 - key: token_ids序列 - 返回: 对应的KV缓存索引 - """ - token_ids = key.cpu().tolist() - result_indices = [] - - current = self.root - position = 0 - relative_position = 0 - - while position < len(token_ids): - token_id = token_ids[position] - print(f"Reading token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") - - # 检查当前节点的token - if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: - # TODO: 将读到的东西存到 result_indices 中 - position += 1 - relative_position += 1 - continue - - # 查找子节点 - child_key = (token_id, relative_position) - if child_key in current.children: - child_info = current.children[child_key] - assert isinstance(child_info[0], CacheNode) - child_hash = child_info[0].hash - self._ensure_node_loaded(child_hash) - current = self.node_cache[child_hash] - relative_position = 0 - else: - # 未找到匹配的路径 - return torch.tensor(result_indices, dtype=torch.int64) - - return torch.tensor(result_indices, dtype=torch.int64) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 151bcd53a..31a306f67 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -1,64 +1,128 @@ import torch -from .cache_controller import HiCacheController +import time +import tempfile +import numpy as np +import torch.distributed as dist +from os.path import join from .radix_cache import RadixCache, TreeNode, match from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager +from lightllm.utils.log_utils import init_logger +from threading import Lock +from enum import Enum +from .shared_arr import SharedArray +from kvcache.python.jit import PyLocalCacheService +logger = init_logger(__name__) -class HiRadixCache(RadixCache): - def __init__(self, cache_controller: HiCacheController, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None): - super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) - self.cache_controller = cache_controller - - def _insert_helper(self, node: TreeNode, key, value): - if node.is_leaf(): - self.evict_tree_set.discard(node) +def wait_until_ready(task, timeout=10.0, check_interval=0.01): + start_time = time.time() + while not task.ready(): + time.sleep(check_interval) + if time.time() - start_time > timeout: + logger.error("Current kv cache task not ready in time") + return False + return True + +class LocalCacheManager: - try: - first_key_id = key[0].item() - if first_key_id in node.children.keys(): - child: TreeNode = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) - if prefix_len == len(key): - if child.is_leaf(): - self.evict_tree_set.discard(child) - child.update_time() - if child.is_leaf(): - self.evict_tree_set.add(child) - return prefix_len + def __init__(self, unique_name: str, rank_in_node: int, mem_manager): + tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}") + self.cache_file = join(tmp_dir, "cache_file") + all_buffers = mem_manager.kv_buffer + all_buffers = all_buffers.view(all_buffers.shape[0], 
all_buffers.shape[1], -1) - elif prefix_len < len(key) and prefix_len < len(child.token_id_key): - if child.is_leaf(): - self.evict_tree_set.discard(child) + self.py_cache_service = PyLocalCacheService( + file=self.cache_file, + storage_size=128 * (1024 ** 3), # 128GB + num_shard=32, + kvcache_tensor=all_buffers, + num_worker=8 + ) - key = key[prefix_len:] - value = value[prefix_len:] - split_parent_node = child.split_node(prefix_len) - new_node = split_parent_node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + def insert(self, tokens, kv_page_indexer, start_pos=0): + t = self.py_cache_service.create( + tokens=tokens, + kv_page_indexer=kv_page_indexer, + mode="w", + start_pos=start_pos) + res = wait_until_ready(t) + if not res: + self.py_cache_service.az5(t) - if split_parent_node.is_leaf(): - self.evict_tree_set.add(split_parent_node) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) + def read(self, tokens, kv_page_indexer, start_pos=0): + t = self.py_cache_service.create( + tokens=tokens, + kv_page_indexer=kv_page_indexer, + mode="r", + start_pos=start_pos) + res = wait_until_ready(t) + return res + + def query(self, tokens): + query_result = self.py_cache_service.query(tokens) + max_len = 0 + for result in query_result: + if result: + max_len += 1 + else: + break + return max_len * self.block_size + + @property + def block_size(self,): + return self.py_cache_service.tokens_per_block + +class HiRadixCache(RadixCache): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager): + super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + self.rank_in_node = rank_in_node + self.local_cache_manager = LocalCacheManager( + unique_name, + rank_in_node, + mem_manager, + ) + self.is_hi_radix_cache = True + self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64) + self.disk_cache_match_count.arr[0] = 0 + self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64) + self.total_match_count.arr[0] = 0 + self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32) + self.disk_cache_match_ratio.arr[0] = 0.0 + logger.info(f"Initializing HiRadixCache {rank_in_node}") - if child.is_leaf(): - self.evict_tree_set.add(child) - return prefix_len - elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) - else: - assert False, "can not run to here" + def insert(self, key, value=None): + share_len = super().insert(key, value) + if share_len == 0: + return 0 + self.local_cache_manager.insert(key, value) + return share_len + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + self.total_match_count.arr[0] += 1 + ans_value_list = [] + ans_value = None + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + if tree_node.node_prefix_total_len != 0: + ans_value = torch.concat(ans_value_list) + max_len = 0 + if tree_node.node_prefix_total_len < len(key): + max_len = self.local_cache_manager.query(key) + if max_len > tree_node.node_prefix_total_len: + pull_len = max_len - tree_node.node_prefix_total_len + self.disk_cache_match_count.arr[0] += 1 + self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / 
self.total_match_count.arr[0] + self.free_radix_cache_to_get_enough_token(pull_len) + buffers = self.mem_manager.alloc(pull_len) + start_pos = 0 + if ans_value is not None: + buffers = torch.concat([ans_value, buffers]) + start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size + logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk") + res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos) + if res: + super().insert(key[:max_len], buffers) else: - new_node = node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) - return 0 - finally: - node.update_time() - if node.is_leaf(): - self.evict_tree_set.add(node) + self.mem_manager.free(buffers[tree_node.node_prefix_total_len:]) + return super().match_prefix(key, update_refs=update_refs) diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 65ec4354b..e4c34bc85 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -123,6 +123,8 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo ) self.tree_total_tokens_num.arr[0] = 0 + self.is_hi_radix_cache = False + def insert(self, key, value=None): if value is None: value = key diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index c10847e3f..f5fcbb4cb 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -158,6 +158,7 @@ async def wait_to_model_ready(self): "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs, "use_reward_model": self.args.use_reward_model, "disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache, + "use_hi_dynamic_prompt_cache": self.args.use_hi_dynamic_prompt_cache, "data_type": self.args.data_type, "eos_id": self.eos_id, "diverse_mode": self.args.diverse_mode, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..0774244b6 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -110,6 +110,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + def _save_promptcache_kvbuffer(self): """ save prompt cache kv buffer diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index dd1ea45fe..c34b2ad97 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -8,6 +8,8 @@ from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferReq +from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache +from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock from 
lightllm.common.basemodel.basemodel import TpPartBaseModel @@ -54,6 +56,7 @@ def init_model(self, kvargs): self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache + self.use_hi_dynamic_prompt_cache = self.args.use_hi_dynamic_prompt_cache self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph @@ -118,7 +121,15 @@ def init_model(self, kvargs): self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) self.radix_cache = ( - RadixCache( + HiRadixCache( + get_unique_server_name(), + self.model.mem_manager.size, + self.rank_in_node, + mem_manager=self.model.mem_manager, + max_seq_length=kvargs.get("max_seq_length", 1024 * 5), + ) + if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache + else RadixCache( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py index b65b4f9c7..bb82457c4 100644 --- a/test/server/test_hicache.py +++ b/test/server/test_hicache.py @@ -4,10 +4,18 @@ import random from threading import Thread, Event from queue import Queue -from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE +from lightllm.server.router.dynamic_prompt.cache_controller import ( + HiCacheController, + CacheNode, + BLOCK_SIZE, + HiHostService, + HiHostTask, +) + class MockMemoryManager: """模拟内存管理器,仅返回连续的索引值""" + def __init__(self): self.current_idx = 0 self.kvcache_store = {} @@ -15,134 +23,124 @@ def __init__(self): def alloc(self, size): indices = list(range(self.current_idx, self.current_idx + size)) self.current_idx += size - self.store(indices, torch.tensor([[0] * 512 for _ in range(size)])) + self.store(indices, torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)])) return indices - + + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kvcache_store[index] = load_tensor_dict["kv_buffer"] + + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kvcache_store[index]} + def to_kvcache(self, indices): + assert all( + [idx in self.kvcache_store for idx in indices] + ), f"Not all of {indices} are not found in kvcache_store" return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices]) - + def store(self, indices, value): - for idx, val in zip(indices, value): - self.kvcache_store[idx] = val - + print(f"[TEST:MemManager] Storing {value.shape} at {indices}") + for idx, value_dim in zip(indices, range(value.shape[0])): + self.kvcache_store[idx] = value[value_dim] + print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}") + return indices + def free(self, indices): + print(f"[TEST:MemManager] Freeing {indices}") for idx in indices: del self.kvcache_store[idx] -class MockTask: - def __init__(self, hashs, mode, value=None): - self.hashs = hashs - self.mode = mode - self._ready = Event() - self.data = value - - def ready(self): - return self._ready.is_set() - - def set_ready(self): - self._ready.set() - -class MockService: - def __init__(self): - self.tasks = Queue() - self.running = True - self.worker = Thread(target=self.process_tasks) - self.worker.daemon = True - self.worker.start() - - def process_tasks(self): - while self.running: - if not self.tasks.empty(): - task = self.tasks.get() - # 模拟随机延迟后完成任务 - delay = random.uniform(0.01, 0.1) - 
time.sleep(delay) - task.set_ready() - print(f"Task for {task.hashs} completed after {delay:.2f}s") - else: - time.sleep(0.01) - - def create(self, hashs, mode, value=None): - task = MockTask(hashs, mode, value) - self.tasks.put(task) - return task - - def commit(self, task): - pass # 在Mock中不需要实现 - - def shutdown(self): - self.running = False - self.worker.join() def setup(): mem_manager = MockMemoryManager() - service = MockService() + service = HiHostService() hicache = HiCacheController(mem_manager) hicache.service = service # 注入模拟服务 - + + indices = mem_manager.alloc(5) + print(mem_manager.to_kvcache(indices)) + # 预先计算单token大小 dummy_indices = mem_manager.alloc(1) kvcache = mem_manager.to_kvcache(dummy_indices[:1]) token_size = kvcache.nelement() * kvcache.element_size() print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}") - + return mem_manager, service, hicache, token_size + def test_basic_write_read(mem_manager, hicache, token_size): # 计算每个块可容纳的token数量 tokens_per_block = BLOCK_SIZE // token_size print(f"[TEST] Each block can hold {tokens_per_block} tokens") - + # 生成测试数据:刚好占满一个块 token_ids = list(range(tokens_per_block)) indices = mem_manager.alloc(len(token_ids)) - + kvcache = mem_manager.to_kvcache(indices) + print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}") + # 写入缓存 hicache.write(torch.tensor(token_ids), torch.tensor(indices)) - + time.sleep(2) + # 等待任务完成 - time.sleep(0.5) # 确保后台线程处理完成 - + hicache.service.wait_till_all_finished() + + mem_manager.free(indices) + # 读取验证 result = hicache.read(torch.tensor(token_ids)) - assert result.tolist() == indices, f"Retrieved indices: {result.tolist()}, Expected indices: {indices}" - print(f"[TEST] Basic test passed. Retrieved indices: {result.tolist()}") + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}" + print("[TEST] Basic test passed. Retrieved kvcache\n\n") + def test_node_splitting(mem_manager, hicache, token_size): tokens_per_block = BLOCK_SIZE // token_size # 生成超过一个块的数据 - token_ids = list(range(tokens_per_block + 1)) + token_ids = list(range(12, 12 + tokens_per_block * 3 + 1)) indices = mem_manager.alloc(len(token_ids)) - + kvcache = mem_manager.to_kvcache(indices) + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) - time.sleep(0.5) - + time.sleep(2) + hicache.service.wait_till_all_finished() + # 验证根节点应该有子节点 root = hicache.root assert len(root.children) > 0 print(f"\nRoot node has {len(root.children)} children") - + # 读取完整序列 result = hicache.read(torch.tensor(token_ids)) - assert result.tolist() == indices - print(f"[TEST] Node splitting test passed. Retrieved indices: {result.tolist()}") + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}" + print(f"[TEST] Node splitting test passed. 
Retrieved kvcache: {result.shape}\n\n") + def test_partial_read(mem_manager, hicache): - token_ids = [1,2,3,4,5] + token_ids = [97, 98, 99, 100, 101, 102] indices = mem_manager.alloc(len(token_ids)) + kvcache = mem_manager.to_kvcache(indices) hicache.write(torch.tensor(token_ids), torch.tensor(indices)) - time.sleep(0.2) - + time.sleep(2) + hicache.service.wait_till_all_finished() + # 查询存在的部分前缀 - result = hicache.read(torch.tensor([1,2,3])) - assert result.tolist() == indices[:3] - print(f"[TEST] Partial read result: {result.tolist()}") - + result = hicache.read(torch.tensor([97, 98, 99])) + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache[:3]).all() + print("[TEST] Partial read passed") + # 查询不存在的前缀 - result = hicache.read(torch.tensor([1,2,9])) - assert len(result) == 0 + result = hicache.read(torch.tensor([97, 98, 100])) + assert len(result) == 2 + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache[:2]).all() print(f"[TEST] Non-existent prefix returned: {result.tolist()}") + def main(): mem_manager, service, hicache, token_size = setup() try: @@ -152,5 +150,6 @@ def main(): finally: service.shutdown() + if __name__ == "__main__": - main() \ No newline at end of file + main() From a0ae71dddd5bc4ae606d7cc4232b3e9861c4dade Mon Sep 17 00:00:00 2001 From: jinbiaoyu Date: Mon, 14 Jul 2025 14:02:54 +0800 Subject: [PATCH 04/13] asynchronous hi radix cahce --- lightllm/common/basemodel/basemodel.py | 25 +- lightllm/common/radixmem_buffer.py | 192 ++++++++++++ lightllm/models/deepseek2/model.py | 24 +- lightllm/models/qwen2/model.py | 24 +- lightllm/server/api_start.py | 9 +- lightllm/server/core/objs/__init__.py | 2 +- lightllm/server/core/objs/req.py | 93 ++++++ lightllm/server/httpserver/manager.py | 20 +- lightllm/server/router/batch.py | 13 +- .../dynamic_prompt/disk_cache_server.py | 292 ++++++++++++++++++ .../router/dynamic_prompt/hiradix_cache.py | 133 ++++---- .../server/router/dynamic_prompt/io_objs.py | 39 +++ .../server/router/dynamic_prompt/manager.py | 154 +++++++++ lightllm/server/router/manager.py | 44 ++- .../server/router/model_infer/infer_batch.py | 26 +- .../model_infer/mode_backend/__init__.py | 1 + .../model_infer/mode_backend/base_backend.py | 18 +- .../chunked_prefill/impl_for_hiradix_cache.py | 26 ++ .../prefill_kv_move_manager.py | 1 - .../server/router/model_infer/model_rpc.py | 30 +- 20 files changed, 1074 insertions(+), 92 deletions(-) create mode 100644 lightllm/common/radixmem_buffer.py create mode 100644 lightllm/server/router/dynamic_prompt/disk_cache_server.py create mode 100644 lightllm/server/router/dynamic_prompt/io_objs.py create mode 100644 lightllm/server/router/dynamic_prompt/manager.py create mode 100644 lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index ff3290233..c0aae6b1d 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -74,6 +74,8 @@ def __init__(self, kvargs): self.quant_type = kvargs.get("quant_type", "none") self.quant_cfg_path = kvargs.get("quant_cfg", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) + self.enable_hiradix_cache = kvargs.get("use_hiradix_cache", False) + self.radix_lock = kvargs.get("radix_lock", None) self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode @@ -162,14 +164,35 @@ def _init_weights(self): def _init_mem_manager(self): 
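# note on the hicache changes added below in this hunk: max_radix_token_num token slots are
# subtracted from max_total_token_num before the main MemoryManager is built, and when
# enable_hiradix_cache is set a separate shared RadixMemoryBuffer of max_radix_token_num slots
# is created (via init_shared_data/get_shared_data) and guarded by radix_lock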
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + max_radix_token_num = 10000 self.mem_manager = MemoryManager( - self.max_total_token_num, + self.max_total_token_num - max_radix_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], layer_num=self.config["n_layer"], mem_fraction=self.mem_fraction, ) + + if self.enable_hiradix_cache: + from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties + mem_propties = MemPropties( + max_radix_token_num, + dtype=self.data_type, + head_num=self.config["num_attention_heads"] // self.tp_world_size_, + head_dim=self.config["n_embed"] // self.config["num_attention_heads"], + layer_num=self.config["n_layer"] + ) + init_shared_data( + mem_propties=mem_propties + ) + self.radix_mem_buffer = RadixMemoryBuffer( + mem_propties, + shared_data=get_shared_data(), + lock=self.radix_lock + ) + self.mem_propties = mem_propties + self.shared_mem_data = get_shared_data() return def _init_kv_move_buffer(self): diff --git a/lightllm/common/radixmem_buffer.py b/lightllm/common/radixmem_buffer.py new file mode 100644 index 000000000..8dcf220dd --- /dev/null +++ b/lightllm/common/radixmem_buffer.py @@ -0,0 +1,192 @@ + +import torch +from dataclasses import dataclass +import torch.multiprocessing as mp +from lightllm.utils.log_utils import init_logger +from typing import List, Union +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from multiprocessing.managers import DictProxy +from multiprocessing import Manager + + +logger = init_logger(__name__) + +@dataclass +class SharedRadixMemoryData: + kv_buffer: torch.Tensor + mem_state: torch.Tensor + req_mem_index: DictProxy + + +@dataclass +class MemPropties: + size: int + dtype: torch.dtype + head_num: int + head_dim: int + layer_num: int + +shared_mem_data: SharedRadixMemoryData = None + +def init_shared_data(mem_propties: MemPropties, device="cuda"): + size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \ + mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num + global shared_mem_data + + if device == "cuda": + kv_buffer = torch.empty( + (layer_num, size, head_num, head_dim), + dtype=dtype, + device="cuda" + ) + else: + kv_buffer = torch.empty( + (layer_num, size, head_num, head_dim), + dtype=dtype, + device="cpu" + ).share_memory_() + + mem_state = torch.arange(size, dtype=torch.int32).share_memory_() + + manager = Manager() + req_mem_index = manager.dict() + shared_mem_data = SharedRadixMemoryData( + kv_buffer=kv_buffer, + mem_state=mem_state, + req_mem_index=req_mem_index + ) + +def get_shared_data() -> SharedRadixMemoryData: + """Get the shared memory data.""" + global shared_mem_data + if shared_mem_data is None: + raise RuntimeError("Shared memory data has not been initialized. 
Call init_shared_data first.") + return shared_mem_data + +class RadixMemoryBuffer: + def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData = None, lock: mp.Lock = None, device="cuda", + rank_in_node=None): + size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \ + mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num + if shared_data is not None: + self.kv_buffer = shared_data.kv_buffer + self.mem_state = shared_data.mem_state + self.req_mem_index = shared_data.req_mem_index + else: + # CPU 上分配 key 和 value(共 2 * head_num) + if device == "cuda": + self.kv_buffer = torch.empty( + (layer_num, size, 2 * head_num, head_dim), + dtype=dtype, + device="cuda" + ) + else: + self.kv_buffer = torch.empty( + (layer_num, size, 2 * head_num, head_dim), + dtype=dtype, + device="cpu" + ).share_memory_() + self.mem_state = torch.arange( + 0, size, dtype=torch.int32 + ).share_memory_() + self.req_mem_index = mp.Manager().dict() + self.lock = lock if lock is not None else mp.Lock() + #TODO profile size + self.size = size # token slot 个数 + self.head_num = head_num + self.head_dim = head_dim + self.layer_num = layer_num + self.dtype = dtype + + can_use_mem_size = self.size + mark_start = 0 + mark_end = self.size + rank_in_node = rank_in_node if rank_in_node is not None else get_current_rank_in_node() + self.can_use_mem_size = SharedInt( + f"{get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}" + ) + self.can_use_mem_size.set_value(can_use_mem_size) + self.mark_start = SharedInt( + f"{get_unique_server_name()}_radix_mem_manger_mark_start_{rank_in_node}" + ) + self.mark_start.set_value(mark_start) + + self.mark_end = SharedInt( + f"{get_unique_server_name()}_radix_mem_manger_mark_end_{rank_in_node}" + ) + self.mark_end.set_value(mark_end) + logger.info(f"create {get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}") + + def _free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + end = self.mark_start.get_value() + start = end - len(free_index) + assert start >= 0, f"error free state start: {end} free len {len(free_index)}" + + if isinstance(free_index, list): + self.mem_state.numpy()[start:end] = free_index + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start.set_value(end - len(free_index)) + + self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() + len(free_index)) + + if self.can_use_mem_size.get_value() == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size.get_value()}") + return + + def free_req_index(self, req_id: int): + """Free the memory index for a specific request ID.""" + with self.lock: + if req_id not in self.req_mem_index: + logger.warning(f"Request ID {req_id} not found in memory index.") + return + index = self.req_mem_index[req_id] + self._free(index) + logger.info(f"Freed memory index for request {req_id} size {len(index)}, left size {self.can_use_mem_size.get_value()}") + del self.req_mem_index[req_id] + + def alloc(self, need_size) -> torch.Tensor: + with self.lock: + if need_size > self.mark_end.get_value() - self.mark_start.get_value(): + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size.get_value()}") + raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.") + + start = self.mark_start.get_value() + end = start + need_size + ans = 
self.mem_state[start:end] + self.mark_start.set_value(start + need_size) + + self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size) + return ans + + def set_req_mem_index(self, req_id: int, index: List[int]): + """Set the memory index for a specific request ID.""" + with self.lock: + if req_id in self.req_mem_index: + logger.info(f"Request ID {req_id} already exists. Overwriting index {self.req_mem_index[req_id]} with {index}.") + self.req_mem_index[req_id] = index + logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}") + + def get_req_mem_index(self, req_id: int) -> List[int]: + """Get the memory index for a specific request ID.""" + with self.lock: + if req_id not in self.req_mem_index: + logger.warning(f"Request ID {req_id} not found. Returning empty list.") + return [] + return self.req_mem_index[req_id] + + def get_kv_buffer(self, index) -> torch.Tensor: + with self.lock: + return self.kv_buffer[:, index, :, :] + + def _get_current_work_num(self) -> int: + return len(self.req_mem_index) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 9101cb963..60a97f11b 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -102,15 +102,35 @@ def _init_mem_manager(self): added_mtp_layer_num = 0 if get_env_start_args().mtp_mode == "deepseekv3": added_mtp_layer_num += get_env_start_args().mtp_step - + + max_radix_token_num = 300000 self.mem_manager = manager_class( - self.max_total_token_num, + self.max_total_token_num - max_radix_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) + if self.enable_hiradix_cache: + from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties + mem_propties = MemPropties( + max_radix_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + ) + init_shared_data( + mem_propties=mem_propties + ) + self.radix_mem_buffer = RadixMemoryBuffer( + mem_propties, + shared_data=get_shared_data(), + lock=self.radix_lock + ) + self.mem_propties = mem_propties + self.shared_mem_data = get_shared_data() return def _init_weights(self): diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index e3d8de461..252380294 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -41,12 +41,34 @@ def _init_mem_manager(self): head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + max_radix_token_num = 10000 self.mem_manager = select_mem_manager_class(self.mode)( - self.max_total_token_num, + self.max_total_token_num - max_radix_token_num, dtype=self.data_type, head_num=tp_k_head_num_, head_dim=head_dim_, layer_num=self.config["num_hidden_layers"], mem_fraction=self.mem_fraction, ) + + if self.enable_hiradix_cache: + from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties + mem_propties = MemPropties( + max_radix_token_num, + dtype=self.data_type, + head_num=2 * tp_k_head_num_, + head_dim=head_dim_, + 
layer_num=self.config["num_hidden_layers"], + ) + init_shared_data( + mem_propties=mem_propties + ) + self.radix_mem_buffer = RadixMemoryBuffer( + mem_propties, + shared_data=get_shared_data(), + lock=self.radix_lock + ) + self.mem_propties = mem_propties + self.shared_mem_data = get_shared_data() return diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 099987308..d69e0176d 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -174,6 +174,7 @@ def normal_or_p_d_start(args): ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" # if use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True + hiradix_cache_port_num = 0 if args.use_hi_dynamic_prompt_cache: assert not args.disable_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" @@ -205,8 +206,11 @@ def normal_or_p_d_start(args): ports_locker.lock_port() node_world_size = args.tp // args.nnodes + + if args.use_hi_dynamic_prompt_cache: + hiradix_cache_port_num = node_world_size can_use_ports = alloc_can_use_network_port( - num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=7 + node_world_size + args.visual_dp * args.visual_tp + hiradix_cache_port_num, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -234,6 +238,9 @@ def normal_or_p_d_start(args): args.audio_port = audio_port args.cache_port = cache_port args.metric_port = metric_port + if args.use_hi_dynamic_prompt_cache: + args.hiradix_cache_ports = can_use_ports[0:node_world_size] + can_use_ports = can_use_ports[node_world_size:] # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 5594df6a0..d66e3bd86 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,5 +1,5 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus +from .req import Req, FinishStatus, RadixStatus from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray from .start_args_type import StartArgs diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f2ebadad1..bdf9a6d53 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -54,6 +54,96 @@ def set_token_ids(self, ids: List[int]): def get_token_ids(self): return list(self.data[: self.size]) +# class RadixStatus(ctypes.Structure): +# _pack_ = 4 +# _fields_ = [("status", ctypes.c_int * 32)] + +# NOCACHE = -2 +# NOT_READY = -1 +# READ_READY = 1 +# WRITE_READY = 2 + +# def __init__(self, init_state=NOT_READY): +# self.status = init_state + +# def set_status(self, new_status: int): +# assert new_status in (self.NOCACHE, self.NOT_READY, self.READ_READY, self.WRITE_READY) +# self.status = new_status + +# def get_status(self) -> int: +# return self.status + +# def no_need_cache(self) -> bool: +# return self.status == self.NOCACHE + +# def is_read_ready(self) -> bool: +# return self.status == self.READ_READY + +# def is_write_ready(self) -> bool: +# return self.status == self.WRITE_READY + +# def is_not_ready(self) -> bool: +# return self.status == self.NOT_READY +class RadixStatus(ctypes.Structure): + _pack_ = 4 + _fields_ = [("status", ctypes.c_int * 32)] + + NOCACHE = -2 + NOT_READY = -1 + READ_READY = 1 + WRITE_READY = 2 + + def __init__(self, 
init_state=NOT_READY): + for i in range(32): + self.status[i] = init_state + + def set_status(self, idx: int, new_status: int): + assert 0 <= idx < 32, f"Index out of range: {idx}" + assert new_status in (self.NOCACHE, self.NOT_READY, self.READ_READY, self.WRITE_READY) + self.status[idx] = new_status + + def get_status(self, idx: int) -> int: + assert 0 <= idx < 32, f"Index out of range: {idx}" + return self.status[idx] + + def is_no_need_cache(self, idx: int) -> bool: + return self.get_status(idx) == self.NOCACHE + + def is_read_ready(self, idx: int) -> bool: + return self.get_status(idx) == self.READ_READY + + def is_write_ready(self, idx: int) -> bool: + return self.get_status(idx) == self.WRITE_READY + + def is_not_ready(self, idx: int) -> bool: + return self.get_status(idx) == self.NOT_READY + + # def all_dp_read_ready_or_nocache(self, indexs: List[int]) -> bool: + # for i in indexs: + # if self.status[i] not in (self.READ_READY, self.NOCACHE): + # return False + # return True + def all_dp_read_ready_or_nocache(self, indexs: List[int]) -> bool: + # return np.all(self.status == self.READ_READY) + # for i in indexs: + # if self.status[i] not in (self.READ_READY, self.NOCACHE): + # return False + # return True + return np.all(np.array(self.status[indexs]) == self.READ_READY) or np.all(np.array(self.status[indexs]) == self.NOCACHE) + + # def all_read_ready_or_nocache(self) -> bool: + # for i in range(32): + # if self.status[i] not in (self.READ_READY, self.NOCACHE): + # return False + # return True + def all_read_ready_or_nocache(self) -> bool: + return np.all(np.array(self.status) == self.READ_READY) or np.all(np.array(self.status) == self.NOCACHE) + + # def all_read_ready(self) -> bool: + # return np.all(self.status == self.READ_READY) + + # def all_no_need_cache(self) -> bool: + # return np.all(self.status == self.NOCACHE) class Req(ctypes.Structure): _pack_ = 4 @@ -98,6 +188,8 @@ class Req(ctypes.Structure): ("mtp_accepted_token_num", ctypes.c_int), # mtp_step 保存一个mtp使用的常量参数,用于快速访问,不会被外部输入初始化 ("_mtp_step", ctypes.c_int), + # 用于标记当前请求的radix状态 + ("radix_status", RadixStatus), ] def get_str(self): @@ -151,6 +243,7 @@ def init( self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids self.mtp_accepted_token_num = 0 self._mtp_step = get_env_start_args().mtp_step + self.radix_status = RadixStatus(RadixStatus.NOT_READY) self.post_init() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fa455c225..61a2f75e9 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -51,7 +51,10 @@ def __init__( context = zmq.asyncio.Context(2) self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - + if self.args.use_hi_dynamic_prompt_cache: + context_hiradix = zmq.asyncio.Context() + self.send_to_hiradix = context_hiradix.socket(zmq.PUSH) + self.send_to_hiradix.bind(f"{args.zmq_mode}127.0.0.1:55555") self.multinode_req_manager = None self.nnodes = args.nnodes self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) @@ -476,10 +479,17 @@ async def transfer_to_next_module( protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) + if self.args.use_hi_dynamic_prompt_cache: + logger.info(f"send_to_hiradix {group_req_objs.to_group_req_index()}") + self.send_to_hiradix.send_pyobj( + group_req_objs.to_group_req_index(), + 
protocol=pickle.HIGHEST_PROTOCOL + ) + else: + self.send_to_router.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) return assert False, "dead code path" diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 14a987f49..f356f4ef6 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -40,8 +40,14 @@ def get_req_list_for_dp(self, dp_index: int): req_list.append(req) return req_list - def filter_out_finished_req(self, shm_req_manager: ShmReqManager): + def release_reqs(self, reqs: List[Req], shm_req_manager: ShmReqManager): + for req in reqs: + shm_req_manager.put_back_req_obj(req) + + def filter_out_finished_req(self): unfinished_req_ids = [] + finished_reqs = [] + for req in self.reqs: # 更新aborted 标记,可以触发推理进程主动退出aborted的请求。 if req.is_aborted: @@ -49,14 +55,13 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): if req.shm_infer_released: logger.info(f"router release req id {req.request_id}") - shm_req_manager.put_back_req_obj(req) - req = None + finished_reqs.append(req) else: unfinished_req_ids.append(req.request_id) self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} - return + return finished_reqs def pop_req(self, req_id): self.reqs = [req for req in self.reqs if req.request_id != req_id] diff --git a/lightllm/server/router/dynamic_prompt/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/disk_cache_server.py new file mode 100644 index 000000000..08a29ca56 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/disk_cache_server.py @@ -0,0 +1,292 @@ +import torch +import time +import tempfile +import rpyc +import zmq +import inspect +import asyncio +import threading +import numpy as np +import torch.multiprocessing as mp +from typing import List, Union +from rpyc.utils.server import ThreadedServer +from os.path import join +from typing import Tuple, Dict, Set, List +from lightllm.utils.log_utils import init_logger +from enum import Enum +from .shared_arr import SharedArray +from .io_objs import ShmReqInfo, GroupReqInfo +from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.core.objs import ShmReqManager +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.common.radixmem_buffer import RadixMemoryBuffer +from lightllm.server.core.objs import Req, RadixStatus + +logger = init_logger(__name__) + +def wait_until_ready(task, timeout=10.0, check_interval=0.01): + start_time = time.time() + while not task.ready(): + time.sleep(check_interval) + if time.time() - start_time > timeout: + logger.error("Current kv cache task not ready in time") + return False + return True + +class RemoteCacheManager: + def __init__(self, unique_name: str, rank_in_node: int, mem_manager): + tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}") + self.cache_file = join(tmp_dir, "cache_file") + all_buffers = mem_manager.kv_buffer + all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) + from kvcache.python.jit import PyLocalCacheService + + self.py_cache_service = PyLocalCacheService( + file=self.cache_file, + storage_size=128 * (1024 ** 3), # 128GB + num_shard=32, + kvcache_tensor=all_buffers, + num_worker=8 + ) + + def insert(self, tokens, kv_page_indexer, start_pos=0): + t = self.py_cache_service.create( + tokens=tokens, + kv_page_indexer=kv_page_indexer, + mode="w", + start_pos=start_pos) + res = 
wait_until_ready(t) + if not res: + self.py_cache_service.az5(t) + + def read(self, tokens, kv_page_indexer, start_pos=0): + t = self.py_cache_service.create( + tokens=tokens, + kv_page_indexer=kv_page_indexer, + mode="r", + start_pos=start_pos) + res = wait_until_ready(t) + return res + + def query(self, tokens): + query_result = self.py_cache_service.query(tokens) + max_len = 0 + for result in query_result: + if result: + max_len += 1 + else: + break + return max_len * self.block_size + + @property + def block_size(self,): + return self.py_cache_service.tokens_per_block + + +class DiskCacheService(rpyc.Service): + def __init__(self, mem_manager=None, remote_cache_manager=None, shm_req_manager=None, rank_in_node=None): + super().__init__() + self.mem_manager = mem_manager + self.remote_cache_manager = remote_cache_manager + self.shm_req_manager = shm_req_manager + self.rank_in_node = rank_in_node + + def exposed_push(self, req_info): + req_info: ShmReqInfo = ShmReqInfo.from_dict(req_info) + req: Req = self.shm_req_manager.get_req_obj_by_index(req_info.shm_req_index) + req.link_prompt_ids_shm_array() + assert req.radix_status.is_write_ready(self.rank_in_node), "radix cache is not ready" + input_token_ids = req.shm_prompt_ids.arr[0 : req.shm_cur_kv_len] + keys = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") + values = self.mem_manager.get_req_mem_index(req_info.request_id) + index = torch.tensor(values, device="cpu", dtype=torch.int32) + logger.info(f"_push_task_loop receive task keys {len(keys)} values {len(values)}") + self.remote_cache_manager.insert(keys, index) + self.mem_manager.free_req_index(req.request_id) + self.set_reqs_radix_status([req], RadixStatus.NOT_READY) + self.shm_req_manager.put_back_req_obj(req) + return {"status": "ok"} + + def set_reqs_radix_status(self, reqs: List[Req], status: int): + for req in reqs: + req.radix_status.set_status(self.rank_in_node, status) + logger.info(f"-->pull loop rank_in_node={self.rank_in_node} set req {req.group_req_id, req.request_id} radix status {req.radix_status.get_status(self.rank_in_node)}") + + def put_back_req_objs(self, reqs: List[Req]): + for req in reqs: + self.shm_req_manager.put_back_req_obj(req) + + def exposed_pull(self, group_req): + group_req: GroupReqInfo = GroupReqInfo.from_dict(group_req) + reqs: List[Req] = [] + for shm_req_index in group_req.shm_req_indexes: + req: Req = self.shm_req_manager.get_req_obj_by_index(shm_req_index) + reqs.append(req) + req = reqs[0] + req.link_prompt_ids_shm_array() + keys = req.get_prompt_ids() + query_len = self.remote_cache_manager.query(tokens=keys) + if query_len == 0: + self.set_reqs_radix_status(reqs, RadixStatus.NOCACHE) + return {"query_len": 0, "kv_indices": []} + index = self.mem_manager.alloc(query_len) + self.remote_cache_manager.read(tokens=keys[:query_len], kv_page_indexer=index) + self.mem_manager.set_req_mem_index( + group_req.group_req_id, index.tolist() + ) + self.set_reqs_radix_status(reqs, RadixStatus.READ_READY) + self.put_back_req_objs(reqs) + return {"query_len": query_len, "kv_indices": index.tolist()} + + +class DiskCacheClient: + def __init__(self, rank_in_node: int, service=None, use_rpc=True, proc=None): + self.rank_in_node = rank_in_node + self.use_rpc = use_rpc + self.service = service + self.proc=proc + if self.use_rpc: + self._push = self._async_wraper(self.service.push) + self._pull = self._async_wraper(self.service.pull) + else: + self._push = self.service.exposed_push + self._pull = self.service.exposed_pull + + def _async_wraper(self, 
func): + async_func = rpyc.async_(func) + + async def _wrapped(*args, **kwargs): + result = async_func(*args, **kwargs) + await asyncio.to_thread(result.wait) + return result.value + + return _wrapped + + async def push(self, req_info: ShmReqInfo): + if self.use_rpc: + return await self._insert(req_info) + else: + return self._insert(req_info) + + async def pull(self, group_req: GroupReqInfo): + if self.use_rpc: + return await self._pull(group_req) + else: + return self._pull(group_req) + + +def start_cache_server(mem_manager, remote_cache_manager, shm_req_manager, rank_in_node, port, init_event): + class CustomService(DiskCacheService): + def __init__(self): + super().__init__(mem_manager, remote_cache_manager, shm_req_manager, rank_in_node) + + def start(): + try: + server = ThreadedServer(CustomService(), + port=port, + protocol_config={"allow_public_attrs": True, "allow_pickle": True}) + init_event.set() + server.start() + except Exception as e: + logger.error(f"Failed to start ThreadedServer: {e}") + + t = threading.Thread(target=start, daemon=True) + t.start() + + logger.info(f"DiskCacheService started on port {port}") + return t + + +def _init_server( + device_id, + mem_queue, + radix_lock: List[mp.Lock], + init_event: mp.Event, + port:int=18861 +): + from lightllm.utils.envs_utils import get_unique_server_name + graceful_registry(inspect.currentframe().f_code.co_name) + torch.cuda.set_device(device_id) + mem_proties, shared_mem_data = mem_queue.get() + mem_manager = RadixMemoryBuffer( + mem_propties=mem_proties, + shared_data=shared_mem_data, + lock=radix_lock, + rank_in_node=device_id + ) + remote_cache_manager = RemoteCacheManager( + unique_name=get_unique_server_name(), + rank_in_node=device_id, + mem_manager=mem_manager, + ) + shm_req_manager = ShmReqManager() + + t = start_cache_server( + mem_manager=mem_manager, + remote_cache_manager=remote_cache_manager, + shm_req_manager=shm_req_manager, + rank_in_node=device_id, + port=port, + init_event=init_event + ) + t.join() + return + +async def start_disk_cache_server_process( + args, + device_id, + node_word_size, + mem_queue, + radix_lock, + port +): + """ + Start the DiskCacheManager in process. 
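    When node_word_size == 1 the DiskCacheService is built in-process and the
    returned DiskCacheClient calls it directly; otherwise a child process hosts
    a rpyc ThreadedServer on the given port and the client talks to it over rpyc.
    A minimal usage sketch, assuming the inference backend has already put
    (mem_propties, shared_mem_data) on mem_queue and that the request group has
    been registered through ShmReqManager (the ids below are illustrative):

        # inside an async function
        import torch.multiprocessing as mp
        from lightllm.server.router.dynamic_prompt.io_objs import GroupReqInfo

        mem_queue, radix_lock = mp.Queue(), mp.Lock()
        client = await start_disk_cache_server_process(
            args, device_id=0, node_word_size=1,
            mem_queue=mem_queue, radix_lock=radix_lock, port=18861,
        )
        # query and read any disk-cached KV for the group on this rank
        result = await client.pull(
            GroupReqInfo(group_req_id=1, shm_req_indexes=[0]).to_dict()
        )
        print(result["query_len"], len(result["kv_indices"]))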
+ """ + from lightllm.utils.envs_utils import get_unique_server_name + if node_word_size == 1: + mem_proties, shared_mem_data = mem_queue.get() + mem_manager = RadixMemoryBuffer( + mem_propties=mem_proties, + shared_data=shared_mem_data, + lock=radix_lock, + rank_in_node=device_id + ) + remote_cache_manager = RemoteCacheManager( + unique_name=get_unique_server_name(), + rank_in_node=device_id, + mem_manager=mem_manager, + ) + shm_req_manager = ShmReqManager() + service = DiskCacheService(mem_manager, remote_cache_manager, shm_req_manager) + client = DiskCacheClient( + service=service, + rank_in_node=0, + use_rpc=False + ) + return client + + init_event = mp.Event() + proc = mp.Process(target=_init_server, args=(device_id, mem_queue, radix_lock, init_event, port)) + proc.start() + + init_event.wait(timeout=60) + + max_wait_times = 20 + for i in range(max_wait_times): + try: + conn = rpyc.connect("localhost", port, config={"allow_pickle": True}) + break + except Exception as e: + asyncio.sleep(2) + + service = conn.root + client = DiskCacheClient( + rank_in_node=device_id, + service=service, + use_rpc=True, + proc=proc + ) + assert proc.is_alive() + logger.info(f"disk cache process for device {device_id} start!") + return client \ No newline at end of file diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 31a306f67..68e217870 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -1,87 +1,94 @@ import torch import time import tempfile +import zmq +import inspect +import threading import numpy as np +import torch.multiprocessing as mp import torch.distributed as dist from os.path import join from .radix_cache import RadixCache, TreeNode, match from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager +from lightllm.common.radixmem_buffer import RadixMemoryBuffer from lightllm.utils.log_utils import init_logger from threading import Lock from enum import Enum from .shared_arr import SharedArray -from kvcache.python.jit import PyLocalCacheService +from .io_objs import ShmReqInfo +from lightllm.server.core.objs import Req, RadixStatus +from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.core.objs import ShmReqManager +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager logger = init_logger(__name__) -def wait_until_ready(task, timeout=10.0, check_interval=0.01): - start_time = time.time() - while not task.ready(): - time.sleep(check_interval) - if time.time() - start_time > timeout: - logger.error("Current kv cache task not ready in time") - return False - return True class LocalCacheManager: - def __init__(self, unique_name: str, rank_in_node: int, mem_manager): - tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}") - self.cache_file = join(tmp_dir, "cache_file") - all_buffers = mem_manager.kv_buffer - all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) + def __init__(self, mem_buffer: RadixMemoryBuffer, mem_manager: MemoryManager, rank_in_node): + self.mem_buffer = mem_buffer + self.mem_manager = mem_manager + self.rank_in_node = rank_in_node - self.py_cache_service = PyLocalCacheService( - file=self.cache_file, - storage_size=128 * (1024 ** 3), # 128GB - num_shard=32, - kvcache_tensor=all_buffers, - num_worker=8 + def 
insert(self, req: Req, key: torch.Tensor, value=None): + index = self.mem_buffer.alloc(len(value)) + pre_index = self.mem_buffer.get_req_mem_index(req.request_id) + self.mem_buffer.set_req_mem_index( + req.request_id, index.tolist() ) + if len(pre_index) != 0: + logger.info(f"pre index {pre_index}, index {index}") + index = index[len(pre_index):] + value = value[len(pre_index):] + dst_kv_buffer = self.mem_buffer.get_kv_buffer(index) + src_kv_buffer = self.mem_manager.get_index_kv_buffer(value)["kv_buffer"] + logger.info(f"insert mem_buffer shape {dst_kv_buffer.shape}, manager buffer shape {src_kv_buffer.shape}") + assert len(src_kv_buffer) == len(dst_kv_buffer), f"src kv buffer len {len(src_kv_buffer)} != dst kv buffer len {len(dst_kv_buffer)}" + self.copy_kv_from_gpu_to_cpu(src_kv_buffer, dst_kv_buffer) + req.radix_status.set_status(self.rank_in_node, RadixStatus.WRITE_READY) - def insert(self, tokens, kv_page_indexer, start_pos=0): - t = self.py_cache_service.create( - tokens=tokens, - kv_page_indexer=kv_page_indexer, - mode="w", - start_pos=start_pos) - res = wait_until_ready(t) - if not res: - self.py_cache_service.az5(t) - - def read(self, tokens, kv_page_indexer, start_pos=0): - t = self.py_cache_service.create( - tokens=tokens, - kv_page_indexer=kv_page_indexer, - mode="r", - start_pos=start_pos) - res = wait_until_ready(t) - return res + def read(self, req: Req, dst_index): + try: + index = self.mem_buffer.get_req_mem_index(req.group_req_id) + src_kv_buffer = self.mem_buffer.get_kv_buffer(index[-len(dst_index)]) + dst_kv_buffer = self.mem_manager.get_index_kv_buffer(dst_index)["kv_buffer"] + logger.info(f"len mem src index and dst index {len(index), len(dst_index)} read mem_buffer shape {src_kv_buffer.shape}, manager buffer shape {dst_kv_buffer.shape}") + assert len(src_kv_buffer) == len(dst_kv_buffer), f"src kv buffer len {len(src_kv_buffer)} != dst kv buffer len {len(dst_kv_buffer)}" + self.copy_kv_from_cpu_to_gpu(src_kv_buffer, dst_kv_buffer) + self.mem_buffer.free_req_index(req.group_req_id) + except Exception as e: + logger.error(f"Local cache read from radix mem_buffer error {e}") + return False + return True - def query(self, tokens): - query_result = self.py_cache_service.query(tokens) - max_len = 0 - for result in query_result: - if result: - max_len += 1 - else: - break - return max_len * self.block_size + def query(self, req: Req): + if req.radix_status.is_no_need_cache(self.rank_in_node): + return 0 + if req.radix_status.is_read_ready(self.rank_in_node): + index = self.mem_buffer.get_req_mem_index(req.group_req_id) + return len(index) + return 0 + + def copy_kv_from_cpu_to_gpu(self, src_kv_tensor, dst_kv_tensor): + dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) + + def copy_kv_from_gpu_to_cpu(self, src_kv_tensor, dst_kv_tensor): + dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) - @property - def block_size(self,): - return self.py_cache_service.tokens_per_block class HiRadixCache(RadixCache): - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, mem_buffer, radix_info_queue): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) self.rank_in_node = rank_in_node self.local_cache_manager = LocalCacheManager( - unique_name, - rank_in_node, - mem_manager, + mem_buffer=mem_buffer, + mem_manager=mem_manager, + rank_in_node=rank_in_node ) + self.radix_info_queue = radix_info_queue self.is_hi_radix_cache = True self.disk_cache_match_count = 
SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64) self.disk_cache_match_count.arr[0] = 0 @@ -91,14 +98,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager): self.disk_cache_match_ratio.arr[0] = 0.0 logger.info(f"Initializing HiRadixCache {rank_in_node}") - def insert(self, key, value=None): - share_len = super().insert(key, value) - if share_len == 0: + def insert(self, key, value=None, req=None): + if len(key) == 0: return 0 - self.local_cache_manager.insert(key, value) + share_len = super().insert(key, value) + if req is None: + return + self.local_cache_manager.insert(req, key, value) return share_len - def match_prefix(self, key, update_refs=False): + def match_prefix(self, req, key, update_refs=False): assert len(key) != 0 self.total_match_count.arr[0] += 1 ans_value_list = [] @@ -108,7 +117,8 @@ def match_prefix(self, key, update_refs=False): ans_value = torch.concat(ans_value_list) max_len = 0 if tree_node.node_prefix_total_len < len(key): - max_len = self.local_cache_manager.query(key) + max_len = self.local_cache_manager.query(req) + logger.debug(f"HiCache rank_in_node={self.rank_in_node} current match radix len {tree_node.node_prefix_total_len}, max len {max_len}") if max_len > tree_node.node_prefix_total_len: pull_len = max_len - tree_node.node_prefix_total_len self.disk_cache_match_count.arr[0] += 1 @@ -120,7 +130,8 @@ def match_prefix(self, key, update_refs=False): buffers = torch.concat([ans_value, buffers]) start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk") - res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos) + # res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos) + res = self.local_cache_manager.read(req, buffers) if res: super().insert(key[:max_len], buffers) else: diff --git a/lightllm/server/router/dynamic_prompt/io_objs.py b/lightllm/server/router/dynamic_prompt/io_objs.py new file mode 100644 index 000000000..7a1da21b2 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/io_objs.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from lightllm.server.core.objs import Req +from typing import List + +@dataclass +class ShmReqInfo: + request_id: int + shm_req_index: int + + def to_dict(self): + return { + "request_id": self.request_id, + "shm_req_index": self.shm_req_index + } + + @staticmethod + def from_dict(d): + return GroupReqInfo( + request_id=d["request_id"], + shm_req_index=d["shm_req_index"] + ) + +@dataclass +class GroupReqInfo: + group_req_id: int + shm_req_indexes: List[int] + + def to_dict(self): + return { + "group_req_id": self.group_req_id, + "shm_req_indexes": self.shm_req_indexes + } + + @staticmethod + def from_dict(d): + return GroupReqInfo( + group_req_id=d["group_req_id"], + shm_req_indexes=d["shm_req_indexes"] + ) \ No newline at end of file diff --git a/lightllm/server/router/dynamic_prompt/manager.py b/lightllm/server/router/dynamic_prompt/manager.py new file mode 100644 index 000000000..53a5a1917 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/manager.py @@ -0,0 +1,154 @@ +import time +import zmq +import zmq.asyncio +import inspect +import pickle +import torch.multiprocessing as mp +import threading +import asyncio +from typing import List +from dataclasses import 
dataclass +from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.graceful_utils import graceful_registry +from .disk_cache_server import DiskCacheClient +from .io_objs import ShmReqInfo, GroupReqInfo + +logger = init_logger(__name__) + +class HiRadixCacheManagerServer: + def __init__( + self, args, mem_queues: List[mp.Queue], radix_locks: List[mp.Lock], router_port: int): + self.args = args + self.mem_queues = mem_queues + self.radix_locks = radix_locks + self.node_world_size = args.tp // args.nnodes + self.disk_cache_processes = [] + self.ports = args.hiradix_cache_ports + self.cache_server_client = [] + context = zmq.asyncio.Context(2) + self.recv_from_httpserver = context.socket(zmq.PULL) + self.recv_from_httpserver.connect(f"{args.zmq_mode}127.0.0.1:55555") + self.clients: List[DiskCacheClient] = [] + self.send_to_router = context.socket(zmq.PUSH) + self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") + logger.info(f"send_to_router {args.zmq_mode}127.0.0.1:{router_port} ") + self.recv_from_router = context.socket(zmq.PULL) + self.recv_from_router.connect(f"{args.zmq_mode}127.0.0.1:66666") + + async def asyn_init(self): + self.pull_queue = asyncio.Queue() + self.push_queue = asyncio.Queue() + + async def start_all(self): + from lightllm.server.router.dynamic_prompt.disk_cache_server import start_disk_cache_server_process + for rank_in_node in range(self.node_world_size): + client = await start_disk_cache_server_process( + self.args, + device_id=rank_in_node, + node_word_size=self.node_world_size, + mem_queue=self.mem_queues[rank_in_node], + radix_lock=self.radix_locks[rank_in_node], + port=self.ports[rank_in_node] + ) + self.clients.append(client) + + async def pull_cache(self, group_req): + tasks = [] + group_req_info = GroupReqInfo( + group_req_id=group_req.group_req_id, + shm_req_indexes=group_req.shm_req_indexes + ).to_dict() + for client in self.clients: + task = client.pull(group_req_info) + tasks.append(task) + all_results = await asyncio.gather(*tasks) + logger.info(f"pull cache results {all_results}") + await self.send_to_router.send_pyobj(group_req, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"send to router pyobj {group_req}") + + async def push_cache(self, req_info): + tasks = [] + for client in self.clients: + task = client.push(req_info) + tasks.append(task) + all_results = await asyncio.gather(*tasks) + logger.info(f"push cache results {all_results}") + + async def pull_woker(self): + while True: + req: ShmReqInfo = await self.pull_queue.get() + await self.pull_cache(req) + await asyncio.sleep(0.01) + + async def push_woker(self): + while True: + req: GroupReqInfo = await self.push_queue.get() + await self.push_cache(req.to_dict()) + await asyncio.sleep(0.01) + + async def run(self): + await self.asyn_init() + await asyncio.gather( + self.loop_for_netio_req_to_pull(), + self.pull_woker(), + self.loop_for_netio_req_to_push(), + self.push_woker() + ) + + async def loop_for_netio_req_to_push(self): + while True: + recv_req: ShmReqInfo = await self.recv_from_router.recv_pyobj() + if isinstance(recv_req, ShmReqInfo): + await self.push_queue.put() + else: + raise ValueError(f"Invalid request: {recv_req}") + + async def loop_for_netio_req_to_pull(self): + while True: + recv_req: GroupReqIndexes = await self.recv_from_httpserver.recv_pyobj() + if isinstance(recv_req, GroupReqIndexes): + await self.pull_queue.put(recv_req) + else: + raise ValueError(f"Invalid 
request: {recv_req}") + +def _init_env_server( + args, + mem_queues, + radix_locks: List[mp.Lock], + init_event: mp.Event, + router_port: int +): + graceful_registry(inspect.currentframe().f_code.co_name) + hiradix_cache_manager = HiRadixCacheManagerServer( + args, + mem_queues=mem_queues, + radix_locks=radix_locks, + router_port=router_port + ) + asyncio.run(hiradix_cache_manager.start_all()) + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + init_event.set() + loop.run_until_complete(hiradix_cache_manager.run()) + except Exception as e: + logger.error(f"hiradix server error happend {e}") + return + +def start_hiradix_cache_manager_process_server( + args, + radix_mem_queues: List[mp.Queue], + radix_locks: List[mp.Lock], + router_port: int +): + """ + Start the HiRadix cache manager process. + """ + init_event = mp.Event() + proc = mp.Process(target=_init_env_server, args=(args, radix_mem_queues, radix_locks, init_event, router_port)) + proc.start() + init_event.wait() + logger.info(f"HiRadix cache manager process started") + assert proc.is_alive() + return \ No newline at end of file diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index f5fcbb4cb..28479ce0b 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -11,7 +11,7 @@ import torch.multiprocessing as mp import torch.distributed as dist import multiprocessing -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue @@ -56,7 +56,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None - + self.use_hiradix_cache = args.use_hi_dynamic_prompt_cache and not args.disable_dynamic_prompt_cache self.mtp_step = args.mtp_step # 共享变量,用于存储router端调度分析得到的机器负载信息 @@ -75,11 +75,17 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.max_wait_tokens = args.router_max_wait_tokens context = zmq.Context(2) self.recv_from_httpserver = context.socket(zmq.PULL) + logger.info(f"recv_from_httpserver {args.zmq_mode}127.0.0.1:{router_port} ") self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}") self.send_to_detokenization = context.socket(zmq.PUSH) self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") + self.router_port = router_port + context_to_radix = zmq.asyncio.Context() + self.send_to_hiradix_server = context_to_radix.socket(zmq.PUSH) + self.send_to_hiradix_server.connect(f"{args.zmq_mode}127.0.0.1:{66666}") + if self.is_multinode_tp: self.mulitnode_group = dist.init_process_group( backend="gloo", @@ -114,6 +120,15 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] + self.radix_mem_queues: List[Union[torch.multiprocessing.Queue, None]] = [ + torch.multiprocessing.Queue() if self.use_hiradix_cache else None for _ in range(self.node_world_size) + ] + self.radix_info_queues: List[Union[torch.multiprocessing.Queue, None]] = [ + torch.multiprocessing.Queue() if self.use_hiradix_cache else None for _ in range(self.node_world_size) + ] + self.radix_locks: List[Union[torch.multiprocessing.Lock, None]] = [ + 
torch.multiprocessing.Lock() if self.use_hiradix_cache else None for _ in range(self.node_world_size) + ] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() @@ -130,6 +145,9 @@ async def wait_to_model_ready(self): info_queue=self.info_queue, mem_queue=self.mem_queues[(rank_id % node_world_size)], router_lock=self.router_lock, + radix_mem_queue=self.radix_mem_queues[(rank_id % node_world_size)], + radix_info_queue=self.radix_info_queues[(rank_id % node_world_size)], + radix_lock=self.radix_locks[(rank_id % node_world_size)] ) self.model_rpc_servers.append(rpc_model) @@ -202,6 +220,11 @@ async def wait_to_model_ready(self): ) start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + + if self.use_hiradix_cache: + # 启动 hi radix cache 管理进程 + from lightllm.server.router.dynamic_prompt.manager import start_hiradix_cache_manager_process_server + start_hiradix_cache_manager_process_server(self.args, self.radix_mem_queues, self.radix_locks, self.router_port) return @@ -308,7 +331,7 @@ async def _step(self): self._add_new_batch_to_running_batch(new_batch=new_batch) await self._prefill_batch(new_batch) self.stats_tool.count_prompt_tokens(new_batch) - self._filter_reqs_from_running_batch() + await self._filter_reqs_from_running_batch() self.has_wait_tokens = 0 # Check if need pause some requests for decode. @@ -325,7 +348,7 @@ async def _step(self): # Decode self.stats_tool.count_output_tokens(self.running_batch) await self._decode_batch() - self._filter_reqs_from_running_batch() + await self._filter_reqs_from_running_batch() self.has_wait_tokens += 1 return @@ -355,9 +378,11 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch): self.running_batch.merge(new_batch) return - def _filter_reqs_from_running_batch(self): + async def _filter_reqs_from_running_batch(self): if self.running_batch is not None: - self.running_batch.filter_out_finished_req(self.shm_req_manager) + finishs_reqs = self.running_batch.filter_out_finished_req() + await self._send_hiradix_manager(finishs_reqs) + self.running_batch.release_reqs(finishs_reqs, self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None return @@ -369,6 +394,13 @@ def _can_decode(self, batch: Batch, dp_index: int): batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num ) + async def _send_hiradix_manager(self, reqs): + if not self.use_hiradix_cache: + return + for req in reqs: + await self.send_to_hiradix_server.send_pyobj(req, protocol=pickle.HIGHEST_PROTOCOL) + return + def _send_detokenization_pack(self): # 发 mtp_step + 1 个 None 包触发一下 detokenization, 因为在开启 mtp feature 以后,每一步 # 生成的 token 数量最多为 mtp_step + 1 个,如果不及时触发 detokenization, 会带来一些性能 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 0774244b6..3904dd628 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -33,9 +33,10 @@ class InferenceContext: vocab_size = None overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 + backend = None def register( - self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int + self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int, backend ): self.req_manager = req_manager self.radix_cache = radix_cache @@ -46,6 +47,7 @@ def register( self.infer_req_ids = [] self.vocab_size = 
vocab_size + self.backend = backend return def get_overlap_stream(self) -> torch.cuda.Stream: @@ -95,7 +97,13 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() if is_group_finished: - prefix_len = self.radix_cache.insert(key, value) + if hasattr(self.radix_cache, "is_hi_radix_cache") and getattr(self.radix_cache, "is_hi_radix_cache"): + prefix_len = self.radix_cache.insert( + key, value, + req=req.shm_req + ) + else: + prefix_len = self.radix_cache.insert(key, value) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -267,10 +275,12 @@ def __init__( # 当开启后,mtp_gen_token_ids 保存多生成的多余的token_id,但是在后面的 # 步骤中需要重新进行校验。 self.mtp_gen_token_ids: List[int] = [] + self.shm_req = None def init_all(self): if self.initialized is False: - self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) + if self.shm_req is None: + self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) @@ -295,7 +305,10 @@ def init_all(self): input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + if hasattr(g_infer_context.radix_cache, "is_hi_radix_cache") and getattr(g_infer_context.radix_cache, "is_hi_radix_cache"): + share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(self.shm_req, key, update_refs=True) + else: + share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -313,6 +326,11 @@ def init_all(self): def is_uninitialized(self): return not self.initialized or self.paused + def is_radix_ready(self): + if self.shm_req is None: + self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) + return g_infer_context.backend.is_radix_ready(self.shm_req) + def get_output_len(self): return self.cur_output_len diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 7ad15f00f..9e946a19d 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -2,6 +2,7 @@ from .continues_batch.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend from .continues_batch.impl_for_reward_model import RewardModelBackend from .chunked_prefill.impl import ChunkedPrefillBackend +from .chunked_prefill.impl_for_hiradix_cache import ChunkedPrefillBackendHiCache from .diverse_backend.impl import DiversehBackend from .chunked_prefill.impl_for_token_healing import TokenHealingBackend from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c34b2ad97..97ffc5e09 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -59,6 +59,8 @@ def init_model(self, kvargs): self.use_hi_dynamic_prompt_cache = self.args.use_hi_dynamic_prompt_cache self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph + self.use_hiradix_cache = kvargs.get("use_hiradix_cache", False) + self.radix_lock = kvargs.get("radix_lock", None) self.logger = init_logger(__name__) @@ -116,6 +118,8 @@ def init_model(self, kvargs): "quant_type": kvargs.get("quant_type", None), "quant_cfg": kvargs.get("quant_cfg", None), "run_mode": self.run_mode, + "use_hiradix_cache": self.use_hiradix_cache, + "radix_lock": self.radix_lock } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing @@ -126,14 +130,15 @@ def init_model(self, kvargs): self.model.mem_manager.size, self.rank_in_node, mem_manager=self.model.mem_manager, - max_seq_length=kvargs.get("max_seq_length", 1024 * 5), + mem_buffer=self.model.radix_mem_buffer, + radix_info_queue=kvargs.get("radix_info_queue", None) ) - if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache + if self.use_hiradix_cache else RadixCache( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, - mem_manager=self.model.mem_manager, + mem_manager=self.model.mem_manager ) if self.use_dynamic_prompt_cache else None @@ -149,6 +154,7 @@ def init_model(self, kvargs): radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, vocab_size=self.model.vocab_size, + backend=self ) # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 @@ -161,6 +167,12 @@ def init_model(self, kvargs): self.init_custom() return + + def is_radix_ready(self, req): + dp_rank_list = range(self.dp_rank_in_node * self.dp_world_size, (self.dp_rank_in_node + 1) * self.dp_world_size) + if req.radix_status.is_no_need_cache(self.rank_in_node) or req.radix_status.is_read_ready(self.rank_in_node) : + return True + return False def init_custom(self): pass diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py new file mode 100644 index 000000000..5a36132df --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py @@ -0,0 +1,26 @@ +import torch +from typing import List, Tuple +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.router.model_infer.mode_backend.pre import ( + prepare_prefill_inputs, + prepare_decode_inputs, +) +import torch.multiprocessing as mp +from .impl import ChunkedPrefillBackend + + +logger = init_logger(__name__) + +class ChunkedPrefillBackendHiCache(ChunkedPrefillBackend): + + def __init__(self, radix_mem_queue: mp.Queue) -> None: + super().__init__() + self.radix_mem_queue = radix_mem_queue + + def init_custom(self): + self.radix_mem_queue.put((self.model.mem_propties, 
self.model.shared_mem_data)) \ No newline at end of file diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index a54b54980..2ea2366cf 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -221,7 +221,6 @@ def __remove_dead_trans_obj(self): gc.collect() return - def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 311c2725f..80ce82de6 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -10,6 +10,7 @@ ContinuesBatchBackend, ReturnPromptLogProbBackend, ChunkedPrefillBackend, + ChunkedPrefillBackendHiCache, DiversehBackend, RewardModelBackend, TokenHealingBackend, @@ -49,12 +50,18 @@ def __init__( rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, mem_queue: mp.Queue, + radix_mem_queue: mp.Queue = None, + radix_info_queue: mp.Queue = None, + radix_lock: mp.Lock = None ): super().__init__() self.args: StartArgs = args self.node_world_size = node_world_size self.info_queue = info_queue self.mem_queue = mem_queue + self.radix_mem_queue = radix_mem_queue + self.radix_info_queue = radix_info_queue + self.radix_lock = radix_lock self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -124,6 +131,12 @@ def init_model(self, kvargs): assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" is_prefill_node = self.args.run_mode == "prefill" is_decode_node = self.args.run_mode == "decode" + use_hiradix_cache = self.args.use_hi_dynamic_prompt_cache and not self.args.disable_dynamic_prompt_cache + kvargs.update({ + "use_hiradix_cache": use_hiradix_cache, + "radix_info_queue": self.radix_info_queue, + "radix_lock": self.radix_lock + }) enable_mtp = self.args.mtp_mode is not None @@ -177,7 +190,10 @@ def init_model(self, kvargs): if enable_mtp: self.backend = ContinuesBatchWithMTPBackend() else: - self.backend = ChunkedPrefillBackend() + if use_hiradix_cache: + self.backend = ChunkedPrefillBackendHiCache(self.radix_mem_queue) + else: + self.backend = ChunkedPrefillBackend() logger.info(f"use {self.backend.__class__.__name__}") self.backend.init_model(kvargs) @@ -287,6 +303,9 @@ def _init_env( rpc_event: mp.Event, rpc_finished_event: mp.Event, success_event: mp.Event, + radix_mem_queue: mp.Queue = None, + radix_info_queue: mp.Queue = None, + radix_lock: mp.Lock = None ): import lightllm.utils.rpyc_fix_utils as _ @@ -300,7 +319,8 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue, + radix_mem_queue, radix_info_queue, radix_lock ) success_event.set() @@ -318,6 +338,9 @@ async def start_model_process( info_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, + radix_mem_queue: mp.Queue = None, + radix_info_queue: mp.Queue = None, + radix_lock: 
mp.Lock = None ): import lightllm.utils.rpyc_fix_utils as _ @@ -335,6 +358,9 @@ async def start_model_process( rpc_event, rpc_finished_event, success_event, + radix_mem_queue, + radix_info_queue, + radix_lock ), ) proc.start() From 9141e0736523c93376e45c6b3407f9d818bf5abc Mon Sep 17 00:00:00 2001 From: jinbiaoyu Date: Tue, 15 Jul 2025 17:32:24 +0800 Subject: [PATCH 05/13] fix some --- lightllm/server/api_start.py | 5 +- lightllm/server/core/objs/req.py | 82 +++++++------------ lightllm/server/httpserver/manager.py | 24 ++++-- .../dynamic_prompt/disk_cache_server.py | 7 +- .../router/dynamic_prompt/hiradix_cache.py | 26 ++++-- .../server/router/dynamic_prompt/io_objs.py | 3 +- .../server/router/dynamic_prompt/manager.py | 23 ++++-- lightllm/server/router/manager.py | 18 ++-- .../server/router/model_infer/infer_batch.py | 10 +-- .../model_infer/mode_backend/base_backend.py | 9 +- 10 files changed, 107 insertions(+), 100 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index d69e0176d..8d4e03d40 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -208,7 +208,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes if args.use_hi_dynamic_prompt_cache: - hiradix_cache_port_num = node_world_size + hiradix_cache_port_num = node_world_size + 2 can_use_ports = alloc_can_use_network_port( num=7 + node_world_size + args.visual_dp * args.visual_tp + hiradix_cache_port_num, used_nccl_ports=already_uesd_ports ) @@ -240,7 +240,8 @@ def normal_or_p_d_start(args): args.metric_port = metric_port if args.use_hi_dynamic_prompt_cache: args.hiradix_cache_ports = can_use_ports[0:node_world_size] - can_use_ports = can_use_ports[node_world_size:] + args.hiradix_server_ports = can_use_ports[node_world_size: node_world_size + 2] + can_use_ports = can_use_ports[node_world_size + 2:] # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index bdf9a6d53..a651e15f9 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -54,57 +54,54 @@ def set_token_ids(self, ids: List[int]): def get_token_ids(self): return list(self.data[: self.size]) -# class RadixStatus(ctypes.Structure): -# _pack_ = 4 -# _fields_ = [("status", ctypes.c_int * 32)] - -# NOCACHE = -2 -# NOT_READY = -1 -# READ_READY = 1 -# WRITE_READY = 2 - -# def __init__(self, init_state=NOT_READY): -# self.status = init_state - -# def set_status(self, new_status: int): -# assert new_status in (self.NOCACHE, self.NOT_READY, self.READ_READY, self.WRITE_READY) -# self.status = new_status - -# def get_status(self) -> int: -# return self.status - -# def no_need_cache(self) -> bool: -# return self.status == self.NOCACHE - -# def is_read_ready(self) -> bool: -# return self.status == self.READ_READY +class ReqRankStatus(ctypes.Structure): + _pack_ = 4 + _fields_ = [("dp_rank_in_node", ctypes.c_int), ("dp_world_size", ctypes.c_int)] -# def is_write_ready(self) -> bool: -# return self.status == self.WRITE_READY + def __init__(self): + self.dp_rank_in_node = 0 + self.dp_world_size = 8 + + def set_status(self, dp_rank_in_node: int, dp_world_size: int): + self.dp_rank_in_node = dp_rank_in_node + self.dp_world_size = dp_world_size -# def is_not_ready(self) -> bool: -# return self.status == self.NOT_READY class RadixStatus(ctypes.Structure): _pack_ = 4 - _fields_ = [("status", ctypes.c_int * 32)] + _fields_ = [("status", ctypes.c_int * 32), 
("rank_status", ReqRankStatus), ("finished", ctypes.c_int)] NOCACHE = -2 NOT_READY = -1 READ_READY = 1 WRITE_READY = 2 + WRITE_DONE = 3 def __init__(self, init_state=NOT_READY): for i in range(32): self.status[i] = init_state + self.rank_status = ReqRankStatus() + self.finished = 0 def set_status(self, idx: int, new_status: int): assert 0 <= idx < 32, f"Index out of range: {idx}" - assert new_status in (self.NOCACHE, self.NOT_READY, self.READ_READY, self.WRITE_READY) + assert new_status in (self.NOCACHE, self.NOT_READY, self.READ_READY, self.WRITE_READY, self.WRITE_DONE) self.status[idx] = new_status + + def set_finished(self): + self.finished = 1 + + def is_finished(self): + self.finished == 1 def get_status(self, idx: int) -> int: assert 0 <= idx < 32, f"Index out of range: {idx}" return self.status[idx] + + def is_write_done(self): + dp_index = self.rank_status.dp_rank_in_node + dp_size = self.rank_status.dp_world_size + rank_list = range(dp_index * dp_size, (dp_index + 1) * dp_size) + return np.all(np.array(self.status)[rank_list] == self.WRITE_DONE) def is_no_need_cache(self, idx: int) -> bool: return self.get_status(idx) == self.NOCACHE @@ -118,32 +115,12 @@ def is_write_ready(self, idx: int) -> bool: def is_not_ready(self, idx: int) -> bool: return self.get_status(idx) == self.NOT_READY - # def all_dp_read_ready_or_nocache(self, indexs: List[int]) -> bool: - # for i in indexs: - # if self.status[i] not in (self.READ_READY, self.NOCACHE): - # return False - # return True def all_dp_read_ready_or_nocache(self, indexs: List[int]) -> bool: - # return np.all(self.status == self.READ_READY) - # for i in indexs: - # if self.status[i] not in (self.READ_READY, self.NOCACHE): - # return False - # return True - return np.all(np.array(self.status[indexs]) == self.READ_READY) or np.all(np.array(self.status[indexs]) == self.NOCACHE) - - # def all_read_ready_or_nocache(self) -> bool: - # for i in range(32): - # if self.status[i] not in (self.READ_READY, self.NOCACHE): - # return False - # return True + return np.all(np.array(self.status)[indexs] == self.READ_READY) or np.all(np.array(self.status)[indexs] == self.NOCACHE) + def all_read_ready_or_nocache(self) -> bool: return np.all(np.array(self.status) == self.READ_READY) or np.all(np.array(self.status) == self.NOCACHE) - # def all_read_ready(self) -> bool: - # return np.all(self.status == self.READ_READY) - - # def all_no_need_cache(self) -> bool: - # return np.all(self.status == self.NOCACHE) class Req(ctypes.Structure): _pack_ = 4 @@ -297,7 +274,6 @@ def can_release(self): # 只有管理节点有一个引用 ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: return True diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 61a2f75e9..bab31f3a7 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -51,10 +51,11 @@ def __init__( context = zmq.asyncio.Context(2) self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - if self.args.use_hi_dynamic_prompt_cache: + self.use_hi_dynamic_prompt_cache = args.use_hi_dynamic_prompt_cache + if self.use_hi_dynamic_prompt_cache: context_hiradix = zmq.asyncio.Context() self.send_to_hiradix = context_hiradix.socket(zmq.PUSH) - self.send_to_hiradix.bind(f"{args.zmq_mode}127.0.0.1:55555") + self.send_to_hiradix.connect(f"{args.zmq_mode}127.0.0.1:{self.args.hiradix_server_ports[0]}") 
self.multinode_req_manager = None self.nnodes = args.nnodes self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) @@ -305,7 +306,7 @@ async def generate( ) req_objs.append(req_obj) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) + req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time, self.use_hi_dynamic_prompt_cache) self.req_id_to_out_inf[group_request_id] = req_status await self.transfer_to_next_module_or_node( @@ -479,8 +480,7 @@ async def transfer_to_next_module( protocol=pickle.HIGHEST_PROTOCOL, ) else: - if self.args.use_hi_dynamic_prompt_cache: - logger.info(f"send_to_hiradix {group_req_objs.to_group_req_index()}") + if self.use_hi_dynamic_prompt_cache: self.send_to_hiradix.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL @@ -705,7 +705,8 @@ async def handle_loop(self): class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time, use_hi_dynamic_prompt_cache) -> None: + self.use_hi_dynamic_prompt_cache = use_hi_dynamic_prompt_cache self.lock = asyncio.Lock() self.event = asyncio.Event() self.group_req_objs = GroupReqObjs( @@ -718,6 +719,11 @@ def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], sta def can_release(self): for req in self.group_req_objs.shm_req_objs: - if not req.can_release(): - return False - return True + if self.use_hi_dynamic_prompt_cache: + if req.can_release() and req.radix_status.is_finished(): + + return True + else: + if req.can_release(): + return True + return False diff --git a/lightllm/server/router/dynamic_prompt/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/disk_cache_server.py index 08a29ca56..8255f2cd3 100644 --- a/lightllm/server/router/dynamic_prompt/disk_cache_server.py +++ b/lightllm/server/router/dynamic_prompt/disk_cache_server.py @@ -103,7 +103,7 @@ def exposed_push(self, req_info): logger.info(f"_push_task_loop receive task keys {len(keys)} values {len(values)}") self.remote_cache_manager.insert(keys, index) self.mem_manager.free_req_index(req.request_id) - self.set_reqs_radix_status([req], RadixStatus.NOT_READY) + self.set_reqs_radix_status([req], RadixStatus.WRITE_DONE) self.shm_req_manager.put_back_req_obj(req) return {"status": "ok"} @@ -128,6 +128,7 @@ def exposed_pull(self, group_req): query_len = self.remote_cache_manager.query(tokens=keys) if query_len == 0: self.set_reqs_radix_status(reqs, RadixStatus.NOCACHE) + self.put_back_req_objs(reqs) return {"query_len": 0, "kv_indices": []} index = self.mem_manager.alloc(query_len) self.remote_cache_manager.read(tokens=keys[:query_len], kv_page_indexer=index) @@ -164,9 +165,9 @@ async def _wrapped(*args, **kwargs): async def push(self, req_info: ShmReqInfo): if self.use_rpc: - return await self._insert(req_info) + return await self._push(req_info) else: - return self._insert(req_info) + return self._push(req_info) async def pull(self, group_req: GroupReqInfo): if self.use_rpc: diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 68e217870..c94cbc7d3 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -34,15 +34,21 @@ def __init__(self, mem_buffer: RadixMemoryBuffer, mem_manager: MemoryManager, ra self.rank_in_node = 
rank_in_node def insert(self, req: Req, key: torch.Tensor, value=None): - index = self.mem_buffer.alloc(len(value)) pre_index = self.mem_buffer.get_req_mem_index(req.request_id) - self.mem_buffer.set_req_mem_index( - req.request_id, index.tolist() - ) - if len(pre_index) != 0: - logger.info(f"pre index {pre_index}, index {index}") - index = index[len(pre_index):] + if len(pre_index) != 0 and len(value) > len(pre_index): + logger.info(f"pre index req {req.request_id} {pre_index}") + alloc_len = len(value) - len(pre_index) + index = self.mem_buffer.alloc(alloc_len) value = value[len(pre_index):] + self.mem_buffer.set_req_mem_index( + req.request_id, pre_index + index.tolist() + ) + logger.info(f"udpate index req {req.request_id} {pre_index + index.tolist()}") + else: + index = self.mem_buffer.alloc(len(value)) + self.mem_buffer.set_req_mem_index( + req.request_id, index.tolist() + ) dst_kv_buffer = self.mem_buffer.get_kv_buffer(index) src_kv_buffer = self.mem_manager.get_index_kv_buffer(value)["kv_buffer"] logger.info(f"insert mem_buffer shape {dst_kv_buffer.shape}, manager buffer shape {src_kv_buffer.shape}") @@ -58,6 +64,7 @@ def read(self, req: Req, dst_index): logger.info(f"len mem src index and dst index {len(index), len(dst_index)} read mem_buffer shape {src_kv_buffer.shape}, manager buffer shape {dst_kv_buffer.shape}") assert len(src_kv_buffer) == len(dst_kv_buffer), f"src kv buffer len {len(src_kv_buffer)} != dst kv buffer len {len(dst_kv_buffer)}" self.copy_kv_from_cpu_to_gpu(src_kv_buffer, dst_kv_buffer) + #TODO no free self.mem_buffer.free_req_index(req.group_req_id) except Exception as e: logger.error(f"Local cache read from radix mem_buffer error {e}") @@ -66,9 +73,11 @@ def read(self, req: Req, dst_index): def query(self, req: Req): if req.radix_status.is_no_need_cache(self.rank_in_node): + logger.info(f"query no need cache {self.rank_in_node} {req.radix_status.get_status(self.rank_in_node)}") return 0 if req.radix_status.is_read_ready(self.rank_in_node): index = self.mem_buffer.get_req_mem_index(req.group_req_id) + logger.info(f"query find cache {self.rank_in_node} {req.radix_status.get_status(self.rank_in_node)} len {len(index)}") return len(index) return 0 @@ -118,7 +127,7 @@ def match_prefix(self, req, key, update_refs=False): max_len = 0 if tree_node.node_prefix_total_len < len(key): max_len = self.local_cache_manager.query(req) - logger.debug(f"HiCache rank_in_node={self.rank_in_node} current match radix len {tree_node.node_prefix_total_len}, max len {max_len}") + logger.debug(f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len {tree_node.node_prefix_total_len}, max len {max_len}") if max_len > tree_node.node_prefix_total_len: pull_len = max_len - tree_node.node_prefix_total_len self.disk_cache_match_count.arr[0] += 1 @@ -136,4 +145,5 @@ def match_prefix(self, req, key, update_refs=False): super().insert(key[:max_len], buffers) else: self.mem_manager.free(buffers[tree_node.node_prefix_total_len:]) + return super().match_prefix(key, update_refs=update_refs) diff --git a/lightllm/server/router/dynamic_prompt/io_objs.py b/lightllm/server/router/dynamic_prompt/io_objs.py index 7a1da21b2..b95db6b3b 100644 --- a/lightllm/server/router/dynamic_prompt/io_objs.py +++ b/lightllm/server/router/dynamic_prompt/io_objs.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from lightllm.server.core.objs import Req from typing import List @dataclass @@ -15,7 +14,7 @@ def to_dict(self): @staticmethod def from_dict(d): - return GroupReqInfo( + return 
ShmReqInfo( request_id=d["request_id"], shm_req_index=d["shm_req_index"] ) diff --git a/lightllm/server/router/dynamic_prompt/manager.py b/lightllm/server/router/dynamic_prompt/manager.py index 53a5a1917..0d8dcb3b8 100644 --- a/lightllm/server/router/dynamic_prompt/manager.py +++ b/lightllm/server/router/dynamic_prompt/manager.py @@ -12,7 +12,9 @@ from lightllm.utils.log_utils import init_logger, log_time_ready from lightllm.utils.graceful_utils import graceful_registry from .disk_cache_server import DiskCacheClient +from lightllm.server.core.objs import ShmReqManager from .io_objs import ShmReqInfo, GroupReqInfo +from lightllm.server.core.objs import Req logger = init_logger(__name__) @@ -26,15 +28,17 @@ def __init__( self.disk_cache_processes = [] self.ports = args.hiradix_cache_ports self.cache_server_client = [] - context = zmq.asyncio.Context(2) + context = zmq.asyncio.Context(3) self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.connect(f"{args.zmq_mode}127.0.0.1:55555") + recv_from_http_port, recv_from_router_port = self.args.hiradix_server_ports + self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{recv_from_http_port}") self.clients: List[DiskCacheClient] = [] self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - logger.info(f"send_to_router {args.zmq_mode}127.0.0.1:{router_port} ") self.recv_from_router = context.socket(zmq.PULL) - self.recv_from_router.connect(f"{args.zmq_mode}127.0.0.1:66666") + self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{recv_from_router_port}") + self.shm_req_manager = ShmReqManager() + async def asyn_init(self): self.pull_queue = asyncio.Queue() @@ -65,7 +69,6 @@ async def pull_cache(self, group_req): all_results = await asyncio.gather(*tasks) logger.info(f"pull cache results {all_results}") await self.send_to_router.send_pyobj(group_req, protocol=pickle.HIGHEST_PROTOCOL) - logger.info(f"send to router pyobj {group_req}") async def push_cache(self, req_info): tasks = [] @@ -73,17 +76,20 @@ async def push_cache(self, req_info): task = client.push(req_info) tasks.append(task) all_results = await asyncio.gather(*tasks) + req: Req = self.shm_req_manager.get_req_obj_by_index(req_info["shm_req_index"]) + assert req.radix_status.is_write_done() + req.radix_status.set_finished() logger.info(f"push cache results {all_results}") async def pull_woker(self): while True: - req: ShmReqInfo = await self.pull_queue.get() + req: GroupReqInfo = await self.pull_queue.get() await self.pull_cache(req) await asyncio.sleep(0.01) async def push_woker(self): while True: - req: GroupReqInfo = await self.push_queue.get() + req: ShmReqInfo = await self.push_queue.get() await self.push_cache(req.to_dict()) await asyncio.sleep(0.01) @@ -99,8 +105,9 @@ async def run(self): async def loop_for_netio_req_to_push(self): while True: recv_req: ShmReqInfo = await self.recv_from_router.recv_pyobj() + logger.info(f"loop_for_netio_req_to_push --> recv req {recv_req}") if isinstance(recv_req, ShmReqInfo): - await self.push_queue.put() + await self.push_queue.put(recv_req) else: raise ValueError(f"Invalid request: {recv_req}") diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 28479ce0b..03f0ebff8 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -15,6 +15,7 @@ from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue +from 
lightllm.server.router.dynamic_prompt.io_objs import ShmReqInfo from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient @@ -75,16 +76,18 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.max_wait_tokens = args.router_max_wait_tokens context = zmq.Context(2) self.recv_from_httpserver = context.socket(zmq.PULL) - logger.info(f"recv_from_httpserver {args.zmq_mode}127.0.0.1:{router_port} ") self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}") self.send_to_detokenization = context.socket(zmq.PUSH) self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") self.router_port = router_port - context_to_radix = zmq.asyncio.Context() - self.send_to_hiradix_server = context_to_radix.socket(zmq.PUSH) - self.send_to_hiradix_server.connect(f"{args.zmq_mode}127.0.0.1:{66666}") + if self.use_hiradix_cache: + hiradix_port = self.args.hiradix_server_ports[1] + context_radix = zmq.asyncio.Context() + self.send_to_hiradix_server = context_radix.socket(zmq.PUSH) + self.send_to_hiradix_server.connect(f"{args.zmq_mode}127.0.0.1:{hiradix_port}") + logger.info(f"send_to_hiradix_server {args.zmq_mode}127.0.0.1:{hiradix_port}") if self.is_multinode_tp: self.mulitnode_group = dist.init_process_group( @@ -398,7 +401,12 @@ async def _send_hiradix_manager(self, reqs): if not self.use_hiradix_cache: return for req in reqs: - await self.send_to_hiradix_server.send_pyobj(req, protocol=pickle.HIGHEST_PROTOCOL) + req_info = ShmReqInfo( + req.request_id, + req.index_in_shm_mem + ) + await self.send_to_hiradix_server.send_pyobj(req_info, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"_send_hiradix_manager {req_info}") return def _send_detokenization_pack(self): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 3904dd628..11189b85c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -279,8 +279,7 @@ def __init__( def init_all(self): if self.initialized is False: - if self.shm_req is None: - self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) + self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) @@ -319,6 +318,7 @@ def init_all(self): self.shm_req.shm_cur_kv_len = self.cur_kv_len + self.init_radix_status() self.initialized = True self.paused = False return @@ -326,10 +326,8 @@ def init_all(self): def is_uninitialized(self): return not self.initialized or self.paused - def is_radix_ready(self): - if self.shm_req is None: - self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) - return g_infer_context.backend.is_radix_ready(self.shm_req) + def init_radix_status(self): + return g_infer_context.backend.set_radix_status(self.shm_req) def get_output_len(self): return self.cur_output_len diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 97ffc5e09..964ffe42c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -168,10 
+168,11 @@ def init_model(self, kvargs): self.init_custom() return - def is_radix_ready(self, req): - dp_rank_list = range(self.dp_rank_in_node * self.dp_world_size, (self.dp_rank_in_node + 1) * self.dp_world_size) - if req.radix_status.is_no_need_cache(self.rank_in_node) or req.radix_status.is_read_ready(self.rank_in_node) : - return True + def set_radix_status(self, req): + if not self.use_hiradix_cache: + return + if self.is_master_in_dp: + req.radix_status.rank_status.set_status(self.dp_rank_in_node, self.dp_world_size) return False def init_custom(self): From e3f4955a2d3b1407c0459b174697e4efb33fa008 Mon Sep 17 00:00:00 2001 From: jinbiaoyu Date: Fri, 18 Jul 2025 16:07:31 +0800 Subject: [PATCH 06/13] update and add radix manager --- lightllm/common/basemodel/basemodel.py | 22 ++- lightllm/common/radixmem_buffer.py | 38 ++--- lightllm/common/radixmem_manager.py | 119 +++++++++++++++ lightllm/models/deepseek2/model.py | 18 ++- lightllm/models/qwen2/model.py | 18 ++- lightllm/server/api_cli.py | 4 +- lightllm/server/api_start.py | 10 +- lightllm/server/core/objs/start_args_type.py | 2 +- lightllm/server/httpserver/manager.py | 14 +- .../dynamic_prompt/disk_cache_server.py | 141 ++++++++++++------ .../router/dynamic_prompt/hiradix_cache.py | 113 ++++++++------ .../server/router/dynamic_prompt/io_objs.py | 47 ++++++ .../server/router/dynamic_prompt/manager.py | 1 - lightllm/server/router/manager.py | 4 +- .../model_infer/mode_backend/base_backend.py | 6 +- .../server/router/model_infer/model_rpc.py | 6 +- 16 files changed, 405 insertions(+), 158 deletions(-) create mode 100644 lightllm/common/radixmem_manager.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c0aae6b1d..1665486e8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -75,6 +75,8 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) self.enable_hiradix_cache = kvargs.get("use_hiradix_cache", False) + self.hiradix_cache_gpu = kvargs.get("hiradix_cache_gpu", False) + self.hiradix_cache_token_num = kvargs.get("hiradix_cache_token_num", None) self.radix_lock = kvargs.get("radix_lock", None) self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode @@ -164,9 +166,11 @@ def _init_weights(self): def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 - max_radix_token_num = 10000 + + max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num + self.mem_manager = MemoryManager( - self.max_total_token_num - max_radix_token_num, + max_total_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], @@ -176,21 +180,27 @@ def _init_mem_manager(self): if self.enable_hiradix_cache: from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties + from lightllm.common.radixmem_manager import RadixBufferManager mem_propties = MemPropties( - max_radix_token_num, + self.hiradix_cache_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], layer_num=self.config["n_layer"] ) init_shared_data( - mem_propties=mem_propties + 
mem_propties=mem_propties, + device="cpu" if not self.hiradix_cache_gpu else "cuda" ) - self.radix_mem_buffer = RadixMemoryBuffer( + radix_mem_buffer = RadixMemoryBuffer( mem_propties, shared_data=get_shared_data(), - lock=self.radix_lock + lock=self.radix_lock, + device="cpu" if not self.hiradix_cache_gpu else "cuda" ) + self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer, + radix_mem_data=get_shared_data(), + lock=self.radix_lock) self.mem_propties = mem_propties self.shared_mem_data = get_shared_data() return diff --git a/lightllm/common/radixmem_buffer.py b/lightllm/common/radixmem_buffer.py index 8dcf220dd..20181d9eb 100644 --- a/lightllm/common/radixmem_buffer.py +++ b/lightllm/common/radixmem_buffer.py @@ -7,7 +7,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from multiprocessing.managers import DictProxy +from multiprocessing.managers import DictProxy, ListProxy from multiprocessing import Manager @@ -18,7 +18,7 @@ class SharedRadixMemoryData: kv_buffer: torch.Tensor mem_state: torch.Tensor req_mem_index: DictProxy - + lru_queue: ListProxy @dataclass class MemPropties: @@ -30,6 +30,7 @@ class MemPropties: shared_mem_data: SharedRadixMemoryData = None + def init_shared_data(mem_propties: MemPropties, device="cuda"): size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \ mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num @@ -49,13 +50,15 @@ def init_shared_data(mem_propties: MemPropties, device="cuda"): ).share_memory_() mem_state = torch.arange(size, dtype=torch.int32).share_memory_() - manager = Manager() req_mem_index = manager.dict() + lru_queue = manager.list() + shared_mem_data = SharedRadixMemoryData( kv_buffer=kv_buffer, mem_state=mem_state, - req_mem_index=req_mem_index + req_mem_index=req_mem_index, + lru_queue=lru_queue ) def get_shared_data() -> SharedRadixMemoryData: @@ -70,29 +73,12 @@ def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData rank_in_node=None): size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \ mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num - if shared_data is not None: - self.kv_buffer = shared_data.kv_buffer - self.mem_state = shared_data.mem_state - self.req_mem_index = shared_data.req_mem_index - else: - # CPU 上分配 key 和 value(共 2 * head_num) - if device == "cuda": - self.kv_buffer = torch.empty( - (layer_num, size, 2 * head_num, head_dim), - dtype=dtype, - device="cuda" - ) - else: - self.kv_buffer = torch.empty( - (layer_num, size, 2 * head_num, head_dim), - dtype=dtype, - device="cpu" - ).share_memory_() - self.mem_state = torch.arange( - 0, size, dtype=torch.int32 - ).share_memory_() - self.req_mem_index = mp.Manager().dict() + + self.kv_buffer = shared_data.kv_buffer + self.mem_state = shared_data.mem_state + self.req_mem_index = shared_data.req_mem_index self.lock = lock if lock is not None else mp.Lock() + #TODO profile size self.size = size # token slot 个数 self.head_num = head_num diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py new file mode 100644 index 000000000..41a393e90 --- /dev/null +++ b/lightllm/common/radixmem_manager.py @@ -0,0 +1,119 @@ +import torch +import time +import xxhash +import numpy as np +from typing import List, Dict, Tuple, Optional +import torch.multiprocessing as mp +from collections 
import OrderedDict + +from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer + +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) + +class RadixBufferManager: + + def __init__(self, + radix_buffer: RadixMemoryBuffer = None, + radix_mem_data: SharedRadixMemoryData = None, + lock: Optional[mp.Lock] = None, + max_entries: int = 100, + chunk_size: int = 64 + ): + self.chunk_size = chunk_size + self.max_entries = max_entries + self.radix_buffer = radix_buffer + self.lru_queue = radix_mem_data.lru_queue + + self.lock = lock if lock is not None else mp.Lock() + + def _compute_hash(self, tokens: List[int]) -> List[Tuple[int, List[int]]]: + chunks = [] + hsum = xxhash.xxh3_64() + cumulative_tokens = [] + + for i in range(0, len(tokens), self.chunk_size): + chunk = tokens[i:i + self.chunk_size] + cumulative_tokens.extend(chunk) + + chunk_np = np.array(chunk, dtype=np.uint32) + hsum.update(chunk_np.tobytes()) + + current_hash = hsum.intdigest() + chunks.append((current_hash, cumulative_tokens.copy())) + + return chunks + + def write(self, tokens: List[int], values: torch.Tensor, start_pos: int) -> None: + with self.lock: + index = start_pos // self.chunk_size + chunks = self._compute_hash(tokens) + + values = values[index * self.chunk_size:] + chunks = chunks[index:] + for i, (hash_val, _) in enumerate(chunks): + if hash not in self.radix_buffer.req_mem_index: + self.radix_buffer.req_mem_index[hash_val] = values[i * self.chunk_size : (i + 1) * self.chunk_size] + self._update_lru_state(hash_val) + + def _update_lru_state(self, hash_val: int): + if hash_val in self.lru_queue: + self.lru_queue.remove(hash_val) + self.lru_queue.append(hash_val) + + while len(self.lru_queue) > self.max_entries: + self.lru_queue.pop(0) + + def free_space(self, required_size: int) -> bool: + with self.lock: + current_free = self.radix_buffer.get_can_use_mem_size.get_value() + + if current_free >= required_size: + return True + + need_to_free = required_size - current_free + freed_size = 0 + + while freed_size < need_to_free and len(self.lru_queue) > 0: + evict_size = self._evict_lru() + freed_size += evict_size + + final_free = self.radix_buffer.can_use_mem_size.get_value() + return final_free >= required_size + + def _evict_lru(self): + if not self.lru_queue: + return + oldest_hash = self.lru_queue[0] + + evict_size = 0 + if oldest_hash in self.radix_buffer.req_mem_index: + indices = self.radix_buffer.req_mem_index[oldest_hash] + evict_size += len(indices) + self.radix_buffer._free(indices) + del self.radix_buffer.req_mem_index[oldest_hash] + + self.lru_queue.pop(0) + return evict_size + + def query_cache(self, tokens: List[int]) -> int: + with self.lock: + chunks = self._compute_hash(tokens) + if not chunks: + return 0, [] + + max_hit = 0 + mem_index = [] + for hash_val, _ in chunks: + if hash_val in self.radix_buffer.req_mem_index: + index_val = self.radix_buffer.req_mem_index[hash_val] + mem_index.extend(index_val) + max_hit += len(index_val) + else: + break + return max_hit, mem_index + + def clear(self): + with self.lock: + self.radix_buffer.req_mem_index.clear() + self.lru_queue[:] = [] \ No newline at end of file diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 60a97f11b..38f49b96f 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -103,9 +103,9 @@ def _init_mem_manager(self): if get_env_start_args().mtp_mode == "deepseekv3": added_mtp_layer_num += get_env_start_args().mtp_step - 
max_radix_token_num = 300000 + max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num self.mem_manager = manager_class( - self.max_total_token_num - max_radix_token_num, + max_total_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], @@ -114,21 +114,27 @@ def _init_mem_manager(self): ) if self.enable_hiradix_cache: from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties + from lightllm.common.radixmem_manager import RadixBufferManager mem_propties = MemPropties( - max_radix_token_num, + self.hiradix_cache_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, ) init_shared_data( - mem_propties=mem_propties + mem_propties=mem_propties, + device="cpu" if not self.hiradix_cache_gpu else "cuda" ) - self.radix_mem_buffer = RadixMemoryBuffer( + radix_mem_buffer = RadixMemoryBuffer( mem_propties, shared_data=get_shared_data(), - lock=self.radix_lock + lock=self.radix_lock, + device="cpu" if not self.hiradix_cache_gpu else "cuda" ) + self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer, + radix_mem_data=get_shared_data(), + lock=self.radix_lock) self.mem_propties = mem_propties self.shared_mem_data = get_shared_data() return diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 252380294..8c69eb727 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -42,9 +42,9 @@ def _init_mem_manager(self): head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - max_radix_token_num = 10000 + max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num self.mem_manager = select_mem_manager_class(self.mode)( - self.max_total_token_num - max_radix_token_num, + max_total_token_num, dtype=self.data_type, head_num=tp_k_head_num_, head_dim=head_dim_, @@ -54,21 +54,27 @@ def _init_mem_manager(self): if self.enable_hiradix_cache: from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties + from lightllm.common.radixmem_manager import RadixBufferManager mem_propties = MemPropties( - max_radix_token_num, + self.hiradix_cache_token_num, dtype=self.data_type, head_num=2 * tp_k_head_num_, head_dim=head_dim_, layer_num=self.config["num_hidden_layers"], ) init_shared_data( - mem_propties=mem_propties + mem_propties=mem_propties, + device="cpu" if not self.hiradix_cache_gpu else "cuda" ) - self.radix_mem_buffer = RadixMemoryBuffer( + radix_mem_buffer = RadixMemoryBuffer( mem_propties, shared_data=get_shared_data(), - lock=self.radix_lock + lock=self.radix_lock, + device="cpu" if not self.hiradix_cache_gpu else "cuda" ) + self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer, + radix_mem_data=get_shared_data(), + lock=self.radix_lock) self.mem_propties = mem_propties self.shared_mem_data = get_shared_data() return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index af0473194..9122f76d8 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -220,7 +220,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--disable_dynamic_prompt_cache", 
action="store_true", help="disable dynamic prompt cache") parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") - parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true", help="enable hierachy prompt cache") + parser.add_argument("--use_hiradix_cache", action="store_true", help="enable hierachy prompt cache") + parser.add_argument("--hiradix_cache_gpu", action="store_true", help="enable hierachy prompt cache gpu") + parser.add_argument("--hiradix_cache_token_num", type=int, default=None , help="set the number of tokens to use hierachy prompt cache") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8d4e03d40..15d1c6a6f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -173,10 +173,10 @@ def normal_or_p_d_start(args): args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" - # if use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True + # if use_hiradix_cache, then use_dynamic_prompt_cache must be True hiradix_cache_port_num = 0 - if args.use_hi_dynamic_prompt_cache: - assert not args.disable_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" + if args.use_hiradix_cache: + assert not args.disable_dynamic_prompt_cache, "use_hiradix_cache must be used with use_dynamic_prompt_cache" # help to manage data stored on Ceph if "s3://" in args.model_dir: @@ -207,7 +207,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes - if args.use_hi_dynamic_prompt_cache: + if args.use_hiradix_cache: hiradix_cache_port_num = node_world_size + 2 can_use_ports = alloc_can_use_network_port( num=7 + node_world_size + args.visual_dp * args.visual_tp + hiradix_cache_port_num, used_nccl_ports=already_uesd_ports @@ -238,7 +238,7 @@ def normal_or_p_d_start(args): args.audio_port = audio_port args.cache_port = cache_port args.metric_port = metric_port - if args.use_hi_dynamic_prompt_cache: + if args.use_hiradix_cache: args.hiradix_cache_ports = can_use_ports[0:node_world_size] args.hiradix_server_ports = can_use_ports[node_world_size: node_world_size + 2] can_use_ports = can_use_ports[node_world_size + 2:] diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0bb3d1912..d68587480 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -46,7 +46,7 @@ class StartArgs: dp_prefill_wait_step: int = field(default=0) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) - use_hi_dynamic_prompt_cache: bool = field(default=False) + use_hiradix_cache: bool = field(default=False) chunked_prefill_size: int = field(default=8192) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index bab31f3a7..a2d7084ca 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -51,8 +51,8 @@ def __init__( context = zmq.asyncio.Context(2) 
self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - self.use_hi_dynamic_prompt_cache = args.use_hi_dynamic_prompt_cache - if self.use_hi_dynamic_prompt_cache: + self.use_hiradix_cache = args.use_hiradix_cache + if self.use_hiradix_cache: context_hiradix = zmq.asyncio.Context() self.send_to_hiradix = context_hiradix.socket(zmq.PUSH) self.send_to_hiradix.connect(f"{args.zmq_mode}127.0.0.1:{self.args.hiradix_server_ports[0]}") @@ -306,7 +306,7 @@ async def generate( ) req_objs.append(req_obj) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time, self.use_hi_dynamic_prompt_cache) + req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time, self.use_hiradix_cache) self.req_id_to_out_inf[group_request_id] = req_status await self.transfer_to_next_module_or_node( @@ -480,7 +480,7 @@ async def transfer_to_next_module( protocol=pickle.HIGHEST_PROTOCOL, ) else: - if self.use_hi_dynamic_prompt_cache: + if self.use_hiradix_cache: self.send_to_hiradix.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL @@ -705,8 +705,8 @@ async def handle_loop(self): class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time, use_hi_dynamic_prompt_cache) -> None: - self.use_hi_dynamic_prompt_cache = use_hi_dynamic_prompt_cache + def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time, use_hiradix_cache) -> None: + self.use_hiradix_cache = use_hiradix_cache self.lock = asyncio.Lock() self.event = asyncio.Event() self.group_req_objs = GroupReqObjs( @@ -719,7 +719,7 @@ def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], sta def can_release(self): for req in self.group_req_objs.shm_req_objs: - if self.use_hi_dynamic_prompt_cache: + if self.use_hiradix_cache: if req.can_release() and req.radix_status.is_finished(): return True diff --git a/lightllm/server/router/dynamic_prompt/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/disk_cache_server.py index 8255f2cd3..089212d7b 100644 --- a/lightllm/server/router/dynamic_prompt/disk_cache_server.py +++ b/lightllm/server/router/dynamic_prompt/disk_cache_server.py @@ -15,11 +15,12 @@ from lightllm.utils.log_utils import init_logger from enum import Enum from .shared_arr import SharedArray -from .io_objs import ShmReqInfo, GroupReqInfo +from .io_objs import ShmReqInfo, GroupReqInfo, HitSate, PullState, PushState, CacheTask from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.common.radixmem_buffer import RadixMemoryBuffer +from lightllm.common.radixmem_manager import RadixBufferManager from lightllm.server.core.objs import Req, RadixStatus logger = init_logger(__name__) @@ -39,8 +40,8 @@ def __init__(self, unique_name: str, rank_in_node: int, mem_manager): self.cache_file = join(tmp_dir, "cache_file") all_buffers = mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) - from kvcache.python.jit import PyLocalCacheService + from kvcache.python.jit import PyLocalCacheService self.py_cache_service = PyLocalCacheService( file=self.cache_file, storage_size=128 * (1024 ** 3), # 128GB @@ -49,33 +50,44 @@ def __init__(self, unique_name: str, rank_in_node: int, mem_manager): num_worker=8 ) - def insert(self, tokens, kv_page_indexer, start_pos=0): + def 
insert(self, cache_task: CacheTask): + assert cache_task.mode == 'w', "Cache task mode must be 'w' for insert" + t = self.py_cache_service.create( - tokens=tokens, - kv_page_indexer=kv_page_indexer, - mode="w", - start_pos=start_pos) + tokens=cache_task.tokens, + kv_page_indexer=cache_task.kv_page_indexer, + mode=cache_task.mode, + start_pos=cache_task.start_pos) res = wait_until_ready(t) + if not res: self.py_cache_service.az5(t) + return False + + return True + + def read(self, cache_task: CacheTask): + assert cache_task.mode == 'r', "Cache task mode must be 'r' for read" - def read(self, tokens, kv_page_indexer, start_pos=0): t = self.py_cache_service.create( - tokens=tokens, - kv_page_indexer=kv_page_indexer, - mode="r", - start_pos=start_pos) + tokens=cache_task.tokens, + kv_page_indexer=cache_task.kv_page_indexer, + mode=cache_task.mode, + start_pos=cache_task.start_pos) + res = wait_until_ready(t) return res - def query(self, tokens): - query_result = self.py_cache_service.query(tokens) + def query(self, cache_task: CacheTask): + query_result = self.py_cache_service.query(cache_task.tokens) + max_len = 0 for result in query_result: if result: max_len += 1 else: break + return max_len * self.block_size @property @@ -84,9 +96,9 @@ def block_size(self,): class DiskCacheService(rpyc.Service): - def __init__(self, mem_manager=None, remote_cache_manager=None, shm_req_manager=None, rank_in_node=None): + def __init__(self, radix_manager=None, remote_cache_manager=None, shm_req_manager=None, rank_in_node=None): super().__init__() - self.mem_manager = mem_manager + self.radix_manager = radix_manager self.remote_cache_manager = remote_cache_manager self.shm_req_manager = shm_req_manager self.rank_in_node = rank_in_node @@ -95,22 +107,32 @@ def exposed_push(self, req_info): req_info: ShmReqInfo = ShmReqInfo.from_dict(req_info) req: Req = self.shm_req_manager.get_req_obj_by_index(req_info.shm_req_index) req.link_prompt_ids_shm_array() - assert req.radix_status.is_write_ready(self.rank_in_node), "radix cache is not ready" - input_token_ids = req.shm_prompt_ids.arr[0 : req.shm_cur_kv_len] - keys = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") - values = self.mem_manager.get_req_mem_index(req_info.request_id) - index = torch.tensor(values, device="cpu", dtype=torch.int32) - logger.info(f"_push_task_loop receive task keys {len(keys)} values {len(values)}") - self.remote_cache_manager.insert(keys, index) - self.mem_manager.free_req_index(req.request_id) - self.set_reqs_radix_status([req], RadixStatus.WRITE_DONE) - self.shm_req_manager.put_back_req_obj(req) - return {"status": "ok"} + + if not req.radix_status.is_write_ready(self.rank_in_node): + raise RuntimeError("Radix cache is not ready for write.") + + token_ids = req.shm_prompt_ids.arr[0 : req.shm_cur_kv_len] + keys = torch.tensor(token_ids, dtype=torch.int64, device="cpu") + + _, index_list = self.radix_manager.query_cache(tokens=token_ids.tolist()) + index_tensor = torch.tensor(index_list, device="cpu", dtype=torch.int32) + assert len(keys) == len(index_tensor), f"keys length {len(keys)} != index length {len(index_list)}" + + if len(keys) != len(index_tensor): + raise ValueError(f"Mismatch in keys and index size: {len(keys)} != {len(index_tensor)}") + + insert_task = CacheTask(tokens=keys, kv_page_indexer=index_tensor, mode='w') + result = self.remote_cache_manager.insert(insert_task) + + reqs = [req] + self.set_reqs_radix_status(reqs, RadixStatus.WRITE_DONE) + self.put_back_req_objs(reqs) + + return 
PushState(state=result).to_dict() def set_reqs_radix_status(self, reqs: List[Req], status: int): for req in reqs: req.radix_status.set_status(self.rank_in_node, status) - logger.info(f"-->pull loop rank_in_node={self.rank_in_node} set req {req.group_req_id, req.request_id} radix status {req.radix_status.get_status(self.rank_in_node)}") def put_back_req_objs(self, reqs: List[Req]): for req in reqs: @@ -122,22 +144,41 @@ def exposed_pull(self, group_req): for shm_req_index in group_req.shm_req_indexes: req: Req = self.shm_req_manager.get_req_obj_by_index(shm_req_index) reqs.append(req) + req = reqs[0] req.link_prompt_ids_shm_array() keys = req.get_prompt_ids() - query_len = self.remote_cache_manager.query(tokens=keys) - if query_len == 0: - self.set_reqs_radix_status(reqs, RadixStatus.NOCACHE) - self.put_back_req_objs(reqs) - return {"query_len": 0, "kv_indices": []} - index = self.mem_manager.alloc(query_len) - self.remote_cache_manager.read(tokens=keys[:query_len], kv_page_indexer=index) - self.mem_manager.set_req_mem_index( - group_req.group_req_id, index.tolist() - ) - self.set_reqs_radix_status(reqs, RadixStatus.READ_READY) + + query_len, _ = self.radix_manager.query_cache(tokens=keys) + if query_len > 0: + radix_state = RadixStatus.READ_READY + cache_state = PullState(query_len, HitSate.MEM) + else: + query_task = CacheTask(tokens=keys) + query_len = self.remote_cache_manager.query(query_task) + + if query_len > 0: + self.radix_manager.free_space(query_len) + index = self.radix_manager.radix_buffer.alloc(query_len) + read_task = CacheTask( + tokens=keys[:query_len], + kv_page_indexer=index, + mode='r' + ) + self.remote_cache_manager.read(read_task) + + self.radix_manager.write(keys=keys[:query_len], values=index.tolist()) + + radix_state = RadixStatus.READ_READY + cache_state = PullState(query_len, HitSate.DISK) + else: + radix_state = RadixStatus.NOCACHE + cache_state = PullState(0, HitSate.NONE) + + self.set_reqs_radix_status(reqs, radix_state) self.put_back_req_objs(reqs) - return {"query_len": query_len, "kv_indices": index.tolist()} + + return cache_state.to_dict() class DiskCacheClient: @@ -176,10 +217,10 @@ async def pull(self, group_req: GroupReqInfo): return self._pull(group_req) -def start_cache_server(mem_manager, remote_cache_manager, shm_req_manager, rank_in_node, port, init_event): +def start_cache_server(radix_manager, remote_cache_manager, shm_req_manager, rank_in_node, port, init_event): class CustomService(DiskCacheService): def __init__(self): - super().__init__(mem_manager, remote_cache_manager, shm_req_manager, rank_in_node) + super().__init__(radix_manager, remote_cache_manager, shm_req_manager, rank_in_node) def start(): try: @@ -220,10 +261,14 @@ def _init_server( rank_in_node=device_id, mem_manager=mem_manager, ) + radix_manager = RadixBufferManager(radix_buffer=mem_manager, + radix_mem_data=shared_mem_data, + lock=radix_lock) + shm_req_manager = ShmReqManager() t = start_cache_server( - mem_manager=mem_manager, + radix_manager=radix_manager, remote_cache_manager=remote_cache_manager, shm_req_manager=shm_req_manager, rank_in_node=device_id, @@ -247,7 +292,7 @@ async def start_disk_cache_server_process( from lightllm.utils.envs_utils import get_unique_server_name if node_word_size == 1: mem_proties, shared_mem_data = mem_queue.get() - mem_manager = RadixMemoryBuffer( + mem_buffer = RadixMemoryBuffer( mem_propties=mem_proties, shared_data=shared_mem_data, lock=radix_lock, @@ -256,10 +301,14 @@ async def start_disk_cache_server_process( remote_cache_manager = 
RemoteCacheManager( unique_name=get_unique_server_name(), rank_in_node=device_id, - mem_manager=mem_manager, + mem_manager=mem_buffer, ) shm_req_manager = ShmReqManager() - service = DiskCacheService(mem_manager, remote_cache_manager, shm_req_manager) + + radix_manager = RadixBufferManager(radix_buffer=mem_buffer, + radix_mem_data=shared_mem_data, + lock=radix_lock) + service = DiskCacheService(radix_manager, remote_cache_manager, shm_req_manager) client = DiskCacheClient( service=service, rank_in_node=0, diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index c94cbc7d3..635452f5b 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -12,6 +12,7 @@ from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager from lightllm.common.radixmem_buffer import RadixMemoryBuffer +from lightllm.common.radixmem_manager import RadixBufferManager from lightllm.utils.log_utils import init_logger from threading import Lock from enum import Enum @@ -28,58 +29,77 @@ class LocalCacheManager: - def __init__(self, mem_buffer: RadixMemoryBuffer, mem_manager: MemoryManager, rank_in_node): - self.mem_buffer = mem_buffer + def __init__(self, radix_manager: RadixBufferManager, mem_manager: MemoryManager, rank_in_node): + self.radix_manager = radix_manager + self.radix_buffer: RadixMemoryBuffer = self.radix_manager.radix_buffer self.mem_manager = mem_manager self.rank_in_node = rank_in_node def insert(self, req: Req, key: torch.Tensor, value=None): - pre_index = self.mem_buffer.get_req_mem_index(req.request_id) - if len(pre_index) != 0 and len(value) > len(pre_index): - logger.info(f"pre index req {req.request_id} {pre_index}") - alloc_len = len(value) - len(pre_index) - index = self.mem_buffer.alloc(alloc_len) - value = value[len(pre_index):] - self.mem_buffer.set_req_mem_index( - req.request_id, pre_index + index.tolist() - ) - logger.info(f"udpate index req {req.request_id} {pre_index + index.tolist()}") - else: - index = self.mem_buffer.alloc(len(value)) - self.mem_buffer.set_req_mem_index( - req.request_id, index.tolist() - ) - dst_kv_buffer = self.mem_buffer.get_kv_buffer(index) - src_kv_buffer = self.mem_manager.get_index_kv_buffer(value)["kv_buffer"] - logger.info(f"insert mem_buffer shape {dst_kv_buffer.shape}, manager buffer shape {src_kv_buffer.shape}") - assert len(src_kv_buffer) == len(dst_kv_buffer), f"src kv buffer len {len(src_kv_buffer)} != dst kv buffer len {len(dst_kv_buffer)}" + query_len, query_index = self._query_cache(req, key) + + alloc_len = len(key) - query_len + if alloc_len == 0: + self._set_radix_staus(req, RadixStatus.WRITE_READY) + return + + new_index = self._alloc_and_copy_kv(alloc_len, value) + + start_pos = max(0, (query_len - 1) // self.chunk_size * self.chunk_size) + self.radix_manager.write( + tokens=key.tolist(), + values=query_index + new_index, + start_pos=start_pos + ) + + self._set_radix_staus(req, RadixStatus.WRITE_READY) + + def _query_cache(self, req, key): + if req.radix_status.is_no_need_cache(self.rank_in_node): + logger.info(f"query no need cache {self.rank_in_node} {req.radix_status.get_status(self.rank_in_node)}") + return 0, [] + + if req.radix_status.is_read_ready(self.rank_in_node): + query_len, mem_index = self.radix_manager.query_cache(key.tolist()) + return query_len, mem_index + return 0, [] + + def _alloc_and_copy_kv(self, alloc_len, value): + assert alloc_len > 0, "No 
allocation needed" + + new_index = self.radix_buffer.alloc(alloc_len) + dst_kv_buffer = self.radix_buffer.get_kv_buffer(new_index) + src_kv_buffer = self.mem_manager.get_index_kv_buffer(value[-alloc_len:])["kv_buffer"] + + assert len(src_kv_buffer) == len(dst_kv_buffer), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" + self.copy_kv_from_gpu_to_cpu(src_kv_buffer, dst_kv_buffer) - req.radix_status.set_status(self.rank_in_node, RadixStatus.WRITE_READY) + return new_index.tolist() - def read(self, req: Req, dst_index): + def _set_radix_staus(self, req, status): + req.radix_status.set_status(self.rank_in_node, status) + + def read(self, key, value, query_index, alloc_len): try: - index = self.mem_buffer.get_req_mem_index(req.group_req_id) - src_kv_buffer = self.mem_buffer.get_kv_buffer(index[-len(dst_index)]) - dst_kv_buffer = self.mem_manager.get_index_kv_buffer(dst_index)["kv_buffer"] - logger.info(f"len mem src index and dst index {len(index), len(dst_index)} read mem_buffer shape {src_kv_buffer.shape}, manager buffer shape {dst_kv_buffer.shape}") - assert len(src_kv_buffer) == len(dst_kv_buffer), f"src kv buffer len {len(src_kv_buffer)} != dst kv buffer len {len(dst_kv_buffer)}" + src_kv_buffer = self.radix_buffer.get_kv_buffer(index=query_index[-alloc_len:]) + dst_kv_buffer = self.mem_manager.get_index_kv_buffer(index=value[-alloc_len:])["kv_buffer"] + + assert len(src_kv_buffer) == len(dst_kv_buffer), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" + self.copy_kv_from_cpu_to_gpu(src_kv_buffer, dst_kv_buffer) - #TODO no free - self.mem_buffer.free_req_index(req.group_req_id) + except Exception as e: - logger.error(f"Local cache read from radix mem_buffer error {e}") + logger.error(f"LocalCache read from radix mem error {e}") return False + return True - def query(self, req: Req): - if req.radix_status.is_no_need_cache(self.rank_in_node): - logger.info(f"query no need cache {self.rank_in_node} {req.radix_status.get_status(self.rank_in_node)}") - return 0 - if req.radix_status.is_read_ready(self.rank_in_node): - index = self.mem_buffer.get_req_mem_index(req.group_req_id) - logger.info(f"query find cache {self.rank_in_node} {req.radix_status.get_status(self.rank_in_node)} len {len(index)}") - return len(index) - return 0 + def query(self, req: Req, key): + return self._query_cache(req, key) + + @property + def chunk_size(self): + return self.radix_manager.chunk_size def copy_kv_from_cpu_to_gpu(self, src_kv_tensor, dst_kv_tensor): dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) @@ -89,11 +109,12 @@ def copy_kv_from_gpu_to_cpu(self, src_kv_tensor, dst_kv_tensor): class HiRadixCache(RadixCache): - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, mem_buffer, radix_info_queue): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, radix_manager, radix_info_queue): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) self.rank_in_node = rank_in_node + self.radix_manager: RadixBufferManager = radix_manager self.local_cache_manager = LocalCacheManager( - mem_buffer=mem_buffer, + radix_manager=self.radix_manager, mem_manager=mem_manager, rank_in_node=rank_in_node ) @@ -126,7 +147,8 @@ def match_prefix(self, req, key, update_refs=False): ans_value = torch.concat(ans_value_list) max_len = 0 if tree_node.node_prefix_total_len < len(key): - max_len = self.local_cache_manager.query(req) + max_len, query_index = self.local_cache_manager.query(req, key) + 
logger.debug(f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len {tree_node.node_prefix_total_len}, max len {max_len}") if max_len > tree_node.node_prefix_total_len: pull_len = max_len - tree_node.node_prefix_total_len @@ -134,13 +156,10 @@ def match_prefix(self, req, key, update_refs=False): self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0] self.free_radix_cache_to_get_enough_token(pull_len) buffers = self.mem_manager.alloc(pull_len) - start_pos = 0 if ans_value is not None: buffers = torch.concat([ans_value, buffers]) - start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk") - # res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos) - res = self.local_cache_manager.read(req, buffers) + res = self.local_cache_manager.read(key[:max_len], buffers, query_index, alloc_len=pull_len) if res: super().insert(key[:max_len], buffers) else: diff --git a/lightllm/server/router/dynamic_prompt/io_objs.py b/lightllm/server/router/dynamic_prompt/io_objs.py index b95db6b3b..93031a631 100644 --- a/lightllm/server/router/dynamic_prompt/io_objs.py +++ b/lightllm/server/router/dynamic_prompt/io_objs.py @@ -1,4 +1,6 @@ +import torch from dataclasses import dataclass +from enum import Enum from typing import List @dataclass @@ -35,4 +37,49 @@ def from_dict(d): return GroupReqInfo( group_req_id=d["group_req_id"], shm_req_indexes=d["shm_req_indexes"] + ) + +@dataclass +class CacheTask: + tokens: torch.Tensor + mode: str = None + kv_page_indexer: torch.Tensor = None + start_pos: torch.Tensor = 0 + +@dataclass +class PushState: + state: bool + + def to_dict(self): + return { + "state": self.state + } + + @staticmethod + def from_dict(d): + return PushState( + state=d["state"], + ) + +class HitSate(Enum): + NONE = -1 + MEM = 0 + DISK = 1 + +@dataclass +class PullState: + match_length: int + cache_source: HitSate + + def to_dict(self): + return { + "match_length": self.match_length, + "cache_source": self.cache_source.name + } + + @staticmethod + def from_dict(d): + return PullState( + match_length=d["match_length"], + cache_source=HitSate[d["cache_source"]] ) \ No newline at end of file diff --git a/lightllm/server/router/dynamic_prompt/manager.py b/lightllm/server/router/dynamic_prompt/manager.py index 0d8dcb3b8..bafc3d029 100644 --- a/lightllm/server/router/dynamic_prompt/manager.py +++ b/lightllm/server/router/dynamic_prompt/manager.py @@ -105,7 +105,6 @@ async def run(self): async def loop_for_netio_req_to_push(self): while True: recv_req: ShmReqInfo = await self.recv_from_router.recv_pyobj() - logger.info(f"loop_for_netio_req_to_push --> recv req {recv_req}") if isinstance(recv_req, ShmReqInfo): await self.push_queue.put(recv_req) else: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 03f0ebff8..5e73ef83a 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -57,7 +57,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None - self.use_hiradix_cache = args.use_hi_dynamic_prompt_cache and not args.disable_dynamic_prompt_cache + 
self.use_hiradix_cache = args.use_hiradix_cache and not args.disable_dynamic_prompt_cache self.mtp_step = args.mtp_step # 共享变量,用于存储router端调度分析得到的机器负载信息 @@ -179,7 +179,7 @@ async def wait_to_model_ready(self): "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs, "use_reward_model": self.args.use_reward_model, "disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache, - "use_hi_dynamic_prompt_cache": self.args.use_hi_dynamic_prompt_cache, + "use_hiradix_cache": self.args.use_hiradix_cache, "data_type": self.args.data_type, "eos_id": self.eos_id, "diverse_mode": self.args.diverse_mode, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 964ffe42c..4f7645e28 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -56,7 +56,6 @@ def init_model(self, kvargs): self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache - self.use_hi_dynamic_prompt_cache = self.args.use_hi_dynamic_prompt_cache self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph self.use_hiradix_cache = kvargs.get("use_hiradix_cache", False) @@ -119,6 +118,9 @@ def init_model(self, kvargs): "quant_cfg": kvargs.get("quant_cfg", None), "run_mode": self.run_mode, "use_hiradix_cache": self.use_hiradix_cache, + "radix_lock": self.radix_lock, + "hiradix_cache_gpu": kvargs.get("hiradix_cache_gpu", False), + "hiradix_cache_token_num": kvargs.get("hiradix_cache_token_num", False), "radix_lock": self.radix_lock } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) @@ -130,7 +132,7 @@ def init_model(self, kvargs): self.model.mem_manager.size, self.rank_in_node, mem_manager=self.model.mem_manager, - mem_buffer=self.model.radix_mem_buffer, + radix_manager=self.model.radix_manager, radix_info_queue=kvargs.get("radix_info_queue", None) ) if self.use_hiradix_cache diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 80ce82de6..1e0ea2091 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -131,11 +131,13 @@ def init_model(self, kvargs): assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" is_prefill_node = self.args.run_mode == "prefill" is_decode_node = self.args.run_mode == "decode" - use_hiradix_cache = self.args.use_hi_dynamic_prompt_cache and not self.args.disable_dynamic_prompt_cache + use_hiradix_cache = self.args.use_hiradix_cache and not self.args.disable_dynamic_prompt_cache kvargs.update({ "use_hiradix_cache": use_hiradix_cache, + "hiradix_cache_gpu": self.args.hiradix_cache_gpu, "radix_info_queue": self.radix_info_queue, - "radix_lock": self.radix_lock + "radix_lock": self.radix_lock, + "hiradix_cache_token_num": self.args.hiradix_cache_token_num }) enable_mtp = self.args.mtp_mode is not None From 1a7e7d3da7d3c100e5cdb384fa074046bf375852 Mon Sep 17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Mon, 21 Jul 2025 11:48:02 +0800 Subject: [PATCH 07/13] fix error --- lightllm/common/radixmem_manager.py | 4 ++-- lightllm/server/router/dynamic_prompt/hiradix_cache.py | 3 ++- lightllm/server/router/manager.py | 2 -- 3 files changed, 4 
insertions(+), 5 deletions(-) diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py index 41a393e90..4eaee1fbf 100644 --- a/lightllm/common/radixmem_manager.py +++ b/lightllm/common/radixmem_manager.py @@ -17,7 +17,7 @@ def __init__(self, radix_buffer: RadixMemoryBuffer = None, radix_mem_data: SharedRadixMemoryData = None, lock: Optional[mp.Lock] = None, - max_entries: int = 100, + max_entries: int = 10000, chunk_size: int = 64 ): self.chunk_size = chunk_size @@ -66,7 +66,7 @@ def _update_lru_state(self, hash_val: int): def free_space(self, required_size: int) -> bool: with self.lock: - current_free = self.radix_buffer.get_can_use_mem_size.get_value() + current_free = self.radix_buffer.can_use_mem_size.get_value() if current_free >= required_size: return True diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 635452f5b..f399528ab 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -42,6 +42,8 @@ def insert(self, req: Req, key: torch.Tensor, value=None): if alloc_len == 0: self._set_radix_staus(req, RadixStatus.WRITE_READY) return + + self.radix_manager.free_space(alloc_len) new_index = self._alloc_and_copy_kv(alloc_len, value) @@ -56,7 +58,6 @@ def insert(self, req: Req, key: torch.Tensor, value=None): def _query_cache(self, req, key): if req.radix_status.is_no_need_cache(self.rank_in_node): - logger.info(f"query no need cache {self.rank_in_node} {req.radix_status.get_status(self.rank_in_node)}") return 0, [] if req.radix_status.is_read_ready(self.rank_in_node): diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 5e73ef83a..062426370 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -87,7 +87,6 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por context_radix = zmq.asyncio.Context() self.send_to_hiradix_server = context_radix.socket(zmq.PUSH) self.send_to_hiradix_server.connect(f"{args.zmq_mode}127.0.0.1:{hiradix_port}") - logger.info(f"send_to_hiradix_server {args.zmq_mode}127.0.0.1:{hiradix_port}") if self.is_multinode_tp: self.mulitnode_group = dist.init_process_group( @@ -406,7 +405,6 @@ async def _send_hiradix_manager(self, reqs): req.index_in_shm_mem ) await self.send_to_hiradix_server.send_pyobj(req_info, protocol=pickle.HIGHEST_PROTOCOL) - logger.info(f"_send_hiradix_manager {req_info}") return def _send_detokenization_pack(self): From b55ca742f4c69743333f3d163f516c6cef85c545 Mon Sep 17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Mon, 21 Jul 2025 13:31:47 +0800 Subject: [PATCH 08/13] update package dir --- lightllm/server/router/dynamic_prompt/hiradix/__init__.py | 0 .../router/dynamic_prompt/{ => hiradix}/disk_cache_server.py | 2 +- .../router/dynamic_prompt/{ => hiradix}/hiradix_cache.py | 4 ++-- .../server/router/dynamic_prompt/{ => hiradix}/io_objs.py | 0 .../server/router/dynamic_prompt/{ => hiradix}/manager.py | 2 +- lightllm/server/router/manager.py | 4 ++-- .../server/router/model_infer/mode_backend/base_backend.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) create mode 100644 lightllm/server/router/dynamic_prompt/hiradix/__init__.py rename lightllm/server/router/dynamic_prompt/{ => hiradix}/disk_cache_server.py (99%) rename lightllm/server/router/dynamic_prompt/{ => hiradix}/hiradix_cache.py (98%) rename lightllm/server/router/dynamic_prompt/{ => 
hiradix}/io_objs.py (100%) rename lightllm/server/router/dynamic_prompt/{ => hiradix}/manager.py (98%) diff --git a/lightllm/server/router/dynamic_prompt/hiradix/__init__.py b/lightllm/server/router/dynamic_prompt/hiradix/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/router/dynamic_prompt/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py similarity index 99% rename from lightllm/server/router/dynamic_prompt/disk_cache_server.py rename to lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py index 089212d7b..008d7b992 100644 --- a/lightllm/server/router/dynamic_prompt/disk_cache_server.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py @@ -14,7 +14,7 @@ from typing import Tuple, Dict, Set, List from lightllm.utils.log_utils import init_logger from enum import Enum -from .shared_arr import SharedArray +from ..shared_arr import SharedArray from .io_objs import ShmReqInfo, GroupReqInfo, HitSate, PullState, PushState, CacheTask from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py similarity index 98% rename from lightllm/server/router/dynamic_prompt/hiradix_cache.py rename to lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py index f399528ab..ca8665690 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py @@ -8,7 +8,7 @@ import torch.multiprocessing as mp import torch.distributed as dist from os.path import join -from .radix_cache import RadixCache, TreeNode, match +from ..radix_cache import RadixCache, TreeNode, match from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager from lightllm.common.radixmem_buffer import RadixMemoryBuffer @@ -16,7 +16,7 @@ from lightllm.utils.log_utils import init_logger from threading import Lock from enum import Enum -from .shared_arr import SharedArray +from ..shared_arr import SharedArray from .io_objs import ShmReqInfo from lightllm.server.core.objs import Req, RadixStatus from lightllm.server.core.objs.io_objs import GroupReqIndexes diff --git a/lightllm/server/router/dynamic_prompt/io_objs.py b/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py similarity index 100% rename from lightllm/server/router/dynamic_prompt/io_objs.py rename to lightllm/server/router/dynamic_prompt/hiradix/io_objs.py diff --git a/lightllm/server/router/dynamic_prompt/manager.py b/lightllm/server/router/dynamic_prompt/hiradix/manager.py similarity index 98% rename from lightllm/server/router/dynamic_prompt/manager.py rename to lightllm/server/router/dynamic_prompt/hiradix/manager.py index bafc3d029..1d8fdeb85 100644 --- a/lightllm/server/router/dynamic_prompt/manager.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/manager.py @@ -45,7 +45,7 @@ async def asyn_init(self): self.push_queue = asyncio.Queue() async def start_all(self): - from lightllm.server.router.dynamic_prompt.disk_cache_server import start_disk_cache_server_process + from lightllm.server.router.dynamic_prompt.hiradix.disk_cache_server import start_disk_cache_server_process for rank_in_node in range(self.node_world_size): client = await start_disk_cache_server_process( self.args, diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 
062426370..d8a0acad3 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -15,7 +15,7 @@ from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue -from lightllm.server.router.dynamic_prompt.io_objs import ShmReqInfo +from lightllm.server.router.dynamic_prompt.hiradix.io_objs import ShmReqInfo from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient @@ -225,7 +225,7 @@ async def wait_to_model_ready(self): if self.use_hiradix_cache: # 启动 hi radix cache 管理进程 - from lightllm.server.router.dynamic_prompt.manager import start_hiradix_cache_manager_process_server + from lightllm.server.router.dynamic_prompt.hiradix.manager import start_hiradix_cache_manager_process_server start_hiradix_cache_manager_process_server(self.args, self.radix_mem_queues, self.radix_locks, self.router_port) return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4f7645e28..6914f147f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -8,7 +8,7 @@ from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache +from lightllm.server.router.dynamic_prompt.hiradix.hiradix_cache import HiRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock From 4f296725cbe8ca40613cc598031ecae968088c0d Mon Sep 17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Mon, 21 Jul 2025 14:58:02 +0800 Subject: [PATCH 09/13] fix format --- lightllm/common/radixmem_manager.py | 2 +- lightllm/server/core/objs/req.py | 2 +- .../hiradix/disk_cache_server.py | 120 +++++++----------- .../dynamic_prompt/hiradix/hiradix_cache.py | 70 +++++----- .../router/dynamic_prompt/hiradix/io_objs.py | 42 ++---- .../router/dynamic_prompt/hiradix/manager.py | 49 +++---- .../model_infer/mode_backend/base_backend.py | 1 - 7 files changed, 106 insertions(+), 180 deletions(-) diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py index 4eaee1fbf..44564d2c6 100644 --- a/lightllm/common/radixmem_manager.py +++ b/lightllm/common/radixmem_manager.py @@ -52,7 +52,7 @@ def write(self, tokens: List[int], values: torch.Tensor, start_pos: int) -> None values = values[index * self.chunk_size:] chunks = chunks[index:] for i, (hash_val, _) in enumerate(chunks): - if hash not in self.radix_buffer.req_mem_index: + if hash_val not in self.radix_buffer.req_mem_index: self.radix_buffer.req_mem_index[hash_val] = values[i * self.chunk_size : (i + 1) * self.chunk_size] self._update_lru_state(hash_val) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index a651e15f9..2ae2cbcca 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -91,7 +91,7 @@ def set_finished(self): self.finished = 1 def is_finished(self): - self.finished == 1 + return self.finished == 1 def get_status(self, idx: int) -> 
int: assert 0 <= idx < 32, f"Index out of range: {idx}" diff --git a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py index 008d7b992..33aee7106 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py @@ -2,21 +2,15 @@ import time import tempfile import rpyc -import zmq import inspect import asyncio import threading -import numpy as np import torch.multiprocessing as mp -from typing import List, Union +from typing import List from rpyc.utils.server import ThreadedServer from os.path import join -from typing import Tuple, Dict, Set, List from lightllm.utils.log_utils import init_logger -from enum import Enum -from ..shared_arr import SharedArray from .io_objs import ShmReqInfo, GroupReqInfo, HitSate, PullState, PushState, CacheTask -from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.common.radixmem_buffer import RadixMemoryBuffer @@ -25,6 +19,7 @@ logger = init_logger(__name__) + def wait_until_ready(task, timeout=10.0, check_interval=0.01): start_time = time.time() while not task.ready(): @@ -34,6 +29,7 @@ def wait_until_ready(task, timeout=10.0, check_interval=0.01): return False return True + class RemoteCacheManager: def __init__(self, unique_name: str, rank_in_node: int, mem_manager): tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}") @@ -42,22 +38,24 @@ def __init__(self, unique_name: str, rank_in_node: int, mem_manager): all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) from kvcache.python.jit import PyLocalCacheService + self.py_cache_service = PyLocalCacheService( file=self.cache_file, - storage_size=128 * (1024 ** 3), # 128GB + storage_size=128 * (1024**3), # 128GB num_shard=32, kvcache_tensor=all_buffers, - num_worker=8 + num_worker=8, ) def insert(self, cache_task: CacheTask): - assert cache_task.mode == 'w', "Cache task mode must be 'w' for insert" + assert cache_task.mode == "w", "Cache task mode must be 'w' for insert" t = self.py_cache_service.create( - tokens=cache_task.tokens, - kv_page_indexer=cache_task.kv_page_indexer, - mode=cache_task.mode, - start_pos=cache_task.start_pos) + tokens=cache_task.tokens, + kv_page_indexer=cache_task.kv_page_indexer, + mode=cache_task.mode, + start_pos=cache_task.start_pos, + ) res = wait_until_ready(t) if not res: @@ -67,13 +65,14 @@ def insert(self, cache_task: CacheTask): return True def read(self, cache_task: CacheTask): - assert cache_task.mode == 'r', "Cache task mode must be 'r' for read" + assert cache_task.mode == "r", "Cache task mode must be 'r' for read" t = self.py_cache_service.create( - tokens=cache_task.tokens, - kv_page_indexer=cache_task.kv_page_indexer, - mode=cache_task.mode, - start_pos=cache_task.start_pos) + tokens=cache_task.tokens, + kv_page_indexer=cache_task.kv_page_indexer, + mode=cache_task.mode, + start_pos=cache_task.start_pos, + ) res = wait_until_ready(t) return res @@ -91,7 +90,9 @@ def query(self, cache_task: CacheTask): return max_len * self.block_size @property - def block_size(self,): + def block_size( + self, + ): return self.py_cache_service.tokens_per_block @@ -121,7 +122,7 @@ def exposed_push(self, req_info): if len(keys) != len(index_tensor): raise ValueError(f"Mismatch in keys and index size: {len(keys)} != {len(index_tensor)}") - 
insert_task = CacheTask(tokens=keys, kv_page_indexer=index_tensor, mode='w') + insert_task = CacheTask(tokens=keys, kv_page_indexer=index_tensor, mode="w") result = self.remote_cache_manager.insert(insert_task) reqs = [req] @@ -156,15 +157,11 @@ def exposed_pull(self, group_req): else: query_task = CacheTask(tokens=keys) query_len = self.remote_cache_manager.query(query_task) - + if query_len > 0: self.radix_manager.free_space(query_len) index = self.radix_manager.radix_buffer.alloc(query_len) - read_task = CacheTask( - tokens=keys[:query_len], - kv_page_indexer=index, - mode='r' - ) + read_task = CacheTask(tokens=keys[:query_len], kv_page_indexer=index, mode="r") self.remote_cache_manager.read(read_task) self.radix_manager.write(keys=keys[:query_len], values=index.tolist()) @@ -186,7 +183,7 @@ def __init__(self, rank_in_node: int, service=None, use_rpc=True, proc=None): self.rank_in_node = rank_in_node self.use_rpc = use_rpc self.service = service - self.proc=proc + self.proc = proc if self.use_rpc: self._push = self._async_wraper(self.service.push) self._pull = self._async_wraper(self.service.pull) @@ -224,9 +221,9 @@ def __init__(self): def start(): try: - server = ThreadedServer(CustomService(), - port=port, - protocol_config={"allow_public_attrs": True, "allow_pickle": True}) + server = ThreadedServer( + CustomService(), port=port, protocol_config={"allow_public_attrs": True, "allow_pickle": True} + ) init_event.set() server.start() except Exception as e: @@ -239,31 +236,21 @@ def start(): return t -def _init_server( - device_id, - mem_queue, - radix_lock: List[mp.Lock], - init_event: mp.Event, - port:int=18861 -): +def _init_server(device_id, mem_queue, radix_lock: List[mp.Lock], init_event: mp.Event, port: int = 18861): from lightllm.utils.envs_utils import get_unique_server_name + graceful_registry(inspect.currentframe().f_code.co_name) torch.cuda.set_device(device_id) mem_proties, shared_mem_data = mem_queue.get() mem_manager = RadixMemoryBuffer( - mem_propties=mem_proties, - shared_data=shared_mem_data, - lock=radix_lock, - rank_in_node=device_id + mem_propties=mem_proties, shared_data=shared_mem_data, lock=radix_lock, rank_in_node=device_id ) remote_cache_manager = RemoteCacheManager( unique_name=get_unique_server_name(), rank_in_node=device_id, mem_manager=mem_manager, ) - radix_manager = RadixBufferManager(radix_buffer=mem_manager, - radix_mem_data=shared_mem_data, - lock=radix_lock) + radix_manager = RadixBufferManager(radix_buffer=mem_manager, radix_mem_data=shared_mem_data, lock=radix_lock) shm_req_manager = ShmReqManager() @@ -273,30 +260,22 @@ def _init_server( shm_req_manager=shm_req_manager, rank_in_node=device_id, port=port, - init_event=init_event + init_event=init_event, ) - t.join() + t.join() return - -async def start_disk_cache_server_process( - args, - device_id, - node_word_size, - mem_queue, - radix_lock, - port -): + + +async def start_disk_cache_server_process(args, device_id, node_word_size, mem_queue, radix_lock, port): """ Start the DiskCacheManager in process. 
""" from lightllm.utils.envs_utils import get_unique_server_name + if node_word_size == 1: mem_proties, shared_mem_data = mem_queue.get() mem_buffer = RadixMemoryBuffer( - mem_propties=mem_proties, - shared_data=shared_mem_data, - lock=radix_lock, - rank_in_node=device_id + mem_propties=mem_proties, shared_data=shared_mem_data, lock=radix_lock, rank_in_node=device_id ) remote_cache_manager = RemoteCacheManager( unique_name=get_unique_server_name(), @@ -305,15 +284,9 @@ async def start_disk_cache_server_process( ) shm_req_manager = ShmReqManager() - radix_manager = RadixBufferManager(radix_buffer=mem_buffer, - radix_mem_data=shared_mem_data, - lock=radix_lock) + radix_manager = RadixBufferManager(radix_buffer=mem_buffer, radix_mem_data=shared_mem_data, lock=radix_lock) service = DiskCacheService(radix_manager, remote_cache_manager, shm_req_manager) - client = DiskCacheClient( - service=service, - rank_in_node=0, - use_rpc=False - ) + client = DiskCacheClient(service=service, rank_in_node=0, use_rpc=False) return client init_event = mp.Event() @@ -327,16 +300,11 @@ async def start_disk_cache_server_process( try: conn = rpyc.connect("localhost", port, config={"allow_pickle": True}) break - except Exception as e: + except Exception: asyncio.sleep(2) service = conn.root - client = DiskCacheClient( - rank_in_node=device_id, - service=service, - use_rpc=True, - proc=proc - ) + client = DiskCacheClient(rank_in_node=device_id, service=service, use_rpc=True, proc=proc) assert proc.is_alive() logger.info(f"disk cache process for device {device_id} start!") - return client \ No newline at end of file + return client diff --git a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py index ca8665690..80110c8c0 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py @@ -1,28 +1,12 @@ import torch -import time -import tempfile -import zmq -import inspect -import threading import numpy as np -import torch.multiprocessing as mp -import torch.distributed as dist -from os.path import join -from ..radix_cache import RadixCache, TreeNode, match -from typing import Tuple, Dict, Set, List +from ..radix_cache import RadixCache from lightllm.common.mem_manager import MemoryManager from lightllm.common.radixmem_buffer import RadixMemoryBuffer from lightllm.common.radixmem_manager import RadixBufferManager from lightllm.utils.log_utils import init_logger -from threading import Lock -from enum import Enum from ..shared_arr import SharedArray -from .io_objs import ShmReqInfo from lightllm.server.core.objs import Req, RadixStatus -from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager logger = init_logger(__name__) @@ -42,24 +26,20 @@ def insert(self, req: Req, key: torch.Tensor, value=None): if alloc_len == 0: self._set_radix_staus(req, RadixStatus.WRITE_READY) return - + self.radix_manager.free_space(alloc_len) new_index = self._alloc_and_copy_kv(alloc_len, value) start_pos = max(0, (query_len - 1) // self.chunk_size * self.chunk_size) - self.radix_manager.write( - tokens=key.tolist(), - values=query_index + new_index, - start_pos=start_pos - ) + self.radix_manager.write(tokens=key.tolist(), values=query_index + new_index, start_pos=start_pos) 
self._set_radix_staus(req, RadixStatus.WRITE_READY) - + def _query_cache(self, req, key): if req.radix_status.is_no_need_cache(self.rank_in_node): return 0, [] - + if req.radix_status.is_read_ready(self.rank_in_node): query_len, mem_index = self.radix_manager.query_cache(key.tolist()) return query_len, mem_index @@ -72,7 +52,9 @@ def _alloc_and_copy_kv(self, alloc_len, value): dst_kv_buffer = self.radix_buffer.get_kv_buffer(new_index) src_kv_buffer = self.mem_manager.get_index_kv_buffer(value[-alloc_len:])["kv_buffer"] - assert len(src_kv_buffer) == len(dst_kv_buffer), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" + assert len(src_kv_buffer) == len( + dst_kv_buffer + ), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" self.copy_kv_from_gpu_to_cpu(src_kv_buffer, dst_kv_buffer) return new_index.tolist() @@ -85,7 +67,9 @@ def read(self, key, value, query_index, alloc_len): src_kv_buffer = self.radix_buffer.get_kv_buffer(index=query_index[-alloc_len:]) dst_kv_buffer = self.mem_manager.get_index_kv_buffer(index=value[-alloc_len:])["kv_buffer"] - assert len(src_kv_buffer) == len(dst_kv_buffer), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" + assert len(src_kv_buffer) == len( + dst_kv_buffer + ), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" self.copy_kv_from_cpu_to_gpu(src_kv_buffer, dst_kv_buffer) @@ -97,16 +81,16 @@ def read(self, key, value, query_index, alloc_len): def query(self, req: Req, key): return self._query_cache(req, key) - + @property def chunk_size(self): return self.radix_manager.chunk_size - + def copy_kv_from_cpu_to_gpu(self, src_kv_tensor, dst_kv_tensor): dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) - + def copy_kv_from_gpu_to_cpu(self, src_kv_tensor, dst_kv_tensor): - dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) + dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) class HiRadixCache(RadixCache): @@ -115,17 +99,19 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, radi self.rank_in_node = rank_in_node self.radix_manager: RadixBufferManager = radix_manager self.local_cache_manager = LocalCacheManager( - radix_manager=self.radix_manager, - mem_manager=mem_manager, - rank_in_node=rank_in_node + radix_manager=self.radix_manager, mem_manager=mem_manager, rank_in_node=rank_in_node ) self.radix_info_queue = radix_info_queue self.is_hi_radix_cache = True - self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64) + self.disk_cache_match_count = SharedArray( + f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64 + ) self.disk_cache_match_count.arr[0] = 0 self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64) self.total_match_count.arr[0] = 0 - self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32) + self.disk_cache_match_ratio = SharedArray( + f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32 + ) self.disk_cache_match_ratio.arr[0] = 0.0 logger.info(f"Initializing HiRadixCache {rank_in_node}") @@ -150,7 +136,9 @@ def match_prefix(self, req, key, update_refs=False): if tree_node.node_prefix_total_len < len(key): max_len, query_index = self.local_cache_manager.query(req, key) - logger.debug(f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len 
{tree_node.node_prefix_total_len}, max len {max_len}") + logger.debug( + f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len {tree_node.node_prefix_total_len}, max len {max_len}" + ) if max_len > tree_node.node_prefix_total_len: pull_len = max_len - tree_node.node_prefix_total_len self.disk_cache_match_count.arr[0] += 1 @@ -159,11 +147,13 @@ def match_prefix(self, req, key, update_refs=False): buffers = self.mem_manager.alloc(pull_len) if ans_value is not None: buffers = torch.concat([ans_value, buffers]) - logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk") + logger.debug( + f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk" + ) res = self.local_cache_manager.read(key[:max_len], buffers, query_index, alloc_len=pull_len) if res: super().insert(key[:max_len], buffers) else: - self.mem_manager.free(buffers[tree_node.node_prefix_total_len:]) - + self.mem_manager.free(buffers[tree_node.node_prefix_total_len :]) + return super().match_prefix(key, update_refs=update_refs) diff --git a/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py b/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py index 93031a631..35c6e7da6 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py @@ -3,23 +3,19 @@ from enum import Enum from typing import List + @dataclass class ShmReqInfo: request_id: int shm_req_index: int def to_dict(self): - return { - "request_id": self.request_id, - "shm_req_index": self.shm_req_index - } + return {"request_id": self.request_id, "shm_req_index": self.shm_req_index} @staticmethod def from_dict(d): - return ShmReqInfo( - request_id=d["request_id"], - shm_req_index=d["shm_req_index"] - ) + return ShmReqInfo(request_id=d["request_id"], shm_req_index=d["shm_req_index"]) + @dataclass class GroupReqInfo: @@ -27,17 +23,12 @@ class GroupReqInfo: shm_req_indexes: List[int] def to_dict(self): - return { - "group_req_id": self.group_req_id, - "shm_req_indexes": self.shm_req_indexes - } - + return {"group_req_id": self.group_req_id, "shm_req_indexes": self.shm_req_indexes} + @staticmethod def from_dict(d): - return GroupReqInfo( - group_req_id=d["group_req_id"], - shm_req_indexes=d["shm_req_indexes"] - ) + return GroupReqInfo(group_req_id=d["group_req_id"], shm_req_indexes=d["shm_req_indexes"]) + @dataclass class CacheTask: @@ -46,14 +37,13 @@ class CacheTask: kv_page_indexer: torch.Tensor = None start_pos: torch.Tensor = 0 + @dataclass class PushState: state: bool def to_dict(self): - return { - "state": self.state - } + return {"state": self.state} @staticmethod def from_dict(d): @@ -61,25 +51,21 @@ def from_dict(d): state=d["state"], ) + class HitSate(Enum): NONE = -1 MEM = 0 DISK = 1 + @dataclass class PullState: match_length: int cache_source: HitSate def to_dict(self): - return { - "match_length": self.match_length, - "cache_source": self.cache_source.name - } + return {"match_length": self.match_length, "cache_source": self.cache_source.name} @staticmethod def from_dict(d): - return PullState( - match_length=d["match_length"], - cache_source=HitSate[d["cache_source"]] - ) \ No newline at end of file + return PullState(match_length=d["match_length"], cache_source=HitSate[d["cache_source"]]) diff --git a/lightllm/server/router/dynamic_prompt/hiradix/manager.py b/lightllm/server/router/dynamic_prompt/hiradix/manager.py index 1d8fdeb85..1ded93199 
100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/manager.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/manager.py @@ -1,15 +1,12 @@ -import time import zmq import zmq.asyncio import inspect import pickle import torch.multiprocessing as mp -import threading import asyncio from typing import List -from dataclasses import dataclass from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from .disk_cache_server import DiskCacheClient from lightllm.server.core.objs import ShmReqManager @@ -18,9 +15,9 @@ logger = init_logger(__name__) + class HiRadixCacheManagerServer: - def __init__( - self, args, mem_queues: List[mp.Queue], radix_locks: List[mp.Lock], router_port: int): + def __init__(self, args, mem_queues: List[mp.Queue], radix_locks: List[mp.Lock], router_port: int): self.args = args self.mem_queues = mem_queues self.radix_locks = radix_locks @@ -39,13 +36,13 @@ def __init__( self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{recv_from_router_port}") self.shm_req_manager = ShmReqManager() - async def asyn_init(self): self.pull_queue = asyncio.Queue() self.push_queue = asyncio.Queue() - + async def start_all(self): from lightllm.server.router.dynamic_prompt.hiradix.disk_cache_server import start_disk_cache_server_process + for rank_in_node in range(self.node_world_size): client = await start_disk_cache_server_process( self.args, @@ -53,15 +50,14 @@ async def start_all(self): node_word_size=self.node_world_size, mem_queue=self.mem_queues[rank_in_node], radix_lock=self.radix_locks[rank_in_node], - port=self.ports[rank_in_node] + port=self.ports[rank_in_node], ) self.clients.append(client) - + async def pull_cache(self, group_req): tasks = [] group_req_info = GroupReqInfo( - group_req_id=group_req.group_req_id, - shm_req_indexes=group_req.shm_req_indexes + group_req_id=group_req.group_req_id, shm_req_indexes=group_req.shm_req_indexes ).to_dict() for client in self.clients: task = client.pull(group_req_info) @@ -96,10 +92,7 @@ async def push_woker(self): async def run(self): await self.asyn_init() await asyncio.gather( - self.loop_for_netio_req_to_pull(), - self.pull_woker(), - self.loop_for_netio_req_to_push(), - self.push_woker() + self.loop_for_netio_req_to_pull(), self.pull_woker(), self.loop_for_netio_req_to_push(), self.push_woker() ) async def loop_for_netio_req_to_push(self): @@ -118,19 +111,11 @@ async def loop_for_netio_req_to_pull(self): else: raise ValueError(f"Invalid request: {recv_req}") -def _init_env_server( - args, - mem_queues, - radix_locks: List[mp.Lock], - init_event: mp.Event, - router_port: int -): + +def _init_env_server(args, mem_queues, radix_locks: List[mp.Lock], init_event: mp.Event, router_port: int): graceful_registry(inspect.currentframe().f_code.co_name) hiradix_cache_manager = HiRadixCacheManagerServer( - args, - mem_queues=mem_queues, - radix_locks=radix_locks, - router_port=router_port + args, mem_queues=mem_queues, radix_locks=radix_locks, router_port=router_port ) asyncio.run(hiradix_cache_manager.start_all()) try: @@ -142,11 +127,9 @@ def _init_env_server( logger.error(f"hiradix server error happend {e}") return + def start_hiradix_cache_manager_process_server( - args, - radix_mem_queues: List[mp.Queue], - radix_locks: List[mp.Lock], - router_port: int + args, radix_mem_queues: List[mp.Queue], radix_locks: List[mp.Lock], router_port: int ): """ Start the 
HiRadix cache manager process. @@ -155,6 +138,6 @@ def start_hiradix_cache_manager_process_server( proc = mp.Process(target=_init_env_server, args=(args, radix_mem_queues, radix_locks, init_event, router_port)) proc.start() init_event.wait() - logger.info(f"HiRadix cache manager process started") + logger.info("HiRadix cache manager process started") assert proc.is_alive() - return \ No newline at end of file + return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 6914f147f..51fa4bfe4 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -118,7 +118,6 @@ def init_model(self, kvargs): "quant_cfg": kvargs.get("quant_cfg", None), "run_mode": self.run_mode, "use_hiradix_cache": self.use_hiradix_cache, - "radix_lock": self.radix_lock, "hiradix_cache_gpu": kvargs.get("hiradix_cache_gpu", False), "hiradix_cache_token_num": kvargs.get("hiradix_cache_token_num", False), "radix_lock": self.radix_lock From ab5d9333135501127086b705a608ef9463f42bf0 Mon Sep 17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Mon, 21 Jul 2025 15:05:55 +0800 Subject: [PATCH 10/13] fix too long --- lightllm/common/radixmem_buffer.py | 4 +++- .../server/router/dynamic_prompt/hiradix/hiradix_cache.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lightllm/common/radixmem_buffer.py b/lightllm/common/radixmem_buffer.py index 20181d9eb..ebe9a2e43 100644 --- a/lightllm/common/radixmem_buffer.py +++ b/lightllm/common/radixmem_buffer.py @@ -143,7 +143,9 @@ def free_req_index(self, req_id: int): def alloc(self, need_size) -> torch.Tensor: with self.lock: if need_size > self.mark_end.get_value() - self.mark_start.get_value(): - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size.get_value()}") + logger.error( + f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size.get_value()}" + ) raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.") start = self.mark_start.get_value() diff --git a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py index 80110c8c0..fe4b5a675 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py @@ -137,7 +137,8 @@ def match_prefix(self, req, key, update_refs=False): max_len, query_index = self.local_cache_manager.query(req, key) logger.debug( - f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len {tree_node.node_prefix_total_len}, max len {max_len}" + f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len " + f"{tree_node.node_prefix_total_len}, max len {max_len}" ) if max_len > tree_node.node_prefix_total_len: pull_len = max_len - tree_node.node_prefix_total_len @@ -148,7 +149,8 @@ def match_prefix(self, req, key, update_refs=False): if ans_value is not None: buffers = torch.concat([ans_value, buffers]) logger.debug( - f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk" + f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}," + f"pulled cache len {pull_len} from disk" ) res = self.local_cache_manager.read(key[:max_len], buffers, query_index, alloc_len=pull_len) if res: From caa2d6c75a5728c2a8efd076ff8a6af62016c28d Mon Sep 
17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Mon, 21 Jul 2025 16:43:17 +0800 Subject: [PATCH 11/13] update & fix format --- lightllm/common/basemodel/basemodel.py | 18 ++---------- lightllm/common/radixmem_buffer.py | 9 ++++-- lightllm/common/radixmem_manager.py | 28 ++++++++++++++++++- lightllm/models/deepseek2/model.py | 18 ++---------- lightllm/models/qwen2/model.py | 18 ++---------- .../model_infer/mode_backend/base_backend.py | 1 - 6 files changed, 42 insertions(+), 50 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1665486e8..2bd5853a2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -179,8 +179,8 @@ def _init_mem_manager(self): ) if self.enable_hiradix_cache: - from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties - from lightllm.common.radixmem_manager import RadixBufferManager + from lightllm.common.radixmem_buffer import get_shared_data, MemPropties + from lightllm.common.radixmem_manager import build_radix_manager mem_propties = MemPropties( self.hiradix_cache_token_num, dtype=self.data_type, @@ -188,19 +188,7 @@ def _init_mem_manager(self): head_dim=self.config["n_embed"] // self.config["num_attention_heads"], layer_num=self.config["n_layer"] ) - init_shared_data( - mem_propties=mem_propties, - device="cpu" if not self.hiradix_cache_gpu else "cuda" - ) - radix_mem_buffer = RadixMemoryBuffer( - mem_propties, - shared_data=get_shared_data(), - lock=self.radix_lock, - device="cpu" if not self.hiradix_cache_gpu else "cuda" - ) - self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer, - radix_mem_data=get_shared_data(), - lock=self.radix_lock) + self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock) self.mem_propties = mem_propties self.shared_mem_data = get_shared_data() return diff --git a/lightllm/common/radixmem_buffer.py b/lightllm/common/radixmem_buffer.py index ebe9a2e43..87d781a32 100644 --- a/lightllm/common/radixmem_buffer.py +++ b/lightllm/common/radixmem_buffer.py @@ -137,14 +137,16 @@ def free_req_index(self, req_id: int): return index = self.req_mem_index[req_id] self._free(index) - logger.info(f"Freed memory index for request {req_id} size {len(index)}, left size {self.can_use_mem_size.get_value()}") + logger.info(f"Freed memory index for request {req_id} size {len(index)}, " + f"left size {self.can_use_mem_size.get_value()}") del self.req_mem_index[req_id] def alloc(self, need_size) -> torch.Tensor: with self.lock: if need_size > self.mark_end.get_value() - self.mark_start.get_value(): logger.error( - f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size.get_value()}" + f"warn no enough cache need_size {need_size} " + f"left_size {self.can_use_mem_size.get_value()}" ) raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.") @@ -160,7 +162,8 @@ def set_req_mem_index(self, req_id: int, index: List[int]): """Set the memory index for a specific request ID.""" with self.lock: if req_id in self.req_mem_index: - logger.info(f"Request ID {req_id} already exists. Overwriting index {self.req_mem_index[req_id]} with {index}.") + logger.info(f"Request ID {req_id} already exists. 
" + f"Overwriting index {self.req_mem_index[req_id]} with {index}.") self.req_mem_index[req_id] = index logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}") diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py index 44564d2c6..398017980 100644 --- a/lightllm/common/radixmem_manager.py +++ b/lightllm/common/radixmem_manager.py @@ -6,6 +6,7 @@ import torch.multiprocessing as mp from collections import OrderedDict +from .radixmem_buffer import MemPropties, init_shared_data, get_shared_data from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer from lightllm.utils.log_utils import init_logger @@ -116,4 +117,29 @@ def query_cache(self, tokens: List[int]) -> int: def clear(self): with self.lock: self.radix_buffer.req_mem_index.clear() - self.lru_queue[:] = [] \ No newline at end of file + self.lru_queue[:] = [] + +def build_radix_manager(mem_propties: MemPropties, + use_gpu: bool, + radix_lock) -> RadixBufferManager: + device = "cuda" if use_gpu else "cpu" + + init_shared_data( + mem_propties=mem_propties, + device=device, + ) + + radix_mem_buffer = RadixMemoryBuffer( + mem_propties=mem_propties, + shared_data=get_shared_data(), + lock=radix_lock, + device=device, + ) + + radix_manager = RadixBufferManager( + radix_buffer=radix_mem_buffer, + radix_mem_data=get_shared_data(), + lock=radix_lock, + ) + + return radix_manager \ No newline at end of file diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 38f49b96f..69d673192 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -113,8 +113,8 @@ def _init_mem_manager(self): mem_fraction=self.mem_fraction, ) if self.enable_hiradix_cache: - from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties - from lightllm.common.radixmem_manager import RadixBufferManager + from lightllm.common.radixmem_buffer import get_shared_data, MemPropties + from lightllm.common.radixmem_manager import build_radix_manager mem_propties = MemPropties( self.hiradix_cache_token_num, dtype=self.data_type, @@ -122,19 +122,7 @@ def _init_mem_manager(self): head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, ) - init_shared_data( - mem_propties=mem_propties, - device="cpu" if not self.hiradix_cache_gpu else "cuda" - ) - radix_mem_buffer = RadixMemoryBuffer( - mem_propties, - shared_data=get_shared_data(), - lock=self.radix_lock, - device="cpu" if not self.hiradix_cache_gpu else "cuda" - ) - self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer, - radix_mem_data=get_shared_data(), - lock=self.radix_lock) + self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock) self.mem_propties = mem_propties self.shared_mem_data = get_shared_data() return diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 8c69eb727..2b86e2e9d 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -53,8 +53,8 @@ def _init_mem_manager(self): ) if self.enable_hiradix_cache: - from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties - from lightllm.common.radixmem_manager import RadixBufferManager + from lightllm.common.radixmem_buffer import MemPropties, get_shared_data, MemPropties + from lightllm.common.radixmem_manager import 
build_radix_manager mem_propties = MemPropties( self.hiradix_cache_token_num, dtype=self.data_type, @@ -62,19 +62,7 @@ def _init_mem_manager(self): head_dim=head_dim_, layer_num=self.config["num_hidden_layers"], ) - init_shared_data( - mem_propties=mem_propties, - device="cpu" if not self.hiradix_cache_gpu else "cuda" - ) - radix_mem_buffer = RadixMemoryBuffer( - mem_propties, - shared_data=get_shared_data(), - lock=self.radix_lock, - device="cpu" if not self.hiradix_cache_gpu else "cuda" - ) - self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer, - radix_mem_data=get_shared_data(), - lock=self.radix_lock) + self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock) self.mem_propties = mem_propties self.shared_mem_data = get_shared_data() return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 51fa4bfe4..25fba774e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -7,7 +7,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache -from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.server.router.dynamic_prompt.hiradix.hiradix_cache import HiRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad From 66fd1ab95a23a1d4f31b530f187a1c80acce03db Mon Sep 17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Tue, 22 Jul 2025 11:33:51 +0800 Subject: [PATCH 12/13] update --- lightllm/common/radixmem_manager.py | 2 +- lightllm/models/qwen2/model.py | 2 +- .../server/router/dynamic_prompt/hiradix/disk_cache_server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py index 398017980..79f48ae84 100644 --- a/lightllm/common/radixmem_manager.py +++ b/lightllm/common/radixmem_manager.py @@ -45,7 +45,7 @@ def _compute_hash(self, tokens: List[int]) -> List[Tuple[int, List[int]]]: return chunks - def write(self, tokens: List[int], values: torch.Tensor, start_pos: int) -> None: + def write(self, tokens: List[int], values: torch.Tensor, start_pos: int=0) -> None: with self.lock: index = start_pos // self.chunk_size chunks = self._compute_hash(tokens) diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 2b86e2e9d..47eb87686 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -53,7 +53,7 @@ def _init_mem_manager(self): ) if self.enable_hiradix_cache: - from lightllm.common.radixmem_buffer import MemPropties, get_shared_data, MemPropties + from lightllm.common.radixmem_buffer import MemPropties, get_shared_data from lightllm.common.radixmem_manager import build_radix_manager mem_propties = MemPropties( self.hiradix_cache_token_num, diff --git a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py index 33aee7106..e37ec2027 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py @@ -164,7 +164,7 @@ def exposed_pull(self, group_req): read_task = CacheTask(tokens=keys[:query_len], kv_page_indexer=index, mode="r") 
self.remote_cache_manager.read(read_task) - self.radix_manager.write(keys=keys[:query_len], values=index.tolist()) + self.radix_manager.write(tokens=keys[:query_len], values=index.tolist()) radix_state = RadixStatus.READ_READY cache_state = PullState(query_len, HitSate.DISK) From b58dc31045777971930ea3e2fdaa471cd199a838 Mon Sep 17 00:00:00 2001 From: "jinbiaoyu@126.com" Date: Fri, 25 Jul 2025 16:33:12 +0800 Subject: [PATCH 13/13] release --- lightllm/common/radixmem_buffer.py | 31 ++++++++-------- lightllm/common/radixmem_manager.py | 35 +++++++++++-------- lightllm/server/detokenization/decode_req.py | 3 ++ lightllm/server/detokenization/manager.py | 2 +- .../hiradix/disk_cache_server.py | 4 +-- .../dynamic_prompt/hiradix/hiradix_cache.py | 4 +-- .../router/dynamic_prompt/hiradix/manager.py | 1 + 7 files changed, 44 insertions(+), 36 deletions(-) diff --git a/lightllm/common/radixmem_buffer.py b/lightllm/common/radixmem_buffer.py index 87d781a32..740e02120 100644 --- a/lightllm/common/radixmem_buffer.py +++ b/lightllm/common/radixmem_buffer.py @@ -90,6 +90,7 @@ def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData mark_start = 0 mark_end = self.size rank_in_node = rank_in_node if rank_in_node is not None else get_current_rank_in_node() + self.rank_in_node = rank_in_node self.can_use_mem_size = SharedInt( f"{get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}" ) @@ -127,6 +128,7 @@ def _free(self, free_index: Union[torch.Tensor, List[int]]): if self.can_use_mem_size.get_value() == len(self.mem_state): logger.debug(f"freed all gpu mem size {self.can_use_mem_size.get_value()}") + return def free_req_index(self, req_id: int): @@ -142,21 +144,20 @@ def free_req_index(self, req_id: int): del self.req_mem_index[req_id] def alloc(self, need_size) -> torch.Tensor: - with self.lock: - if need_size > self.mark_end.get_value() - self.mark_start.get_value(): - logger.error( - f"warn no enough cache need_size {need_size} " - f"left_size {self.can_use_mem_size.get_value()}" - ) - raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.") - - start = self.mark_start.get_value() - end = start + need_size - ans = self.mem_state[start:end] - self.mark_start.set_value(start + need_size) - - self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size) - return ans + if need_size > self.mark_end.get_value() - self.mark_start.get_value(): + logger.error( + f"warn no enough cache need_size {need_size} " + f"left_size {self.can_use_mem_size.get_value()}" + ) + raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.") + + start = self.mark_start.get_value() + end = start + need_size + ans = self.mem_state[start:end] + self.mark_start.set_value(start + need_size) + + self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size) + return ans def set_req_mem_index(self, req_id: int, index: List[int]): """Set the memory index for a specific request ID.""" diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py index 79f48ae84..1c2011c62 100644 --- a/lightllm/common/radixmem_manager.py +++ b/lightllm/common/radixmem_manager.py @@ -65,22 +65,27 @@ def _update_lru_state(self, hash_val: int): while len(self.lru_queue) > self.max_entries: self.lru_queue.pop(0) - def free_space(self, required_size: int) -> bool: - with self.lock: - current_free = self.radix_buffer.can_use_mem_size.get_value() - - if current_free >= required_size: - return True - - need_to_free = required_size - 
current_free - freed_size = 0 - - while freed_size < need_to_free and len(self.lru_queue) > 0: - evict_size = self._evict_lru() - freed_size += evict_size + def _free_space(self, required_size: int) -> bool: + current_free = self.radix_buffer.can_use_mem_size.get_value() + + if current_free >= required_size: + return True - final_free = self.radix_buffer.can_use_mem_size.get_value() - return final_free >= required_size + need_to_free = required_size - current_free + freed_size = 0 + + while freed_size < need_to_free and len(self.lru_queue) > 0: + evict_size = self._evict_lru() + freed_size += evict_size + + final_free = self.radix_buffer.can_use_mem_size.get_value() + return final_free >= required_size + + def alloc(self, required_size: int) -> bool: + with self.lock: + self._free_space(required_size) + ans = self.radix_buffer.alloc(required_size) + return ans def _evict_lru(self): if not self.lru_queue: diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9a85ea089..4c17f3f98 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -8,9 +8,11 @@ class DecodeReq: def __init__( self, + args, req: Req, is_pd_decode_mode: bool, ) -> None: + self.args = args self.request_id = req.request_id self.group_req_id = req.group_req_id self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist() @@ -59,6 +61,7 @@ def can_set_release_mark(self): self.req.finish_status.is_finished() and self.req.candetoken_out_len == len(self.output_ids) and self.req.finish_token_index == self.input_len + len(self.output_ids) - 1 + and (self.req.radix_status.is_finished() if self.args.use_hiradix_cache else True) ): return True return False diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 8f32d0992..57e0422d5 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -71,7 +71,7 @@ async def handle_loop(self): ) # p d 分离模式,decode节点的解码需要做一些特殊的修复。 - decode_req = DecodeReq(req, self.is_pd_decode_mode) + decode_req = DecodeReq(self.args, req, self.is_pd_decode_mode) if self.is_pd_decode_mode: decode_req = decode_mode_fix(decode_req, self.tokenizer, self.eos_id) # token_healing mode 的特殊初始化 diff --git a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py index e37ec2027..293e011ab 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py @@ -159,8 +159,8 @@ def exposed_pull(self, group_req): query_len = self.remote_cache_manager.query(query_task) if query_len > 0: - self.radix_manager.free_space(query_len) - index = self.radix_manager.radix_buffer.alloc(query_len) + + index = self.radix_manager.alloc(query_len) read_task = CacheTask(tokens=keys[:query_len], kv_page_indexer=index, mode="r") self.remote_cache_manager.read(read_task) diff --git a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py index fe4b5a675..33682095a 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py @@ -27,8 +27,6 @@ def insert(self, req: Req, key: torch.Tensor, value=None): self._set_radix_staus(req, RadixStatus.WRITE_READY) return - self.radix_manager.free_space(alloc_len) - new_index = 
self._alloc_and_copy_kv(alloc_len, value) start_pos = max(0, (query_len - 1) // self.chunk_size * self.chunk_size) @@ -48,7 +46,7 @@ def _query_cache(self, req, key): def _alloc_and_copy_kv(self, alloc_len, value): assert alloc_len > 0, "No allocation needed" - new_index = self.radix_buffer.alloc(alloc_len) + new_index = self.radix_manager.alloc(alloc_len) dst_kv_buffer = self.radix_buffer.get_kv_buffer(new_index) src_kv_buffer = self.mem_manager.get_index_kv_buffer(value[-alloc_len:])["kv_buffer"] diff --git a/lightllm/server/router/dynamic_prompt/hiradix/manager.py b/lightllm/server/router/dynamic_prompt/hiradix/manager.py index 1ded93199..4b695071c 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix/manager.py +++ b/lightllm/server/router/dynamic_prompt/hiradix/manager.py @@ -75,6 +75,7 @@ async def push_cache(self, req_info): req: Req = self.shm_req_manager.get_req_obj_by_index(req_info["shm_req_index"]) assert req.radix_status.is_write_done() req.radix_status.set_finished() + self.shm_req_manager.put_back_req_obj(req) logger.info(f"push cache results {all_results}") async def pull_woker(self):
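
Note (illustrative sketch, not part of any patch above): the later commits converge on a small host-side API for the radix KV buffer — build_radix_manager() wires MemPropties, the shared RadixMemoryBuffer and a lock into a RadixBufferManager, and that manager exposes alloc(required_size) (LRU eviction plus buffer allocation under one lock, per PATCH 13), write(tokens, values, start_pos=0) and query_cache(tokens). The sketch below only restates the call pattern visible in the diffs; the MemPropties values, the lock, and the token ids are placeholders, and it assumes it runs inside an already-initialized LightLLM server process (unique server name, rank, shared memory), so it is a usage sketch under those assumptions rather than repository code.

import torch
import torch.multiprocessing as mp

from lightllm.common.radixmem_buffer import MemPropties
from lightllm.common.radixmem_manager import build_radix_manager

# Placeholder geometry; the real values come from the model config,
# as in the _init_mem_manager() diffs for qwen2 / deepseek2 above.
mem_propties = MemPropties(
    16384,                # hiradix_cache_token_num (placeholder)
    dtype=torch.float16,  # stands in for self.data_type
    head_num=8,
    head_dim=128,
    layer_num=32,
)
radix_lock = mp.Lock()

# CPU-side buffer when use_gpu=False (mirrors the hiradix_cache_gpu flag).
radix_manager = build_radix_manager(mem_propties, use_gpu=False, radix_lock=radix_lock)

# Allocation goes through the manager: it evicts LRU entries if the
# shared buffer is short on space, then hands back buffer indices.
tokens = list(range(256))
index = radix_manager.alloc(len(tokens))

# Record which buffer slots hold the KV for these token chunks.
radix_manager.write(tokens=tokens, values=index.tolist())

# Later lookups return the matched prefix length and the stored indices,
# matching how HiRadixCache._query_cache() consumes the result.
query_len, mem_index = radix_manager.query_cache(tokens)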