3 changes: 3 additions & 0 deletions deepspeed/inference/v2/config_v2.py
@@ -41,3 +41,6 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
"""

quantization: QuantizationConfig = {}

enable_prefix_cache: bool = False
""" Enable prefix cache for the model. """
45 changes: 43 additions & 2 deletions deepspeed/inference/v2/engine_v2.py
@@ -87,7 +87,8 @@ def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngi
self._batch = RaggedBatchWrapper(self._config.state_manager)
self._state_manager = DSStateManager(self._config.state_manager,
self._model.kv_cache_config(),
base_mp_group=self._base_mp_group)
base_mp_group=self._base_mp_group,
enable_prefix_cache=self._config.enable_prefix_cache)
self._model.set_state_manager(self._state_manager)

def _initialize_tp_group(self):
@@ -129,7 +130,13 @@ def put(self,
for uid, tokens in zip(batch_uids, batch_tokens):

host_seq_desc = self._state_manager.get_or_create_sequence(uid)
self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
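# Freshly allocated KV blocks should have no outstanding references; claim them for
# this sequence and evict any stale prefix-cache entries that still map to them.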
if self._config.enable_prefix_cache and new_block_ids is not None:
for block_id in new_block_ids:
assert self._state_manager._ref_counts[block_id.item()] == 0

self._state_manager.increment_ref_count(new_block_ids)
self._state_manager._block_map.delete(new_block_ids)
host_seq_desc.pre_forward(tokens.numel())

# We can disable checks since we already validated schedulability.
@@ -239,13 +246,47 @@ def get_remaining_block_capacity(self, uid: int) -> int:
return 0
return self._model.get_remaining_block_capacity(seq_desc)

def lookup_cache(self, tokens: torch.Tensor) -> Tuple[int, torch.Tensor]:
"""
Look up the longest previously cached prefix of ``tokens`` in the KV cache.

Arguments:
tokens (torch.Tensor): The token IDs of the sequence to look up.

Returns:
Tuple[int, torch.Tensor]: The length of the cached prefix in tokens and the IDs of the KV blocks that hold it. When prefix caching is disabled, (0, torch.tensor([])) is returned.
"""
if self._config.enable_prefix_cache:
return self._state_manager.lookup_cache(tokens)
return 0, torch.tensor([])

def setup_cached_sequence(self, uid: int, cached_length: int, block_ids: torch.Tensor) -> None:
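"""
Attach an already-cached prefix to a new sequence: advance the sequence descriptor past
the cached_length tokens covered by the cache, extend its KV cache with the cached
block_ids, and take a reference on those blocks so they are not freed while this
sequence uses them.
"""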
if self._config.enable_prefix_cache:
seq = self._state_manager.get_or_create_sequence(uid)
seq.pre_forward(cached_length)
seq.post_forward()
seq.extend_kv_cache(block_ids)
seq.num_prefix_cache_blocks = len(block_ids)
self._state_manager.increment_ref_count(block_ids)
self._state_manager._kv_cache.allocate_blocks(block_ids)

def update_prefix_cache(self, uid: int, tokens: torch.Tensor) -> None:
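"""
Update the prefix cache with the tokens seen so far for sequence uid, so that future
requests sharing the same token prefix can reuse its KV blocks. No-op when prefix
caching is disabled.
"""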
if self._config.enable_prefix_cache:
self._state_manager.update_cache(uid, tokens)

def flush(self, uid: int) -> None:
"""
Remove all state associated with a sequence from the inference engine.

Arguments:
uid (int): The UID of the sequence to flush.
"""
seq = self._state_manager.get_sequence(uid)

if self._config.enable_prefix_cache:
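# Only blocks whose reference count drops to zero are returned to the allocator;
# blocks still referenced by other sequences sharing the prefix stay resident.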
self._state_manager.decrement_ref_count(seq.all_block_ids())
blocks_to_free = [b.item() for b in seq.all_block_ids() if self._state_manager._ref_counts[b.item()] == 0]
else:
blocks_to_free = seq.all_block_ids()
self._state_manager._kv_cache.free(blocks_to_free)

self._state_manager.flush_sequence(uid)

def serialize(self, save_path: str) -> None:
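Before the next file, a hedged sketch of how a caller might drive the prefix-cache APIs added to the engine above. The method names (lookup_cache, setup_cached_sequence, put, update_prefix_cache) come from this diff; the driver function itself, the assumption that put returns logits, and the exact point at which update_prefix_cache is invoked are illustrative:

import torch

def run_with_prefix_cache(engine, uid: int, prompt_ids: torch.Tensor):
    # 1. Check whether a prefix of this prompt already has KV blocks cached.
    cached_length, cached_blocks = engine.lookup_cache(prompt_ids)

    # 2. If so, attach those blocks to the new sequence before scheduling it.
    if cached_length > 0:
        engine.setup_cached_sequence(uid, cached_length, cached_blocks)

    # 3. Only the uncached suffix of the prompt needs to be run through the model.
    logits = engine.put([uid], [prompt_ids[cached_length:]])

    # 4. Make this sequence's blocks reusable by later requests with the same prefix.
    engine.update_prefix_cache(uid, prompt_ids)
    return logits
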
@@ -204,14 +204,17 @@ def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int:
raise NotImplementedError()

@abstractmethod
def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None:
def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> Optional[torch.Tensor]:
"""
Given a sequence and the number of new tokens in the sequence, determine
whether or not additional KV-storage is needed and allocate it if so.

Args:
sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage.
n_new_tokens (int): The number of new tokens in the sequence.

Returns:
Optional[torch.Tensor]: The allocated KV block IDs. If no new blocks are needed, this returns None.
"""
raise NotImplementedError()

@@ -356,7 +356,7 @@ def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: in
def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int:
return sequence.seen_tokens % self.attn.kv_block_size

def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None:
def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> Optional[torch.Tensor]:
"""
See ``DSInferenceModelBase.maybe_allocate_kv`` for documentation.

@@ -369,6 +369,9 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -
if n_needed_blocks > 0:
new_blocks = self.state_manager.allocate_blocks(n_needed_blocks)
sequence.extend_kv_cache(new_blocks)
return new_blocks

return None

def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]:
"""
125 changes: 98 additions & 27 deletions deepspeed/inference/v2/ragged/blocked_allocator.py
@@ -3,30 +3,24 @@

# DeepSpeed Team

from typing import Iterable, Union
from typing import Iterable, Union, List
from abc import ABC, abstractmethod

import torch


class BlockedAllocator:
class BlockedAllocatorBase(ABC):
"""
Allocator class for managing which blocks are free/used in the
blocked KV-cache. This is a simple allocator that uses a linked list
to keep track of which blocks are free/used. The cost of allocation/deallocation
is O(blocks), where blocks is the number of blocks to allocate/deallocate.

TODO(cmikeh2): Evaluate performance of this allocator and migrate
to C++ if necessary.
blocked KV-cache.
"""

# Number of blocks in the KV-cache(s).
_num_blocks: int

# Array of blocks, where each element is the next block in the linked list.
_blocks: torch.Tensor

# Index of the head of the linked list.
_head: int

# Number of free blocks in the KV-cache.
_free_blocks: int

@@ -43,10 +37,9 @@ def __init__(self, num_blocks: int) -> None:
raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}')

self._num_blocks = num_blocks
self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True)
self._head = 0
self._free_blocks = num_blocks

@abstractmethod
def allocate(self, num_blocks: int) -> torch.Tensor:
"""
Allocate a list of blocks from the associated KV-caches. This will
@@ -57,7 +50,50 @@ def allocate(self, num_blocks: int) -> torch.Tensor:
num_blocks (int): The number of blocks to allocate.

Returns:
List[int]: The list of blocks allocated.
torch.Tensor: The list of blocks allocated.
"""
pass

@abstractmethod
def free(self, blocks: Union[Iterable[int], int]) -> None:
"""
Free a list of blocks in the associated KV-caches. This will return
the blocks to the free pool if they are valid and not already free.

Parameters:
blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block
is to be freed, this can be alone as an integer.
"""
pass

@property
def free_blocks(self) -> int:
"""
Return the number of free blocks in the KV-cache.
"""
return self._free_blocks


class BlockedAllocator(BlockedAllocatorBase):
"""
This is a simple allocator that uses a linked list to keep track of which blocks are free/used. The cost of allocation/deallocation
is O(blocks), where blocks is the number of blocks to allocate/deallocate.

TODO(cmikeh2): Evaluate performance of this allocator and migrate
to C++ if necessary.
"""

# Index of the head of the linked list.
_head: int

def __init__(self, num_blocks: int) -> None:
super().__init__(num_blocks)
self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True)
self._head = 0

def allocate(self, num_blocks: int) -> torch.Tensor:
"""
Refer to the docstring of `BlockedAllocatorBase.allocate`.
"""
if num_blocks > self._free_blocks:
raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks')
@@ -73,13 +109,7 @@ def allocate(self, num_blocks: int) -> torch.Tensor:

def free(self, blocks: Union[Iterable[int], int]) -> None:
"""
Return a list of blocks to the free pool. If a single invalid block is provided (i.e.,
one that is out of range of the allocator or is already free), then an exception is raised
and no blocks are freed.

Parameters:
blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block
is to be freed, this can be alone as an integer.
Refer to the docstring of `BlockedAllocatorBase.free`.
"""
if isinstance(blocks, int):
blocks = [blocks]
@@ -97,9 +127,50 @@ def free(self, blocks: Union[Iterable[int], int]) -> None:
self._head = block
self._free_blocks += 1

@property
def free_blocks(self) -> int:
"""
Return the number of free blocks in the KV-cache.
"""
return self._free_blocks

class LinearScanBlockedAllocator(BlockedAllocatorBase):
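"""
Allocator that tracks usage in a flat array (0 = free, -1 = in use) and allocates by
linearly scanning for free blocks. Unlike the linked-list allocator, it can also pin
specific, already-known block IDs, which prefix-cache reuse requires.
"""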

def __init__(self, num_blocks: int) -> None:
super().__init__(num_blocks)
self._blocks = torch.zeros(num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True)

def allocate(self, num_blocks: int) -> torch.Tensor:
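"""
Refer to the docstring of `BlockedAllocatorBase.allocate`.
"""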
if num_blocks > self._free_blocks:
raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks')

allocated_blocks = torch.zeros(num_blocks, dtype=torch.int32)

alloc_idx = 0
for i in range(self._num_blocks):
if self._blocks[i].item() == 0:
allocated_blocks[alloc_idx] = i
self._blocks[i] = -1
self._free_blocks -= 1
alloc_idx += 1
if alloc_idx == num_blocks:
break

return allocated_blocks

def allocate_blocks(self, blocks: List[int]) -> None:
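"""
Mark the given block IDs as allocated. Used when a prefix-cache hit supplies the exact
blocks to reuse; blocks that are already in use are left unchanged.
"""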
for block in blocks:
if self._blocks[block] != -1:
self._blocks[block] = -1
self._free_blocks -= 1

def free(self, blocks: Union[Iterable[int], int]) -> None:
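"""
Refer to the docstring of `BlockedAllocatorBase.free`.
"""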

if isinstance(blocks, int):
blocks = [blocks]

for block in blocks:
# Parse all blocks for validity before mutating the list.
if block < 0 or block >= self._num_blocks:
raise ValueError(f'Invalid block {block} provided to free')

if self._blocks[block] != -1:
raise ValueError(f'Block {block} is already free')

for block in blocks:
self._blocks[block] = 0
self._free_blocks += 1
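
For illustration, a small usage sketch of the linear-scan allocator above (not part of this diff; it only assumes the classes defined in this file are importable from their path in this PR):

from deepspeed.inference.v2.ragged.blocked_allocator import LinearScanBlockedAllocator

alloc = LinearScanBlockedAllocator(num_blocks=8)

# Ordinary allocation scans for the first free slots.
fresh = alloc.allocate(2)        # e.g. tensor([0, 1], dtype=torch.int32)

# A prefix-cache hit re-pins specific, already-known block IDs instead.
alloc.allocate_blocks([4, 5])

alloc.free(fresh.tolist())       # blocks 0 and 1 return to the pool
print(alloc.free_blocks)         # 6
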
14 changes: 11 additions & 3 deletions deepspeed/inference/v2/ragged/kv_cache.py
@@ -15,7 +15,7 @@
from deepspeed.accelerator import get_accelerator
from ..inference_utils import elem_size
from ..logging import inference_logger
from .blocked_allocator import BlockedAllocator
from .blocked_allocator import BlockedAllocator, LinearScanBlockedAllocator
from .manager_configs import AllocationMode, KVCacheConfig, MemoryConfig


@@ -61,7 +61,8 @@ def __init__(self,
configs: Tuple[KVCacheConfig, ...],
memory_config: MemoryConfig,
mp_group: Optional[Any] = None,
offload: bool = False) -> None:
offload: bool = False,
enable_prefix_cache: bool = False) -> None:
"""
Create a container that will maintain the storage and allocations for a set of
blocked KV-caches.
@@ -136,7 +137,11 @@ def __init__(self,
f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype,
device=get_accelerator().current_device()))
allocators.append(BlockedAllocator(num_blocks))

if enable_prefix_cache:
allocators.append(LinearScanBlockedAllocator(num_blocks))
else:
allocators.append(BlockedAllocator(num_blocks))

self._caches = tuple(caches)
self._allocators = tuple(allocators)
Expand All @@ -152,6 +157,9 @@ def reserve(self, num_blocks: int, cache_group: int = 0) -> torch.Tensor:
"""
return self._allocators[cache_group].allocate(num_blocks)

def allocate_blocks(self, blocks: Iterable[int], cache_group: int = 0) -> None:
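"""
Mark specific block IDs as allocated in the given cache group. Only supported when
prefix caching is enabled, i.e. when the underlying allocator is a
LinearScanBlockedAllocator.
"""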
return self._allocators[cache_group].allocate_blocks(blocks)

def free(self, blocks: Iterable[int], cache_group: int = 0) -> None:
"""
Free a set of blocks from the cache. This will mark the blocks as free in the
56 changes: 56 additions & 0 deletions deepspeed/inference/v2/ragged/prefix_block_map.py
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Dict, FrozenSet, List
import hashlib

import torch


def token_ids_to_hash(token_ids: torch.Tensor):
# Convert the tensor to bytes
tensor_bytes = token_ids.numpy().tobytes()
hash_obj = hashlib.sha256()
# Update the hash object with the bytes
hash_obj.update(tensor_bytes)
# Get the hexadecimal digest of the hash
return hash_obj.hexdigest()


class PrefixBlockMap():

def __init__(self, block_size: int):
self.tokens_to_blocks: Dict[str, List[int]] = {}
self.blocks_to_tokens: Dict[FrozenSet[int], str] = {}
self.block_size: int = block_size

def lookup(self, tokens: torch.Tensor) -> torch.Tensor:
n_blocks = len(tokens) // self.block_size
cached_blocks = torch.tensor([], dtype=torch.int32)
for i in range(n_blocks):
chunk = tokens[:(i + 1) * self.block_size]
hash = token_ids_to_hash(chunk)
if hash in self.tokens_to_blocks:
cached_blocks = self.tokens_to_blocks[hash]
else:
break
return cached_blocks

def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None:
n_blocks = len(tokens) // self.block_size
for i in range(num_already_cached_blocks, n_blocks):
chunk = tokens[:(i + 1) * self.block_size]
hash = token_ids_to_hash(chunk)
if hash not in self.tokens_to_blocks:

Reviewer: Do we still need this if statement if the for loop is already starting from num_already_cached_blocks?

Author (collaborator): I'm not sure it is necessary, but there might be a complicated case. Assume we are running two requests. Request 1 is using cached blocks 0, 1, 2 and has just generated the last token of its current block, so it saves the generated sequence and a hash. Request 2 is also using cached blocks 0, 1, 2, but its generation is a few steps behind Request 1, so the two requests do not share the last block. Request 2 may nevertheless generate exactly the same tokens for that block and would then try to update the cache with the same hash. In this case, I didn't want to overwrite the cache.

Reviewer: Sounds good, so this is to ensure we are sharing the recently generated blocks.

self.tokens_to_blocks[hash] = new_block_ids[:i + 1]
self.blocks_to_tokens[frozenset(new_block_ids[:i + 1])] = hash

def delete(self, block_ids: List[int]) -> None:
# Normalize to plain ints so membership tests match the integer keys stored by extend().
blocks_set = frozenset(int(b) for b in block_ids)
# Iterate over a copy so entries can be deleted while scanning.
for used_blocks, hash in list(self.blocks_to_tokens.items()):
# Any cached prefix that references one of the freed blocks is no longer valid.
if blocks_set & used_blocks:
del self.tokens_to_blocks[hash]
del self.blocks_to_tokens[used_blocks]
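
A small sketch (not part of this diff) of the map's behaviour, including the non-overwriting case discussed in the review thread above; the block size of 4 and the block IDs are arbitrary:

import torch
from deepspeed.inference.v2.ragged.prefix_block_map import PrefixBlockMap

pmap = PrefixBlockMap(block_size=4)
prompt = torch.arange(8, dtype=torch.int32)   # two full blocks of tokens

# The first request registers its two KV blocks under the prefix hashes.
pmap.extend(prompt, new_block_ids=[10, 11], num_already_cached_blocks=0)

# A second request with the same prompt finds both blocks already cached.
print(pmap.lookup(prompt))                    # [10, 11]

# extend() skips hashes that are already present, so a request that later reproduces
# the same tokens does not overwrite the original mapping.
pmap.extend(prompt, new_block_ids=[20, 21], num_already_cached_blocks=0)
print(pmap.lookup(prompt))                    # still [10, 11]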