diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py
index 85e4b7a0e0a0..75742a6e8ec7 100644
--- a/deepspeed/inference/v2/config_v2.py
+++ b/deepspeed/inference/v2/config_v2.py
@@ -41,3 +41,6 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
     """

     quantization: QuantizationConfig = {}
+
+    enable_prefix_cache: bool = False
+    """ Enable prefix caching so KV blocks computed for shared prompt prefixes can be reused across sequences. """
diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py
index 4a358310377f..71fca3ee50c3 100644
--- a/deepspeed/inference/v2/engine_v2.py
+++ b/deepspeed/inference/v2/engine_v2.py
@@ -87,7 +87,8 @@ def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngi
         self._batch = RaggedBatchWrapper(self._config.state_manager)
         self._state_manager = DSStateManager(self._config.state_manager,
                                              self._model.kv_cache_config(),
-                                             base_mp_group=self._base_mp_group)
+                                             base_mp_group=self._base_mp_group,
+                                             enable_prefix_cache=self._config.enable_prefix_cache)
         self._model.set_state_manager(self._state_manager)

     def _initialize_tp_group(self):
@@ -129,7 +130,13 @@ def put(self,

         for uid, tokens in zip(batch_uids, batch_tokens):
             host_seq_desc = self._state_manager.get_or_create_sequence(uid)
-            self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
+            new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
+            if self._config.enable_prefix_cache and new_block_ids is not None:
+                for block_id in new_block_ids:
+                    assert self._state_manager._ref_counts[block_id.item()] == 0
+
+                self._state_manager.increment_ref_count(new_block_ids)
+                self._state_manager._block_map.delete(new_block_ids)
             host_seq_desc.pre_forward(tokens.numel())

             # We can disable checks since we already validated schedulability.
@@ -239,6 +246,31 @@ def get_remaining_block_capacity(self, uid: int) -> int:
             return 0
         return self._model.get_remaining_block_capacity(seq_desc)

+    def lookup_cache(self, tokens: torch.Tensor) -> Tuple[int, torch.Tensor]:
+        """
+        Look up the prefix cache for the given tokens, returning the number of cached tokens and the cached KV block IDs.
+
+        Arguments:
+            tokens (torch.Tensor): The prompt tokens to look up in the prefix cache.
+        """
+        if self._config.enable_prefix_cache:
+            return self._state_manager.lookup_cache(tokens)
+        return 0, torch.tensor([])
+
+    def setup_cached_sequence(self, uid: int, cached_length: int, block_ids: torch.Tensor) -> None:
+        if self._config.enable_prefix_cache:
+            seq = self._state_manager.get_or_create_sequence(uid)
+            seq.pre_forward(cached_length)
+            seq.post_forward()
+            seq.extend_kv_cache(block_ids)
+            seq.num_prefix_cache_blocks = len(block_ids)
+            self._state_manager.increment_ref_count(block_ids)
+            self._state_manager._kv_cache.allocate_blocks(block_ids)
+
+    def update_prefix_cache(self, uid: int, tokens: torch.Tensor) -> None:
+        if self._config.enable_prefix_cache:
+            self._state_manager.update_cache(uid, tokens)
+
     def flush(self, uid: int) -> None:
         """
         Remove all state associated with a sequence from the inference engine.
@@ -246,6 +278,15 @@ def flush(self, uid: int) -> None:
         Arguments:
             uid (int): The UID of the sequence to flush.
         """
+        seq = self._state_manager.get_sequence(uid)
+
+        if self._config.enable_prefix_cache:
+            self._state_manager.decrement_ref_count(seq.all_block_ids())
+            blocks_to_free = [b.item() for b in seq.all_block_ids() if self._state_manager._ref_counts[b.item()] == 0]
+        else:
+            blocks_to_free = seq.all_block_ids()
+        self._state_manager._kv_cache.free(blocks_to_free)
+
         self._state_manager.flush_sequence(uid)

     def serialize(self, save_path: str) -> None:
diff --git a/deepspeed/inference/v2/model_implementations/inference_model_base.py b/deepspeed/inference/v2/model_implementations/inference_model_base.py
index 894a4137407e..d1bba390c70f 100644
--- a/deepspeed/inference/v2/model_implementations/inference_model_base.py
+++ b/deepspeed/inference/v2/model_implementations/inference_model_base.py
@@ -204,7 +204,7 @@ def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int:
         raise NotImplementedError()

     @abstractmethod
-    def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None:
+    def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> Optional[torch.Tensor]:
         """
         Given a sequence and the number of new tokens in the sequence, determine whether
         or not additional KV-storage is needed and allocate it if so.
@@ -212,6 +212,9 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -
         Args:
             sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage.
             n_new_tokens (int): The number of new tokens in the sequence.
+
+        Returns:
+            Optional[torch.Tensor]: The allocated KV block IDs. If no new blocks are needed, this returns None.
         """
         raise NotImplementedError()

diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py
index fae67dc8fc2a..751ce0de1af1 100644
--- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py
+++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py
@@ -356,7 +356,7 @@ def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: in
     def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int:
         return sequence.seen_tokens % self.attn.kv_block_size

-    def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None:
+    def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> Optional[torch.Tensor]:
         """
         See ``DSInferenceModelBase.maybe_allocate_kv`` for documentation.

@@ -369,6 +369,9 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -
         if n_needed_blocks > 0:
             new_blocks = self.state_manager.allocate_blocks(n_needed_blocks)
             sequence.extend_kv_cache(new_blocks)
+            return new_blocks
+
+        return None

     def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]:
         """
diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py
index 7884d8cccb47..b54817360cd7 100644
--- a/deepspeed/inference/v2/ragged/blocked_allocator.py
+++ b/deepspeed/inference/v2/ragged/blocked_allocator.py
@@ -3,30 +3,24 @@

 # DeepSpeed Team

-from typing import Iterable, Union
+from typing import Iterable, Union, List
+from abc import ABC, abstractmethod

 import torch


-class BlockedAllocator:
+class BlockedAllocatorBase(ABC):
     """
     Allocator class for managing which blocks are free/used in the
-    blocked KV-cache. This is a simple allocator that uses a linked list
-    to keep track of which blocks are free/used. The cost of allocation/deallocation
-    is O(blocks), where blocks is the number of blocks to allocate/deallocate.
-
-    TODO(cmikeh2): Evaluate performance of this allocator and migrate
-    to C++ if necessary.
+    blocked KV-cache.
     """
+
     # Number of blocks in the KV-cache(s).
     _num_blocks: int

     # Array of blocks, where each element is the next block in the linked list.
     _blocks: torch.Tensor

-    # Index of the head of the linked list.
-    _head: int
-
     # Number of free blocks in the KV-cache.
     _free_blocks: int

@@ -43,10 +37,9 @@ def __init__(self, num_blocks: int) -> None:
             raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}')

         self._num_blocks = num_blocks
-        self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True)
-        self._head = 0
         self._free_blocks = num_blocks

+    @abstractmethod
     def allocate(self, num_blocks: int) -> torch.Tensor:
         """
         Allocate a list of blocks from the associated KV-caches. This will
@@ -57,7 +50,50 @@ def allocate(self, num_blocks: int) -> torch.Tensor:
             num_blocks (int): The number of blocks to allocate.

         Returns:
-            List[int]: The list of blocks allocated.
+            torch.Tensor: The list of blocks allocated.
+        """
+        pass
+
+    @abstractmethod
+    def free(self, blocks: Union[Iterable[int], int]) -> None:
+        """
+        Free a list of blocks in the associated KV-caches. This will return
+        the blocks to the free pool if they are valid and not already free.
+
+        Parameters:
+            blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block
+                is to be freed, this can be alone as an integer.
+        """
+        pass
+
+    @property
+    def free_blocks(self) -> int:
+        """
+        Return the number of free blocks in the KV-cache.
+        """
+        return self._free_blocks
+
+
+class BlockedAllocator(BlockedAllocatorBase):
+    """
+    This is a simple allocator that uses a linked list to keep track of which blocks are free/used.
+    The cost of allocation/deallocation is O(blocks), where blocks is the number of blocks to allocate/deallocate.
+
+    TODO(cmikeh2): Evaluate performance of this allocator and migrate
+    to C++ if necessary.
+    """
+
+    # Index of the head of the linked list.
+    _head: int
+
+    def __init__(self, num_blocks: int) -> None:
+        super().__init__(num_blocks)
+        self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True)
+        self._head = 0
+
+    def allocate(self, num_blocks: int) -> torch.Tensor:
+        """
+        Refer to the docstring of `BlockedAllocatorBase.allocate`.
         """
         if num_blocks > self._free_blocks:
             raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks')
@@ -73,13 +109,7 @@ def allocate(self, num_blocks: int) -> torch.Tensor:

     def free(self, blocks: Union[Iterable[int], int]) -> None:
         """
-        Return a list of blocks to the free pool. If a single invalid block is provided (i.e.,
-        one that is out of range of the allocator or is already free), then an exception is raised
-        and no blocks are freed.
-
-        Parameters:
-            blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block
-                is to be freed, this can be alone as an integer.
+        Refer to the docstring of `BlockedAllocatorBase.free`.
         """
         if isinstance(blocks, int):
             blocks = [blocks]
@@ -97,9 +127,50 @@ def free(self, blocks: Union[Iterable[int], int]) -> None:
             self._head = block
             self._free_blocks += 1

-    @property
-    def free_blocks(self) -> int:
-        """
-        Return the number of free blocks in the KV-cache.
-        """
-        return self._free_blocks
+
+class LinearScanBlockedAllocator(BlockedAllocatorBase):
+
+    def __init__(self, num_blocks: int) -> None:
+        super().__init__(num_blocks)
+        self._blocks = torch.zeros(num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True)
+
+    def allocate(self, num_blocks: int) -> torch.Tensor:
+        if num_blocks > self._free_blocks:
+            raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks')
+
+        allocated_blocks = torch.zeros(num_blocks, dtype=torch.int32)
+
+        alloc_idx = 0
+        for i in range(self._num_blocks):
+            if self._blocks[i].item() == 0:
+                allocated_blocks[alloc_idx] = i
+                self._blocks[i] = -1
+                self._free_blocks -= 1
+                alloc_idx += 1
+                if alloc_idx == num_blocks:
+                    break
+
+        return allocated_blocks
+
+    def allocate_blocks(self, blocks: List[int]) -> None:
+        for block in blocks:
+            if self._blocks[block] != -1:
+                self._blocks[block] = -1
+                self._free_blocks -= 1
+
+    def free(self, blocks: Union[Iterable[int], int]) -> None:
+
+        if isinstance(blocks, int):
+            blocks = [blocks]
+
+        for block in blocks:
+            # Check all blocks for validity before mutating any allocator state.
+            if block < 0 or block >= self._num_blocks:
+                raise ValueError(f'Invalid block {block} provided to free')
+
+            if self._blocks[block] != -1:
+                raise ValueError(f'Block {block} is already free')
+
+        for block in blocks:
+            self._blocks[block] = 0
+            self._free_blocks += 1
diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py
index ceba3190b93c..5bcb766aa690 100644
--- a/deepspeed/inference/v2/ragged/kv_cache.py
+++ b/deepspeed/inference/v2/ragged/kv_cache.py
@@ -15,7 +15,7 @@
 from deepspeed.accelerator import get_accelerator
 from ..inference_utils import elem_size
 from ..logging import inference_logger
-from .blocked_allocator import BlockedAllocator
+from .blocked_allocator import BlockedAllocator, LinearScanBlockedAllocator
 from .manager_configs import AllocationMode, KVCacheConfig, MemoryConfig


@@ -61,7 +61,8 @@ def __init__(self,
                  configs: Tuple[KVCacheConfig, ...],
                  memory_config: MemoryConfig,
                  mp_group: Optional[Any] = None,
-                 offload: bool = False) -> None:
+                 offload: bool = False,
+                 enable_prefix_cache: bool = False) -> None:
         """
         Create a container that will maintain the storage and allocations for a set of
         blocked KV-caches.
@@ -136,7 +137,11 @@ def __init__(self,
                 f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
             caches.append(
                 torch.empty(alloc_shape, dtype=config.cache_dtype, device=get_accelerator().current_device()))
-            allocators.append(BlockedAllocator(num_blocks))
+
+            if enable_prefix_cache:
+                allocators.append(LinearScanBlockedAllocator(num_blocks))
+            else:
+                allocators.append(BlockedAllocator(num_blocks))

         self._caches = tuple(caches)
         self._allocators = tuple(allocators)
@@ -152,6 +157,9 @@ def reserve(self, num_blocks: int, cache_group: int = 0) -> torch.Tensor:
         """
         return self._allocators[cache_group].allocate(num_blocks)

+    def allocate_blocks(self, blocks: Iterable[int], cache_group: int = 0) -> None:
+        return self._allocators[cache_group].allocate_blocks(blocks)
+
     def free(self, blocks: Iterable[int], cache_group: int = 0) -> None:
         """
         Free a set of blocks from the cache. This will mark the blocks as free in the
diff --git a/deepspeed/inference/v2/ragged/prefix_block_map.py b/deepspeed/inference/v2/ragged/prefix_block_map.py
new file mode 100644
index 000000000000..d59c750366b2
--- /dev/null
+++ b/deepspeed/inference/v2/ragged/prefix_block_map.py
@@ -0,0 +1,56 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Dict, FrozenSet, List
+import hashlib
+
+import torch
+
+
+def token_ids_to_hash(token_ids: torch.Tensor) -> str:
+    # Convert the tensor to bytes
+    tensor_bytes = token_ids.numpy().tobytes()
+    hash_obj = hashlib.sha256()
+    # Update the hash object with the bytes
+    hash_obj.update(tensor_bytes)
+    # Get the hexadecimal digest of the hash
+    return hash_obj.hexdigest()
+
+
+class PrefixBlockMap:
+
+    def __init__(self, block_size: int):
+        self.tokens_to_blocks: Dict[str, List[int]] = {}
+        self.blocks_to_tokens: Dict[FrozenSet[int], str] = {}
+        self.block_size: int = block_size
+
+    def lookup(self, tokens: torch.Tensor) -> torch.Tensor:
+        n_blocks = len(tokens) // self.block_size
+        cached_blocks = torch.tensor([], dtype=torch.int32)
+        for i in range(n_blocks):
+            chunk = tokens[:(i + 1) * self.block_size]
+            hash = token_ids_to_hash(chunk)
+            if hash in self.tokens_to_blocks:
+                cached_blocks = self.tokens_to_blocks[hash]
+            else:
+                break
+        return cached_blocks
+
+    def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None:
+        n_blocks = len(tokens) // self.block_size
+        for i in range(num_already_cached_blocks, n_blocks):
+            chunk = tokens[:(i + 1) * self.block_size]
+            hash = token_ids_to_hash(chunk)
+            if hash not in self.tokens_to_blocks:
+                self.tokens_to_blocks[hash] = new_block_ids[:i + 1]
+                self.blocks_to_tokens[frozenset(new_block_ids[:i + 1])] = hash
+
+    def delete(self, block_ids: List[int]) -> None:
+        blocks_set = frozenset(block_ids)
+        for used_blocks, hash in list(self.blocks_to_tokens.items()):
+            # Invalidate any cached prefix that references one of the freed blocks.
+            if blocks_set & used_blocks:
+                del self.tokens_to_blocks[hash]
+                del self.blocks_to_tokens[used_blocks]
diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py
index ecc3c52a5834..01958e7e1fc5 100644
--- a/deepspeed/inference/v2/ragged/ragged_manager.py
+++ b/deepspeed/inference/v2/ragged/ragged_manager.py
@@ -5,6 +5,7 @@

 import torch
 from typing import Any, Dict, Optional, Tuple
+from collections import defaultdict

 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import RaggedUtilsBuilder
@@ -14,6 +15,7 @@
 from .kv_cache import BlockedKVCache
 from .manager_configs import DSStateManagerConfig, KVCacheConfig
 from .sequence_descriptor import DSSequenceDescriptor
+from .prefix_block_map import PrefixBlockMap


 class DSStateManager:
@@ -55,7 +57,8 @@ class DSStateManager:
     def __init__(self,
                  config: DSStateManagerConfig,
                  kv_configs: Tuple[KVCacheConfig, ...],
-                 base_mp_group: Optional[Any] = None) -> None:
+                 base_mp_group: Optional[Any] = None,
+                 enable_prefix_cache: bool = False) -> None:
         """
         The key

@@ -95,7 +98,12 @@ def __init__(self,
         self._kv_cache = BlockedKVCache(self._kv_configs,
                                         self._config.memory_config,
                                         mp_group=base_mp_group,
-                                        offload=self._config.offload)
+                                        offload=self._config.offload,
+                                        enable_prefix_cache=enable_prefix_cache)
+
+        assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now."
+        self._block_map = PrefixBlockMap(self._kv_configs[0].block_size)
+        self._ref_counts = defaultdict(int)

     def get_cache(self, cache_id: int, cache_group: int = 0) -> torch.Tensor:
         """
@@ -116,8 +124,6 @@ def flush_sequence(self, uid: int) -> None:
             return

         seq = self._seqs[uid]
-        for i in range(self.n_kv_cache_groups):
-            self._kv_cache.free(seq.all_block_ids(cache_group=i), cache_group=i)
         self._tracking_allocator.free(seq.tracking_id)

         del self._seqs[uid]
@@ -167,6 +173,37 @@ def _create_sequence(self, uid: int) -> DSSequenceDescriptor:
         logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.")
         return self._seqs[uid]

+    def lookup_cache(self, tokens: torch.Tensor) -> Tuple[int, torch.Tensor]:
+        """ Return the length (in tokens) of the longest cached prefix and its KV block IDs. """
+        block_ids = self._block_map.lookup(tokens)
+        assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now."
+
+        cache_hit_length = len(block_ids) * self._kv_configs[0].block_size
+
+        if cache_hit_length == tokens.numel():
+            # Logits are not cached, so drop the last block to force its tokens to be recomputed.
+            block_ids = block_ids[:-1]
+            return len(block_ids) * self._kv_configs[0].block_size, block_ids
+        return cache_hit_length, block_ids
+
+    def update_cache(self, uid: int, tokens: torch.Tensor) -> None:
+        """
+        Publish the newly filled KV blocks of the given sequence to the prefix-cache block map.
+        """
+        seq = self.get_sequence(uid)
+        num_full_blocks = seq.seen_tokens // self._kv_configs[0].block_size
+        if num_full_blocks > seq.num_prefix_cache_blocks:
+            self._block_map.extend(tokens[:seq.seen_tokens], seq.all_block_ids(), seq.num_prefix_cache_blocks)
+            seq.num_prefix_cache_blocks = num_full_blocks
+
+    def increment_ref_count(self, block_ids: torch.Tensor) -> None:
+        for block_id in block_ids:
+            self._ref_counts[block_id.item()] += 1
+
+    def decrement_ref_count(self, block_ids: torch.Tensor) -> None:
+        for block_id in block_ids:
+            self._ref_counts[block_id.item()] -= 1
+
     @property
     def tracked_sequences(self) -> Dict[int, DSSequenceDescriptor]:
         """
@@ -186,7 +223,7 @@ def kv_block_size(self) -> int:
         """
         Return the block size of the KV cache.
         """
-        return self._kv_config.block_size
+        return self._kv_configs[0].block_size

     @property
     def n_kv_cache_groups(self) -> int:
diff --git a/deepspeed/inference/v2/ragged/sequence_descriptor.py b/deepspeed/inference/v2/ragged/sequence_descriptor.py
index 6b9f65255eec..25c1072ff042 100644
--- a/deepspeed/inference/v2/ragged/sequence_descriptor.py
+++ b/deepspeed/inference/v2/ragged/sequence_descriptor.py
@@ -95,6 +95,8 @@ class DSSequenceDescriptor(BaseSequenceDescriptor):
     # are stored. Used on flush.
     _tracking_id: int

+    _num_prefix_cache_blocks: int
+
     def __init__(self,
                  tracking_id: int,
                  kv_cache_ids: Tuple[torch.Tensor, ...],
@@ -132,6 +134,8 @@ def __init__(self,
             assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0]
             assert len(kv_cache_ids.shape) == 2

+        self._num_prefix_cache_blocks = 0
+
     @property
     def seen_tokens(self) -> int:
         """
@@ -278,3 +282,20 @@ def free_kv_cache(self, free_ids: Union[List[torch.IntTensor], torch.IntTensor],
             to have the same shape.
         """
         raise NotImplementedError("Partial KV-cache freeing is not yet supported.")
+
+    @property
+    def num_prefix_cache_blocks(self) -> int:
+        """
+        The number of prefix cache blocks for the sequence.
+        """
+        return self._num_prefix_cache_blocks
+
+    @num_prefix_cache_blocks.setter
+    def num_prefix_cache_blocks(self, num_blocks: int) -> None:
+        """
+        Set the number of prefix cache blocks for the sequence.
+
+        Arguments:
+            num_blocks (int): The number of prefix cache blocks for the sequence.
+        """
+        self._num_prefix_cache_blocks = num_blocks
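Below is a minimal sketch (not part of the patch) of how a caller might drive the new prefix-cache API end to end. The helper name generate_with_prefix_cache, the engine object, and the single put() call are assumptions for illustration only; the real serving loop, decode iterations, and error handling are omitted.

    import torch

    def generate_with_prefix_cache(engine, uid: int, prompt_tokens: torch.Tensor) -> None:
        # 1. Ask the state manager whether a prefix of this prompt is already cached.
        cached_length, cached_block_ids = engine.lookup_cache(prompt_tokens)

        # 2. Attach the cached KV blocks to the sequence so only the suffix is recomputed.
        if cached_length > 0:
            engine.setup_cached_sequence(uid, cached_length, cached_block_ids)

        # 3. Schedule only the uncached suffix of the prompt.
        engine.put([uid], [prompt_tokens[cached_length:]])

        # 4. After the forward pass, publish the newly filled blocks for reuse by later requests.
        engine.update_prefix_cache(uid, prompt_tokens)

        # 5. On completion, flush() decrements ref-counts and frees only blocks no other live
        #    sequence still holds; block-map entries for freed blocks are invalidated lazily
        #    when those blocks are reallocated in put().
        engine.flush(uid)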
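For the hashing scheme itself, the following standalone sketch (illustrative only; the token values and block IDs are made up, and the import path assumes this patch is applied) shows that PrefixBlockMap registers one entry per full block of a prompt and that a lookup returns the block IDs of the longest previously registered prefix. Note that the hash is computed over the raw tensor bytes, so the token dtype must be consistent between extend() and lookup().

    import torch
    from deepspeed.inference.v2.ragged.prefix_block_map import PrefixBlockMap

    block_size = 4
    block_map = PrefixBlockMap(block_size)

    # Suppose KV blocks 10 and 11 hold the state for the two full blocks of prompt_a.
    prompt_a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32)
    block_map.extend(prompt_a, [10, 11], num_already_cached_blocks=0)

    # prompt_b shares only its first full block ([1, 2, 3, 4]) with prompt_a.
    prompt_b = torch.tensor([1, 2, 3, 4, 5, 6, 7, 9, 0], dtype=torch.int32)
    print(block_map.lookup(prompt_b))  # -> [10], the block IDs of the longest cached prefix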