From 696e3b80e588ab522acab722c43b871ffb040bf3 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 23 Apr 2024 01:43:54 +0000 Subject: [PATCH 01/17] reuse prefix of kv cache --- deepspeed/inference/v2/engine_v2.py | 28 ++++- .../inference_transformer_base.py | 5 +- .../inference/v2/ragged/prefix_block_tree.py | 106 ++++++++++++++++++ .../inference/v2/ragged/ragged_manager.py | 29 +++++ 4 files changed, 165 insertions(+), 3 deletions(-) create mode 100644 deepspeed/inference/v2/ragged/prefix_block_tree.py diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index 4a358310377f..33e39105661f 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -129,7 +129,9 @@ def put(self, for uid, tokens in zip(batch_uids, batch_tokens): host_seq_desc = self._state_manager.get_or_create_sequence(uid) - self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) + new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) + if new_block_ids is not None: + self._state_manager.increment_ref_count(new_block_ids) host_seq_desc.pre_forward(tokens.numel()) # We can disable checks since we already validated schedulability. @@ -148,9 +150,10 @@ def put(self, # We return one set of logits per sequence in the batch (saves cost on unembedding) assert logits.shape[0] == self._batch.current_sequences - for uid in batch_uids: + for uid, tokens in zip(batch_uids, batch_tokens): host_seq_desc = self._state_manager.get_sequence(uid) host_seq_desc.post_forward() # Updates sequence metadata. + self._state_manager.update_cache(uid, tokens) self._model.maybe_free_kv(host_seq_desc) return logits @@ -239,6 +242,24 @@ def get_remaining_block_capacity(self, uid: int) -> int: return 0 return self._model.get_remaining_block_capacity(seq_desc) + def lookup_cache(self, uid: int, tokens: torch.Tensor) -> None: + """ + Lookup the KV cache for a given sequence and allocate the necessary blocks. + + Arguments: + uid (int): The UID of the sequence. + tokens (torch.Tensor): The tokens to allocate. + """ + print(f"lookup_cache lookup_cache uid: {uid}") + return self._state_manager.lookup_cache(uid, tokens) + + def setup_cached_sequence(self, uid: int, cached_length:int, block_ids: torch.Tensor) -> None: + seq = self._state_manager.get_or_create_sequence(uid) + seq.pre_forward(cached_length) + seq.post_forward() + seq.extend_kv_cache(block_ids) + self._state_manager.increment_ref_count(block_ids) + def flush(self, uid: int) -> None: """ Remove all state associated with a sequence from the inference engine. @@ -246,6 +267,9 @@ def flush(self, uid: int) -> None: Arguments: uid (int): The UID of the sequence to flush. 
""" + seq = self._state_manager.get_sequence(uid) + self._state_manager.decrement_ref_count(seq.all_block_ids()) + self._state_manager.flush_sequence(uid) def serialize(self, save_path: str) -> None: diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py index fae67dc8fc2a..8e697266dbeb 100644 --- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -356,7 +356,7 @@ def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: in def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int: return sequence.seen_tokens % self.attn.kv_block_size - def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None: + def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> Optional[torch.Tensor]: """ See ``DSInferenceModelBase.maybe_allocate_kv`` for documentation. @@ -369,6 +369,9 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) - if n_needed_blocks > 0: new_blocks = self.state_manager.allocate_blocks(n_needed_blocks) sequence.extend_kv_cache(new_blocks) + return new_blocks + + return None def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]: """ diff --git a/deepspeed/inference/v2/ragged/prefix_block_tree.py b/deepspeed/inference/v2/ragged/prefix_block_tree.py new file mode 100644 index 000000000000..e7fd48d156de --- /dev/null +++ b/deepspeed/inference/v2/ragged/prefix_block_tree.py @@ -0,0 +1,106 @@ +from typing import Any, Dict, Optional, List +import hashlib +from collections import defaultdict + +import torch + + +def token_ids_to_hash(token_ids: torch.Tensor): + # Convert the tensor to bytes + tensor_bytes = token_ids.numpy().tobytes() + hash_obj = hashlib.sha256() + # Update the hash object with the bytes + hash_obj.update(tensor_bytes) + # Get the hexadecimal digest of the hash + return hash_obj.hexdigest() + + +class PrefixBlockNode: + def __init__(self, token_ids: torch.Tensor, block_id: int, children: Dict): + self.token_ids = token_ids + self.block_id = block_id + self.children = children + self.hash = token_ids_to_hash(token_ids) + self.ref_count = 1 + + def add_child(self, hash, node): + self.children[hash] = node + + def get_child(self, prefix): + return self.children.get(prefix) + + def inc_ref_count(self): + self.ref_count += 1 + + def dec_ref_count(self): + self.ref_count -= 1 + + def __repr__(self): + return f"PrefixBlockNode(token_ids={self.token_ids.shape}, block_id={self.block_id}, children={self.children}, ref_count={self.ref_count})" + + +class PrefixBlockTree(): + + def __init__(self, block_size: int): + self.root: PrefixBlockNode = PrefixBlockNode(token_ids=torch.tensor([], dtype=torch.int32), block_id=-1, children={}) + + # Mapping from uid to token_ids. 
+ self.prefix_nodes: Dict[int, List[PrefixBlockNode]] = defaultdict(list) + self.tokens: Dict[int, torch.Tensor] = defaultdict(lambda: torch.tensor([], dtype=torch.int32)) + self.block_size: int = block_size + + def lookup(self, tokens: torch.Tensor, increment_ref=False, decrement_ref=False) -> List[PrefixBlockNode]: + assert not (increment_ref and decrement_ref), 'increment_ref and decrement_ref cannot be set to True at the same time' + + chunks = torch.split(tokens, self.block_size) + current_node = self.root + path = [] + for chunk in chunks: + hash = token_ids_to_hash(chunk) + # print(f"lookup chunk={chunk.shape} hash={hash}") + if hash in current_node.children: + current_node = current_node.children[hash] + path.append(current_node) + if increment_ref: + current_node.inc_ref_count() + else: + break + return path + + def allocate(self, tokens: torch.Tensor) -> List[int]: + path = self.lookup(tokens) + if len(path) == 0: + return torch.tensor([], dtype=torch.int32) + return torch.cat([node.block_id.unsqueeze(0) for node in path]) + + def extend(self, uid: int, tokens: torch.Tensor, new_block_ids: List[int]) -> None: + path = self.prefix_nodes[uid] + self.tokens[uid] = torch.cat([self.tokens[uid], tokens]) + + n_full_blocks = len(self.tokens[uid]) // self.block_size + new_full_blocks = n_full_blocks - len(path) + + if new_full_blocks == 0: + return + + chunks = torch.split(tokens, self.block_size)[len(path):len(path) + new_full_blocks] + if len(path) == 0: + current_node = self.root + else: + current_node = path[-1] + + for chunk, block_id in zip(chunks, new_block_ids): + hash = token_ids_to_hash(chunk) + assert hash not in current_node.children, 'Chunk already exists in the tree' + + new_node = PrefixBlockNode(token_ids=chunk, block_id=block_id, children={}) + current_node.add_child(hash, new_node) + + path.append(current_node) + current_node = current_node.children[hash] + # current_node.inc_ref_count() + + # print(f"adding chunk to tree: hash={hash} current_node={current_node}") + + def delete(self, prefix_ids: torch.Tensor) -> None: + self.lookup(prefix_ids, decrement_ref=True) diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index ecc3c52a5834..7593e3c4952f 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -5,6 +5,7 @@ import torch from typing import Any, Dict, Optional, Tuple +from collections import defaultdict from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import RaggedUtilsBuilder @@ -14,6 +15,7 @@ from .kv_cache import BlockedKVCache from .manager_configs import DSStateManagerConfig, KVCacheConfig from .sequence_descriptor import DSSequenceDescriptor +from .prefix_block_tree import PrefixBlockTree class DSStateManager: @@ -97,6 +99,11 @@ def __init__(self, mp_group=base_mp_group, offload=self._config.offload) + assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." + self._block_tree = PrefixBlockTree(self._kv_configs[0].block_size) + self._ref_counts = defaultdict(int) + + def get_cache(self, cache_id: int, cache_group: int = 0) -> torch.Tensor: """ Return the Tensor associated with the given cache id in the specified cache group. 
@@ -167,6 +174,28 @@ def _create_sequence(self, uid: int) -> DSSequenceDescriptor: logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.") return self._seqs[uid] + def lookup_cache(self, uid: int, tokens: torch.Tensor) -> int: + + block_ids = self._block_tree.allocate(tokens) + assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." + return len(block_ids) * self._kv_configs[0].block_size, block_ids + + def update_cache(self, uid: int, tokens: torch.Tensor) -> None: + """ + Update the KV cache for the given sequence id. + """ + print(f"update_cache tokens={tokens.shape} numel={tokens.numel()}") + seq = self.get_sequence(uid) + self._block_tree.extend(uid, tokens, seq.all_block_ids()) + + def increment_ref_count(self, block_ids: torch.Tensor) -> None: + for block_id in block_ids: + self._ref_counts[block_id.item()] += 1 + + def decrement_ref_count(self, block_ids: torch.Tensor) -> None: + for block_id in block_ids: + self._ref_counts[block_id.item()] -= 1 + @property def tracked_sequences(self) -> Dict[int, DSSequenceDescriptor]: """ From 2e8ac1cc7e9caef19292e9f29f21acc355af35f3 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 26 Apr 2024 03:02:35 +0000 Subject: [PATCH 02/17] free block with no ref --- deepspeed/inference/v2/engine_v2.py | 9 +- .../inference/v2/ragged/prefix_block_tree.py | 109 +++++------------- .../inference/v2/ragged/ragged_manager.py | 12 +- 3 files changed, 38 insertions(+), 92 deletions(-) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index 33e39105661f..cf0ca2653d9e 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -132,6 +132,7 @@ def put(self, new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) if new_block_ids is not None: self._state_manager.increment_ref_count(new_block_ids) + self._state_manager._block_tree.delete(new_block_ids) host_seq_desc.pre_forward(tokens.numel()) # We can disable checks since we already validated schedulability. @@ -242,16 +243,14 @@ def get_remaining_block_capacity(self, uid: int) -> int: return 0 return self._model.get_remaining_block_capacity(seq_desc) - def lookup_cache(self, uid: int, tokens: torch.Tensor) -> None: + def lookup_cache(self, tokens: torch.Tensor) -> None: """ Lookup the KV cache for a given sequence and allocate the necessary blocks. Arguments: - uid (int): The UID of the sequence. tokens (torch.Tensor): The tokens to allocate. 
""" - print(f"lookup_cache lookup_cache uid: {uid}") - return self._state_manager.lookup_cache(uid, tokens) + return self._state_manager.lookup_cache(tokens) def setup_cached_sequence(self, uid: int, cached_length:int, block_ids: torch.Tensor) -> None: seq = self._state_manager.get_or_create_sequence(uid) @@ -269,6 +268,8 @@ def flush(self, uid: int) -> None: """ seq = self._state_manager.get_sequence(uid) self._state_manager.decrement_ref_count(seq.all_block_ids()) + no_ref_blocks = [b.item() for b in seq.all_block_ids() if self._state_manager._ref_counts[b.item()] == 0] + self._state_manager._kv_cache.free(no_ref_blocks) self._state_manager.flush_sequence(uid) diff --git a/deepspeed/inference/v2/ragged/prefix_block_tree.py b/deepspeed/inference/v2/ragged/prefix_block_tree.py index e7fd48d156de..61d341a12992 100644 --- a/deepspeed/inference/v2/ragged/prefix_block_tree.py +++ b/deepspeed/inference/v2/ragged/prefix_block_tree.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, List +from typing import Dict, List, Set import hashlib from collections import defaultdict @@ -15,92 +15,39 @@ def token_ids_to_hash(token_ids: torch.Tensor): return hash_obj.hexdigest() -class PrefixBlockNode: - def __init__(self, token_ids: torch.Tensor, block_id: int, children: Dict): - self.token_ids = token_ids - self.block_id = block_id - self.children = children - self.hash = token_ids_to_hash(token_ids) - self.ref_count = 1 - - def add_child(self, hash, node): - self.children[hash] = node - - def get_child(self, prefix): - return self.children.get(prefix) - - def inc_ref_count(self): - self.ref_count += 1 - - def dec_ref_count(self): - self.ref_count -= 1 - - def __repr__(self): - return f"PrefixBlockNode(token_ids={self.token_ids.shape}, block_id={self.block_id}, children={self.children}, ref_count={self.ref_count})" - - -class PrefixBlockTree(): +class PrefixBlockMap(): def __init__(self, block_size: int): - self.root: PrefixBlockNode = PrefixBlockNode(token_ids=torch.tensor([], dtype=torch.int32), block_id=-1, children={}) - - # Mapping from uid to token_ids. 
- self.prefix_nodes: Dict[int, List[PrefixBlockNode]] = defaultdict(list) - self.tokens: Dict[int, torch.Tensor] = defaultdict(lambda: torch.tensor([], dtype=torch.int32)) + self.tokens_to_blocks: Dict[str, List[int]] = {} + self.blocks_to_tokens: Dict[Set[int], str] = {} self.block_size: int = block_size - - def lookup(self, tokens: torch.Tensor, increment_ref=False, decrement_ref=False) -> List[PrefixBlockNode]: - assert not (increment_ref and decrement_ref), 'increment_ref and decrement_ref cannot be set to True at the same time' - - chunks = torch.split(tokens, self.block_size) - current_node = self.root - path = [] - for chunk in chunks: + + def lookup(self, tokens: torch.Tensor) -> torch.Tensor: + n_blocks = len(tokens) // self.block_size + cached_blocks = torch.tensor([], dtype=torch.int32) + for i in range(n_blocks): + chunk = tokens[:(i+1)*self.block_size] hash = token_ids_to_hash(chunk) - # print(f"lookup chunk={chunk.shape} hash={hash}") - if hash in current_node.children: - current_node = current_node.children[hash] - path.append(current_node) - if increment_ref: - current_node.inc_ref_count() + if hash in self.tokens_to_blocks: + cached_blocks = self.tokens_to_blocks[hash] else: break - return path - - def allocate(self, tokens: torch.Tensor) -> List[int]: - path = self.lookup(tokens) - if len(path) == 0: - return torch.tensor([], dtype=torch.int32) - return torch.cat([node.block_id.unsqueeze(0) for node in path]) - - def extend(self, uid: int, tokens: torch.Tensor, new_block_ids: List[int]) -> None: - path = self.prefix_nodes[uid] - self.tokens[uid] = torch.cat([self.tokens[uid], tokens]) - - n_full_blocks = len(self.tokens[uid]) // self.block_size - new_full_blocks = n_full_blocks - len(path) - - if new_full_blocks == 0: - return + return cached_blocks - chunks = torch.split(tokens, self.block_size)[len(path):len(path) + new_full_blocks] - if len(path) == 0: - current_node = self.root - else: - current_node = path[-1] - - for chunk, block_id in zip(chunks, new_block_ids): + def extend(self, tokens: torch.Tensor, new_block_ids: List[int]) -> None: + n_blocks = len(tokens) // self.block_size + for i in range(n_blocks): + chunk = tokens[:(i+1)*self.block_size] hash = token_ids_to_hash(chunk) - assert hash not in current_node.children, 'Chunk already exists in the tree' - - new_node = PrefixBlockNode(token_ids=chunk, block_id=block_id, children={}) - current_node.add_child(hash, new_node) - - path.append(current_node) - current_node = current_node.children[hash] - # current_node.inc_ref_count() - - # print(f"adding chunk to tree: hash={hash} current_node={current_node}") + if hash not in self.tokens_to_blocks: + self.tokens_to_blocks[hash] = new_block_ids[:i+1] + self.blocks_to_tokens[frozenset(new_block_ids[:i+1])] = hash + + def delete(self, block_ids: List[int]) -> None: + blocks_set = frozenset(block_ids) + for used_blocks, hash in self.blocks_to_tokens.items(): + # check intersection + if blocks_set & used_blocks: + del self.tokens_to_blocks[hash] + del self.blocks_to_tokens[used_blocks] - def delete(self, prefix_ids: torch.Tensor) -> None: - self.lookup(prefix_ids, decrement_ref=True) diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index 7593e3c4952f..50296786d344 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -15,7 +15,7 @@ from .kv_cache import BlockedKVCache from .manager_configs import DSStateManagerConfig, KVCacheConfig from 
.sequence_descriptor import DSSequenceDescriptor -from .prefix_block_tree import PrefixBlockTree +from .prefix_block_tree import PrefixBlockMap class DSStateManager: @@ -100,7 +100,7 @@ def __init__(self, offload=self._config.offload) assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." - self._block_tree = PrefixBlockTree(self._kv_configs[0].block_size) + self._block_tree = PrefixBlockMap(self._kv_configs[0].block_size) self._ref_counts = defaultdict(int) @@ -123,8 +123,6 @@ def flush_sequence(self, uid: int) -> None: return seq = self._seqs[uid] - for i in range(self.n_kv_cache_groups): - self._kv_cache.free(seq.all_block_ids(cache_group=i), cache_group=i) self._tracking_allocator.free(seq.tracking_id) del self._seqs[uid] @@ -174,9 +172,9 @@ def _create_sequence(self, uid: int) -> DSSequenceDescriptor: logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.") return self._seqs[uid] - def lookup_cache(self, uid: int, tokens: torch.Tensor) -> int: + def lookup_cache(self, tokens: torch.Tensor) -> int: - block_ids = self._block_tree.allocate(tokens) + block_ids = self._block_tree.lookup(tokens) assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." return len(block_ids) * self._kv_configs[0].block_size, block_ids @@ -186,7 +184,7 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None: """ print(f"update_cache tokens={tokens.shape} numel={tokens.numel()}") seq = self.get_sequence(uid) - self._block_tree.extend(uid, tokens, seq.all_block_ids()) + self._block_tree.extend(tokens, seq.all_block_ids()) def increment_ref_count(self, block_ids: torch.Tensor) -> None: for block_id in block_ids: From 619a363787c9eeab42b37310de3fdf0c7b9f0c1c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 28 Apr 2024 22:08:38 +0000 Subject: [PATCH 03/17] reversed block id list when freeing --- deepspeed/inference/v2/ragged/blocked_allocator.py | 2 +- deepspeed/inference/v2/ragged/ragged_manager.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py index 7884d8cccb47..04069ff4c918 100644 --- a/deepspeed/inference/v2/ragged/blocked_allocator.py +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -92,7 +92,7 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: if self._blocks[block] != -1: raise ValueError(f'Block {block} is already free') - for block in blocks: + for block in reversed(blocks): self._blocks[block] = self._head self._head = block self._free_blocks += 1 diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index 50296786d344..604597c7b29c 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -182,7 +182,6 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None: """ Update the KV cache for the given sequence id. 
""" - print(f"update_cache tokens={tokens.shape} numel={tokens.numel()}") seq = self.get_sequence(uid) self._block_tree.extend(tokens, seq.all_block_ids()) From 73075073c6b17efc5178fc0d8cd5a64c5a000238 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 29 Apr 2024 20:23:50 +0000 Subject: [PATCH 04/17] add option to enable prefix cache --- deepspeed/inference/v2/config_v2.py | 3 +++ deepspeed/inference/v2/engine_v2.py | 34 ++++++++++++++++++----------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 85e4b7a0e0a0..d88d233a3632 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -41,3 +41,6 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel): """ quantization: QuantizationConfig = {} + + enable_prefix_cache: bool = False + diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index cf0ca2653d9e..895eefc8289b 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -130,7 +130,7 @@ def put(self, host_seq_desc = self._state_manager.get_or_create_sequence(uid) new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) - if new_block_ids is not None: + if self._config.enable_prefix_cache and new_block_ids is not None: self._state_manager.increment_ref_count(new_block_ids) self._state_manager._block_tree.delete(new_block_ids) host_seq_desc.pre_forward(tokens.numel()) @@ -154,7 +154,8 @@ def put(self, for uid, tokens in zip(batch_uids, batch_tokens): host_seq_desc = self._state_manager.get_sequence(uid) host_seq_desc.post_forward() # Updates sequence metadata. - self._state_manager.update_cache(uid, tokens) + if self._config.enable_prefix_cache: + self._state_manager.update_cache(uid, tokens) self._model.maybe_free_kv(host_seq_desc) return logits @@ -243,21 +244,24 @@ def get_remaining_block_capacity(self, uid: int) -> int: return 0 return self._model.get_remaining_block_capacity(seq_desc) - def lookup_cache(self, tokens: torch.Tensor) -> None: + def lookup_cache(self, tokens: torch.Tensor) -> Tuple[int, torch.Tensor]: """ Lookup the KV cache for a given sequence and allocate the necessary blocks. Arguments: - tokens (torch.Tensor): The tokens to allocate. + block IDs (torch.Tensor): The tokens to allocate. """ - return self._state_manager.lookup_cache(tokens) + if self._config.enable_prefix_cache: + return self._state_manager.lookup_cache(tokens) + return 0, torch.tensor([]) def setup_cached_sequence(self, uid: int, cached_length:int, block_ids: torch.Tensor) -> None: - seq = self._state_manager.get_or_create_sequence(uid) - seq.pre_forward(cached_length) - seq.post_forward() - seq.extend_kv_cache(block_ids) - self._state_manager.increment_ref_count(block_ids) + if self._config.enable_prefix_cache: + seq = self._state_manager.get_or_create_sequence(uid) + seq.pre_forward(cached_length) + seq.post_forward() + seq.extend_kv_cache(block_ids) + self._state_manager.increment_ref_count(block_ids) def flush(self, uid: int) -> None: """ @@ -267,9 +271,13 @@ def flush(self, uid: int) -> None: uid (int): The UID of the sequence to flush. 
""" seq = self._state_manager.get_sequence(uid) - self._state_manager.decrement_ref_count(seq.all_block_ids()) - no_ref_blocks = [b.item() for b in seq.all_block_ids() if self._state_manager._ref_counts[b.item()] == 0] - self._state_manager._kv_cache.free(no_ref_blocks) + + if self._config.enable_prefix_cache: + self._state_manager.decrement_ref_count(seq.all_block_ids()) + blocks_to_free = [b.item() for b in seq.all_block_ids() if self._state_manager._ref_counts[b.item()] == 0] + else: + blocks_to_free = seq.all_block_ids() + self._state_manager._kv_cache.free(blocks_to_free) self._state_manager.flush_sequence(uid) From a28b706bf0f9cb4f28e0b33617f177c4d3e8a9af Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 8 May 2024 01:31:11 +0000 Subject: [PATCH 05/17] fix use of prefix cache --- deepspeed/inference/v2/engine_v2.py | 7 +++++-- .../inference/v2/ragged/blocked_allocator.py | 20 ++++++++++++++++++- deepspeed/inference/v2/ragged/kv_cache.py | 3 +++ .../inference/v2/ragged/ragged_manager.py | 11 ++++++++-- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index 895eefc8289b..fb4faa70e9e4 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -154,8 +154,6 @@ def put(self, for uid, tokens in zip(batch_uids, batch_tokens): host_seq_desc = self._state_manager.get_sequence(uid) host_seq_desc.post_forward() # Updates sequence metadata. - if self._config.enable_prefix_cache: - self._state_manager.update_cache(uid, tokens) self._model.maybe_free_kv(host_seq_desc) return logits @@ -262,6 +260,11 @@ def setup_cached_sequence(self, uid: int, cached_length:int, block_ids: torch.Te seq.post_forward() seq.extend_kv_cache(block_ids) self._state_manager.increment_ref_count(block_ids) + self._state_manager._kv_cache.allocate_blocks(block_ids) + + def update_prefix_cache(self, uid: int, tokens: torch.Tensor) -> None: + if self._config.enable_prefix_cache: + self._state_manager.update_cache(uid, tokens) def flush(self, uid: int) -> None: """ diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py index 04069ff4c918..e4a5ef4d39da 100644 --- a/deepspeed/inference/v2/ragged/blocked_allocator.py +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -3,7 +3,7 @@ # DeepSpeed Team -from typing import Iterable, Union +from typing import Iterable, Union, List import torch @@ -71,6 +71,24 @@ def allocate(self, num_blocks: int) -> torch.Tensor: return allocated_blocks + def allocate_blocks(self, blocks: List[int]) -> None: + + for block in blocks: + if self._blocks[block] != -1: + self._blocks[block] = -1 + self._free_blocks -= 1 + + self._head = -1 + for i, b in enumerate(self._blocks): + next_available = b.item() + if next_available == -1: + while self._blocks[next_available].item() == -1: + next_available += 1 + self._blocks[i] = next_available + else: + if self._head == -1: + self._head = i + def free(self, blocks: Union[Iterable[int], int]) -> None: """ Return a list of blocks to the free pool. 
If a single invalid block is provided (i.e., diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py index ceba3190b93c..88e5a794ca07 100644 --- a/deepspeed/inference/v2/ragged/kv_cache.py +++ b/deepspeed/inference/v2/ragged/kv_cache.py @@ -152,6 +152,9 @@ def reserve(self, num_blocks: int, cache_group: int = 0) -> torch.Tensor: """ return self._allocators[cache_group].allocate(num_blocks) + def allocate_blocks(self, blocks: Iterable[int], cache_group: int = 0) -> None: + return self._allocators[cache_group].allocate_blocks(blocks) + def free(self, blocks: Iterable[int], cache_group: int = 0) -> None: """ Free a set of blocks from the cache. This will mark the blocks as free in the diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index 604597c7b29c..0df76b99e2c5 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -176,7 +176,14 @@ def lookup_cache(self, tokens: torch.Tensor) -> int: block_ids = self._block_tree.lookup(tokens) assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." - return len(block_ids) * self._kv_configs[0].block_size, block_ids + + cache_hit_length = len(block_ids) * self._kv_configs[0].block_size + + if cache_hit_length == tokens.numel(): + # we don't keep logits in the cache, so we need to recompute + block_ids = block_ids[:-1] + return len(block_ids) * self._kv_configs[0].block_size, block_ids + return cache_hit_length, block_ids def update_cache(self, uid: int, tokens: torch.Tensor) -> None: """ @@ -212,7 +219,7 @@ def kv_block_size(self) -> int: """ Return the block size of the KV cache. """ - return self._kv_config.block_size + return self._kv_configs[0].block_size @property def n_kv_cache_groups(self) -> int: From d8e9d2858c8d19bad2c57e49e5ebed3d232466cb Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 25 May 2024 02:21:21 +0000 Subject: [PATCH 06/17] fix allocation bug --- .../inference/v2/ragged/blocked_allocator.py | 99 ++++++++++++++++--- deepspeed/inference/v2/ragged/kv_cache.py | 5 +- 2 files changed, 89 insertions(+), 15 deletions(-) diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py index e4a5ef4d39da..7f93c22b9c3f 100644 --- a/deepspeed/inference/v2/ragged/blocked_allocator.py +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -71,24 +71,97 @@ def allocate(self, num_blocks: int) -> torch.Tensor: return allocated_blocks + def free(self, blocks: Union[Iterable[int], int]) -> None: + """ + Return a list of blocks to the free pool. If a single invalid block is provided (i.e., + one that is out of range of the allocator or is already free), then an exception is raised + and no blocks are freed. + + Parameters: + blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block + is to be freed, this can be alone as an integer. + """ + if isinstance(blocks, int): + blocks = [blocks] + + for block in blocks: + # Parse all blocks for validity before mutating the list. 
+ if block < 0 or block >= self._num_blocks: + raise ValueError(f'Invalid block {block} provided to free') + + if self._blocks[block] != -1: + raise ValueError(f'Block {block} is already free') + + for block in reversed(blocks): + self._blocks[block] = self._head + self._head = block + self._free_blocks += 1 + + @property + def free_blocks(self) -> int: + """ + Return the number of free blocks in the KV-cache. + """ + return self._free_blocks + + +class LinearScanBlockedAllocator: + # Number of blocks in the KV-cache(s). + _num_blocks: int + + # Array of blocks, where each element is the next block in the linked list. + _blocks: torch.Tensor + + # Number of free blocks in the KV-cache. + _free_blocks: int + + def __init__(self, num_blocks: int) -> None: + + if num_blocks < 1: + raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}') + + self._num_blocks = num_blocks + self._blocks = torch.zeros(num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) + self._free_blocks = num_blocks + + def allocate(self, num_blocks: int) -> torch.Tensor: + """ + Allocate a list of blocks from the associated KV-caches. This will + return `num_blocks` blocks from the KV-cache if they are available, + or raise an exception if there are not enough free blocks. + + Parameters: + num_blocks (int): The number of blocks to allocate. + + Returns: + List[int]: The list of blocks allocated. + """ + + if num_blocks > self._free_blocks: + raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks') + + allocated_blocks = torch.zeros(num_blocks, dtype=torch.int32) + + alloc_idx = 0 + for i in range(self._num_blocks): + if self._blocks[i].item() == 0: + allocated_blocks[alloc_idx] = i + self._blocks[i] = -1 + self._free_blocks -= 1 + alloc_idx += 1 + if alloc_idx == num_blocks: + break + + return allocated_blocks + def allocate_blocks(self, blocks: List[int]) -> None: for block in blocks: + # assert self._blocks[block] == 0, f"Block {block} is already allocated" if self._blocks[block] != -1: self._blocks[block] = -1 self._free_blocks -= 1 - self._head = -1 - for i, b in enumerate(self._blocks): - next_available = b.item() - if next_available == -1: - while self._blocks[next_available].item() == -1: - next_available += 1 - self._blocks[i] = next_available - else: - if self._head == -1: - self._head = i - def free(self, blocks: Union[Iterable[int], int]) -> None: """ Return a list of blocks to the free pool. If a single invalid block is provided (i.e., @@ -99,6 +172,7 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block is to be freed, this can be alone as an integer. 
""" + if isinstance(blocks, int): blocks = [blocks] @@ -111,8 +185,7 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: raise ValueError(f'Block {block} is already free') for block in reversed(blocks): - self._blocks[block] = self._head - self._head = block + self._blocks[block] = 0 self._free_blocks += 1 @property diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py index 88e5a794ca07..7cebdccab830 100644 --- a/deepspeed/inference/v2/ragged/kv_cache.py +++ b/deepspeed/inference/v2/ragged/kv_cache.py @@ -15,7 +15,7 @@ from deepspeed.accelerator import get_accelerator from ..inference_utils import elem_size from ..logging import inference_logger -from .blocked_allocator import BlockedAllocator +from .blocked_allocator import BlockedAllocator, LinearScanBlockedAllocator from .manager_configs import AllocationMode, KVCacheConfig, MemoryConfig @@ -136,7 +136,8 @@ def __init__(self, f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.") caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype, device=get_accelerator().current_device())) - allocators.append(BlockedAllocator(num_blocks)) + # allocators.append(BlockedAllocator(num_blocks)) + allocators.append(LinearScanBlockedAllocator(num_blocks)) self._caches = tuple(caches) self._allocators = tuple(allocators) From 9cffae127a2d67f9ecf70643482ab92c0e2be5ba Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 09:31:58 +0000 Subject: [PATCH 07/17] use normal allocator if prefix sharing is disabled --- deepspeed/inference/v2/engine_v2.py | 3 ++- deepspeed/inference/v2/ragged/kv_cache.py | 10 +++++++--- deepspeed/inference/v2/ragged/ragged_manager.py | 6 ++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index fb4faa70e9e4..dcde7e5ecc44 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -87,7 +87,8 @@ def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngi self._batch = RaggedBatchWrapper(self._config.state_manager) self._state_manager = DSStateManager(self._config.state_manager, self._model.kv_cache_config(), - base_mp_group=self._base_mp_group) + base_mp_group=self._base_mp_group, + enable_prefix_cache=self._config.enable_prefix_cache) self._model.set_state_manager(self._state_manager) def _initialize_tp_group(self): diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py index 7cebdccab830..fb45740ef25a 100644 --- a/deepspeed/inference/v2/ragged/kv_cache.py +++ b/deepspeed/inference/v2/ragged/kv_cache.py @@ -61,7 +61,8 @@ def __init__(self, configs: Tuple[KVCacheConfig, ...], memory_config: MemoryConfig, mp_group: Optional[Any] = None, - offload: bool = False) -> None: + offload: bool = False, + enable_prefix_cache: bool = False) -> None: """ Create a container that will maintain the storage and allocations for a set of blocked KV-caches. 
@@ -136,8 +137,11 @@ def __init__(self, f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.") caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype, device=get_accelerator().current_device())) - # allocators.append(BlockedAllocator(num_blocks)) - allocators.append(LinearScanBlockedAllocator(num_blocks)) + + if enable_prefix_cache: + allocators.append(LinearScanBlockedAllocator(num_blocks)) + else: + allocators.append(BlockedAllocator(num_blocks)) self._caches = tuple(caches) self._allocators = tuple(allocators) diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index 0df76b99e2c5..a10cdc812ffb 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -57,7 +57,8 @@ class DSStateManager: def __init__(self, config: DSStateManagerConfig, kv_configs: Tuple[KVCacheConfig, ...], - base_mp_group: Optional[Any] = None) -> None: + base_mp_group: Optional[Any] = None, + enable_prefix_cache: bool = False) -> None: """ The key @@ -97,7 +98,8 @@ def __init__(self, self._kv_cache = BlockedKVCache(self._kv_configs, self._config.memory_config, mp_group=base_mp_group, - offload=self._config.offload) + offload=self._config.offload, + enable_prefix_cache=enable_prefix_cache) assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." self._block_tree = PrefixBlockMap(self._kv_configs[0].block_size) From f10f6f44e242c25215db9301d17f99d8112f8711 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 19:32:12 +0000 Subject: [PATCH 08/17] match type hint --- .../v2/model_implementations/inference_model_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/model_implementations/inference_model_base.py b/deepspeed/inference/v2/model_implementations/inference_model_base.py index 894a4137407e..d1bba390c70f 100644 --- a/deepspeed/inference/v2/model_implementations/inference_model_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_model_base.py @@ -204,7 +204,7 @@ def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int: raise NotImplementedError() @abstractmethod - def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None: + def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> Optional[torch.Tensor]: """ Given a sequence and the number of new tokens in the sequence, determine whether or not additional KV-storage is needed and allocate it if so. @@ -212,6 +212,9 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) - Args: sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage. n_new_tokens (int): The number of new tokens in the sequence. + + Returns: + Optional[torch.Tensor]: The allocated KV block IDs. If no new blocks are needed, this returns None. 
""" raise NotImplementedError() From 0bedc39b4e4a6b8db07549c7a8bfccf680b036f8 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 20:05:37 +0000 Subject: [PATCH 09/17] refactor allocator --- .../inference/v2/ragged/blocked_allocator.py | 128 ++++++++---------- 1 file changed, 54 insertions(+), 74 deletions(-) diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py index 7f93c22b9c3f..e9b2ec670caa 100644 --- a/deepspeed/inference/v2/ragged/blocked_allocator.py +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -4,29 +4,23 @@ # DeepSpeed Team from typing import Iterable, Union, List +from abc import ABC, abstractmethod import torch -class BlockedAllocator: +class BlockedAllocatorBase(ABC): """ Allocator class for managing which blocks are free/used in the - blocked KV-cache. This is a simple allocator that uses a linked list - to keep track of which blocks are free/used. The cost of allocation/deallocation - is O(blocks), where blocks is the number of blocks to allocate/deallocate. - - TODO(cmikeh2): Evaluate performance of this allocator and migrate - to C++ if necessary. + blocked KV-cache. """ + # Number of blocks in the KV-cache(s). _num_blocks: int # Array of blocks, where each element is the next block in the linked list. _blocks: torch.Tensor - # Index of the head of the linked list. - _head: int - # Number of free blocks in the KV-cache. _free_blocks: int @@ -43,10 +37,9 @@ def __init__(self, num_blocks: int) -> None: raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}') self._num_blocks = num_blocks - self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) - self._head = 0 self._free_blocks = num_blocks + @abstractmethod def allocate(self, num_blocks: int) -> torch.Tensor: """ Allocate a list of blocks from the associated KV-caches. This will @@ -57,7 +50,50 @@ def allocate(self, num_blocks: int) -> torch.Tensor: num_blocks (int): The number of blocks to allocate. Returns: - List[int]: The list of blocks allocated. + torch.Tensor: The list of blocks allocated. + """ + pass + + @abstractmethod + def free(self, blocks: Union[Iterable[int], int]) -> None: + """ + Free a list of blocks in the associated KV-caches. This will return + the blocks to the free pool if they are valid and not already free. + + Parameters: + blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block + is to be freed, this can be alone as an integer. + """ + pass + + @property + def free_blocks(self) -> int: + """ + Return the number of free blocks in the KV-cache. + """ + return self._free_blocks + + +class BlockedAllocator(BlockedAllocatorBase): + """ + This is a simple allocator that uses a linked list to keep track of which blocks are free/used. The cost of allocation/deallocation + is O(blocks), where blocks is the number of blocks to allocate/deallocate. + + TODO(cmikeh2): Evaluate performance of this allocator and migrate + to C++ if necessary. + """ + + # Index of the head of the linked list. + _head: int + + def __init__(self, num_blocks: int) -> None: + super().__init__(num_blocks) + self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) + self._head = 0 + + def allocate(self, num_blocks: int) -> torch.Tensor: + """ + Refer to the docstring of `BlockedAllocatorBase.allocate`. 
""" if num_blocks > self._free_blocks: raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks') @@ -73,13 +109,7 @@ def allocate(self, num_blocks: int) -> torch.Tensor: def free(self, blocks: Union[Iterable[int], int]) -> None: """ - Return a list of blocks to the free pool. If a single invalid block is provided (i.e., - one that is out of range of the allocator or is already free), then an exception is raised - and no blocks are freed. - - Parameters: - blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block - is to be freed, this can be alone as an integer. + Refer to the docstring of `BlockedAllocatorBase.free`. """ if isinstance(blocks, int): blocks = [blocks] @@ -92,51 +122,19 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: if self._blocks[block] != -1: raise ValueError(f'Block {block} is already free') - for block in reversed(blocks): + for block in blocks: self._blocks[block] = self._head self._head = block self._free_blocks += 1 - @property - def free_blocks(self) -> int: - """ - Return the number of free blocks in the KV-cache. - """ - return self._free_blocks - - -class LinearScanBlockedAllocator: - # Number of blocks in the KV-cache(s). - _num_blocks: int - - # Array of blocks, where each element is the next block in the linked list. - _blocks: torch.Tensor - - # Number of free blocks in the KV-cache. - _free_blocks: int +class LinearScanBlockedAllocator(BlockedAllocatorBase): def __init__(self, num_blocks: int) -> None: - - if num_blocks < 1: - raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}') - - self._num_blocks = num_blocks + super().__init__(num_blocks) self._blocks = torch.zeros(num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) - self._free_blocks = num_blocks - - def allocate(self, num_blocks: int) -> torch.Tensor: - """ - Allocate a list of blocks from the associated KV-caches. This will - return `num_blocks` blocks from the KV-cache if they are available, - or raise an exception if there are not enough free blocks. - Parameters: - num_blocks (int): The number of blocks to allocate. - Returns: - List[int]: The list of blocks allocated. - """ - + def allocate(self, num_blocks: int) -> torch.Tensor: if num_blocks > self._free_blocks: raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks') @@ -155,23 +153,12 @@ def allocate(self, num_blocks: int) -> torch.Tensor: return allocated_blocks def allocate_blocks(self, blocks: List[int]) -> None: - for block in blocks: - # assert self._blocks[block] == 0, f"Block {block} is already allocated" if self._blocks[block] != -1: self._blocks[block] = -1 self._free_blocks -= 1 def free(self, blocks: Union[Iterable[int], int]) -> None: - """ - Return a list of blocks to the free pool. If a single invalid block is provided (i.e., - one that is out of range of the allocator or is already free), then an exception is raised - and no blocks are freed. - - Parameters: - blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block - is to be freed, this can be alone as an integer. - """ if isinstance(blocks, int): blocks = [blocks] @@ -187,10 +174,3 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: for block in reversed(blocks): self._blocks[block] = 0 self._free_blocks += 1 - - @property - def free_blocks(self) -> int: - """ - Return the number of free blocks in the KV-cache. 
- """ - return self._free_blocks From bc92d2be3e7238805a117fb1304c1ff914ff6c71 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 20:10:21 +0000 Subject: [PATCH 10/17] simplify loop --- deepspeed/inference/v2/config_v2.py | 1 + deepspeed/inference/v2/engine_v2.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index d88d233a3632..7f863ee2b1ce 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -43,4 +43,5 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel): quantization: QuantizationConfig = {} enable_prefix_cache: bool = False + """ Enable prefix cache for the model. """ diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index dcde7e5ecc44..856768c36e25 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -152,7 +152,7 @@ def put(self, # We return one set of logits per sequence in the batch (saves cost on unembedding) assert logits.shape[0] == self._batch.current_sequences - for uid, tokens in zip(batch_uids, batch_tokens): + for uid in batch_uids: host_seq_desc = self._state_manager.get_sequence(uid) host_seq_desc.post_forward() # Updates sequence metadata. self._model.maybe_free_kv(host_seq_desc) From c2028ec3d58349bbb8b196779562b340afe0c409 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 20:13:59 +0000 Subject: [PATCH 11/17] refactor --- deepspeed/inference/v2/config_v2.py | 1 - deepspeed/inference/v2/engine_v2.py | 4 ++-- .../inference_transformer_base.py | 2 +- .../inference/v2/ragged/blocked_allocator.py | 4 ++-- deepspeed/inference/v2/ragged/kv_cache.py | 2 +- ...prefix_block_tree.py => prefix_block_map.py} | 17 ++++++++++------- deepspeed/inference/v2/ragged/ragged_manager.py | 9 ++++----- 7 files changed, 20 insertions(+), 19 deletions(-) rename deepspeed/inference/v2/ragged/{prefix_block_tree.py => prefix_block_map.py} (85%) diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 7f863ee2b1ce..75742a6e8ec7 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -44,4 +44,3 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel): enable_prefix_cache: bool = False """ Enable prefix cache for the model. 
""" - diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index 856768c36e25..8b79ddc803fe 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -253,8 +253,8 @@ def lookup_cache(self, tokens: torch.Tensor) -> Tuple[int, torch.Tensor]: if self._config.enable_prefix_cache: return self._state_manager.lookup_cache(tokens) return 0, torch.tensor([]) - - def setup_cached_sequence(self, uid: int, cached_length:int, block_ids: torch.Tensor) -> None: + + def setup_cached_sequence(self, uid: int, cached_length: int, block_ids: torch.Tensor) -> None: if self._config.enable_prefix_cache: seq = self._state_manager.get_or_create_sequence(uid) seq.pre_forward(cached_length) diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py index 8e697266dbeb..751ce0de1af1 100644 --- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -370,7 +370,7 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) - new_blocks = self.state_manager.allocate_blocks(n_needed_blocks) sequence.extend_kv_cache(new_blocks) return new_blocks - + return None def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]: diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py index e9b2ec670caa..790eaf9abb7f 100644 --- a/deepspeed/inference/v2/ragged/blocked_allocator.py +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -12,7 +12,7 @@ class BlockedAllocatorBase(ABC): """ Allocator class for managing which blocks are free/used in the - blocked KV-cache. + blocked KV-cache. """ # Number of blocks in the KV-cache(s). @@ -129,11 +129,11 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: class LinearScanBlockedAllocator(BlockedAllocatorBase): + def __init__(self, num_blocks: int) -> None: super().__init__(num_blocks) self._blocks = torch.zeros(num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) - def allocate(self, num_blocks: int) -> torch.Tensor: if num_blocks > self._free_blocks: raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks') diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py index fb45740ef25a..5bcb766aa690 100644 --- a/deepspeed/inference/v2/ragged/kv_cache.py +++ b/deepspeed/inference/v2/ragged/kv_cache.py @@ -137,7 +137,7 @@ def __init__(self, f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.") caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype, device=get_accelerator().current_device())) - + if enable_prefix_cache: allocators.append(LinearScanBlockedAllocator(num_blocks)) else: diff --git a/deepspeed/inference/v2/ragged/prefix_block_tree.py b/deepspeed/inference/v2/ragged/prefix_block_map.py similarity index 85% rename from deepspeed/inference/v2/ragged/prefix_block_tree.py rename to deepspeed/inference/v2/ragged/prefix_block_map.py index 61d341a12992..85a2e83140f6 100644 --- a/deepspeed/inference/v2/ragged/prefix_block_tree.py +++ b/deepspeed/inference/v2/ragged/prefix_block_map.py @@ -1,6 +1,10 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + from typing import Dict, List, Set import hashlib -from collections import defaultdict import torch @@ -21,12 +25,12 @@ def __init__(self, block_size: int): self.tokens_to_blocks: Dict[str, List[int]] = {} self.blocks_to_tokens: Dict[Set[int], str] = {} self.block_size: int = block_size - + def lookup(self, tokens: torch.Tensor) -> torch.Tensor: n_blocks = len(tokens) // self.block_size cached_blocks = torch.tensor([], dtype=torch.int32) for i in range(n_blocks): - chunk = tokens[:(i+1)*self.block_size] + chunk = tokens[:(i + 1) * self.block_size] hash = token_ids_to_hash(chunk) if hash in self.tokens_to_blocks: cached_blocks = self.tokens_to_blocks[hash] @@ -37,11 +41,11 @@ def lookup(self, tokens: torch.Tensor) -> torch.Tensor: def extend(self, tokens: torch.Tensor, new_block_ids: List[int]) -> None: n_blocks = len(tokens) // self.block_size for i in range(n_blocks): - chunk = tokens[:(i+1)*self.block_size] + chunk = tokens[:(i + 1) * self.block_size] hash = token_ids_to_hash(chunk) if hash not in self.tokens_to_blocks: - self.tokens_to_blocks[hash] = new_block_ids[:i+1] - self.blocks_to_tokens[frozenset(new_block_ids[:i+1])] = hash + self.tokens_to_blocks[hash] = new_block_ids[:i + 1] + self.blocks_to_tokens[frozenset(new_block_ids[:i + 1])] = hash def delete(self, block_ids: List[int]) -> None: blocks_set = frozenset(block_ids) @@ -50,4 +54,3 @@ def delete(self, block_ids: List[int]) -> None: if blocks_set & used_blocks: del self.tokens_to_blocks[hash] del self.blocks_to_tokens[used_blocks] - diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index a10cdc812ffb..3de4059283eb 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -15,7 +15,7 @@ from .kv_cache import BlockedKVCache from .manager_configs import DSStateManagerConfig, KVCacheConfig from .sequence_descriptor import DSSequenceDescriptor -from .prefix_block_tree import PrefixBlockMap +from .prefix_block_map import PrefixBlockMap class DSStateManager: @@ -102,10 +102,9 @@ def __init__(self, enable_prefix_cache=enable_prefix_cache) assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." - self._block_tree = PrefixBlockMap(self._kv_configs[0].block_size) + self._block_map = PrefixBlockMap(self._kv_configs[0].block_size) self._ref_counts = defaultdict(int) - def get_cache(self, cache_id: int, cache_group: int = 0) -> torch.Tensor: """ Return the Tensor associated with the given cache id in the specified cache group. @@ -176,7 +175,7 @@ def _create_sequence(self, uid: int) -> DSSequenceDescriptor: def lookup_cache(self, tokens: torch.Tensor) -> int: - block_ids = self._block_tree.lookup(tokens) + block_ids = self._block_map.lookup(tokens) assert len(self._kv_configs) == 1, "Only one KV cache group is supported for now." cache_hit_length = len(block_ids) * self._kv_configs[0].block_size @@ -192,7 +191,7 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None: Update the KV cache for the given sequence id. 
""" seq = self.get_sequence(uid) - self._block_tree.extend(tokens, seq.all_block_ids()) + self._block_map.extend(tokens, seq.all_block_ids()) def increment_ref_count(self, block_ids: torch.Tensor) -> None: for block_id in block_ids: From 959204c71614bb4b2c21ba2fa448454dae863525 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 20:15:13 +0000 Subject: [PATCH 12/17] remove unnecessary reverse --- deepspeed/inference/v2/ragged/blocked_allocator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py index 790eaf9abb7f..b54817360cd7 100644 --- a/deepspeed/inference/v2/ragged/blocked_allocator.py +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -171,6 +171,6 @@ def free(self, blocks: Union[Iterable[int], int]) -> None: if self._blocks[block] != -1: raise ValueError(f'Block {block} is already free') - for block in reversed(blocks): + for block in blocks: self._blocks[block] = 0 self._free_blocks += 1 From 9289ff76cf2d070f8e5e726d3ebc19e2ef59ee59 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 20:30:17 +0000 Subject: [PATCH 13/17] fix attribute name --- deepspeed/inference/v2/engine_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index 8b79ddc803fe..533d2ea4f34c 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -133,7 +133,7 @@ def put(self, new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) if self._config.enable_prefix_cache and new_block_ids is not None: self._state_manager.increment_ref_count(new_block_ids) - self._state_manager._block_tree.delete(new_block_ids) + self._state_manager._block_map.delete(new_block_ids) host_seq_desc.pre_forward(tokens.numel()) # We can disable checks since we already validated schedulability. 
From 0c8e0e60c12682e13369ebf8fbe1bfd6a8b02fac Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 30 May 2024 02:44:01 +0000 Subject: [PATCH 14/17] update prefix cache at every iteration --- .../inference/v2/ragged/prefix_block_map.py | 4 ++-- .../inference/v2/ragged/ragged_manager.py | 5 ++++- .../v2/ragged/sequence_descriptor.py | 21 +++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/deepspeed/inference/v2/ragged/prefix_block_map.py b/deepspeed/inference/v2/ragged/prefix_block_map.py index 85a2e83140f6..d59c750366b2 100644 --- a/deepspeed/inference/v2/ragged/prefix_block_map.py +++ b/deepspeed/inference/v2/ragged/prefix_block_map.py @@ -38,9 +38,9 @@ def lookup(self, tokens: torch.Tensor) -> torch.Tensor: break return cached_blocks - def extend(self, tokens: torch.Tensor, new_block_ids: List[int]) -> None: + def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None: n_blocks = len(tokens) // self.block_size - for i in range(n_blocks): + for i in range(num_already_cached_blocks, n_blocks): chunk = tokens[:(i + 1) * self.block_size] hash = token_ids_to_hash(chunk) if hash not in self.tokens_to_blocks: diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index 3de4059283eb..fa07f663d54a 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -191,7 +191,10 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None: Update the KV cache for the given sequence id. """ seq = self.get_sequence(uid) - self._block_map.extend(tokens, seq.all_block_ids()) + num_full_blocks = tokens.numel() // self._kv_configs[0].block_size + if num_full_blocks > seq.num_prefix_cache_blocks: + self._block_map.extend(tokens, seq.all_block_ids(), seq.num_prefix_cache_blocks) + seq.num_prefix_cache_blocks = num_full_blocks def increment_ref_count(self, block_ids: torch.Tensor) -> None: for block_id in block_ids: diff --git a/deepspeed/inference/v2/ragged/sequence_descriptor.py b/deepspeed/inference/v2/ragged/sequence_descriptor.py index 6b9f65255eec..25c1072ff042 100644 --- a/deepspeed/inference/v2/ragged/sequence_descriptor.py +++ b/deepspeed/inference/v2/ragged/sequence_descriptor.py @@ -95,6 +95,8 @@ class DSSequenceDescriptor(BaseSequenceDescriptor): # are stored. Used on flush. _tracking_id: int + _num_prefix_cache_blocks: int + def __init__(self, tracking_id: int, kv_cache_ids: Tuple[torch.Tensor, ...], @@ -132,6 +134,8 @@ def __init__(self, assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0] assert len(kv_cache_ids.shape) == 2 + self._num_prefix_cache_blocks = 0 + @property def seen_tokens(self) -> int: """ @@ -278,3 +282,20 @@ def free_kv_cache(self, free_ids: Union[List[torch.IntTensor], torch.IntTensor], to have the same shape. """ raise NotImplementedError("Partial KV-cache freeing is not yet supported.") + + @property + def num_prefix_cache_blocks(self) -> int: + """ + The number of prefix cache blocks for the sequence. + """ + return self._num_prefix_cache_blocks + + @num_prefix_cache_blocks.setter + def num_prefix_cache_blocks(self, num_blocks: int) -> None: + """ + Set the number of prefix cache blocks for the sequence. + + Arguments: + num_blocks (int): The number of prefix cache blocks for the sequence. 
+ """ + self._num_prefix_cache_blocks = num_blocks From ea50fb5d3f47d02025f61a108719b0ed507ba68e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 31 May 2024 07:35:32 +0000 Subject: [PATCH 15/17] skip looking up cache --- deepspeed/inference/v2/engine_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index 533d2ea4f34c..f310dafe7984 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -260,6 +260,7 @@ def setup_cached_sequence(self, uid: int, cached_length: int, block_ids: torch.T seq.pre_forward(cached_length) seq.post_forward() seq.extend_kv_cache(block_ids) + seq.num_prefix_cache_blocks = len(block_ids) self._state_manager.increment_ref_count(block_ids) self._state_manager._kv_cache.allocate_blocks(block_ids) From 1493ab7e6e15b1ab0f7051d38e03aa62e3f0a647 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 3 Jun 2024 00:11:05 +0000 Subject: [PATCH 16/17] fix prefix tokens to cache --- deepspeed/inference/v2/ragged/ragged_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index fa07f663d54a..01958e7e1fc5 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -191,9 +191,9 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None: Update the KV cache for the given sequence id. """ seq = self.get_sequence(uid) - num_full_blocks = tokens.numel() // self._kv_configs[0].block_size + num_full_blocks = seq.seen_tokens // self._kv_configs[0].block_size if num_full_blocks > seq.num_prefix_cache_blocks: - self._block_map.extend(tokens, seq.all_block_ids(), seq.num_prefix_cache_blocks) + self._block_map.extend(tokens[:seq.seen_tokens], seq.all_block_ids(), seq.num_prefix_cache_blocks) seq.num_prefix_cache_blocks = num_full_blocks def increment_ref_count(self, block_ids: torch.Tensor) -> None: From 07a4c441bc1a38ff919c763a9d398e4c5786b97d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 3 Jun 2024 16:40:46 +0000 Subject: [PATCH 17/17] add assertion --- deepspeed/inference/v2/engine_v2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py index f310dafe7984..71fca3ee50c3 100644 --- a/deepspeed/inference/v2/engine_v2.py +++ b/deepspeed/inference/v2/engine_v2.py @@ -132,6 +132,9 @@ def put(self, host_seq_desc = self._state_manager.get_or_create_sequence(uid) new_block_ids = self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) if self._config.enable_prefix_cache and new_block_ids is not None: + for block_id in new_block_ids: + assert self._state_manager._ref_counts[block_id.item()] == 0 + self._state_manager.increment_ref_count(new_block_ids) self._state_manager._block_map.delete(new_block_ids) host_seq_desc.pre_forward(tokens.numel())