diff --git a/examples/offline_inference_rerope.py b/examples/offline_inference_rerope.py
new file mode 100644
index 000000000..34f33ed36
--- /dev/null
+++ b/examples/offline_inference_rerope.py
@@ -0,0 +1,179 @@
+import contextlib
+import json
+import os
+import sys
+import time
+from dataclasses import asdict
+
+from transformers import AutoTokenizer
+
+# setting for rerope
+os.environ["VLLM_USE_REROPE"] = "true"
+
+# Third Party
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+from vllm.engine.arg_utils import EngineArgs
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def setup_environment_variables():
+    os.environ["VLLM_USE_V1"] = "1"
+    os.environ["PYTHONHASHSEED"] = "123456"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,6,7"
+    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
+    os.environ["REROPE_WINDOW"] = "32768"
+    os.environ["TRAINING_LENGTH"] = "32768"
+
+    global data_dir
+    data_dir = os.getenv("DATA_DIR", "/home/externals/wangwenxin21/wx_data")
+
+    if not os.path.isdir(data_dir):
+        create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
+        if create.lower() in ("", "y"):
+            os.makedirs(data_dir, exist_ok=True)
+        else:
+            print("Exiting. Directory not created.")
+            sys.exit(1)
+
+
+@contextlib.contextmanager
+def build_llm_with_uc(module_path: str, name: str, model: str):
+    ktc = KVTransferConfig(
+        kv_connector=name,
+        kv_connector_module_path=module_path,
+        kv_role="kv_both",
+        kv_connector_extra_config={
+            "ucm_connectors": [
+                {
+                    "ucm_connector_name": "UcmNfsStore",
+                    "ucm_connector_config": {
+                        "storage_backends": data_dir,
+                        "use_direct": False,
+                    },
+                }
+            ],
+        },
+    )
+
+    llm_args = EngineArgs(
+        model=model,
+        kv_transfer_config=ktc,
+        hf_overrides={
+            "max_position_embeddings": 430080,
+        },
+        gpu_memory_utilization=0.8,
+        max_num_batched_tokens=8192,
+        block_size=16,
+        enforce_eager=True,
+        tensor_parallel_size=4,
+    )
+
+    llm = LLM(**asdict(llm_args))
+    try:
+        yield llm
+    finally:
+        logger.info("LLM engine is exiting.")
+
+
+def print_output(
+    llm: LLM,
+    prompt: str,
+    sampling_params: SamplingParams,
+    req_str: str,
+):
+    start = time.time()
+    outputs = llm.generate(prompt, sampling_params)
+    print("-" * 50)
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print(f"Generated text: {generated_text!r}")
+    print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
+    print("-" * 50)
+
+
+def main():
+    module_path = "ucm.integration.vllm.ucm_connector"
+    name = "UCMConnector"
+    model = os.getenv("MODEL_PATH", "/home/wx/models/Qwen2.5-14B-Instruct")
+
+    tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
+    setup_environment_variables()
+
+    with build_llm_with_uc(module_path, name, model) as llm:
+
+        data_all = []
+        with open(
+            "/home/wx/va_clean/data/multifieldqa_zh.jsonl", "r", encoding="utf-8"
+        ) as f:
+            for line in f:
+                data_all.append(json.loads(line))
+
+        materials = []
+        questions = []
+        references = []
+        batch_size = 75
+        num_batch = 2
+        for idx in range(num_batch):
+            data = data_all[idx * batch_size : (idx + 1) * batch_size]
+
+            materials.append(
+                "\n\n".join(
+                    [
+                        f"【语料{i+1}】\n{item.get('context', '')}"
+                        for i, item in enumerate(data)
+                    ]
+                )
+            )
+            questions.append(
+                "\n".join(
+                    [
+                        f"{i+1}. {item.get('input', '')}"
+                        for i, item in enumerate(data[:15])
+                    ]
+                )
+            )
+            references.append(
+                [
+                    f"{i+1}. 
{item.get('answers', '')}" + for i, item in enumerate(data[:15]) + ] + ) + + system_prompt = "你是一个AI助手,请根据以下材料回答问题。" + tokenized_inputs = [] + for material, question in zip(materials, questions): + content = ( + "请根据以下文本内容回答后面的问题:\n\n" + "【文本内容开始】\n" + f"{material}\n" + "【文本内容结束】\n\n" + "请回答以下问题:\n" + f"{question}" + ) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": content}, + ] + inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + ) + tokenized_inputs.append(inputs) + + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048) + + for req in range(num_batch): + print_output( + llm, tokenized_inputs[req], sampling_params, "request_" + str(req) + ) + + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/ucm/integration/vllm/patch/apply_patch.py b/ucm/integration/vllm/patch/apply_patch.py index 39f5ccbb0..67587ef3d 100644 --- a/ucm/integration/vllm/patch/apply_patch.py +++ b/ucm/integration/vllm/patch/apply_patch.py @@ -37,6 +37,13 @@ PLATFORM = os.getenv("PLATFORM") +vllm_use_rerope = os.getenv("VLLM_USE_REROPE", "0").lower() in ( + "1", + "true", + "yes", + "on", +) + def _patch_ascend() -> bool: return PLATFORM == "ascend" @@ -94,6 +101,8 @@ def apply_all_patches() -> None: # Apply version-specific patches match version: + case "0.9.2" if vllm_use_rerope: + _apply_patches_rerope() case "0.9.2": _apply_patches_v092() case _: @@ -120,6 +129,13 @@ def _apply_patches_v092() -> None: _apply_ascend_patch() # apply vllm-ascend-adapt.patch +def _apply_patches_rerope() -> None: + """Apply patches for vLLM 0.9.2 for triton rerope""" + from .patch_funcs.v092.vllm_rerope_patch import _apply_rerope_adapt_patches + + _apply_rerope_adapt_patches() + + def install_import_hook() -> None: """Install an import hook to automatically apply patches when vLLM is imported.""" global _import_hook_installed, _vllm_import_hook diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py new file mode 100644 index 000000000..aecda1f85 --- /dev/null +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py @@ -0,0 +1,707 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+#
+
+from __future__ import annotations
+
+from torch.library import Library
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+_UCM_UNIFIED_ATTENTION_WITH_OUTPUT_REGISTERED = False
+
+
+def _apply_rerope_adapt_patches() -> None:
+    try:
+        _patch_attention_spec()
+        _patch_request_succeed_dumped_blocks()
+        _patch_qwen_model()
+        _patch_attention_layer()
+        _patch_triton_attn()
+
+    except Exception as e:
+        logger.error(f"Failed to apply rerope patches: {e}", exc_info=True)
+        raise
+
+
+# ==================== vllm/v1/kv_cache_interface.py ====================
+def _patch_attention_spec() -> None:
+    """Patch AttentionSpec so page_size_bytes accounts for the extra rerope key cache."""
+    try:
+        from vllm.utils import cdiv, get_dtype_size
+        from vllm.v1.kv_cache_interface import AttentionSpec
+
+        def _page_size_bytes_rerope(self: "AttentionSpec") -> int:
+            """
+            Patched version of page_size_bytes property.
+            REROPE support with coefficient=3.
+            """
+
+            coef = 3
+
+            return (
+                coef
+                * self.block_size
+                * self.num_kv_heads
+                * self.head_size
+                * get_dtype_size(self.dtype)
+            )
+
+        AttentionSpec.page_size_bytes = property(_page_size_bytes_rerope)
+
+    except ImportError:
+        logger.warning(
+            "Could not patch AttentionSpec with _page_size_bytes_rerope - module not found"
+        )
+
+
+# ==================== vllm/v1/request.py ====================
+def _patch_request_succeed_dumped_blocks() -> None:
+    """Patch Request to add succeed_dumped_blocks field."""
+    try:
+        from vllm.v1.request import Request
+
+        original_init = Request.__init__
+
+        def __init__(self, *args, **kwargs):
+            original_init(self, *args, **kwargs)
+            self.succeed_dumped_blocks = []
+
+        Request.__init__ = __init__
+    except ImportError:
+        logger.warning("Could not patch Request.__init__ - module not found")
+
+
+# ==================== vllm/model_executor/models/qwen2.py ====================
+def _patch_qwen_model() -> None:
+    """Patch Qwen2 attention to produce the extra rerope query/key tensors."""
+    try:
+        import math
+
+        import torch
+        from vllm.forward_context import get_forward_context
+        from vllm.model_executor.models.qwen2 import Qwen2Attention
+
+        from ucm.sparse.rerope.rerope_utils import default_config
+
+        REROPE_WINDOW = default_config.rerope_window
+        TRAINING_LENGTH = default_config.training_length
+
+        def Qwen2Attention_forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+        ) -> torch.Tensor:
+            attn_metadata = get_forward_context().attn_metadata
+
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+            if attn_metadata and next(iter(attn_metadata.values())).use_rerope:
+                q *= (
+                    ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH))
+                    .clip(1)
+                    .to(q.dtype)
+                )
+                q2 = q.clone()
+                k2 = k.clone()
+                k0 = k.clone()
+
+                q, k = self.rotary_emb(positions, q, k)
+                q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2)
+                del k2
+            else:
+                q, k = self.rotary_emb(positions, q, k)
+                q2, k0 = None, None
+
+            attn_output = self.attn(q, k, q2, k0, v)
+            output, _ = self.o_proj(attn_output)
+            return output
+
+        Qwen2Attention.forward = Qwen2Attention_forward
+
+    except ImportError:
+        logger.warning("Could not patch qwen2 model - module not found")
+
+
+# ==================== vllm/attention/layer.py ====================
+def _patch_attention_layer() -> None:
+    """Patch the attention layer to pass the rerope tensors through to the backend."""
+    try:
+        from typing import Optional
+
+        import torch
+        from vllm.attention.layer import (
+            maybe_save_kv_layer_to_connector,
+            wait_for_kv_layer_from_connector,
+        )
+        from vllm.forward_context import ForwardContext, get_forward_context
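+        # NOTE (rerope patch): the replacement forward below mirrors vllm's
+        # Attention.forward but additionally threads `query2` (the query
+        # rotated at the fixed REROPE_WINDOW position) and `key2` (the
+        # un-rotated key copy produced by the Qwen2 patch above) through the
+        # unified_attention custom ops to the backend impl. Both are None
+        # when rerope is inactive for the batch.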
+ + def attn_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + query2: Optional[torch.Tensor], + key2: Optional[torch.Tensor], + value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(query, key, value) + if self.use_output: + output_shape = output_shape if output_shape is not None else query.shape + output = torch.zeros( + output_shape, dtype=query.dtype, device=query.device + ) + hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if query2 is not None: + query2 = query2.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if key2 is not None: + key2 = key2.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + query, + key, + query2, + key2, + value, + self_kv_cache, + attn_metadata, + output=output, + ) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, query2, key2, value, output, self.layer_name + ) + return output.view(-1, hidden_size) + else: + if self.use_direct_call: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward( + self, + query, + key, + query2, + key2, + value, + self_kv_cache, + attn_metadata, + ) + else: + return torch.ops.vllm.unified_attention( + query, key, query2, key2, value, self.layer_name + ) + + vllm_ops = torch.ops.vllm + orig_unified_attention_with_output = vllm_ops.unified_attention_with_output + orig_unified_attention = vllm_ops.unified_attention + + def _wrap_op_overload(orig, impl): + class _Wrapper: + def __init__(self, orig): + self._orig = orig + + def __call__(self, *args, **kwargs): + return impl(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._orig, name) + + return _Wrapper(orig) + + def unified_attention_impl( + query: torch.Tensor, + key: torch.Tensor, + query2: Optional[torch.Tensor], + key2: 
Optional[torch.Tensor], + value: torch.Tensor, + layer_name: str, + ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + + output = self.impl.forward( + self, query, key, query2, key2, value, kv_cache, attn_metadata + ) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + def unified_attention_with_output_impl( + query: torch.Tensor, + key: torch.Tensor, + query2: Optional[torch.Tensor], + key2: Optional[torch.Tensor], + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, + ) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + query, + key, + query2, + key2, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + ) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + vllm_ops.unified_attention_with_output = _wrap_op_overload( + orig_unified_attention_with_output, unified_attention_with_output_impl + ) + vllm_ops.unified_attention = _wrap_op_overload( + orig_unified_attention, unified_attention_impl + ) + from vllm.attention import layer + + layer.Attention.forward = attn_forward + layer.unified_attention = unified_attention_impl + layer.unified_attention_with_output = unified_attention_with_output_impl + + except ImportError: + logger.warning("Could not patch layer - module not found") + + +# ==================== vllm/v1/attention/backends/triton_attn.py ==================== +def _patch_triton_attn() -> None: + """Patch triton_attn to support rerope""" + try: + from dataclasses import dataclass + from typing import Optional + + import torch + from vllm import _custom_ops as ops + from vllm.attention.ops.triton_unified_attention import unified_attention + from vllm.platforms import current_platform + from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata + from vllm.v1.attention.backends.triton_attn import ( + TritonAttentionBackend, + TritonAttentionImpl, + TritonAttentionMetadata, + TritonAttentionMetadataBuilder, + ) + from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + make_local_attention_virtual_batches, + ) + + from ucm.sparse.rerope.rerope_utils import default_config + from ucm.sparse.rerope.triton_unified_attention_rerope import ( + unified_attention_rerope, + ) + + REROPE_WINDOW = default_config.rerope_window + + @dataclass + class TritonAttentionMetadata_add: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + num_actual_tokens: int # Number of tokens excluding padding. 
+ max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + use_rerope: bool = False + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + + # for local attention + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor + local_seqused_k: torch.Tensor + local_block_table: torch.Tensor + local_max_query_len: int + local_max_seq_len: int + local_scheduler_metadata: Optional[torch.Tensor] + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + TritonAttentionMetadata = TritonAttentionMetadata_add + + def TritonAttentionMetadataBuilder_build( + self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata + ) -> TritonAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True + ) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] + + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + ( + seqlens_q_local_np, + virt_q_cu_seqlens_np, + virt_k_seqlens_np, + virt_block_table_tensor, + ) = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[: num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table_tensor, + self.block_size, + ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True + ) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True + ) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + + local_attn_metadata = TritonAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, + local_block_table=virt_block_table_tensor, + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=None, + ) + + # saving for the max input tokens length + max_prompt_len = 0 + for req_id in self.runner.input_batch.req_id_to_index.keys(): + req_state = self.runner.requests.get(req_id) + if req_state: + prompt_len = len(req_state.prompt_token_ids) + max_prompt_len = max(max_prompt_len, prompt_len) + + use_rerope = max_prompt_len > REROPE_WINDOW + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.runner.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.runner.device + ) + suffix_kv_lens = 
self.runner.seq_lens_np[:num_reqs] - common_prefix_len + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.runner.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + attn_metadata = TritonAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, + use_rerope=use_rerope, + ) + return attn_metadata + + TritonAttentionMetadataBuilder.build = TritonAttentionMetadataBuilder_build + + def TritonAttentionBackend_get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + + return (3, num_blocks, block_size, num_kv_heads, head_size) + + TritonAttentionBackend.get_kv_cache_shape = staticmethod( + TritonAttentionBackend_get_kv_cache_shape + ) + + def TritonAttentionImpl_forwad( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + query2: Optional[torch.Tensor], + key2: Optional[torch.Tensor], + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TritonAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + key_cache, value_cache, key_cache2 = kv_cache.unbind(0) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. 
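+                # (rerope patch) the rotated keys are written to key_cache as
+                # usual; when key2 (the un-rotated copy) is present it is
+                # written to key_cache2 at the same slot_mapping, so the kernel
+                # can score out-of-window tokens with the window-clamped query2.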
+ torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + if key2 is not None: + torch.ops._C_cache_ops.reshape_and_cache_flash( + key2, + value, + key_cache2, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + if key_cache2 is not None: + key_cache2 = key_cache2.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + num_tokens, num_heads, head_size = query.shape + assert ( + layer._q_scale == 1.0 + ), "A non 1.0 q_scale is not currently supported." + if not current_platform.is_rocm(): + # Skip Q quantization on ROCm, since dequantizing back to + # f32 in the attention kernel is not supported. + query, _ = ops.scaled_fp8_quant( + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) + query = query.reshape((num_tokens, num_heads, head_size)) + if query2 is not None: + query2, _ = ops.scaled_fp8_quant( + query2.reshape( + (num_tokens, num_heads * head_size) + ).contiguous(), + layer._q_scale, + ) + query2 = query2.reshape((num_tokens, num_heads, head_size)) + + use_local_attn = ( + self.use_irope and attn_metadata.local_attn_metadata is not None + ) + + if use_local_attn: + assert attn_metadata.local_attn_metadata is not None + local_metadata = attn_metadata.local_attn_metadata + cu_seqlens_q = local_metadata.local_query_start_loc + seqused_k = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + block_table = local_metadata.local_block_table + else: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + if attn_metadata.use_rerope: + unified_attention_rerope( + q=query[:num_actual_tokens], + k=key_cache, + q2=query2[:num_actual_tokens], + k2=key_cache2, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + rerope_window=REROPE_WINDOW, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + return output + + TritonAttentionImpl.forward = TritonAttentionImpl_forwad + + except ImportError: + logger.warning("Could not patch triton attention - module not found") + \ No newline at end of file diff --git a/ucm/sparse/rerope/rerope_utils.py b/ucm/sparse/rerope/rerope_utils.py new file mode 100644 index 
000000000..27418299b --- /dev/null +++ b/ucm/sparse/rerope/rerope_utils.py @@ -0,0 +1,15 @@ +import os + + +class VllmPatchConfig: + + @property + def rerope_window(self) -> int: + return int(os.getenv("REROPE_WINDOW", "32768")) + + @property + def training_length(self) -> int: + return int(os.getenv("TRAINING_LENGTH", "32768")) + + +default_config = VllmPatchConfig() diff --git a/ucm/sparse/rerope/triton_unified_attention_rerope.py b/ucm/sparse/rerope/triton_unified_attention_rerope.py new file mode 100644 index 000000000..6334b8643 --- /dev/null +++ b/ucm/sparse/rerope/triton_unified_attention_rerope.py @@ -0,0 +1,885 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch +import triton +import triton.language as tl +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: tl.constexpr, + use_q_block_mode: tl.constexpr, +): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + query2_ptr, # [num_tokens, num_query_heads, head_size] + key_cache2_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + query2_stride_0: tl.int64, # int + query2_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + REROPE_WINDOW: tl.constexpr, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_k_cache2_0: tl.int64, # int + stride_k_cache2_1: tl.int64, # int + stride_k_cache2_2: tl.int64, # int + stride_k_cache2_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = 
tl.program_id(1) + + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + query2_offset = ( + query_offset_0[:, None] * query2_stride_0 + + query_offset_1[:, None] * query2_stride_1 + + offs_d[None, :] + ) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + Q2 = tl.load( + query2_ptr + query2_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = ( + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1 + ) + + k_offset = ( + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1 + ) + + k2_offset = ( + physical_block_idx * stride_k_cache2_0 + + kv_head_idx * stride_k_cache2_2 + + offs_d[:, None] * stride_k_cache2_3 + + offs_n[None, :] * stride_k_cache2_1 + ) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) + + K2_load = tl.load(key_cache2_ptr + k2_offset, mask=dim_mask[:, None], other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + if K2_load.dtype.is_fp8(): + if Q2.dtype.is_fp8(): + K2 = K2_load + else: + K2 = (K2_load.to(tl.float32) * tl.load(k_scale)).to(Q2.dtype) + else: + K2 = K2_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + 
v_offset, mask=dim_mask[None, :], other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S1 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + S1 += scale * tl.dot(Q, K) + + # rerope mask + S2 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + S2 += scale * tl.dot(Q2, K2) + + query_pos_rerope = context_len + query_pos[:, None] + 1 + key_pos_rerope = seq_offset[None, :] + + valid_query_mask = query_pos[:, None] < cur_batch_query_len + pos_diff = tl.abs(query_pos_rerope - key_pos_rerope) + rerope_mask = pos_diff < REROPE_WINDOW + rerope_mask = rerope_mask & valid_query_mask + + if USE_SOFTCAP: + S1 = apply_softcap(S1, softcap) + S2 = apply_softcap(S2, softcap) + + S = tl.where(rerope_mask, S1, S2) + + # causal mask + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) + + if SLIDING_WINDOW > 0: + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + query2_ptr, # [num_tokens, num_query_heads, head_size] + key_cache2_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + query2_stride_0: tl.int64, # int + query2_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + 
USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_k_cache2_0: tl.int64, # int + stride_k_cache2_1: tl.int64, # int + stride_k_cache2_2: tl.int64, # int + stride_k_cache2_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + REROPE_WINDOW: tl.constexpr, # int + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + query2_offset = ( + query_offset_0[:, None] * query2_stride_0 + + query_offset_1[:, None] * query2_stride_1 + + offs_d[None, :] + ) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + Q2 = tl.load( + query2_ptr + query2_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles within current segment + for j in range( + segm_idx * blocks_per_segment, + min((segm_idx + 1) * blocks_per_segment, num_blocks), + ): + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + 
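+        # (rerope patch) per KV tile: build offsets into the rotated key
+        # cache, the un-rotated key cache (key_cache2) and the value cache,
+        # compute both score variants S1 = Q*K and S2 = Q2*K2, then select
+        # per position below based on REROPE_WINDOW.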
offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = ( + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1 + ) + + k_offset = ( + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1 + ) + + k2_offset = ( + physical_block_idx * stride_k_cache2_0 + + kv_head_idx * stride_k_cache2_2 + + offs_d[:, None] * stride_k_cache2_3 + + offs_n[None, :] * stride_k_cache2_1 + ) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) + + K2_load = tl.load(key_cache2_ptr + k2_offset, mask=dim_mask[:, None], other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + if K2_load.dtype.is_fp8(): + if Q2.dtype.is_fp8(): + K2 = K2_load + else: + K2 = (K2_load.to(tl.float32) * tl.load(k_scale)).to(Q2.dtype) + else: + K2 = K2_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S1 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + S1 += scale * tl.dot(Q, K) + + # rerope mask + S2 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + S2 += scale * tl.dot(Q2, K2) + + query_pos_rerope = context_len + query_pos[:, None] + 1 + key_pos_rerope = seq_offset[None, :] + + valid_query_mask = query_pos[:, None] < cur_batch_query_len + pos_diff = tl.abs(query_pos_rerope - key_pos_rerope) + rerope_mask = pos_diff < REROPE_WINDOW + rerope_mask = rerope_mask & valid_query_mask + + if USE_SOFTCAP: + S1 = apply_softcap(S1, softcap) + S2 = apply_softcap(S2, softcap) + + S = tl.where(rerope_mask, S1, S2) + + # causal mask + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) + + if SLIDING_WINDOW > 0: + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + # [num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) + + # load segment maxima + segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + query_token_idx.to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * 
HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + # write result + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def unified_attention_rerope( + q, + k, + q2, + k2, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + rerope_window, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert ( + q.element_size() >= 2 or block_size >= 32 + ), "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = 16 + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + # if batch contains a prefill + if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + with torch.cuda.nvtx.range("atten_2D"): + kernel_unified_attention_2d[ + ( + total_num_q_blocks, + num_kv_heads, + ) + ]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + query2_ptr=q2, + key_cache2_ptr=k2, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + query2_stride_0=q2.stride(0), + query2_stride_1=q2.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + REROPE_WINDOW=rerope_window, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_k_cache2_0=k.stride(0), + stride_k_cache2_1=k.stride(1), + stride_k_cache2_2=k.stride(2), + stride_k_cache2_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + ) + torch.cuda.synchronize() + else: + # for initial 
version, NUM_SEGMENTS = 16 is chosen as a default + # value that showed good performance in tests + NUM_SEGMENTS = 16 + + segm_output = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + + kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + query2_ptr=q2, + key_cache2_ptr=k2, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + query2_stride_0=q2.stride(0), + query2_stride_1=q2.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_k_cache2_0=k2.stride(0), + stride_k_cache2_1=k2.stride(1), + stride_k_cache2_2=k2.stride(2), + stride_k_cache2_3=k2.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + REROPE_WINDOW=rerope_window, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + + reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + \ No newline at end of file