From bc281e2e8a812fc9a5b0f5f523710bccdaf5031c Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 1 Dec 2025 12:40:30 +0000 Subject: [PATCH 01/79] qwen3_moe mtp --- lightllm/models/qwen3_moe_mtp/__init__.py | 0 .../qwen3_moe_mtp/layer_infer/__init__.py | 0 .../layer_infer/pre_layer_infer.py | 60 +++++++++++++++++++ .../qwen3_moe_mtp/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 29 +++++++++ lightllm/models/qwen3_moe_mtp/model.py | 47 +++++++++++++++ .../model_infer/mode_backend/base_backend.py | 10 +++- 7 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 lightllm/models/qwen3_moe_mtp/__init__.py create mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen3_moe_mtp/model.py diff --git a/lightllm/models/qwen3_moe_mtp/__init__.py b/lightllm/models/qwen3_moe_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..66a41da73 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,60 @@ +import torch + +from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward + + +class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): + """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_context_forward( + self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) + rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + return ans_logics + + def _mtp_token_forward( + self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) + rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + + cat_embdings = torch.cat((input_embdings, 
tgt_embdings), dim=-1) + + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + return ans_logics + + def context_forward( + self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_context_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..f5b805647 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,29 @@ +import numpy as np +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight + + +class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + # 与DeepseekV3模型共享 + self.wte_weight_ = None + self.lm_head_weight_ = None + return + + def load_hf_weights(self, weights): + if "model.layers.0.eh_proj.weight" in weights: + self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() + if "model.layers.0.enorm.weight" in weights: + self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) + if "model.layers.0.hnorm.weight" in weights: + self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) + if "model.layers.0.shared_head.norm.weight" in weights: + self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"]) + return + + def verify_load(self): + errors = "weights load not ok" + weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py new file mode 100644 index 000000000..4586db4be --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -0,0 +1,47 @@ +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight +from lightllm.common.basemodel import TpPartBaseModel + + +class Qwen3MOEMTPModel(Qwen3MOEModel): + + pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight + pre_layer_infer_class = Deepseek3MTPPreLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mem_layer_start = kvargs.pop("mem_layer_start", 0) + return + + def _init_custom(self): + 
self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_weights(self): + super()._init_weights() + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ + return + + def _init_infer_layer(self): + super()._init_infer_layer() + # reset the layer_num_ of the self.layers_infer + for layer in self.layers_infer: + layer.layer_num_ = layer.layer_num_ + self.mem_layer_start + return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c37b8d74e..3c9b98b38 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -35,6 +35,7 @@ from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet @@ -321,9 +322,12 @@ def init_mtp_draft_model(self, main_kvargs: dict): } mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) - assert mtp_model_cfg["model_type"] == "deepseek_v3" - assert mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN" - self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + if mtp_model_cfg["model_type"] == "deepseekv3": + self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "qwen3_moe": + self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) + else: + assert False, f"error mtp mode {mtp_model_cfg['model_type']}" self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return From f4f841536d3e7d81d2d5e389d196f68417e326ba Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 1 Dec 2025 12:52:47 +0000 Subject: [PATCH 02/79] fix weight name --- lightllm/common/basemodel/basemodel.py | 4 +- .../qwen3_moe_mtp/layer_infer/__init__.py | 0 .../layer_infer/pre_layer_infer.py | 60 ------------------- .../pre_and_post_layer_weight.py | 18 +++--- lightllm/models/qwen3_moe_mtp/model.py | 4 +- 5 files changed, 13 insertions(+), 73 deletions(-) delete mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py delete mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 2d4209028..718eb9d19 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -993,7 +993,9 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) + is_deepseekv3_mtp_draft_model = 
"Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str( + self.__class__ + ) if is_deepseekv3_mtp_draft_model: special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn( token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py deleted file mode 100644 index 66a41da73..000000000 --- a/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch - -from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward - - -class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): - """ """ - - def __init__(self, network_config, mode): - super().__init__(network_config, mode) - self.eps_ = network_config["rms_norm_eps"] - self.hidden_size = network_config["hidden_size"] - return - - def _mtp_context_forward( - self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight - ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens - assert input_embdings.shape[0] == tgt_embdings.shape[0] - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) - - cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) - return ans_logics - - def _mtp_token_forward( - self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight - ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens - assert input_embdings.shape[0] == tgt_embdings.shape[0] - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) - - cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) - return ans_logics - - def context_forward( - self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight - ): - input_embdings = super().context_forward(input_ids, infer_state, layer_weight) - return self._mtp_context_forward(input_embdings, infer_state, layer_weight) - - def token_forward( - self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight - ): - input_embdings = super().token_forward(input_ids, infer_state, layer_weight) - return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py 
b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index f5b805647..408992178 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -2,23 +2,21 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Qwen3MOEMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - # 与DeepseekV3模型共享 + # 与Qwen3MOE模型共享 self.wte_weight_ = None self.lm_head_weight_ = None return def load_hf_weights(self, weights): - if "model.layers.0.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() - if "model.layers.0.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) - if "model.layers.0.hnorm.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) - if "model.layers.0.shared_head.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"]) + if "model.0.proj.weight" in weights: + self.eh_proj_weight_ = self._cuda(weights["model.0.proj.weight"]).t() + if "model.0.norm_after_embedding.weight" in weights: + self.enorm_weight_ = self._cuda(weights["model.0.norm_after_embedding.weight"]) + if "model.0.norm_before_output.weight" in weights: + self.hnorm_weight_ = self._cuda(weights["model.0.norm_before_output.weight"]) return def verify_load(self): diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 4586db4be..ba6c82804 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -1,12 +1,12 @@ from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer -from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight from lightllm.common.basemodel import TpPartBaseModel class Qwen3MOEMTPModel(Qwen3MOEModel): - pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight + pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer def __init__(self, kvargs: dict): From 652dd7d11d5de6ae440368c0abe6f638bf2c9708 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 1 Dec 2025 13:13:11 +0000 Subject: [PATCH 03/79] fix qwen3 fa3 mtp --- .../models/llama/flashattention_infer_struct.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index c6e7aa560..4cfd72e81 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -38,22 +38,29 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len + args_mtp_step = get_env_start_args().mtp_step + att_batch_size = self.batch_size // (args_mtp_step + 1) if self.batch_size <= 
model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: page_buffer = FlashAttentionStateInfo.get_page_table_buffer( model.graph_max_batch_size, model.graph_max_len_in_batch ) self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) else: self.page_table = torch.empty( - (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device + (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device ) + page_table_copy( page_table=self.page_table[:, :max_seq_len_k], req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx, + b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], ) + if args_mtp_step > 0: + self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + else: + self.b_att_seq_len = self.b_seq_len if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: From 11dc305c6e994fd82a959db029ce3bbc7957492a Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 1 Dec 2025 13:17:57 +0000 Subject: [PATCH 04/79] fix --- lightllm/models/llama/layer_infer/transformer_layer_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 8c6015677..ea44fe2e5 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -883,10 +883,10 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS k_cache=cache_k, v_cache=cache_v, page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, + cache_seqlens=infer_state.b_att_seq_len, cu_seqlens_q=infer_state.cu_seqlens_q, cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, + max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=sm_scale, causal=True, window_size=(-1, -1), From 5c2ae2407d8f0d50b2e912d658ca53be69761501 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 1 Dec 2025 13:27:58 +0000 Subject: [PATCH 05/79] fix --- .../qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 408992178..bb18a4dda 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -21,7 +21,7 @@ def load_hf_weights(self, weights): def verify_load(self): errors = "weights load not ok" - weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_] + weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors return From f09c9bb400b39147b5c9d28aa7389721fd1d59c7 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 1 Dec 2025 13:51:25 +0000 Subject: [PATCH 06/79] fix --- .../layer_weights/pre_and_post_layer_weight.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git 
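Note: during MTP decode every request contributes mtp_step + 1 rows to the flattened batch (the verified token plus the draft tokens), so the FlashAttention metadata above is rebuilt per request rather than per row: the strided slice [mtp_step :: mtp_step + 1] keeps the last row of each request's group, whose b_seq_len is that request's full KV length. A small illustration with made-up values:

    import torch

    mtp_step = 1                                     # one draft token per request
    b_req_idx = torch.tensor([0, 0, 1, 1, 2, 2])     # 3 requests, (mtp_step + 1) rows each
    b_seq_len = torch.tensor([17, 18, 40, 41, 9, 10])

    att_batch_size = b_req_idx.numel() // (mtp_step + 1)       # 3
    b_req_idx_att = b_req_idx[mtp_step :: (mtp_step + 1)]      # tensor([0, 1, 2])
    b_att_seq_len = b_seq_len[mtp_step :: (mtp_step + 1)]      # tensor([18, 41, 10]), per-request KV length

    assert b_req_idx_att.numel() == att_batch_size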
a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index bb18a4dda..57d98eec9 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -11,12 +11,12 @@ def __init__(self, data_type, network_config, mode): return def load_hf_weights(self, weights): - if "model.0.proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.0.proj.weight"]).t() - if "model.0.norm_after_embedding.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.0.norm_after_embedding.weight"]) - if "model.0.norm_before_output.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.0.norm_before_output.weight"]) + if "model.layers.0.proj.weight" in weights: + self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t() + if "model.layers.0.norm_after_embedding.weight" in weights: + self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"]) + if "model.layers.0.norm_before_output.weight" in weights: + self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"]) return def verify_load(self): From c9de6f65059d493d9f1ddfdea830dd99a6d6edea Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 04:48:17 +0000 Subject: [PATCH 07/79] fix rebase --- lightllm/models/qwen2/model.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 5b756aadf..19d3c00a6 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -3,6 +3,7 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_env_start_args @ModelRegistry("qwen2") @@ -41,12 +42,20 @@ def _init_mem_manager(self): head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3_eagle": + added_mtp_layer_num += 1 + elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": + added_mtp_layer_num += get_env_start_args().mtp_step + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=tp_k_head_num_, head_dim=head_dim_, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) return From 60b31139df185c2ad75ac6b98b9df2c890c68320 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 09:16:36 +0000 Subject: [PATCH 08/79] mtp dense --- .../qwen3_moe_mtp/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 37 +++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py new file mode 100644 index 
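Note: because the draft models reuse the main model's mem_manager (patch 01), the KV pool sized in the qwen2 change above has to reserve extra layers for them, and each draft layer later writes to a slot past the main model's layers via mem_layer_start. A sketch of that bookkeeping, mirroring the _init_mem_manager logic (layer counts are example values):

    def kv_cache_layer_num(num_hidden_layers, mtp_mode, mtp_step):
        # Eagle-style MTP shares one draft layer; vanilla MTP adds one layer per speculative step.
        if mtp_mode == "deepseekv3_eagle":
            return num_hidden_layers + 1
        if mtp_mode == "deepseekv3_vanilla":
            return num_hidden_layers + mtp_step
        return num_hidden_layers

    # e.g. a 48-layer main model with 2-step vanilla MTP keeps KV for 50 layers;
    # a one-layer draft model i is then addressed as KV layer 48 + i.
    assert kv_cache_layer_num(48, "deepseekv3_vanilla", 2) == 50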
000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..d21917340 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,37 @@ +import os +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +import triton +from typing import Tuple +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3MOEMTPTransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + return + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings From bfa1cfa1674c7faac87231601b6170c98d5c6d64 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 09:30:23 +0000 Subject: [PATCH 09/79] mtp dense weight --- .../layer_weights/transformer_layer_weight.py | 11 +++++++++++ lightllm/models/qwen3_moe_mtp/model.py | 5 +++++ 2 files changed, 16 insertions(+) create mode 100644 lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..61df1d7e1 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,11 @@ +import os +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight + + +class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight(self): + self._init_ffn() diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index ba6c82804..d73483fd6 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -1,6 +1,8 @@ from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from 
lightllm.models.qwen3_moe_mtp.layer_infer.transformer_layer_infer import Qwen3MOEMTPTransformerLayerInfer +from lightllm.models.qwen3_moe_mtp.layer_weights.transformer_layer_weight import Qwen3MOEMTPTransformerLayerWeight from lightllm.common.basemodel import TpPartBaseModel @@ -9,6 +11,9 @@ class Qwen3MOEMTPModel(Qwen3MOEModel): pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer + transformer_weight_class = Qwen3MOEMTPTransformerLayerWeight + transformer_layer_infer_class = Qwen3MOEMTPTransformerLayerInfer + def __init__(self, kvargs: dict): self._pre_init(kvargs) super().__init__(kvargs) From 323761e7e1c222de6e8d3707cd84c9261c9ed389 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 09:43:18 +0000 Subject: [PATCH 10/79] fix --- .../qwen3_moe_mtp/layer_weights/transformer_layer_weight.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index 61df1d7e1..b14a5e785 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -8,4 +8,7 @@ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None return def _init_weight(self): - self._init_ffn() + if self.is_moe: + self._init_moe() + else: + self._init_ffn() From e973b5ea9aa40ad9de565980f24c506114708818 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 10:29:52 +0000 Subject: [PATCH 11/79] fix --- .../qwen3_moe_mtp/layer_weights/transformer_layer_weight.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index b14a5e785..f0a33f723 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -1,5 +1,6 @@ import os from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -12,3 +13,8 @@ def _init_weight(self): self._init_moe() else: self._init_ffn() + + def _init_norm(self): + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) From aeb46094e5e148360cd11863fc33ef05554a8b4c Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 10:45:02 +0000 Subject: [PATCH 12/79] fix --- .../qwen3_moe_mtp/layer_weights/transformer_layer_weight.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index f0a33f723..feb06c5d4 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -9,6 +9,7 @@ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None return def _init_weight(self): + self._init_norm() if self.is_moe: self._init_moe() else: From 224c398d089944ca26a9693346e66001258777ac Mon Sep 17 00:00:00 2001 From: shihaobai 
<1798930569@qq.com> Date: Tue, 30 Dec 2025 03:22:21 +0000 Subject: [PATCH 13/79] remove mtp norm --- lightllm/common/basemodel/basemodel.py | 6 +- lightllm/models/mistral_mtp/__init__.py | 0 .../mistral_mtp/layer_infer/__init__.py | 0 .../layer_infer/post_layer_infer.py | 149 ++++++++++++++++++ .../layer_infer/pre_layer_infer.py | 113 +++++++++++++ .../layer_infer/transformer_layer_infer.py | 33 ++++ .../mistral_mtp/layer_weights/__init__.py | 0 .../layer_weights/hf_load_utils.py | 73 +++++++++ .../pre_and_post_layer_weight.py | 36 +++++ .../layer_weights/transformer_layer_weight.py | 17 ++ lightllm/models/mistral_mtp/model.py | 80 ++++++++++ .../model_infer/mode_backend/base_backend.py | 3 + 12 files changed, 508 insertions(+), 2 deletions(-) create mode 100644 lightllm/models/mistral_mtp/__init__.py create mode 100644 lightllm/models/mistral_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/mistral_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py create mode 100644 lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/mistral_mtp/model.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 718eb9d19..9065293ae 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -993,8 +993,10 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str( - self.__class__ + is_deepseekv3_mtp_draft_model = ( + "Deepseek3MTPModel" in str(self.__class__) + or "Qwen3MOEMTPModel" in str(self.__class__) + or "MistralMTPModel" in str(self.__class__) ) if is_deepseekv3_mtp_draft_model: special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn( diff --git a/lightllm/models/mistral_mtp/__init__.py b/lightllm/models/mistral_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral_mtp/layer_infer/__init__.py b/lightllm/models/mistral_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..fa0d0b374 --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py @@ -0,0 +1,149 @@ +import os +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight + +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from einops import rearrange +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel import PostLayerInferTpl +from lightllm.utils.infer_utils import mark_cost_time +from lightllm.distributed.communication_op import all_gather + + +class MistralMTPPostLayerInfer(PostLayerInferTpl): 
+ """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + self.eps_ = network_config["rms_norm_eps"] + self.vocab_size_ = network_config["vocab_size"] + self.embed_dim_ = network_config["n_embed"] + return + + def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: + return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_) + + def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo): + + if infer_state.is_prefill and infer_state.is_token_healing: + batch_size = infer_state.batch_size + b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() + select_index = [] + start_index = 0 + select_token_num = 0 + for cur_len in b_seq_len_numpy: + select_index.append(start_index + cur_len - 1) + start_index += cur_len + select_token_num += 1 + + last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) + last_input = self.alloc_tensor((select_token_num, self.embed_dim_), dtype=input_embdings.dtype) + last_input[:, :] = input_embdings[last_index, :] + return last_input, select_token_num + + if infer_state.is_prefill and not infer_state.return_all_prompt_logics: + batch_size = infer_state.batch_size + last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype) + last_index = ( + torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 + ) + last_input[:, :] = input_embdings[last_index, :] + return last_input, batch_size + + if infer_state.is_prefill and infer_state.return_all_prompt_logics: + total_tokens = infer_state.total_token_num + return input_embdings, total_tokens + + if not infer_state.is_prefill: + batch_size = infer_state.batch_size + return input_embdings[-batch_size:, :], batch_size + + assert False, "Error State" + + def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) + input_embdings_dtype = input_embdings.dtype + input_embdings = None + last_input = last_input.permute(1, 0).view(-1, token_num) + logic_batch = self.alloc_tensor( + (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype + ) + torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch) + + last_input = None + if self.tp_world_size_ == 1: + gather_data = logic_batch + else: + gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) + all_gather( + [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], + logic_batch, + group=infer_state.dist_group, + async_op=False, + ) + logic_batch = None + ans_logics = self.alloc_tensor( + (token_num, self.vocab_size_), + dtype=torch.float32, + is_graph_out=True, + microbatch_index=infer_state.microbatch_index, + ) + ans_logics[:, :] = gather_data.permute(1, 0) + gather_data = None + return ans_logics + + def tpsp_token_forward( + self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight + ): + if self.tp_world_size_ > 1: + assert len(input_embdings.shape) == 2 + token_num, hidden_dim = input_embdings.shape + gather_data = torch.empty( + (self.tp_world_size_ * token_num, hidden_dim), device=input_embdings.device, 
dtype=input_embdings.dtype + ) + all_gather( + [gather_data[i * token_num : (i + 1) * token_num, :] for i in range(self.tp_world_size_)], + input_embdings, + group=infer_state.dist_group, + async_op=False, + ) + # len(infer_state.position_sin) 获取真实输入长度 + input_embdings = gather_data[0 : len(infer_state.position_sin)] + + if infer_state.need_dp_prefill_balance: + input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) + + return self.token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight) + + def overlap_tpsp_token_forward( + self, + input_embdings: torch.Tensor, + input_embdings1: torch.Tensor, + infer_state: LlamaInferStateInfo, + infer_state1: LlamaInferStateInfo, + layer_weight: BaseLayerWeight, + ): + if getattr(infer_state, "hook", None) is not None: + infer_state.hook() + infer_state.hook = None + + if infer_state.need_dp_prefill_balance: + input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) + + logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) + + if getattr(infer_state1, "hook", None) is not None: + infer_state1.hook() + infer_state1.hook = None + + if infer_state1.need_dp_prefill_balance: + input_embdings1 = infer_state1._all_to_all_unbalance_get(data=input_embdings1) + + logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) + + return logics, logics1 diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..b3ad86004 --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,113 @@ +import os +import torch +import torch.distributed as dist +import numpy as np + +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.basemodel import PreLayerInferTpl +from lightllm.utils.infer_utils import mark_cost_time +from lightllm.models.llama.triton_kernel.embedding import embedding +from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args + + +class MistralMTPPreLayerInfer(PreLayerInferTpl): + """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) + self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) + + return + + def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + input_embdings = self.alloc_tensor( + (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ + ) + embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + if self.tp_world_size_ > 1: + all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return input_embdings + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + input_embdings = self.alloc_tensor( + (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ + ) + embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + if self.tp_world_size_ > 1: + all_reduce(input_embdings, 
op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return input_embdings + + def tpsp_context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight + ): + if get_env_start_args().enable_dp_prefill_balance: + input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) + + input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) + from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy + + padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) + return padded_input_embdings + + def tpsp_token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + input_embdings = self.token_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) + from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy + + padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) + return padded_input_embdings + + def overlap_tpsp_token_forward( + self, + input_ids: torch.Tensor, + input_ids1: torch.Tensor, + infer_state: LlamaInferStateInfo, + infer_state1: LlamaInferStateInfo, + layer_weight: LlamaPreAndPostLayerWeight, + ): + + input_embdings = self.token_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) + from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy + + padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) + + input_embdings1 = self.token_forward(input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight) + from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy + + padded_input_embdings1 = sp_pad_copy( + input_embdings1, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_ + ) + + return padded_input_embdings, padded_input_embdings1 + + def overlap_tpsp_context_forward( + self, + input_ids: torch.Tensor, + input_ids1: torch.Tensor, + infer_state: LlamaInferStateInfo, + infer_state1: LlamaInferStateInfo, + layer_weight: LlamaPreAndPostLayerWeight, + ): + if get_env_start_args().enable_dp_prefill_balance: + input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) + + input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) + from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy + + padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) + + if get_env_start_args().enable_dp_prefill_balance: + input_ids1 = infer_state1.prefill_dp_balance(input_ids=input_ids1) + + input_embdings1 = self.context_forward( + input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight + ) + from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy + + padded_input_embdings1 = sp_pad_copy( + input_embdings1, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_ + ) + + return padded_input_embdings, padded_input_embdings1 diff --git a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..5724f32af --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,33 @@ +import torch.functional as F +import 
torch.distributed as dist +import numpy as np +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer +from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MistralMTPTransformerLayerInfer(MistralTransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + return + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings diff --git a/lightllm/models/mistral_mtp/layer_weights/__init__.py b/lightllm/models/mistral_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py b/lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py new file mode 100644 index 000000000..1819f6319 --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py @@ -0,0 +1,73 @@ +import torch +import os +import gc +from safetensors import safe_open +from tqdm import tqdm +import lightllm.utils.petrel_helper as utils +from lightllm.utils.dist_utils import get_current_device_id + + +def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): + # fix bug for 多线程加载的时候,每个线程内部的cuda device 会切回 0, 修改后来保证不会出现bug + import torch.distributed as dist + + torch.cuda.set_device(get_current_device_id()) + + if use_safetensors: + weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + weights = {k: weights.get_tensor(k) for k in weights.keys()} + else: + weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + + weights = {k: v for k, v in weights.items() if k.startswith("mtp.")} + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weights) + del weights + gc.collect() + + +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + if weight_dict: + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weight_dict) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weight_dict) + del weight_dict + 
return + use_safetensors = True + files = utils.PetrelHelper.list(weight_dir, extension="all") + candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) + if len(candidate_files) == 0: + use_safetensors = False + candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) + assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." + from functools import partial + from multiprocessing.pool import ThreadPool as Pool + + partial_func = partial( + load_func, + use_safetensors=use_safetensors, + pre_post_layer=pre_post_layer, + transformer_layer_list=transformer_layer_list, + weight_dir=weight_dir, + ) # noqa + worker = int(os.environ.get("LOADWORKER", 1)) + with Pool(worker) as p: + iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) + desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" + iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str) + + for _ in iterator: + pass + + return diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..988b9241f --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,36 @@ +import numpy as np +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight + + +class MistralMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + self.wte_weight_ = None + self.lm_head_weight_ = None + self.final_norm_weight_ = None + return + + def load_hf_weights(self, weights): + rename_weights(weights) + if "model.eh_proj.weight" in weights: + self.eh_proj_weight_ = self._cuda(weights["model.eh_proj.weight"]).t() + if "model.enorm.weight" in weights: + self.enorm_weight_ = self._cuda(weights["model.enorm.weight"]) + if "model.hnorm.weight" in weights: + self.hnorm_weight_ = self._cuda(weights["model.hnorm.weight"]) + return + + def verify_load(self): + errors = "weights load not ok" + weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return + + +def rename_weights(weights): + all_keys = list(weights.keys()) + for key in all_keys: + if key.startswith("mtp."): + weights[key.replace("mtp.", "model.")] = weights.pop(key) + return weights diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..057322d89 --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,17 @@ +from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight + + +class MistralMTPTransformerLayerWeight(LlamaTransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight(self): + self._init_norm() + self._init_ffn() + + def _init_norm(self): + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) diff --git 
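Note: the Mistral MTP head ships inside the main checkpoint under an mtp. prefix, so the loader above filters to those tensors and rename_weights then rewrites the prefix to the model. names the generic weight classes expect. A condensed sketch of that filter-and-rename pass over a hypothetical state dict:

    def select_mtp_weights(state_dict):
        # Keep only draft-head tensors and turn "mtp.enorm.weight" into "model.enorm.weight".
        return {"model." + k[len("mtp."):]: v
                for k, v in state_dict.items() if k.startswith("mtp.")}

    sd = {"mtp.enorm.weight": 1, "mtp.layers.0.mlp.gate_proj.weight": 2, "model.embed_tokens.weight": 3}
    assert select_mtp_weights(sd) == {"model.enorm.weight": 1, "model.layers.0.mlp.gate_proj.weight": 2}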
a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py new file mode 100644 index 000000000..07cca282b --- /dev/null +++ b/lightllm/models/mistral_mtp/model.py @@ -0,0 +1,80 @@ +from lightllm.models.mistral.model import MistralTpPartModel +from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight +from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer +from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel +from .layer_weights.hf_load_utils import load_hf_weights + + +class MistralMTPModel(MistralTpPartModel): + + pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight + pre_layer_infer_class = Deepseek3MTPPreLayerInfer + + transformer_weight_class = MistralMTPTransformerLayerWeight + transformer_layer_infer_class = MistralMTPTransformerLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mem_layer_start = kvargs.pop("mem_layer_start", 0) + return + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = 1 + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, mode=self.mode + ) + num_layer = 1 + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + mode=self.mode, + quant_cfg=self.quant_cfg, + ) + for i in range(num_layer) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ + return + + def _init_infer_layer(self): + super()._init_infer_layer() + # reset the layer_num_ of the self.layers_infer + for layer in self.layers_infer: + layer.layer_num_ = layer.layer_num_ + self.mem_layer_start + return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 3c9b98b38..86d86fe7e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -36,6 +36,7 @@ from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel +from lightllm.models.mistral_mtp.model import 
MistralMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet @@ -326,6 +327,8 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "qwen3_moe": self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "mistral": + self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) else: assert False, f"error mtp mode {mtp_model_cfg['model_type']}" From 71bcd72ed0b800f61daec1215fe00c2454cc7cbb Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 12 Dec 2025 09:39:57 +0000 Subject: [PATCH 14/79] mtp dense --- lightllm/models/qwen3_moe_mtp/model.py | 28 +++++++++++++++++-- lightllm/server/api_cli.py | 1 + .../model_infer/mode_backend/base_backend.py | 7 +++-- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index d73483fd6..a53fa8f0e 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe_mtp.layer_infer.transformer_layer_infer import Qwen3MOEMTPTransformerLayerInfer from lightllm.models.qwen3_moe_mtp.layer_weights.transformer_layer_weight import Qwen3MOEMTPTransformerLayerWeight from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights class Qwen3MOEMTPModel(Qwen3MOEModel): @@ -22,6 +23,7 @@ def __init__(self, kvargs: dict): def _pre_init(self, kvargs: dict): self.main_model: TpPartBaseModel = kvargs.pop("main_model") self.mem_layer_start = kvargs.pop("mem_layer_start", 0) + self.mtp_index = kvargs.pop("mtp_index") return def _init_custom(self): @@ -38,7 +40,29 @@ def _init_mem_manager(self): return def _init_weights(self): - super()._init_weights() + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, mode=self.mode + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + mode=self.mode, + quant_cfg=self.quant_cfg, + ) + for i in range(self.mtp_index, self.mtp_index + self.config["n_layer"]) + ] + if self.load_way == "HF": + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ @@ -48,5 +72,5 @@ def _init_infer_layer(self): super()._init_infer_layer() # reset the layer_num_ of the self.layers_infer for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + self.mem_layer_start + layer.layer_num_ = layer.layer_num_ + self.mem_layer_start - self.mtp_index return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 68293fc92..165bd5109 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -549,6 +549,7 @@ def make_argument_parser() -> 
argparse.ArgumentParser: parser.add_argument( "--mtp_draft_model_dir", type=str, + nargs="+", default=None, help="""Path to the draft model for the MTP multi-prediction feature, used for loading the MTP multi-output token model.""", diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 86d86fe7e..5945a53c4 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -298,9 +298,9 @@ def init_mtp_draft_model(self, main_kvargs: dict): assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) mtp_model_kvargs = { - "weight_dir": self.args.mtp_draft_model_dir, + "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, "load_way": main_kvargs["load_way"], "mode": main_kvargs["mode"], @@ -320,9 +320,10 @@ def init_mtp_draft_model(self, main_kvargs: dict): "run_mode": "normal", "main_model": self.model, "mem_layer_start": self.model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + "mtp_index": i, } - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) if mtp_model_cfg["model_type"] == "deepseekv3": self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "qwen3_moe": From 5046c53b6dcd5bcb551de683d2f74765ffdc2705 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 25 Dec 2025 09:12:43 +0000 Subject: [PATCH 15/79] update --- .../layer_weights/pre_and_post_layer_weight.py | 8 ++++++++ lightllm/models/qwen3_moe_mtp/model.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 57d98eec9..19f6958f6 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -11,12 +11,20 @@ def __init__(self, data_type, network_config, mode): return def load_hf_weights(self, weights): + vob_size = self.network_config_["vocab_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] if "model.layers.0.proj.weight" in weights: self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t() if "model.layers.0.norm_after_embedding.weight" in weights: self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"]) if "model.layers.0.norm_before_output.weight" in weights: self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"]) + if "lm_head.weight" in weights: + self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + if "model.norm.weight" in weights: + self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) return def verify_load(self): diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index a53fa8f0e..43f9cd473 100644 --- 
a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -64,8 +64,8 @@ def _init_weights(self): self.pre_post_weight.verify_load() [weight.verify_load() for weight in self.trans_layers_weight] self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ - self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ - self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ + # self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + # self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ return def _init_infer_layer(self): From 043799b704b4c90e2c7e086c593c226ae4855d68 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 06:23:08 +0000 Subject: [PATCH 16/79] fix --- lightllm/models/deepseek_mtp/model.py | 9 +- .../layer_infer/post_layer_infer.py | 144 +----------------- .../layer_infer/pre_layer_infer.py | 108 +------------ .../layer_weights/hf_load_utils.py | 73 --------- lightllm/models/mistral_mtp/model.py | 10 +- lightllm/models/qwen2/model.py | 8 +- lightllm/models/qwen3_moe_mtp/model.py | 14 +- .../model_infer/mode_backend/base_backend.py | 5 +- lightllm/utils/envs_utils.py | 12 ++ 9 files changed, 41 insertions(+), 342 deletions(-) delete mode 100644 lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 2e2e95187..0325268e8 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -1,3 +1,4 @@ +from typing import List from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight @@ -16,7 +17,7 @@ def __init__(self, kvargs: dict): def _pre_init(self, kvargs: dict): self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mem_layer_start = kvargs.pop("mem_layer_start", 0) + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") return def _init_custom(self): @@ -40,7 +41,11 @@ def _init_weights(self): def _init_infer_layer(self): super()._init_infer_layer() + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) # reset the layer_num_ of the self.layers_infer for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + self.mem_layer_start + layer.layer_num_ = layer.layer_num_ + total_pre_layers_num return diff --git a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py index fa0d0b374..5eac249ba 100644 --- a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py @@ -1,149 +1,9 @@ -import os -import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np -from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from einops import rearrange -from 
lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward -from lightllm.common.basemodel import PostLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.distributed.communication_op import all_gather - -class MistralMTPPostLayerInfer(PostLayerInferTpl): +class MistralMTPPostLayerInfer(LlamaPostLayerInfer): """ """ def __init__(self, network_config, mode): super().__init__(network_config, mode) - self.eps_ = network_config["rms_norm_eps"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] return - - def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_) - - def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo): - - if infer_state.is_prefill and infer_state.is_token_healing: - batch_size = infer_state.batch_size - b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() - select_index = [] - start_index = 0 - select_token_num = 0 - for cur_len in b_seq_len_numpy: - select_index.append(start_index + cur_len - 1) - start_index += cur_len - select_token_num += 1 - - last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor((select_token_num, self.embed_dim_), dtype=input_embdings.dtype) - last_input[:, :] = input_embdings[last_index, :] - return last_input, select_token_num - - if infer_state.is_prefill and not infer_state.return_all_prompt_logics: - batch_size = infer_state.batch_size - last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype) - last_index = ( - torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 - ) - last_input[:, :] = input_embdings[last_index, :] - return last_input, batch_size - - if infer_state.is_prefill and infer_state.return_all_prompt_logics: - total_tokens = infer_state.total_token_num - return input_embdings, total_tokens - - if not infer_state.is_prefill: - batch_size = infer_state.batch_size - return input_embdings[-batch_size:, :], batch_size - - assert False, "Error State" - - def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = self.alloc_tensor( - (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype - ) - torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch) - - last_input = None - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - logic_batch = None - ans_logics = self.alloc_tensor( - (token_num, self.vocab_size_), - dtype=torch.float32, - is_graph_out=True, - microbatch_index=infer_state.microbatch_index, - ) - ans_logics[:, :] = gather_data.permute(1, 0) - gather_data = None - 
return ans_logics - - def tpsp_token_forward( - self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight - ): - if self.tp_world_size_ > 1: - assert len(input_embdings.shape) == 2 - token_num, hidden_dim = input_embdings.shape - gather_data = torch.empty( - (self.tp_world_size_ * token_num, hidden_dim), device=input_embdings.device, dtype=input_embdings.dtype - ) - all_gather( - [gather_data[i * token_num : (i + 1) * token_num, :] for i in range(self.tp_world_size_)], - input_embdings, - group=infer_state.dist_group, - async_op=False, - ) - # len(infer_state.position_sin) 获取真实输入长度 - input_embdings = gather_data[0 : len(infer_state.position_sin)] - - if infer_state.need_dp_prefill_balance: - input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) - - return self.token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight) - - def overlap_tpsp_token_forward( - self, - input_embdings: torch.Tensor, - input_embdings1: torch.Tensor, - infer_state: LlamaInferStateInfo, - infer_state1: LlamaInferStateInfo, - layer_weight: BaseLayerWeight, - ): - if getattr(infer_state, "hook", None) is not None: - infer_state.hook() - infer_state.hook = None - - if infer_state.need_dp_prefill_balance: - input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) - - logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) - - if getattr(infer_state1, "hook", None) is not None: - infer_state1.hook() - infer_state1.hook = None - - if infer_state1.need_dp_prefill_balance: - input_embdings1 = infer_state1._all_to_all_unbalance_get(data=input_embdings1) - - logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) - - return logics, logics1 diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index b3ad86004..58d130d4c 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -1,113 +1,9 @@ -import os -import torch -import torch.distributed as dist -import numpy as np +from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.basemodel import PreLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.llama.triton_kernel.embedding import embedding -from lightllm.distributed.communication_op import all_reduce -from lightllm.utils.envs_utils import get_env_start_args - -class MistralMTPPreLayerInfer(PreLayerInferTpl): +class MistralMTPPreLayerInfer(Deepseek3MTPPreLayerInfer): """ """ def __init__(self, network_config, mode): super().__init__(network_config, mode) - tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) - self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) - return - - def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, 
input_embdings) - if self.tp_world_size_ > 1: - all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - return input_embdings - - def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) - if self.tp_world_size_ > 1: - all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - return input_embdings - - def tpsp_context_forward( - self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight - ): - if get_env_start_args().enable_dp_prefill_balance: - input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) - - input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - return padded_input_embdings - - def tpsp_token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.token_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - return padded_input_embdings - - def overlap_tpsp_token_forward( - self, - input_ids: torch.Tensor, - input_ids1: torch.Tensor, - infer_state: LlamaInferStateInfo, - infer_state1: LlamaInferStateInfo, - layer_weight: LlamaPreAndPostLayerWeight, - ): - - input_embdings = self.token_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - - input_embdings1 = self.token_forward(input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings1 = sp_pad_copy( - input_embdings1, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_ - ) - - return padded_input_embdings, padded_input_embdings1 - - def overlap_tpsp_context_forward( - self, - input_ids: torch.Tensor, - input_ids1: torch.Tensor, - infer_state: LlamaInferStateInfo, - infer_state1: LlamaInferStateInfo, - layer_weight: LlamaPreAndPostLayerWeight, - ): - if get_env_start_args().enable_dp_prefill_balance: - input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) - - input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - - if get_env_start_args().enable_dp_prefill_balance: - input_ids1 = infer_state1.prefill_dp_balance(input_ids=input_ids1) - - input_embdings1 = self.context_forward( - input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight - ) - from 
lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings1 = sp_pad_copy( - input_embdings1, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_ - ) - - return padded_input_embdings, padded_input_embdings1 diff --git a/lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py b/lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py deleted file mode 100644 index 1819f6319..000000000 --- a/lightllm/models/mistral_mtp/layer_weights/hf_load_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import os -import gc -from safetensors import safe_open -from tqdm import tqdm -import lightllm.utils.petrel_helper as utils -from lightllm.utils.dist_utils import get_current_device_id - - -def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): - # fix bug for 多线程加载的时候,每个线程内部的cuda device 会切回 0, 修改后来保证不会出现bug - import torch.distributed as dist - - torch.cuda.set_device(get_current_device_id()) - - if use_safetensors: - weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - weights = {k: weights.get_tensor(k) for k in weights.keys()} - else: - weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") - - weights = {k: v for k, v in weights.items() if k.startswith("mtp.")} - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights) - del weights - gc.collect() - - -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - if isinstance(data_type, str): - data_type = torch.float16 if data_type == "fp16" else torch.float32 - if pre_post_layer is not None: - assert pre_post_layer.data_type_ == data_type, "type is not right" - if transformer_layer_list is not None: - assert transformer_layer_list[0].data_type_ == data_type, "type is not right" - if weight_dict: - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weight_dict) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weight_dict) - del weight_dict - return - use_safetensors = True - files = utils.PetrelHelper.list(weight_dir, extension="all") - candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) - if len(candidate_files) == 0: - use_safetensors = False - candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) - assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." 
- from functools import partial - from multiprocessing.pool import ThreadPool as Pool - - partial_func = partial( - load_func, - use_safetensors=use_safetensors, - pre_post_layer=pre_post_layer, - transformer_layer_list=transformer_layer_list, - weight_dir=weight_dir, - ) # noqa - worker = int(os.environ.get("LOADWORKER", 1)) - with Pool(worker) as p: - iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) - desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" - iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str) - - for _ in iterator: - pass - - return diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index 07cca282b..76793ba7e 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -1,6 +1,7 @@ +from typing import List from lightllm.models.mistral.model import MistralTpPartModel from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight -from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.mistral_mtp.layer_infer.pre_layer_infer import MistralMTPPreLayerInfer from lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight from lightllm.common.basemodel import TpPartBaseModel @@ -10,7 +11,7 @@ class MistralMTPModel(MistralTpPartModel): pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight - pre_layer_infer_class = Deepseek3MTPPreLayerInfer + pre_layer_infer_class = MistralMTPPreLayerInfer transformer_weight_class = MistralMTPTransformerLayerWeight transformer_layer_infer_class = MistralMTPTransformerLayerInfer @@ -22,7 +23,7 @@ def __init__(self, kvargs: dict): def _pre_init(self, kvargs: dict): self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mem_layer_start = kvargs.pop("mem_layer_start", 0) + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") return def _init_some_value(self): @@ -74,7 +75,4 @@ def _init_weights(self): def _init_infer_layer(self): super()._init_infer_layer() - # reset the layer_num_ of the self.layers_infer - for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + self.mem_layer_start return diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 19d3c00a6..2bee50ee2 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -3,7 +3,7 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @ModelRegistry("qwen2") @@ -44,11 +44,7 @@ def _init_mem_manager(self): tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = 0 - if get_env_start_args().mtp_mode == "deepseekv3_eagle": - added_mtp_layer_num += 1 - elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": - added_mtp_layer_num += get_env_start_args().mtp_step + added_mtp_layer_num = get_added_mtp_kv_layer_num() self.mem_manager = select_mem_manager_class()( 
self.max_total_token_num, diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 43f9cd473..f57f633f3 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -1,3 +1,4 @@ +from typing import List from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer @@ -22,8 +23,7 @@ def __init__(self, kvargs: dict): def _pre_init(self, kvargs: dict): self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mem_layer_start = kvargs.pop("mem_layer_start", 0) - self.mtp_index = kvargs.pop("mtp_index") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") return def _init_custom(self): @@ -43,6 +43,7 @@ def _init_weights(self): self.pre_post_weight = self.pre_and_post_weight_class( self.data_type, network_config=self.config, mode=self.mode ) + mtp_index = len(self.mtp_previous_draft_models) self.trans_layers_weight = [ self.transformer_weight_class( i, @@ -51,7 +52,7 @@ def _init_weights(self): mode=self.mode, quant_cfg=self.quant_cfg, ) - for i in range(self.mtp_index, self.mtp_index + self.config["n_layer"]) + for i in range(mtp_index, mtp_index + self.config["n_layer"]) ] if self.load_way == "HF": load_hf_weights( @@ -70,7 +71,12 @@ def _init_weights(self): def _init_infer_layer(self): super()._init_infer_layer() + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) + # reset the layer_num_ of the self.layers_infer for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + self.mem_layer_start - self.mtp_index + layer.layer_num_ = layer.layer_num_ + total_pre_layers_num return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 5945a53c4..162e32ad1 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -319,12 +319,11 @@ def init_mtp_draft_model(self, main_kvargs: dict): "quant_cfg": main_kvargs.get("quant_cfg", None), "run_mode": "normal", "main_model": self.model, - "mem_layer_start": self.model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], - "mtp_index": i, + "mtp_previous_draft_models": self.draft_models.copy(), } mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) - if mtp_model_cfg["model_type"] == "deepseekv3": + if mtp_model_cfg["model_type"] == "deepseek_v3": self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "qwen3_moe": self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 8995afbc5..ba1702ed9 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -235,3 +235,15 @@ def enable_huge_page(): "sudo reboot" """ return enable_env_vars("LIGHTLLM_HUGE_PAGE_ENABLE") + + +@lru_cache(maxsize=None) +def get_added_mtp_kv_layer_num() -> int: + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3_eagle": + added_mtp_layer_num += 1 + elif 
get_env_start_args().mtp_mode == "deepseekv3_vanilla": + added_mtp_layer_num += get_env_start_args().mtp_step + + return added_mtp_layer_num From 72616efd483756485ca48f4b17c6dd2fca068e34 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 06:31:35 +0000 Subject: [PATCH 17/79] fix --- lightllm/server/router/model_infer/mode_backend/base_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 162e32ad1..c3c6447f2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -41,7 +41,6 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule -from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient class ModeBackend: From 7955d7714acbd3056e93556f5e8650ea70ee1c7a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 06:40:57 +0000 Subject: [PATCH 18/79] fix mtp mistral model --- lightllm/models/mistral_mtp/model.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index 76793ba7e..e63a16f48 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -5,7 +5,6 @@ from lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight from lightllm.common.basemodel import TpPartBaseModel -from .layer_weights.hf_load_utils import load_hf_weights class MistralMTPModel(MistralTpPartModel): @@ -45,34 +44,14 @@ def _init_mem_manager(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - num_layer = 1 - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(num_layer) - ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] + self.config["n_layer"] = 1 + super()._init_weights() self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ return def _init_infer_layer(self): + self.config["n_layer"] = 1 super()._init_infer_layer() return From 47f768ad3d4b951663897d5737320a6d0df8f2a4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 07:38:57 +0000 Subject: [PATCH 19/79] mistral mtp pre layer infer --- .../layer_infer/pre_layer_infer.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index 58d130d4c..019a74bc7 100644 --- 
a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -1,4 +1,8 @@ +import torch from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward class MistralMTPPreLayerInfer(Deepseek3MTPPreLayerInfer): @@ -7,3 +11,51 @@ class MistralMTPPreLayerInfer(Deepseek3MTPPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) return + + def _mtp_context_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) + + tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.final_norm_weight_, eps=self.eps_) + rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + return ans_logics + + def _mtp_token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) + + tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.final_norm_weight_, eps=self.eps_) + rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + return ans_logics + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_context_forward(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_token_forward(input_embdings, infer_state, layer_weight) From f63a725487baddb5f61b58fb434e927b45cb863c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 07:50:11 +0000 Subject: [PATCH 20/79] fix pre layer mtp --- .../pre_and_post_layer_weight.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index 988b9241f..e974ac2b0 100644 
--- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,4 +1,3 @@ -import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -11,13 +10,12 @@ def __init__(self, data_type, network_config, mode): return def load_hf_weights(self, weights): - rename_weights(weights) - if "model.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.eh_proj.weight"]).t() - if "model.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.enorm.weight"]) - if "model.hnorm.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.hnorm.weight"]) + if "mtp.eh_proj.weight" in weights: + self.eh_proj_weight_ = self._cuda(weights["mtp.eh_proj.weight"]).t() + if "mtp.enorm.weight" in weights: + self.enorm_weight_ = self._cuda(weights["mtp.enorm.weight"]) + if "mtp.hnorm.weight" in weights: + self.hnorm_weight_ = self._cuda(weights["mtp.hnorm.weight"]) return def verify_load(self): @@ -26,11 +24,3 @@ def verify_load(self): for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors return - - -def rename_weights(weights): - all_keys = list(weights.keys()) - for key in all_keys: - if key.startswith("mtp."): - weights[key.replace("mtp.", "model.")] = weights.pop(key) - return weights From acdd94e73e35e6781b389e1949328a5b8ed84c83 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 08:20:36 +0000 Subject: [PATCH 21/79] fix mistral mtp weight load --- .../mistral_mtp/layer_weights/transformer_layer_weight.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 057322d89..ef8c8a7a5 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -5,6 +5,13 @@ class MistralMTPTransformerLayerWeight(LlamaTransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + + self._gate_weight_name = f"mtp.layers.{self.layer_num_}.mlp.gate_proj.weight" + self._up_weight_name = f"mtp.layers.{self.layer_num_}.mlp.up_proj.weight" + self._down_weight_name = f"mtp.layers.{self.layer_num_}.mlp.down_proj.weight" + + self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None return def _init_weight(self): From 253a60c5516b96236d0eda94d5c3fd5a397c9ee9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 08:47:52 +0000 Subject: [PATCH 22/79] fix --- lightllm/models/mistral_mtp/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index e63a16f48..1ce27c1d0 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -2,6 +2,7 @@ from lightllm.models.mistral.model import MistralTpPartModel from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight from lightllm.models.mistral_mtp.layer_infer.pre_layer_infer import MistralMTPPreLayerInfer +from lightllm.models.mistral_mtp.layer_infer.post_layer_infer import MistralMTPPostLayerInfer from 
lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight from lightllm.common.basemodel import TpPartBaseModel @@ -15,6 +16,8 @@ class MistralMTPModel(MistralTpPartModel): transformer_weight_class = MistralMTPTransformerLayerWeight transformer_layer_infer_class = MistralMTPTransformerLayerInfer + post_layer_infer_class = MistralMTPPostLayerInfer + def __init__(self, kvargs: dict): self._pre_init(kvargs) super().__init__(kvargs) From 979cd27ad2c7b27166c5b4ba18e6a35956723677 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 30 Dec 2025 10:11:19 +0000 Subject: [PATCH 23/79] fix --- lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index 019a74bc7..2a1b87fd9 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -1,11 +1,11 @@ import torch -from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward -class MistralMTPPreLayerInfer(Deepseek3MTPPreLayerInfer): +class MistralMTPPreLayerInfer(LlamaPreLayerInfer): """ """ def __init__(self, network_config, mode): From 1fd4c924dee5a9f14528f2d069d53ad3d6a3498d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 02:30:31 +0000 Subject: [PATCH 24/79] fix mistral support fa3 --- lightllm/models/mistral/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index ef7e5d695..2318b9daa 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -8,8 +8,10 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_env_start_args @ModelRegistry("mistral") @@ -40,6 +42,10 @@ def _init_custom(self): self._init_to_get_rotary() return + def _init_inferstate_cls(self): + if get_env_start_args().enable_fa3: + self.infer_state_class = FlashAttentionStateInfo + def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] From 7036a2747be81061a2b4d8b2c659b82e2de1eeef Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 06:06:34 +0000 Subject: [PATCH 25/79] fix weight --- .../layer_weights/transformer_layer_weight.py | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git 
a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index ef8c8a7a5..034713ee4 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -1,23 +1,45 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight -class MistralMTPTransformerLayerWeight(LlamaTransformerLayerWeight): +class MistralMTPTransformerLayerWeight(TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + def _init_weight_names(self): self._gate_weight_name = f"mtp.layers.{self.layer_num_}.mlp.gate_proj.weight" + self._gate_bias_name = None self._up_weight_name = f"mtp.layers.{self.layer_num_}.mlp.up_proj.weight" + self._up_bias_name = None self._down_weight_name = f"mtp.layers.{self.layer_num_}.mlp.down_proj.weight" + self._down_bias_name = None self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" self._ffn_norm_bias_name = None - return def _init_weight(self): self._init_norm() self._init_ffn() + def _init_ffn(self): + self.gate_up_proj = ROWMMWeight( + weight_names=[self._gate_weight_name, self._up_weight_name], + data_type=self.data_type_, + bias_names=[self._gate_bias_name, self._up_bias_name], + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="gate_up_proj", + ) + self.down_proj = COLMMWeight( + weight_names=self._down_weight_name, + data_type=self.data_type_, + bias_names=self._down_bias_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="down_proj", + ) + def _init_norm(self): self.ffn_norm_weight_ = NormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name From 4f8765b3bb7a9d71fca4d5a18af22a40e5ede15e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 06:42:42 +0000 Subject: [PATCH 26/79] fix mtp_avg_token_per_step calcu --- lightllm/server/httpserver/manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 4c40814b8..613722787 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -550,6 +550,7 @@ async def _wait_to_token_package( event = req_status.event unfinished_count = sampling_params.best_of out_token_counter = 0 + sub_req_id_to_mtp_accepted_token_num: Dict[int, int] = {} first_token_cost_ms = sys.float_info.max prompt_tokens = len(prompt_ids) is_first_token = True @@ -579,6 +580,8 @@ async def _wait_to_token_package( prompt_cache_len = metadata.pop("prompt_cache_len", 0) cpu_prompt_cache_len = metadata.pop("cpu_prompt_cache_len", 0) disk_prompt_cache_len = metadata.pop("disk_prompt_cache_len", 0) + sub_req_id_to_mtp_accepted_token_num[sub_req_id] = metadata.get("mtp_accepted_token_num", 0) + if is_first_token: first_token_cost_ms = (time.time() - start_time) * 1000 is_first_token = False @@ -605,7 +608,7 @@ async def _wait_to_token_package( disk_prompt_cache_ratio = disk_prompt_cache_len / prompt_tokens mtp_avg_token_per_step = 
out_token_counter / max( - (out_token_counter - metadata["mtp_accepted_token_num"]), 1 + (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") logger.info( From f3364497ae683fa843e400a932db0e99ca39bfd3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 07:36:01 +0000 Subject: [PATCH 27/79] diverse_mode support mtp --- .../model_infer/mode_backend/diverse_backend/impl.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5b0bf3335..6791e778a 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -15,12 +15,19 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from ..chunked_prefill.impl import ChunkedPrefillBackend from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.utils.envs_utils import get_env_start_args class DiversehBackend(ChunkedPrefillBackend): def __init__(self) -> None: super().__init__() - self.prefill = self.beam_prefill + + if get_env_start_args().mtp_mode: + # 当前只有 mistral mtp 可以使用 diverse mode 的 mtp 功能。 + self.prefill = self.beam_prefill + else: + self.prefill = self.beam_prefill + self.classed_req_strict_prefill = True def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]): From de918c7fd41d0b8d92bd3afd21d1cd92e29776a6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 09:29:37 +0000 Subject: [PATCH 28/79] fix init weights and init layers --- lightllm/common/basemodel/basemodel.py | 8 +-- .../bloom/layer_weights/hf_load_utils.py | 58 ------------------- lightllm/models/bloom/model.py | 31 ---------- lightllm/models/deepseek2/model.py | 38 ------------ lightllm/models/deepseek_mtp/model.py | 11 ++-- .../llama/layer_weights/ds_load_utils.py | 49 ---------------- lightllm/models/llama/model.py | 39 ------------- 7 files changed, 8 insertions(+), 226 deletions(-) delete mode 100755 lightllm/models/bloom/layer_weights/hf_load_utils.py delete mode 100644 lightllm/models/llama/layer_weights/ds_load_utils.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 9065293ae..3a5af8811 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -156,7 +156,7 @@ def _init_quant(self): self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") - def _init_weights(self): + def _init_weights(self, start_layer_index=0): self.pre_post_weight = self.pre_and_post_weight_class( self.data_type, network_config=self.config, mode=self.mode ) @@ -168,7 +168,7 @@ def _init_weights(self): mode=self.mode, quant_cfg=self.quant_cfg, ) - for i in range(self.config["n_layer"]) + for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] load_hf_weights( self.data_type, @@ -214,12 +214,12 @@ def _init_req_manager(self): self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, self.mem_manager) return - def _init_infer_layer(self): + def _init_infer_layer(self, start_layer_index=0): self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) self.layers_infer = [ self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) - for i in range(self.config["n_layer"]) + for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] return diff --git a/lightllm/models/bloom/layer_weights/hf_load_utils.py b/lightllm/models/bloom/layer_weights/hf_load_utils.py deleted file mode 100755 index 01c4c5862..000000000 --- a/lightllm/models/bloom/layer_weights/hf_load_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import os -import gc -from safetensors import safe_open - - -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - if isinstance(data_type, str): - data_type = torch.float16 if data_type == 'fp16' else torch.float32 - if pre_post_layer is not None: - assert pre_post_layer.data_type_ == data_type, "type is not right" - if transformer_layer_list is not None: - assert transformer_layer_list[0].data_type_ == data_type, "type is not right" - if weight_dict: - new_w = {} - for k,v in weight_dict.items(): - if "transformer." in k: - new_w[k[len("transformer."):]] = v - else: - new_w[k] = v - del weight_dict - weight_dict = new_w - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weight_dict) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weight_dict) - del weight_dict - return - use_safetensors = True - files = os.listdir(weight_dir) - candidate_files = list(filter(lambda x : x.endswith('.safetensors'), files)) - if len(candidate_files) == 0: - use_safetensors = False - candidate_files = list(filter(lambda x : x.endswith('.bin'), files)) - assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." - for file_ in candidate_files: - if use_safetensors: - weights = safe_open(os.path.join(weight_dir, file_), 'pt', 'cpu') - weights = {k: weights.get_tensor(k) for k in weights.keys()} - else: - weights = torch.load(os.path.join(weight_dir, file_), 'cpu') - new_w = {} - for k,v in weights.items(): - if "transformer." 
in k: - new_w[k[len("transformer."):]] = v - else: - new_w[k] = v - del weights - weights = new_w - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights) - del weights - gc.collect() - return diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 2c341a790..7e44ec2eb 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -1,17 +1,11 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.bloom.layer_infer.pre_layer_infer import BloomPreLayerInfer from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight -from lightllm.models.bloom.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel -from lightllm.common.build_utils import repair_config - @ModelRegistry("bloom") class BloomTpPartModel(TpPartBaseModel): @@ -41,28 +35,3 @@ def _init_config(self): def _reset_num_key_value_heads(self): self.config["num_key_value_heads"] = self.config["num_attention_heads"] return - - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(self.config["n_layer"]) - ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 6dfd88970..b644fe477 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -112,44 +112,6 @@ def _init_mem_manager(self): ) return - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(self.config["n_layer"]) - ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return - - def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) - self.layers_infer = [ - self.transformer_layer_infer_class( - i, - network_config=self.config, - mode=self.mode, - ) - for i in range(self.config["n_layer"]) - ] - return - def _init_to_get_yarn_rotary(self): from 
lightllm.models.llama.yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_deepseek_mscale diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 0325268e8..8876e2630 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -33,19 +33,16 @@ def _init_mem_manager(self): self.mem_manager = self.main_model.mem_manager return - def _init_weights(self): - super()._init_weights() + def _init_weights(self, start_layer_index=None): + super()._init_weights(start_layer_index=0) self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ return - def _init_infer_layer(self): - super()._init_infer_layer() + def _init_infer_layer(self, start_layer_index=None): total_pre_layers_num = len(self.main_model.layers_infer) total_pre_layers_num += sum( [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] ) - # reset the layer_num_ of the self.layers_infer - for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + total_pre_layers_num + super()._init_infer_layer(start_layer_index=total_pre_layers_num) return diff --git a/lightllm/models/llama/layer_weights/ds_load_utils.py b/lightllm/models/llama/layer_weights/ds_load_utils.py deleted file mode 100644 index 091c056ca..000000000 --- a/lightllm/models/llama/layer_weights/ds_load_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import collections -import torch -import os -import gc - -def load_ds_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None, prefix="", num_layer=0): - if weight_dict: - return weight_dict - files = os.listdir(weight_dir) - candidate_files = sorted(list(filter(lambda x : x.endswith('.pt') and x.startswith('layer'), files))) - assert len(candidate_files) != 0, "can only support pytorch tensor format for weights." - if weight_dict: - weights_all = weight_dict - else: - weights_all = {} - for file_ in candidate_files: - file_split = file_.split('-') - layer_num = int(file_split[0].split('_')[-1]) - rank_num = int(file_split[0].split('_')[-1]) - weights = torch.load(os.path.join(weight_dir, file_), 'cpu') - for k,v in weights.items(): - if layer_num >=3 and layer_num < 3 + num_layer: - k = prefix + str(layer_num - 3) + '.' 
+ k - if layer_num == num_layer + 5: - k = 'lm_head.weight' - if layer_num == num_layer + 4: - k = 'model.norm.weight' - if layer_num == 1: - k = 'model.embed_tokens.weight' - if k not in weights_all: - weights_all[k] = v - else: - if 'q_proj' in k or 'k_proj' in k or 'v_proj' in k or 'gate_proj' in k or 'up_proj' in k: - weights_all[k] = torch.cat([weights_all[k], v], dim=0) - elif 'o_proj' in k or 'down_proj' in k: - weights_all[k] = torch.cat([weights_all[k], v], dim=1) - else: - weights_all[k] = v - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights_all) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights_all) - del weights_all - gc.collect() - return - -if __name__ == '__main__': - load_ds_weight('fp16', '/nvme/baishihao/llama7b', prefix='model.layers.', num_layer=32) \ No newline at end of file diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index a228e0025..73de5b4ad 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -8,9 +8,6 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.layer_weights.ds_load_utils import load_ds_weights -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo @@ -134,42 +131,6 @@ def _init_custom(self): raise ValueError(f"Unknown RoPE scaling type {scaling_type}") return - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(self.config["n_layer"]) - ] - if self.load_way == "HF": - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - else: - load_ds_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - prefix="model.layers.", - num_layer=self.config["n_layer"], - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return - def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) if self.config.get("rope_scaling", {}) is None: From 2013d6fb7d34e922ef1d1cf70df78b42694c360f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 09:38:04 +0000 Subject: [PATCH 29/79] fix init weights and init layers --- lightllm/models/deepseek_mtp/model.py | 2 ++ lightllm/models/mistral_mtp/model.py | 11 +++++--- lightllm/models/qwen3_moe_mtp/model.py | 37 +++++--------------------- 3 files changed, 15 insertions(+), 35 deletions(-) diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 
8876e2630..0204e292a 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -34,12 +34,14 @@ def _init_mem_manager(self): return def _init_weights(self, start_layer_index=None): + assert start_layer_index is None super()._init_weights(start_layer_index=0) self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ return def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None total_pre_layers_num = len(self.main_model.layers_infer) total_pre_layers_num += sum( [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index 1ce27c1d0..0132db80f 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -46,15 +46,18 @@ def _init_mem_manager(self): self.mem_manager = self.main_model.mem_manager return - def _init_weights(self): + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + self.config["n_layer"] = 1 - super()._init_weights() + super()._init_weights(start_layer_index=0) self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ return - def _init_infer_layer(self): + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None self.config["n_layer"] = 1 - super()._init_infer_layer() + super()._init_infer_layer(start_layer_index=0) return diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index f57f633f3..1552695cd 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -5,7 +5,6 @@ from lightllm.models.qwen3_moe_mtp.layer_infer.transformer_layer_infer import Qwen3MOEMTPTransformerLayerInfer from lightllm.models.qwen3_moe_mtp.layer_weights.transformer_layer_weight import Qwen3MOEMTPTransformerLayerWeight from lightllm.common.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights class Qwen3MOEMTPModel(Qwen3MOEModel): @@ -39,44 +38,20 @@ def _init_mem_manager(self): self.mem_manager = self.main_model.mem_manager return - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None mtp_index = len(self.mtp_previous_draft_models) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(mtp_index, mtp_index + self.config["n_layer"]) - ] - if self.load_way == "HF": - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] + super()._init_weights(start_layer_index=mtp_index) self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ # self.pre_post_weight.lm_head_weight_ = 
self.main_model.pre_post_weight.lm_head_weight_ # self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ return - def _init_infer_layer(self): - super()._init_infer_layer() + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None total_pre_layers_num = len(self.main_model.layers_infer) total_pre_layers_num += sum( [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] ) - - # reset the layer_num_ of the self.layers_infer - for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + total_pre_layers_num + super()._init_infer_layer(start_layer_index=total_pre_layers_num) return From b4872c0d80fd884278ee806fc0e141dca2a7ae4f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 10:18:22 +0000 Subject: [PATCH 30/79] rename mtp --- .../CN/source/tutorial/api_server_args_zh.rst | 8 ++- .../EN/source/tutorial/api_server_args_zh.rst | 8 ++- lightllm/common/basemodel/basemodel.py | 65 ++++++++++--------- lightllm/common/basemodel/batch_objs.py | 14 ++-- lightllm/common/basemodel/infer_struct.py | 6 +- lightllm/models/deepseek2/model.py | 8 +-- .../layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/pre_layer_infer.py | 4 +- lightllm/server/api_cli.py | 7 +- lightllm/server/core/objs/req.py | 4 +- lightllm/server/core/objs/start_args_type.py | 4 +- .../model_infer/mode_backend/base_backend.py | 4 +- .../mode_backend/chunked_prefill/impl.py | 6 +- .../mode_backend/dp_backend/impl.py | 26 +++----- .../mode_backend/mtp_pre_process.py | 4 +- lightllm/utils/envs_utils.py | 4 +- .../static_inference/model_infer_mtp.py | 10 +-- 17 files changed, 94 insertions(+), 92 deletions(-) diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index 5c53dca84..ce7a79ab9 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -447,10 +447,12 @@ MTP 多预测参数 .. option:: --mtp_mode - 支持的 mtp 模式,建议使用 deepseekv3_eagle获得更好的性能体验,可选值: + 支持的 mtp 模式,建议使用 eagle_with_att获得更好的性能体验,可选值: - * ``deepseekv3_vanilla`` - * ``deepseekv3_eagle`` + * ``vanilla_with_att`` + * ``eagle_with_att`` + * ``vanilla_no_att`` + * ``eagle_no_att`` * ``None``: 不启用 mtp(默认) .. option:: --mtp_draft_model_dir diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index 1b72287b6..1644bbab5 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -444,10 +444,12 @@ MTP Multi-Prediction Parameters .. option:: --mtp_mode - Supported mtp modes, it is recommended to use deepseekv3_eagle for better performance, optional values: + Supported mtp modes, it is recommended to use eagle_with_att for better performance, optional values: - * ``deepseekv3_vanilla`` - * ``deepseekv3_eagle`` + * ``vanilla_with_att`` + * ``eagle_with_att`` + * ``vanilla_no_att`` + * ``eagle_no_att`` * ``None``: Do not enable mtp (default) .. 
option:: --mtp_draft_model_dir diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 3a5af8811..72d85c40c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -89,7 +89,12 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode - self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"] + self.is_mtp_mode = self.args.mtp_mode in [ + "vanilla_with_att", + "eagle_with_att", + "vanilla_no_att", + "eagle_no_att", + ] self.prefill_graph: PrefillCudaGraph = None self._init_config() @@ -303,7 +308,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.dist_group = dist_group_manager.get_group(microbatch_index) # 特殊模型,特殊模式的特定变量初始化操作。 - infer_state.deepseekv3_mtp_draft_input_hiddens = model_input.deepseekv3_mtp_draft_input_hiddens + infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens return infer_state @@ -343,9 +348,9 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s ) # 特殊模型,特殊模式的特殊变量的特殊 padding - if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: - new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( - input=new_model_input.deepseekv3_mtp_draft_input_hiddens, + if new_model_input.mtp_draft_input_hiddens is not None: + new_model_input.mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( + input=new_model_input.mtp_draft_input_hiddens, new_batch_size=new_batch_size, ) @@ -388,9 +393,9 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle ] # 特殊模型,特殊模式的特殊变量的特殊 padding - if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: - new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( - input=new_model_input.deepseekv3_mtp_draft_input_hiddens, + if new_model_input.mtp_draft_input_hiddens is not None: + new_model_input.mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( + input=new_model_input.mtp_draft_input_hiddens, new_batch_size=new_handle_token_num, ) @@ -405,9 +410,9 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba new_model_output.logits = new_model_output.logits[0:origin_batch_size] # 特殊模型,特殊模式的特殊变量的特殊 unpad - if new_model_output.deepseekv3_mtp_main_output_hiddens is not None: - _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens - new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_batch_size] + if new_model_output.mtp_main_output_hiddens is not None: + _hidden_states = new_model_output.mtp_main_output_hiddens + new_model_output.mtp_main_output_hiddens = _hidden_states[0:origin_batch_size] return new_model_output @@ -421,9 +426,9 @@ def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, o new_model_output.logits = new_model_output.logits[0:-1] # 特殊模型,特殊模式的特殊变量的特殊 unpad - if new_model_output.deepseekv3_mtp_main_output_hiddens is not None: - _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens - new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num] + if new_model_output.mtp_main_output_hiddens is not None: + _hidden_states = new_model_output.mtp_main_output_hiddens + new_model_output.mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num] return new_model_output @@ -559,8 +564,8 @@ def 
prefill_func(input_tensors, infer_state): model_output = ModelOutput(logits=predict_logits) # 特殊模型特殊模式的额外输出 - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = input_embs + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = input_embs # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 @@ -581,14 +586,14 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo): post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight) - if self.is_deepseekv3_mtp_mode: + if self.is_mtp_mode: graph_out_hiddens = input_embs.contiguous() model_output = ModelOutput(logits=predict_logits.contiguous()) # 特殊模型特殊模式的额外输出 - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = graph_out_hiddens # 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。 if infer_state.is_cuda_graph: @@ -756,9 +761,9 @@ def _overlap_tpsp_context_forward( model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = ModelOutput(logits=predict_logits1.contiguous()) - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = input_embs.contiguous() - model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1.contiguous() + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = input_embs.contiguous() + model_output1.mtp_main_output_hiddens = input_embs1.contiguous() return model_output, model_output1 @@ -779,16 +784,16 @@ def _overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight ) - if self.is_deepseekv3_mtp_mode: + if self.is_mtp_mode: graph_out_hiddens = input_embs.contiguous() graph_out_hiddens1 = input_embs1.contiguous() model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = ModelOutput(logits=predict_logits1.contiguous()) - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens - model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1 + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = graph_out_hiddens + model_output1.mtp_main_output_hiddens = graph_out_hiddens1 if infer_state.is_cuda_graph: model_output.to_no_ref_tensor() @@ -993,16 +998,16 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_deepseekv3_mtp_draft_model = ( + is_mtp_draft_model = ( "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(self.__class__) or "MistralMTPModel" in str(self.__class__) ) - if is_deepseekv3_mtp_draft_model: - special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn( + if is_mtp_draft_model: + special_model_input["mtp_draft_input_hiddens"] = torch.randn( token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" ) else: - special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None + special_model_input["mtp_draft_input_hiddens"] = None return special_model_input diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 5a98a13df..138f08427 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -46,9 +46,9 @@ class ModelInput: # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输入变量。只在特殊的模型模式下才会具体使用和生效。 - # deepseekv3_mtp_draft_input_hiddens 用于 
deepseekv3 模型 mtp 模式下 + # mtp_draft_input_hiddens 用于模型 mtp 模式下 # 的 draft 模型的输入 - deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None + mtp_draft_input_hiddens: Optional[torch.Tensor] = None def to_cuda(self): if self.input_ids is not None: @@ -90,12 +90,12 @@ class ModelOutput: # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输出变量。只在特殊的模型模式下才会具体使用和生效。 - # deepseekv3_mtp_main_output_hiddens 用于在mtp模式下,llm main model - # 输出最后一层的hidden state 状态用于 draft 模型的 deepseekv3_mtp_draft_input_hiddens + # mtp_main_output_hiddens 用于在mtp模式下,llm main model + # 输出最后一层的hidden state 状态用于 draft 模型的 mtp_draft_input_hiddens # 输入 - deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None + mtp_main_output_hiddens: Optional[torch.Tensor] = None def to_no_ref_tensor(self): self.logits = tensor_to_no_ref_tensor(self.logits) - if self.deepseekv3_mtp_main_output_hiddens is not None: - self.deepseekv3_mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.deepseekv3_mtp_main_output_hiddens) + if self.mtp_main_output_hiddens is not None: + self.mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.mtp_main_output_hiddens) diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 3c28a47b8..a014bad64 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -71,10 +71,10 @@ def __init__(self): # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 - # deepseekv3 mtp draft model 使用的额外输入参数, - # 在开启 mtp_mode == deepseekv3 时,mtp draft model + # mtp draft model 使用的额外输入参数, + # 在开启 mtp_mode 时,mtp draft model # 的输入会用到,其他模型和场景都不会用到 - self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None + self.mtp_draft_input_hiddens: Optional[torch.Tensor] = None # 在单节点多dp的运行模式下,在进行prefill的阶段,如果出现了dp之间数据不平衡的现象, # 可以将推理的数据,进行重新分配到各个dp,在做 att 之前,重新 all to all 到各自的 diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index b644fe477..218b2e130 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -12,7 +12,7 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id @@ -96,11 +96,7 @@ def _init_mem_manager(self): manager_class = select_mem_manager_class() # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = 0 - if get_env_start_args().mtp_mode == "deepseekv3_eagle": - added_mtp_layer_num += 1 - elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": - added_mtp_layer_num += get_env_start_args().mtp_step + added_mtp_layer_num = get_added_mtp_kv_layer_num() self.mem_manager = manager_class( self.max_total_token_num, diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index 0991103e1..b81fbc510 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -18,7 +18,7 @@ def __init__(self, network_config, mode): def _mtp_context_forward( self, input_embdings, infer_state: Deepseek2InferStateInfo, 
layer_weight: Deepseek3MTPPreAndPostLayerWeight ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + tgt_embdings = infer_state.mtp_draft_input_hiddens assert ( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" @@ -36,7 +36,7 @@ def _mtp_context_forward( def _mtp_token_forward( self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index 2a1b87fd9..f43b85d63 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -15,7 +15,7 @@ def __init__(self, network_config, mode): def _mtp_context_forward( self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + tgt_embdings = infer_state.mtp_draft_input_hiddens assert ( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" @@ -35,7 +35,7 @@ def _mtp_context_forward( def _mtp_token_forward( self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 165bd5109..d193bab41 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -542,9 +542,12 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--mtp_mode", - choices=["deepseekv3_vanilla", "deepseekv3_eagle", None], + choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att", None], default=None, - help="""supported mtp mode, None is not enable mtp, """, + help="""Supported MTP modes. + None: Disables MTP. + *_with_att: Uses the MTP model with an attention mechanism to predict the next draft token. 
+ *_no_att: Uses the MTP model without an attention module to predict the next draft token.""", ) parser.add_argument( "--mtp_draft_model_dir", diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 1e1796335..f489aac9c 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -352,8 +352,8 @@ def get_decode_need_tokens(self): need_tokens = min(self.input_len + self.shm_cur_output_len - self.shm_cur_kv_len, self.chunked_prefill_size) if need_tokens == 1 and self._mtp_step > 0: # self._mtp_step > 0 时,说明开启了mtp 模式,每次decode需要额外的mem token 资源 - # "deepseekv3_vanilla" 模式需要的 mem 用量为 self._mtp_step + 1 - # "deepseekv3_eagle" 模式需要的 mem 用量为 (self._mtp_step + 1)* 2 + # "vanilla_with_att" 模式需要的 mem 用量为 self._mtp_step + 1 + # "eagle_with_att" 模式需要的 mem 用量为 (self._mtp_step + 1)* 2 # 为了简化统一 返回 (self._mtp_step + 1)* 2 need_tokens = (self._mtp_step + 1) * 2 diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 611ff772b..5ebadaf16 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -124,7 +124,9 @@ class StartArgs: ) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) - mtp_mode: Optional[str] = field(default=None) + mtp_mode: Optional[str] = field( + default=None, metadata={"choices": ["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"]} + ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) kv_quant_calibration_config_path: Optional[str] = field(default=None) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c3c6447f2..129404af8 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -289,9 +289,9 @@ def init_mtp_draft_model(self, main_kvargs: dict): os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode == "deepseekv3_vanilla": + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode == "deepseekv3_eagle": + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 858f4713a..fca60fed6 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -325,7 +325,7 @@ def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOut draft_model_input = prepare_mtp_prefill_inputs( model_input=draft_model_input, b_next_token_ids=draft_next_token_ids_gpu, - deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output.mtp_main_output_hiddens, ) draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) @@ -349,7 +349,7 @@ def _draft_decode_vanilla( for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - draft_model_input.deepseekv3_mtp_draft_input_hiddens 
= draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) @@ -393,7 +393,7 @@ def _draft_decode_eagle( for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index a1414b8b2..f4a3729e4 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -534,7 +534,7 @@ def _draft_decode_vanilla( for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids_gpu - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) @@ -585,7 +585,7 @@ def _draft_decode_eagle( for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids_gpu - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) @@ -672,13 +672,13 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I draft_model_input0 = prepare_mtp_prefill_inputs( model_input=draft_model_input0, b_next_token_ids=draft_next_token_ids_gpu0, - deepseekv3_mtp_draft_input_hiddens=draft_model_output0.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output0.mtp_main_output_hiddens, ) draft_model_input1 = prepare_mtp_prefill_inputs( model_input=draft_model_input1, b_next_token_ids=draft_next_token_ids_gpu1, - deepseekv3_mtp_draft_input_hiddens=draft_model_output1.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output1.mtp_main_output_hiddens, ) draft_model_output0, draft_model_output1 = self.draft_models[ @@ -836,7 +836,7 @@ def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOut draft_model_input = prepare_mtp_prefill_inputs( model_input=draft_model_input, b_next_token_ids=draft_next_token_ids_gpu, - deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output.mtp_main_output_hiddens, ) draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) @@ -874,13 +874,9 @@ def _draft_decode_vanilla_overlap( for draft_model_idx in 
range(self.mtp_step): draft_model_input0.input_ids = draft_next_token_ids_gpu0 - draft_model_input0.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output0.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 - draft_model_input1.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output1.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input1.mtp_draft_input_hiddens = draft_model_output1.mtp_main_output_hiddens draft_model_output0, draft_model_output1 = self.draft_models[draft_model_idx].microbatch_overlap_decode( draft_model_input0, draft_model_input1 @@ -949,13 +945,9 @@ def _draft_decode_eagle_overlap( for _step in range(self.mtp_step): draft_model_input0.input_ids = draft_next_token_ids_gpu0 - draft_model_input0.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output0.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 - draft_model_input1.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output1.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input1.mtp_draft_input_hiddens = draft_model_output1.mtp_main_output_hiddens draft_model_idx = _step % self.num_mtp_models draft_model_output0, draft_model_output1 = self.draft_models[draft_model_idx].microbatch_overlap_decode( diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py index 3991e42de..dbce73a94 100644 --- a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py @@ -5,7 +5,7 @@ def prepare_mtp_prefill_inputs( - model_input: ModelInput, b_next_token_ids: torch.Tensor, deepseekv3_mtp_draft_input_hiddens: torch.Tensor + model_input: ModelInput, b_next_token_ids: torch.Tensor, mtp_draft_input_hiddens: torch.Tensor ): new_model_input = copy.copy(model_input) new_input_ids = gen_mtp_new_input_ids( @@ -15,5 +15,5 @@ def prepare_mtp_prefill_inputs( b_ready_cache_len=model_input.b_ready_cache_len, ) new_model_input.input_ids = new_input_ids - new_model_input.deepseekv3_mtp_draft_input_hiddens = deepseekv3_mtp_draft_input_hiddens + new_model_input.mtp_draft_input_hiddens = mtp_draft_input_hiddens return new_model_input diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index ba1702ed9..06f53b307 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -241,9 +241,9 @@ def enable_huge_page(): def get_added_mtp_kv_layer_num() -> int: # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 - if get_env_start_args().mtp_mode == "deepseekv3_eagle": + if get_env_start_args().mtp_mode == "eagle_with_att": added_mtp_layer_num += 1 - elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": + elif get_env_start_args().mtp_mode == "vanilla_with_att": added_mtp_layer_num += get_env_start_args().mtp_step return added_mtp_layer_num diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index ef2ada64c..942af0f88 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -148,7 +148,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ # Draft model Prefill # For simplicity, we'll 
just take the input of main_model to draft model. - model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens for draft_model_id in range(len(draft_models)): draft_model = draft_models[draft_model_id] model_output = draft_model.forward(model_input) @@ -156,7 +156,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() draft_ids.append(predict_ids) - model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens torch.cuda.synchronize() prefill_end_time = time.time() @@ -218,7 +218,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ # draft decode model_input.input_ids = predict_ids.reshape(-1) - model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens for draft_model_id in range(len(draft_models)): draft_model = draft_models[draft_model_id] @@ -228,11 +228,11 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ prob_out = torch.softmax(model_output.logits, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) model_input.input_ids = predict_ids.reshape(-1) - model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens # accept all draft ids by default. model_input.input_ids = predict_ids.reshape(-1) - model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens torch.cuda.synchronize() if i % 100 == 0 or i == output_len - 1: step_end_time = time.time() From 37c04a66a5865b76f54bd02e0f9e4ab211ef0fc4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 10:24:08 +0000 Subject: [PATCH 31/79] fix cpu cache kv layer num --- lightllm/utils/kv_cache_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index ed183e393..4875b4eee 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -8,7 +8,12 @@ import numpy as np import triton from functools import lru_cache -from lightllm.utils.envs_utils import get_env_start_args, enable_huge_page, get_llm_data_type +from lightllm.utils.envs_utils import ( + get_env_start_args, + enable_huge_page, + get_llm_data_type, + get_added_mtp_kv_layer_num, +) from lightllm.utils.log_utils import init_logger from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class @@ -111,7 +116,7 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - cpu_cache_meta.layer_num += 1 + cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size()) From a4416b31c61c77f1a9f257bbc14d252342c958e2 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 10:41:41 +0000 Subject: [PATCH 32/79] fix mem 
layer --- lightllm/common/basemodel/basemodel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 72d85c40c..4f7893633 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -24,7 +24,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size -from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.common.triton_utils.autotuner import AutotuneLevel @@ -188,12 +188,13 @@ def _init_weights(self, start_layer_index=0): def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + added_mtp_layer_num = get_added_mtp_kv_layer_num() self.mem_manager: MemoryManager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], - layer_num=self.config["n_layer"], + layer_num=self.config["n_layer"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) return From 085f9db6a0143afbaef006d8cc8830c118e32291 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 10:53:32 +0000 Subject: [PATCH 33/79] fix bloom --- lightllm/models/bloom/layer_infer/post_layer_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index 0cf8f8e99..ee533d86d 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -57,6 +57,6 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh ) logic_batch = None - ans_logics = gather_data.permute(1, 0).float() + ans_logics = gather_data.permute(1, 0).float().contiguous() gather_data = None return ans_logics From 1a1dd87fb36d93c7b9ef9134f64737899f375667 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 11:01:14 +0000 Subject: [PATCH 34/79] fix added_mtp_kv_layer_num --- lightllm/common/basemodel/basemodel.py | 3 +-- lightllm/models/deepseek2/model.py | 5 +---- lightllm/models/gemma3/model.py | 3 ++- lightllm/models/gemma_2b/model.py | 3 ++- lightllm/models/llama/model.py | 3 ++- lightllm/models/mistral/model.py | 3 ++- lightllm/models/qwen2/model.py | 5 +---- lightllm/models/starcoder/model.py | 3 ++- lightllm/models/starcoder2/model.py | 3 ++- 9 files changed, 15 insertions(+), 16 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4f7893633..66785b0ec 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -188,13 +188,12 @@ def _init_weights(self, start_layer_index=0): def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 - added_mtp_layer_num = get_added_mtp_kv_layer_num() self.mem_manager: MemoryManager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // 
self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], - layer_num=self.config["n_layer"] + added_mtp_layer_num, + layer_num=self.config["n_layer"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 218b2e130..cbbd6b1a5 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -95,15 +95,12 @@ def _verify_params(self): def _init_mem_manager(self): manager_class = select_mem_manager_class() - # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = get_added_mtp_kv_layer_num() - self.mem_manager = manager_class( self.max_total_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], - layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index 42326169a..dc4f03b7e 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -6,6 +6,7 @@ from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.models.gemma3.infer_struct import Gemma3InferStateInfo from lightllm.models.gemma3.layer_infer.post_layer_infer import Gemma3PostLayerInfer from lightllm.models.gemma3.layer_infer.pre_layer_infer import Gemma3PreLayerInfer @@ -148,7 +149,7 @@ def _init_mem_manager(self): dtype=torch.bfloat16, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=256, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/gemma_2b/model.py b/lightllm/models/gemma_2b/model.py index 4b425c9ce..2563c7e79 100644 --- a/lightllm/models/gemma_2b/model.py +++ b/lightllm/models/gemma_2b/model.py @@ -7,6 +7,7 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @ModelRegistry("gemma") @@ -43,7 +44,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"], head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 73de5b4ad..95465a9e6 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -13,6 +13,7 @@ from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from 
lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id @@ -88,7 +89,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=head_dim_, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index 2318b9daa..d32f51ae7 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -11,6 +11,7 @@ from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.envs_utils import get_env_start_args @@ -55,7 +56,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=head_dim, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 2bee50ee2..d2f067c42 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -43,15 +43,12 @@ def _init_mem_manager(self): head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = get_added_mtp_kv_layer_num() - self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=tp_k_head_num_, head_dim=head_dim_, - layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/starcoder/model.py b/lightllm/models/starcoder/model.py index ea2aeabbc..3fcd14208 100644 --- a/lightllm/models/starcoder/model.py +++ b/lightllm/models/starcoder/model.py @@ -6,6 +6,7 @@ from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer from lightllm.common.build_utils import repair_config from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.basemodel import InferStateInfo @@ -46,7 +47,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"], head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/starcoder2/model.py b/lightllm/models/starcoder2/model.py index a299c08be..2b7545914 100644 --- a/lightllm/models/starcoder2/model.py +++ b/lightllm/models/starcoder2/model.py @@ -9,6 +9,7 @@ from lightllm.common.build_utils import repair_config from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.common.basemodel import 
TpPartBaseModel @@ -52,7 +53,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return From 5da25946aa05fbd077f5867f68303120f2d34f97 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 11:17:02 +0000 Subject: [PATCH 35/79] fix token decode kernel for int32 overflow --- .../token_attention_nopad_att1.py | 65 ++++++++++++++----- .../token_attention_nopad_reduceV.py | 3 + 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py index 923d6d83b..174aea6c1 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py @@ -7,15 +7,27 @@ @triton.jit def _fwd_kernel_token_att1( - Q, K, sm_scale, Alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 + Q, + K, + sm_scale, + Alibi, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 Att_Out, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - att_stride_h, att_stride_bs, - + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -36,11 +48,19 @@ def _fwd_kernel_token_att1( block_stard_index = start_n * BLOCK_N block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + stride_req_to_tokens_s = tl.cast(stride_req_to_tokens_s, tl.int64) + cur_batch_req_id = tl.cast(cur_batch_req_id, tl.int64) + stride_kbs = tl.cast(stride_kbs, tl.int64) + for start_mark in range(0, block_mask, 1): # 用来判断当前 mask 是否需要计算 alibi_m = tl.load(Alibi + cur_head) q = tl.load(Q + off_q + start_mark) offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -68,12 +88,25 @@ def token_att_fwd(q, k, att_out, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B num_warps = 2 _fwd_kernel_token_att1[grid]( - q, k, sm_scale, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + q, + k, + sm_scale, + alibi, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, att_out, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - att_out.stride(0), att_out.stride(1), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), 
BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, @@ -88,7 +121,9 @@ def torch_att(xq, xk, bs, seqlen, num_head, head_dim): keys = xk xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) - scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + scores = ( + (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + ) print("s ", scores.shape) return scores @@ -99,4 +134,4 @@ def torch_att1(xq, xk, seqlen, num_head, head_dim): logics = torch.sum(xq * xk, dim=-1, keepdim=False) logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics \ No newline at end of file + return logics diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py index 3960083fa..8ee01c340 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py @@ -29,6 +29,9 @@ def _fwd_kernel_token_att2( cur_batch = tl.program_id(0) cur_head = tl.program_id(1) + stride_vbs = tl.cast(stride_vbs, tl.int64) + stride_pbs = tl.cast(stride_pbs, tl.int64) + offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) From 74e0f7546fb2319f0f7bb16045cb6b6730fbe86d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 31 Dec 2025 11:25:02 +0000 Subject: [PATCH 36/79] fix mtp mode support --- .../server/router/model_infer/mode_backend/base_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 129404af8..92653bc0c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -323,10 +323,13 @@ def init_mtp_draft_model(self, main_kvargs: dict): mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) if mtp_model_cfg["model_type"] == "deepseek_v3": + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "qwen3_moe": + assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "mistral": + assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) else: assert False, f"error mtp mode {mtp_model_cfg['model_type']}" From 594fa949b64bab4fceebd5bc2f06999aa7e5b24a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 1 Jan 2026 14:03:52 +0800 Subject: [PATCH 37/79] fix --- lightllm/models/deepseek2/model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index cbbd6b1a5..e4ce7c826 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -146,8 +146,3 @@ def _init_to_get_yarn_rotary(self): self._sin_cached = (freqs.sin() * _mscale).to(self.data_type).cuda() return - - @final - def _context_forward(self, input_ids, infer_state): - predict_logics = super()._context_forward(input_ids, infer_state) - return predict_logics From 5488e4a4f1564c18c158dcc08feae1336b86d16f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 1 Jan 2026 14:21:43 
+0800 Subject: [PATCH 38/79] fix is_egale_mtp --- .../router/model_infer/mode_backend/chunked_prefill/impl.py | 2 +- .../server/router/model_infer/mode_backend/dp_backend/impl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index fca60fed6..f3450261b 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -40,7 +40,7 @@ def __init__(self) -> None: if get_env_start_args().mtp_mode: self.prefill = self.prefill_mtp self.decode = self.decode_mtp - self.is_mtp_eagle = get_env_start_args().mtp_mode == "deepseekv3_eagle" + self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla else: diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index f4a3729e4..df10a6d4e 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -34,7 +34,7 @@ def __init__(self) -> None: # 在 mtp 模式下切换绑定的prefill 和 decode 函数 if get_env_start_args().mtp_mode: - self.is_mtp_eagle = get_env_start_args().mtp_mode == "deepseekv3_eagle" + self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step if self.enable_prefill_microbatch_overlap: self.prefill = self.prefill_overlap_mtp From 616c37fbcb9870bfc42341edce651f37c1ad7c79 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 1 Jan 2026 19:55:25 +0800 Subject: [PATCH 39/79] fix inferstate input_ids --- lightllm/common/basemodel/basemodel.py | 47 ++++++------- lightllm/common/basemodel/cuda_graph.py | 51 ++++++-------- lightllm/common/basemodel/infer_struct.py | 6 +- .../template/pre_layer_infer_template.py | 2 - .../layer_weights/meta_weights/__init__.py | 3 +- .../meta_weights/embedding_weight.py | 67 +++++++++++++++++++ .../layer_weights/meta_weights/norm_weight.py | 17 ++++- .../pre_and_post_layer_weight.py | 13 ++++ .../basemodel}/triton_kernel/embedding.py | 0 .../basemodel}/triton_kernel/rmsnorm.py | 0 .../deepseek2/flashattention_infer_struct.py | 6 +- .../models/deepseek2/flashinfer_struct.py | 8 +-- lightllm/models/deepseek2/infer_struct.py | 4 +- lightllm/models/gemma3/infer_struct.py | 4 +- .../llama/flashattention_infer_struct.py | 18 ++--- lightllm/models/llama/flashinfer_struct.py | 12 ++-- lightllm/models/llama/infer_struct.py | 4 +- .../llama/layer_infer/post_layer_infer.py | 35 ++++------ .../llama/layer_infer/pre_layer_infer.py | 18 +---- .../pre_and_post_layer_weight.py | 44 ++++++------ lightllm/models/qwen/infer_struct.py | 6 +- lightllm/models/qwen2_vl/infer_struct.py | 6 +- 22 files changed, 212 insertions(+), 159 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py rename lightllm/{models/llama => common/basemodel}/triton_kernel/embedding.py (100%) rename lightllm/{models/llama => common/basemodel}/triton_kernel/rmsnorm.py (100%) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py 
index 66785b0ec..b6635fbab 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -275,6 +275,7 @@ def forward(self, model_input: ModelInput): def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): infer_state = self.infer_state_class() + infer_state.input_ids = model_input.input_ids infer_state.is_prefill = model_input.is_prefill infer_state.is_token_healing = self.is_token_healing infer_state.return_all_prompt_logics = self.return_all_prompt_logics @@ -462,8 +463,8 @@ def _prefill( prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() - infer_state.init_some_extra_state(self, model_input.input_ids) - model_output = self._context_forward(model_input.input_ids, infer_state) + infer_state.init_some_extra_state(self) + model_output = self._context_forward(infer_state) if is_padded_model_input: model_output = self._create_unpad_prefill_model_output( model_output, origin_handle_token_num=origin_handle_token_num @@ -493,7 +494,7 @@ def _decode( infer_state.b_seq_len, infer_state.mem_index, ) - infer_state.init_some_extra_state(self, padded_model_input.input_ids) + infer_state.init_some_extra_state(self) if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True @@ -514,14 +515,15 @@ def _decode( infer_state.b_seq_len, infer_state.mem_index, ) - infer_state.init_some_extra_state(self, model_input.input_ids) + infer_state.init_some_extra_state(self) model_output = self._token_forward(model_input.input_ids, infer_state) return model_output @final - def _context_forward(self, input_ids, infer_state: InferStateInfo): + def _context_forward(self, infer_state: InferStateInfo): run_mode_index = 1 if self.enable_tpsp_mix_mode else 0 + input_ids = infer_state.input_ids cuda_input_ids = input_ids pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index] @@ -573,8 +575,9 @@ def prefill_func(input_tensors, infer_state): return model_output @final - def _token_forward(self, input_ids, infer_state: InferStateInfo): + def _token_forward(self, infer_state: InferStateInfo): run_mode_index = 1 if self.enable_tpsp_mix_mode else 0 + input_ids = infer_state.input_ids cuda_input_ids = input_ids pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index] input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight) @@ -620,7 +623,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state0.mem_index, max_q_seq_len=infer_state0.max_q_seq_len, ) - infer_state0.init_some_extra_state(self, input_ids0) + infer_state0.init_some_extra_state(self) infer_state1 = self._create_inferstate(model_input1, 1) init_req_to_token_indexes( @@ -632,7 +635,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state1.mem_index, max_q_seq_len=infer_state1.max_q_seq_len, ) - infer_state1.init_some_extra_state(self, input_ids1) + infer_state1.init_some_extra_state(self) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -686,7 +689,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.b_seq_len, infer_state0.mem_index, ) - infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) + infer_state0.init_some_extra_state(self) infer_state1 = self._create_inferstate(padded_model_input1, 1) copy_kv_index_to_req( 
self.req_manager.req_to_token_indexs, @@ -694,7 +697,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.b_seq_len, infer_state1.mem_index, ) - infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) + infer_state1.init_some_extra_state(self) if self.graph.need_capture(find_graph_batch_size): infer_state0.is_cuda_graph = True @@ -702,16 +705,12 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output0, model_output1 = self.graph.capture_decode( self._overlap_tpsp_token_forward, - padded_model_input0.input_ids, infer_state0, - input_ids1=padded_model_input1.input_ids, infer_state1=infer_state1, ) else: model_output0, model_output1 = self.graph.replay( - padded_model_input0.input_ids, infer_state0, - input_ids1=padded_model_input1.input_ids, infer_state1=infer_state1, ) @@ -726,7 +725,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.b_seq_len, infer_state0.mem_index, ) - infer_state0.init_some_extra_state(self, model_input0.input_ids) + infer_state0.init_some_extra_state(self) infer_state1 = self._create_inferstate(model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -734,20 +733,16 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.b_seq_len, infer_state1.mem_index, ) - infer_state1.init_some_extra_state(self, model_input1.input_ids) + infer_state1.init_some_extra_state(self) - model_output0, model_output1 = self._overlap_tpsp_token_forward( - model_input0.input_ids, infer_state0, input_ids1=model_input1.input_ids, infer_state1=infer_state1 - ) + model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) return model_output0, model_output1 @final - def _overlap_tpsp_context_forward( - self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo - ): + def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): g_cache_manager.cache_env_in() input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( - input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( @@ -768,11 +763,9 @@ def _overlap_tpsp_context_forward( return model_output, model_output1 @final - def _overlap_tpsp_token_forward( - self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo - ): + def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( - input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) for i in range(self.layers_num): diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 15c55e91c..9eeab7270 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -62,9 +62,10 @@ def find_closest_graph_batch_size(self, batch_size): else: return None - def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo): + def _capture_decode(self, decode_func, infer_state: InferStateInfo): dist_group: 
CustomProcessGroup = infer_state.dist_group graph_obj = torch.cuda.CUDAGraph() + input_ids = infer_state.input_ids batch_size = input_ids.shape[0] infer_state.max_len_in_batch = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size @@ -78,27 +79,26 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf # 中的 tensor。 for _ in range(1): torch.cuda.synchronize() - decode_func(input_ids, copy.copy(infer_state)) + decode_func(copy.copy(infer_state)) torch.cuda.synchronize() with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - model_output = decode_func(input_ids, infer_state) - self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) + model_output = decode_func(infer_state) + self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() return model_output def _capture_decode_overlap( self, decode_func, - input_ids: torch.Tensor, infer_state: InferStateInfo, - input_ids1: torch.Tensor, infer_state1: InferStateInfo, ): dist_group: CustomProcessGroup = infer_state.dist_group dist_group1 = infer_state1.dist_group graph_obj = torch.cuda.CUDAGraph() + input_ids = infer_state.input_ids batch_size = input_ids.shape[0] infer_state.max_len_in_batch = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size @@ -107,17 +107,15 @@ def _capture_decode_overlap( # warmup for _ in range(1): torch.cuda.synchronize() - decode_func(input_ids, copy.copy(infer_state), input_ids1, copy.copy(infer_state1)) + decode_func(copy.copy(infer_state), copy.copy(infer_state1)) torch.cuda.synchronize() with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) + model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, - input_ids, infer_state, - input_ids1, infer_state1, model_output, model_output1, @@ -128,59 +126,50 @@ def _capture_decode_overlap( def capture_decode( self, decode_func, - input_ids: torch.Tensor, infer_state: InferStateInfo, - input_ids1: Optional[torch.Tensor] = None, - infer_state1: Optional[torch.Tensor] = None, + infer_state1: Optional[InferStateInfo] = None, ): """ Capture the cuda graph for the decoding stage. input_ids1 and infer_state1 is used for the overlap. 
""" if self.enable_decode_microbatch_overlap: - return self._capture_decode_overlap(decode_func, input_ids, infer_state, input_ids1, infer_state1) + return self._capture_decode_overlap(decode_func, infer_state, infer_state1) else: - assert input_ids1 is None and infer_state1 is None - return self._capture_decode(decode_func, input_ids, infer_state) + assert infer_state1 is None + return self._capture_decode(decode_func, infer_state) - def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo): - batch_size = input_ids.shape[0] - graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size] - graph_input_ids.copy_(input_ids) + def _replay(self, infer_state: InferStateInfo): + batch_size = infer_state.input_ids.shape[0] + graph_obj, graph_infer_state, graph_output = self.graph[batch_size] graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() return graph_output def _replay_overlap( self, - input_ids: torch.Tensor, infer_state: InferStateInfo, - input_ids1: torch.Tensor, infer_state1: InferStateInfo, ): - batch_size = input_ids.shape[0] + batch_size = infer_state.input_ids.shape[0] ( graph_obj, - graph_input_ids, graph_infer_state, - graph_input_ids1, graph_infer_state1, graph_model_output, graph_model_output1, ) = self.graph[batch_size] - graph_input_ids.copy_(input_ids) graph_infer_state.copy_for_cuda_graph(infer_state) - graph_input_ids1.copy_(input_ids1) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() return graph_model_output, graph_model_output1 - def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None): + def replay(self, infer_state, infer_state1=None): if self.enable_decode_microbatch_overlap: - return self._replay_overlap(input_ids, infer_state, input_ids1, infer_state1) + return self._replay_overlap(infer_state, infer_state1) else: - assert input_ids1 is None and infer_state1 is None - return self._replay(input_ids, infer_state) + assert infer_state1 is None + return self._replay(infer_state) @torch.no_grad() def warmup(self, model): diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index a014bad64..305e24d83 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -19,6 +19,7 @@ class InferStateInfo: """ def __init__(self): + self.input_ids: torch.Tensor = None self.batch_size: int = None self.total_token_num: int = None self.b_req_idx: torch.Tensor = None @@ -88,7 +89,8 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): + def init_some_extra_state(self, model): + if self.is_prefill: ( self.b_q_seq_len, @@ -97,7 +99,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.b1_cu_kv_seq_len, self.position_ids, ) = gen_prefill_params( - input_token_num=input_ids.shape[0], + input_token_num=self.input_ids.shape[0], b_ready_cache_len=self.b_ready_cache_len, b_seq_len=self.b_seq_len, ) diff --git a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py index 6dc1ebb50..e7a084079 100644 --- a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py @@ -8,8 +8,6 @@ class PreLayerInferTpl(PreLayerInfer): def __init__(self, network_config, mode): 
super().__init__(network_config, mode) self.eps_ = 1e-5 - self.vob_start_id_ = -1 - self.vob_end_id_ = -1 return def _norm(self, input, infer_state, layer_weight) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index b3dab0614..c8d356394 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -6,6 +6,7 @@ COLMMWeight, ROWBMMWeight, ) -from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight +from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight, NoTpNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP +from .embedding_weight import EmbeddingWeight, LMHeadWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py new file mode 100644 index 000000000..75d197138 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -0,0 +1,67 @@ +import torch +import numpy as np +from typing import Dict, Optional +from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel + + +class EmbeddingWeight(BaseWeightTpl): + def __init__(self, weight_name, data_type, vocab_size: int): + super().__init__() + self.weight_name: str = weight_name + self.data_type_ = data_type + self.weight: torch.Tensor = None + self.vocab_size = vocab_size + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = split_indexes[self.tp_rank_] + self.tp_vocab_end_id = split_indexes[self.tp_rank_ + 1] + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights and self.weight is None: + t_weight = weights[self.weight_name] + assert len(t_weight) == self.vocab_size + self.weight = ( + t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :] + .to(self.data_type_) + .cuda(get_current_device_id()) + ) + + def verify_load(self): + load_ok = True + load_ok = load_ok and self.weight is not None + + return load_ok + + def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + if out is None: + out = alloc_func( + (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device + ) + + embedding_kernel( + input_ids=input_ids, + weight=self.weight, + vob_start_id=self.tp_vocab_start_id, + vob_end_id=self.tp_vocab_end_id, + out=out, + ) + + return out + + def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + assert input.ndim == 2 + if out is None: + out = alloc_func( + (self.weight.shape[0], input.shape[1]), + dtype=input.dtype, + device=input.device, + ) + + torch.mm(self.weight, input, out=out) + return out + + +class LMHeadWeight(EmbeddingWeight): + def __init__(self, weight_name, data_type, vocab_size): + super().__init__(weight_name, data_type, vocab_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 7ec672ab8..a3f10bf39 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ 
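The new EmbeddingWeight above shards the vocabulary with np.linspace and passes vob_start_id/vob_end_id to the Triton kernel, so each rank resolves only ids in its own shard and the later all_reduce in the pre-layer reconstructs the full embedding. A plain-PyTorch reference of the per-rank result (an assumption about the kernel's out-of-shard handling, used only for illustration):

import numpy as np
import torch

def tp_embedding_reference(input_ids, full_weight, tp_rank, tp_world_size):
    vocab_size = full_weight.shape[0]
    bounds = np.linspace(0, vocab_size, tp_world_size + 1, dtype=np.int64)
    start, end = int(bounds[tp_rank]), int(bounds[tp_rank + 1])
    out = torch.zeros(input_ids.shape[0], full_weight.shape[1], dtype=full_weight.dtype)
    in_shard = (input_ids >= start) & (input_ids < end)
    # Rows whose token id falls outside the local shard stay zero and are filled by other ranks.
    out[in_shard] = full_weight[start:end][input_ids[in_shard] - start]
    return out

weight = torch.randn(11, 4)
ids = torch.tensor([0, 5, 7, 10])
summed = sum(tp_embedding_reference(ids, weight, r, 2) for r in range(2))  # stands in for the all_reduce
assert torch.allclose(summed, weight[ids])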
b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,6 +1,8 @@ import torch +from typing import Optional from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward class NormWeight(BaseWeightTpl): @@ -9,8 +11,8 @@ def __init__(self, weight_name, data_type, bias_name=None): self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type - self.weight = None - self.bias = None + self.weight: torch.Tensor = None + self.bias: Optional[torch.Tensor] = None def load_hf_weights(self, weights): if self.weight_name in weights: @@ -28,6 +30,17 @@ def verify_load(self): return load_ok +class NoTpNormWeight(NormWeight): + def __init__(self, weight_name, data_type, bias_name=None): + super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + def rmsnorm_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + assert self.bias is None + return rmsnorm_forward(x=input, weight=self.weight, eps=eps) + + class GEMMANormWeight(NormWeight): def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name, data_type, bias_name) diff --git a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py index bb670b289..19eb67017 100644 --- a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py @@ -1,4 +1,5 @@ from .base_layer_weight import BaseLayerWeight +from .meta_weights import BaseWeight, MMWeightTpl class PreAndPostLayerWeight(BaseLayerWeight): @@ -9,3 +10,15 @@ def __init__(self, data_type, network_config, mode): self.mode = mode self.init_static_params() return + + def load_hf_weights(self, weights): + """ + load weights + """ + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: + with self.lock: + attr.load_hf_weights(weights) + elif isinstance(attr, BaseWeight): + attr.load_hf_weights(weights) diff --git a/lightllm/models/llama/triton_kernel/embedding.py b/lightllm/common/basemodel/triton_kernel/embedding.py similarity index 100% rename from lightllm/models/llama/triton_kernel/embedding.py rename to lightllm/common/basemodel/triton_kernel/embedding.py diff --git a/lightllm/models/llama/triton_kernel/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/rmsnorm.py similarity index 100% rename from lightllm/models/llama/triton_kernel/rmsnorm.py rename to lightllm/common/basemodel/triton_kernel/rmsnorm.py diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index 0725fa337..72ba8a43b 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -23,8 +23,8 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): ] return cls._shared_page_table_buffer - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) args_mtp_step = get_env_start_args().mtp_step if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len @@ -51,7 +51,7 @@ def init_some_extra_state(self, model, input_ids: 
torch.Tensor): ].view(att_batch_size, model.graph_max_len_in_batch) else: self.page_table = torch.empty((att_batch_size, self.max_len_in_batch), dtype=torch.int32).to( - input_ids.device + self.input_ids.device ) page_table_copy( page_table=self.page_table[:, :max_seq_len_k], diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index a00c45601..db6386f79 100644 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -14,15 +14,15 @@ def __init__(self): self.decode_wrapper = None self.flashinfer_extra_state = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) self.flashinfer_extra_state = model.flashinfer_extra_state import flashinfer if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: - self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) + self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(self.input_ids.device) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ : self.batch_size * self.flashinfer_extra_state.max_seq_length @@ -31,7 +31,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.kv_indices = torch.empty( self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32, - device=input_ids.device, + device=self.input_ids.device, ) repack_kv_index( self.req_manager.req_to_token_indexs, diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index f05f52f2f..0c2ef3048 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -10,8 +10,8 @@ def __init__(self): super().__init__() self.kv_starts = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) if not self.is_prefill: self.kv_starts = self.b1_cu_kv_seq_len diff --git a/lightllm/models/gemma3/infer_struct.py b/lightllm/models/gemma3/infer_struct.py index 4145124af..33bd44815 100644 --- a/lightllm/models/gemma3/infer_struct.py +++ b/lightllm/models/gemma3/infer_struct.py @@ -12,8 +12,8 @@ def __init__(self): self.position_sin_local = None self.position_cos_local = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len position_ids = self.position_ids diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 4cfd72e81..9f71cbbc5 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -25,12 +25,12 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): ] return cls._shared_page_table_buffer - def _init_flash_attention_state(self, model, input_ids: torch.Tensor): + def _init_flash_attention_state(self, model): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() self.page_table = torch.empty( - (self.batch_size, 
self.max_seq_len), dtype=torch.int32, device=input_ids.device + (self.batch_size, self.max_seq_len), dtype=torch.int32, device=self.input_ids.device ) self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) else: @@ -49,7 +49,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): ].reshape(att_batch_size, model.graph_max_len_in_batch) else: self.page_table = torch.empty( - (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device + (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=self.input_ids.device ) page_table_copy( @@ -64,7 +64,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: - device = input_ids.device + device = self.input_ids.device # q_scale和token_batch_ids在对q做per head量化使用,为了节省资源在推理外部初始化 self.q_scale = torch.empty( (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device @@ -84,7 +84,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): else torch.ones( (self.mem_manager.layer_num, self.batch_size, head_num), dtype=torch.float32, - device=input_ids.device, + device=self.input_ids.device, ) ) self.v_descale = ( @@ -95,12 +95,12 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): else torch.ones( (self.mem_manager.layer_num, self.batch_size, head_num), dtype=torch.float32, - device=input_ids.device, + device=self.input_ids.device, ) ) return - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) - self._init_flash_attention_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + self._init_flash_attention_state(model) return diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index a0c40b57a..7f9beac1d 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -14,8 +14,8 @@ def __init__(self): self.decode_wrapper = None self.flashinfer_extra_state = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) self.flashinfer_extra_state = model.flashinfer_extra_state import flashinfer @@ -23,7 +23,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device + (self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device ) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ @@ -33,7 +33,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.kv_indices = torch.empty( self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32, - device=input_ids.device, + device=self.input_ids.device, ) repack_kv_index( @@ -71,11 +71,11 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if get_env_start_args().enable_flashinfer_prefill: q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() - kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + kv_last_page_len = 
torch.full((self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device) kv_indices = torch.empty( self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32, - device=input_ids.device, + device=self.input_ids.device, ) repack_kv_index( self.req_manager.req_to_token_indexs, diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index 6373b3782..3bba43976 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -10,8 +10,8 @@ def __init__(self): self.position_cos = None self.position_sin = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len self.q_max_seq_len = self.max_q_seq_len diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 28e60a952..ce8f00681 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -4,13 +4,9 @@ import torch.distributed as dist import numpy as np from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight - from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from einops import rearrange from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.common.basemodel import PostLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time from lightllm.distributed.communication_op import all_gather @@ -20,15 +16,13 @@ class LlamaPostLayerInfer(PostLayerInferTpl): def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = network_config["rms_norm_eps"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_) - - def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo): + return layer_weight.final_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_) + def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo): + embed_dim_ = input_embdings.shape[1] if infer_state.is_prefill and infer_state.is_token_healing: batch_size = infer_state.batch_size b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() @@ -41,13 +35,13 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo select_token_num += 1 last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor((select_token_num, self.embed_dim_), dtype=input_embdings.dtype) + last_input = self.alloc_tensor((select_token_num, embed_dim_), dtype=input_embdings.dtype) last_input[:, :] = input_embdings[last_index, :] return last_input, select_token_num if infer_state.is_prefill and not infer_state.return_all_prompt_logics: batch_size = infer_state.batch_size - last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype) + last_input = self.alloc_tensor((batch_size, embed_dim_), dtype=input_embdings.dtype) last_index = ( 
torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 ) @@ -64,23 +58,22 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo assert False, "Error State" - def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + def token_forward( + self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight + ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) input_embdings_dtype = input_embdings.dtype input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = self.alloc_tensor( - (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype - ) - torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch) - + logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) last_input = None + vocab_size = layer_weight.lm_head_weight_.vocab_size if self.tp_world_size_ == 1: gather_data = logic_batch else: - gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) + gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) all_gather( [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], logic_batch, @@ -89,7 +82,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_ ) logic_batch = None ans_logics = self.alloc_tensor( - (token_num, self.vocab_size_), + (token_num, vocab_size), dtype=torch.float32, ) ans_logics[:, :] = gather_data.permute(1, 0) @@ -112,7 +105,7 @@ def tpsp_token_forward( async_op=False, ) # len(infer_state.position_sin) 获取真实输入长度 - input_embdings = gather_data[0 : len(infer_state.position_sin)] + input_embdings = gather_data[0 : infer_state.handle_token_num] if infer_state.need_dp_prefill_balance: input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index 99b7db5bf..e8fa70626 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -1,13 +1,8 @@ -import os import torch import torch.distributed as dist -import numpy as np - from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel import PreLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce from lightllm.utils.envs_utils import get_env_start_args @@ -17,25 +12,16 @@ class LlamaPreLayerInfer(PreLayerInferTpl): def __init__(self, network_config, mode): super().__init__(network_config, mode) - tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) - self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) - return def context_forward(self, input_ids, infer_state: 
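token_forward in post_layer_infer.py above now takes its logits shard from lm_head_weight_.lm_head and lets all_gather assemble the shards along the vocab axis before the final permute. A single-process NumPy reference (an assumption, not the distributed code) of what the mm plus all_gather compute:

import numpy as np

def sharded_logits(hidden, lm_head, tp_world_size):
    # hidden: (embed_dim, token_num) after the permute; lm_head: (vocab, embed_dim).
    vocab = lm_head.shape[0]
    cuts = np.linspace(0, vocab, tp_world_size + 1, dtype=np.int64)
    shards = [lm_head[cuts[r] : cuts[r + 1]] @ hidden for r in range(tp_world_size)]
    return np.concatenate(shards, axis=0)  # what the all_gather into gather_data reassembles

hidden = np.random.randn(8, 3)
lm_head = np.random.randn(32, 8)
assert np.allclose(sharded_logits(hidden, lm_head, 4), lm_head @ hidden)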
LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.embdding_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.embdding_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 711406e3f..9754094ad 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -2,33 +2,31 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - return - def verify_load(self): - errors = "weights load not ok" - weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + vocab_size = self.network_config_["vocab_size"] + self.embdding_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + vocab_size=vocab_size, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_ = self.embdding_weight_ + else: + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + vocab_size=vocab_size, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) return diff --git 
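The LlamaPreAndPostLayerWeight refactor above only declares weight objects in __init__; loading is handled by the generic PreAndPostLayerWeight.load_hf_weights introduced earlier in this patch, which scans attributes for weight objects, and when tie_word_embeddings is set the lm_head shares the EmbeddingWeight instance outright. A toy version of that reflection-based loading (simplified: the MMWeightTpl lock branch is omitted, and the ToyWeight class is an illustrative stand-in):

class ToyWeight:
    def __init__(self, name):
        self.name, self.tensor = name, None

    def load_hf_weights(self, weights):
        if self.name in weights:
            self.tensor = weights[self.name]

class ToyPrePost:
    def __init__(self):
        self.embed = ToyWeight("model.embed_tokens.weight")
        self.norm = ToyWeight("model.norm.weight")

    def load_hf_weights(self, weights):
        # Every attribute that is a weight object loads itself; subclasses add no loading code.
        for attr_name in dir(self):
            attr = getattr(self, attr_name, None)
            if isinstance(attr, ToyWeight):
                attr.load_hf_weights(weights)

layer = ToyPrePost()
layer.load_hf_weights({"model.norm.weight": [1.0]})
assert layer.norm.tensor == [1.0] and layer.embed.tensor is None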
a/lightllm/models/qwen/infer_struct.py b/lightllm/models/qwen/infer_struct.py index deff17ce2..d575f6d95 100644 --- a/lightllm/models/qwen/infer_struct.py +++ b/lightllm/models/qwen/infer_struct.py @@ -11,13 +11,13 @@ def __init__(self): self.position_sin = None self.logn_values = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): + def init_some_extra_state(self, model): use_dynamic_ntk = model.config.get("use_dynamic_ntk", False) if not use_dynamic_ntk: - super().init_some_extra_state(model, input_ids) + super().init_some_extra_state(model) return - InferStateInfo.init_some_extra_state(self, model, input_ids) + InferStateInfo.init_some_extra_state(self, model) if self.is_prefill: position_ids = self.position_ids self.position_sin = [] diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index ce7938b6a..838590325 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -16,10 +16,10 @@ def __init__(self): self.position_cos = None self.position_sin = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): + def init_some_extra_state(self, model): rope_scaling = model.config.get("rope_scaling", {}) self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) - InferStateInfo.init_some_extra_state(self, model, input_ids) + InferStateInfo.init_some_extra_state(self, model) if self.is_prefill: self.position_ids = self.get_mrope_position(self.multimodal_params) else: @@ -38,7 +38,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if get_env_start_args().enable_fa3: self.max_seq_len = self.max_kv_seq_len self.q_max_seq_len = self.max_q_seq_len - self.init_flash_attention_state_func(model, input_ids) + self.init_flash_attention_state_func(model) return def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: From 419e0fe1dfb5fcf42253c0be868581e091259e88 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 1 Jan 2026 19:58:45 +0800 Subject: [PATCH 40/79] fix get input len --- .../deepseek2/layer_infer/transformer_layer_infer.py | 8 ++++---- lightllm/models/llama/layer_infer/post_layer_infer.py | 4 ++-- .../models/llama/layer_infer/transformer_layer_infer.py | 6 +++--- .../qwen3_moe/layer_infer/transformer_layer_infer.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 30d37d1df..76dd1d568 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -190,7 +190,7 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device ) all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.position_cos), :] + input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) q = layer_weight.q_weight_.mm(input) @@ -223,7 +223,7 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, qkv_dim), dtype=qkv.dtype, device=qkv.device ) all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False) - qkv = gather_qkv[0 : len(infer_state.position_cos), :] + qkv = gather_qkv[0 : len(infer_state.input_ids), :] if infer_state.need_dp_prefill_balance: qkv = infer_state._all_to_all_unbalance_get(data=qkv) @@ -273,8 
+273,8 @@ def _tpsp_get_o( input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :]) - e_o_tensor = o_tensor[len(infer_state.position_cos) :, :] + layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) + e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] if e_o_tensor.shape[0] > 0: e_o_tensor.fill_(0) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index ce8f00681..ce224bc39 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -104,8 +104,8 @@ def tpsp_token_forward( group=infer_state.dist_group, async_op=False, ) - # len(infer_state.position_sin) 获取真实输入长度 - input_embdings = gather_data[0 : infer_state.handle_token_num] + # len(infer_state.input_ids) 获取真实输入长度 + input_embdings = gather_data[0 : len(infer_state.input_ids)] if infer_state.need_dp_prefill_balance: input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index ea44fe2e5..83c090495 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -224,7 +224,7 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device ) all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.position_cos), :] + input = gather_input[0 : len(infer_state.input_ids), :] q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) @@ -415,8 +415,8 @@ def _tpsp_get_o( input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :]) - e_o_tensor = o_tensor[len(infer_state.position_cos) :, :] + layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) + e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] if e_o_tensor.shape[0] > 0: e_o_tensor.fill_(0) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 45f1f59d7..bd7bca39e 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -95,7 +95,7 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device ) all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.position_cos), :] + input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) From ea06919f820399241b1bcc06d4ad69f14207179f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 
09:00:48 +0800 Subject: [PATCH 41/79] fix norm --- .../basemodel/layer_weights/meta_weights/norm_weight.py | 8 ++++++-- lightllm/common/basemodel/triton_kernel/rmsnorm.py | 3 ++- lightllm/models/llama/layer_infer/post_layer_infer.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index a3f10bf39..da272d8c8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -36,9 +36,13 @@ def __init__(self, weight_name, data_type, bias_name=None): self.tp_world_size_ = 1 self.tp_rank_ = 0 - def rmsnorm_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + def rmsnorm_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: assert self.bias is None - return rmsnorm_forward(x=input, weight=self.weight, eps=eps) + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) class GEMMANormWeight(NormWeight): diff --git a/lightllm/common/basemodel/triton_kernel/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/rmsnorm.py index 0140847af..19f7c84d3 100644 --- a/lightllm/common/basemodel/triton_kernel/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/rmsnorm.py @@ -44,12 +44,13 @@ def _rms_norm_fwd_fused( tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) -def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None): +def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): # allocate output y = torch.empty_like(x) if out is None else out # reshape input data into 2D tensor x_arg = x.view(-1, x.shape[-1]) y_arg = y.view(-1, x.shape[-1]) + assert x_arg.shape[-1] == weight.shape[0] assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index ce224bc39..d9b82cae7 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -19,7 +19,7 @@ def __init__(self, network_config, mode): return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.final_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_) + return layer_weight.final_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo): embed_dim_ = input_embdings.shape[1] From 35ab3bb3c70e2cacfbdd2e2e20e02100d78e30d9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 09:34:17 +0800 Subject: [PATCH 42/79] fix norm head tp --- .../layer_weights/meta_weights/__init__.py | 3 +- .../meta_weights/att_sink_weight.py | 25 +++++++++++++ .../layer_weights/meta_weights/norm_weight.py | 36 +++++++++++++++---- .../layer_weights/transformer_layer_weight.py | 11 +++--- .../layer_infer/transformer_layer_infer.py | 2 +- .../layer_weights/transformer_layer_weight.py | 9 +++-- 6 files changed, 67 insertions(+), 19 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py diff --git 
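PATCH 41 above gives rmsnorm_forward an optional preallocated output plus a feature-size assert. A plain-PyTorch reference of the usual RMSNorm computation with the same contract (a sketch, not the Triton kernel):

import torch

def rmsnorm_reference(x, weight, eps, out=None):
    assert x.shape[-1] == weight.shape[0]
    if out is None:
        out = torch.empty_like(x)
    # y = x * rsqrt(mean(x^2) + eps) * weight, computed in float32 for stability.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out.copy_((x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight)
    return out

x = torch.randn(4, 16, dtype=torch.float16)
w = torch.ones(16, dtype=torch.float16)
y = rmsnorm_reference(x, w, eps=1e-6)
assert y.shape == x.shape and y.dtype == x.dtype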
a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index c8d356394..d6804bb44 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -6,7 +6,8 @@ COLMMWeight, ROWBMMWeight, ) -from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight, NoTpNormWeight +from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight, NoTpNormWeight, TpHeadNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP from .embedding_weight import EmbeddingWeight, LMHeadWeight +from .att_sink_weight import TpAttSinkWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py new file mode 100644 index 000000000..7d785ea82 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -0,0 +1,25 @@ +import torch +from typing import Dict +from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id + + +class TpAttSinkWeight(BaseWeightTpl): + def __init__(self, weight_name: str, data_type, tp_head_num: int): + super().__init__() + self.weight_name = weight_name + self.data_type_ = data_type + self.weight: torch.Tensor = None + self.tp_head_num = tp_head_num + assert self.tp_head_num > 0 + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + start_head_index = self.tp_head_num * self.tp_rank_ + end_head_index = self.tp_head_num * (self.tp_rank_ + 1) + + if self.weight_name in weights: + self.weight = ( + weights[self.weight_name][start_head_index:end_head_index] + .to(self.data_type_) + .cuda(get_current_device_id()) + ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index da272d8c8..9ff00cc49 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -29,13 +29,6 @@ def verify_load(self): load_ok = load_ok and self.bias is not None return load_ok - -class NoTpNormWeight(NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) - self.tp_world_size_ = 1 - self.tp_rank_ = 0 - def rmsnorm_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: @@ -45,6 +38,13 @@ def rmsnorm_forward( return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) +class NoTpNormWeight(NormWeight): + def __init__(self, weight_name, data_type, bias_name=None): + super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + class GEMMANormWeight(NormWeight): def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name, data_type, bias_name) @@ -67,3 +67,25 @@ def load_hf_weights(self, weights): self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) if self.bias_name in weights: self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + + +class TpHeadNormWeight(NormWeight): + def __init__(self, weight_name, data_type, tp_head_num, bias_name=None): + 
super().__init__(weight_name, data_type, bias_name) + self.tp_head_num = tp_head_num + assert self.tp_head_num > 0 + + def load_hf_weights(self, weights): + start = self.tp_head_num * self.tp_rank_ + end = self.tp_head_num * (self.tp_rank_ + 1) + + if self.weight_name in weights: + self.weight: torch.Tensor = ( + weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + ) + assert self.weight.ndim == 2 + if self.bias_name in weights: + self.bias: torch.Tensor = ( + weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + ) + assert self.bias.ndim == 2 diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py index fff92abf5..bfe1f534c 100644 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -1,8 +1,5 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, - TpNormWeight, -) +from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight, TpHeadNormWeight class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): @@ -14,17 +11,17 @@ def _parse_config(self): super()._parse_config() self.use_qk_norm = self.network_config_.get("use_qk_norm", False) - def _init_norm(self, weights): + def _init_norm(self): q_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ k_split_head = self.network_config_["num_key_value_heads"] // self.tp_world_size_ self.att_norm_weight_ = NormWeight(self._att_norm_weight_name, self.data_type_) if self.use_qk_norm: - self.q_norm_weight_ = TpNormWeight( + self.q_norm_weight_ = TpHeadNormWeight( f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_, q_split_head ) - self.k_norm_weight_ = TpNormWeight( + self.k_norm_weight_ = TpHeadNormWeight( f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_, k_split_head ) diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 1246af090..8aacb457d 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -69,7 +69,7 @@ def _ffn( return hidden_states.view(num_tokens, hidden_dim) def _context_sliding_attention_flashattention( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None + self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7e6035dc5..94f3efc15 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -5,6 +5,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight +from 
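TpAttSinkWeight and TpHeadNormWeight above both shard along the head dimension rather than a flat embedding split: each rank keeps tp_head_num consecutive heads of the full parameter. A minimal sketch of that per-rank slicing (illustrative only; the function name is an assumption):

import torch

def shard_heads(full_param, tp_rank, tp_head_num):
    # Rank r keeps heads [r * tp_head_num, (r + 1) * tp_head_num).
    start = tp_head_num * tp_rank
    return full_param[start : start + tp_head_num]

sinks = torch.arange(8.0)  # e.g. one attention-sink logit per head
assert torch.equal(shard_heads(sinks, tp_rank=1, tp_head_num=4), torch.tensor([4.0, 5.0, 6.0, 7.0]))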
lightllm.common.basemodel.layer_weights.meta_weights import TpAttSinkWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.utils.log_utils import init_logger @@ -57,8 +58,6 @@ def _init_moe(self): def _init_weight_names(self): super()._init_weight_names() - self._attn_sink_name = f"model.layers.{self.layer_num_}.self_attn.sinks" - self._q_bias_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" @@ -71,7 +70,11 @@ def _init_weight(self): super()._init_weight() n_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.attn_sinks = TpNormWeight(self._attn_sink_name, torch.bfloat16, n_split_head) + self.attn_sinks = TpAttSinkWeight( + weight_name=f"model.layers.{self.layer_num_}.self_attn.sinks", + data_type=torch.bfloat16, + tp_head_num=n_split_head, + ) def _init_ffn(self): self._init_moe() From aedb85dcb42d4a3ea044fcbdc6dd17c32e67e7fe Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 09:52:39 +0800 Subject: [PATCH 43/79] fix norm --- .../layer_weights/meta_weights/__init__.py | 2 +- .../layer_weights/meta_weights/norm_weight.py | 26 +++++++++++-------- .../layer_weights/transformer_layer_weight.py | 4 +-- .../layer_weights/transformer_layer_weight.py | 12 +++++---- .../layer_weights/transformer_layer_weight.py | 10 +++---- .../layer_weights/transformer_layer_weight.py | 6 ++--- .../layer_weights/transformer_layer_weight.py | 6 ++--- .../layer_weights/transformer_layer_weight.py | 4 +-- .../pre_and_post_layer_weight.py | 2 +- .../layer_weights/transformer_layer_weight.py | 6 ++--- .../layer_weights/transformer_layer_weight.py | 4 +-- .../layer_weights/transformer_layer_weight.py | 1 - .../layer_weights/transformer_layer_weight.py | 6 ++--- 13 files changed, 47 insertions(+), 42 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index d6804bb44..a0b54f3d6 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -6,7 +6,7 @@ COLMMWeight, ROWBMMWeight, ) -from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight, NoTpNormWeight, TpHeadNormWeight +from .norm_weight import NoTpGEMMANormWeight, TpNormWeight, NoTpNormWeight, TpHeadNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP from .embedding_weight import EmbeddingWeight, LMHeadWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 9ff00cc49..c159451ae 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -5,7 +5,7 @@ from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward -class NormWeight(BaseWeightTpl): +class _NormWeight(BaseWeightTpl): def __init__(self, weight_name, data_type, bias_name=None): super().__init__() self.weight_name = weight_name @@ -14,12 +14,6 @@ def __init__(self, weight_name, data_type, bias_name=None): self.weight: torch.Tensor = None self.bias: Optional[torch.Tensor] = None - def load_hf_weights(self, weights): - if self.weight_name in weights: - 
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) - def verify_load(self): load_ok = True # Verify weight. The weight must be not None. @@ -32,29 +26,39 @@ def verify_load(self): def rmsnorm_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 assert self.bias is None if out is None: out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) -class NoTpNormWeight(NormWeight): +class NoTpNormWeight(_NormWeight): def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) self.tp_world_size_ = 1 self.tp_rank_ = 0 + def load_hf_weights(self, weights): + if self.weight_name in weights: + self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights: + self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + -class GEMMANormWeight(NormWeight): +class NoTpGEMMANormWeight(_NormWeight): def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name, data_type, bias_name) + assert self.bias_name is None + self.tp_world_size_ = 1 + self.tp_rank_ = 0 def load_hf_weights(self, weights): if self.weight_name in weights: self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id()) -class TpNormWeight(NormWeight): +class TpNormWeight(_NormWeight): def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): super().__init__(weight_name, data_type, bias_name) self.split_n_embed = split_n_embed @@ -69,7 +73,7 @@ def load_hf_weights(self, weights): self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) -class TpHeadNormWeight(NormWeight): +class TpHeadNormWeight(_NormWeight): def __init__(self, weight_name, data_type, tp_head_num, bias_name=None): super().__init__(weight_name, data_type, bias_name) self.tp_head_num = tp_head_num diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py index bfe1f534c..4af0ab8f0 100644 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -1,5 +1,5 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight, TpHeadNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight, TpHeadNormWeight class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): @@ -15,7 +15,7 @@ def _init_norm(self): q_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ k_split_head = self.network_config_["num_key_value_heads"] // self.tp_world_size_ - self.att_norm_weight_ = NormWeight(self._att_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NoTpNormWeight(self._att_norm_weight_name, self.data_type_) if self.use_qk_norm: self.q_norm_weight_ = TpHeadNormWeight( diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py 
b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c899751eb..611878f9e 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -7,7 +7,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NormWeight, + NoTpNormWeight, FusedMoeWeightEP, ROWBMMWeight, create_tp_moe_wegiht_obj, @@ -299,14 +299,16 @@ def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") def _init_norm(self): - self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) - self.ffn_norm_weight_ = NormWeight( + self.att_norm_weight_ = NoTpNormWeight( + f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_ + ) + self.ffn_norm_weight_ = NoTpNormWeight( f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ ) - self.kv_a_layernorm_ = NormWeight( + self.kv_a_layernorm_ = NoTpNormWeight( f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ ) if self.q_lora_rank is not None: - self.q_a_layernorm_ = NormWeight( + self.q_a_layernorm_ = NoTpNormWeight( f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ ) diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 6f5530461..71227bd9b 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -1,7 +1,7 @@ import torch import numpy as np from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight -from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight @@ -66,12 +66,12 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) - self.pre_feedforward_layernorm_weight_ = NormWeight( + self.k_norm_weight_ = NoTpNormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) + self.q_norm_weight_ = NoTpNormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + self.pre_feedforward_layernorm_weight_ = NoTpNormWeight( self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None ) - self.post_feedforward_layernorm_weight_ = NormWeight( + self.post_feedforward_layernorm_weight_ = NoTpNormWeight( self._post_feedforward_layernorm_name, self.data_type_, bias_name=None ) diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 32248e6dd..1916bd095 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import GEMMANormWeight, ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight, 
ROWMMWeight class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight): @@ -29,5 +29,5 @@ def _init_qkv(self): ) def _init_norm(self): - self.att_norm_weight_ = GEMMANormWeight(self._att_norm_weight_name, self.data_type_) - self.ffn_norm_weight_ = GEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NoTpGEMMANormWeight(self._att_norm_weight_name, self.data_type_) + self.ffn_norm_weight_ = NoTpGEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 624717007..6b92272ee 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight class LlamaTransformerLayerWeight(TransformerLayerWeight): @@ -103,9 +103,9 @@ def _init_ffn(self): ) def _init_norm(self): - self.att_norm_weight_ = NormWeight( + self.att_norm_weight_ = NoTpNormWeight( self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) - self.ffn_norm_weight_ = NormWeight( + self.ffn_norm_weight_ = NoTpNormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 034713ee4..6607dbb70 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -1,5 +1,5 @@ from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight class MistralMTPTransformerLayerWeight(TransformerLayerWeight): @@ -41,6 +41,6 @@ def _init_ffn(self): ) def _init_norm(self): - self.ffn_norm_weight_ = NormWeight( + self.ffn_norm_weight_ = NoTpNormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index a56c5d6cb..8c1e288a4 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,7 +1,7 @@ import torch import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 4c0ef586f..86b9e172a 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, + NoTpNormWeight, ) @@ -20,5 +20,5 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) - self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self.q_norm_weight_ = NoTpNormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NoTpNormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index feb06c5d4..22d4d1950 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import os from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -16,6 +16,6 @@ def _init_weight(self): self._init_ffn() def _init_norm(self): - self.ffn_norm_weight_ = NormWeight( + self.ffn_norm_weight_ = NoTpNormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) diff --git a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py index dc7b00862..a1a73f674 100755 --- a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py @@ -1,5 +1,4 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight class StablelmTransformerLayerWeight(Qwen2TransformerLayerWeight): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index f1de0bdc1..881a378e2 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -7,7 +7,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NormWeight, + NoTpNormWeight, TpNormWeight, ) from lightllm.utils.dist_utils import get_current_device_id @@ -119,10 +119,10 @@ def _init_ffn(self): ) def _init_norm(self): - self.att_norm_weight_ = NormWeight( + self.att_norm_weight_ = NoTpNormWeight( self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) - self.ffn_norm_weight_ = NormWeight( + self.ffn_norm_weight_ = NoTpNormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) if self.qk_norm: From 82112f8ec7f1e360642ae29d5f09d92dfabf7ca8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 10:28:00 +0800 Subject: [PATCH 44/79] fix bloom --- .../layer_weights/meta_weights/norm_weight.py | 14 +++++ .../basemodel}/triton_kernel/layernorm.py | 19 +++---- .../bloom/layer_infer/post_layer_infer.py | 
53 +++---------------- .../bloom/layer_infer/pre_layer_infer.py | 21 ++------ .../layer_infer/transformer_layer_infer.py | 28 ++++------ .../pre_and_post_layer_weight.py | 53 +++++++------------ 6 files changed, 61 insertions(+), 127 deletions(-) rename lightllm/{models/bloom => common/basemodel}/triton_kernel/layernorm.py (84%) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index c159451ae..6c1fbc736 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -3,6 +3,7 @@ from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward class _NormWeight(BaseWeightTpl): @@ -32,6 +33,19 @@ def rmsnorm_forward( out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) + def layernorm_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + assert self.bias is not None + + _tout = layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps) + if out is None: + return _tout + else: + out.copy_(_tout) + return out + class NoTpNormWeight(_NormWeight): def __init__(self, weight_name, data_type, bias_name=None): diff --git a/lightllm/models/bloom/triton_kernel/layernorm.py b/lightllm/common/basemodel/triton_kernel/layernorm.py similarity index 84% rename from lightllm/models/bloom/triton_kernel/layernorm.py rename to lightllm/common/basemodel/triton_kernel/layernorm.py index 6911d707b..538bb2b13 100644 --- a/lightllm/models/bloom/triton_kernel/layernorm.py +++ b/lightllm/common/basemodel/triton_kernel/layernorm.py @@ -24,15 +24,15 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _mean += a mean = tl.sum(_mean, axis=0) / N # Compute variance _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.) 
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -42,7 +42,7 @@ def _layer_norm_fwd_fused( mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) b = tl.load(B + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = (x - mean) * rstd y = x_hat * w + b # Write output @@ -72,17 +72,18 @@ def _layer_norm_fwd_fused( # BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) # return y + def layernorm_forward(x, weight, bias, eps): return torch.layer_norm(x, (x.shape[-1],), weight, bias, eps) -def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): +def test_layer_norm(M, N, dtype, eps=1e-5, device="cuda"): # create data x_shape = (M, N) - w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda') - bias = torch.rand(w_shape, dtype=dtype, device='cuda') - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + bias = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") # forward pass y_tri = layernorm_forward(x, weight, bias, eps) y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index ee533d86d..7938869f5 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -1,62 +1,21 @@ import torch import torch.functional as F import numpy as np - from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight -from einops import rearrange -from lightllm.common.basemodel import InferStateInfo, PostLayerInferTpl -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.distributed.communication_op import all_gather +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.common.build_utils import repair_config -class BloomPostLayerInfer(PostLayerInferTpl): +class BloomPostLayerInfer(LlamaPostLayerInfer): """ """ def __init__(self, network_config, mode): + repair_config(config=network_config, same_names=["layer_norm_epsilon", "rms_norm_eps"]) super().__init__(network_config, mode) - assert network_config["vocab_size"] % self.tp_world_size_ == 0 - self.eps_ = network_config["layer_norm_epsilon"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layernorm_forward(input, layer_weight.final_norm_weight_, layer_weight.final_norm_bias_, eps=self.eps_) - - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - batch_size = infer_state.batch_size - last_input = self.alloc_tensor( - (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype + return layer_weight.final_norm_weight_.layernorm_forward( + input=input, eps=self.eps_, alloc_func=self.alloc_tensor ) - if infer_state.is_prefill: - last_index = ( - torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 - ) - last_input[:, :] 
= input_embdings[last_index, :] - else: - last_input[:, :] = input_embdings[-batch_size:, :] - - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = self._norm(last_input, infer_state, layer_weight) - last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, batch_size) - logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input) - last_input = None - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor( - (self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings_dtype - ) - split_size = self.vocab_size_ // self.tp_world_size_ - all_gather( - [gather_data[i * split_size : (i + 1) * split_size, :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - logic_batch = None - - ans_logics = gather_data.permute(1, 0).float().contiguous() - gather_data = None - return ans_logics diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index 1733c20b0..baf1d3084 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -3,9 +3,6 @@ from lightllm.common.basemodel import PreLayerInferTpl from lightllm.common.basemodel import InferStateInfo from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce @@ -15,32 +12,20 @@ class BloomPreLayerInfer(PreLayerInferTpl): def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = network_config["layer_norm_epsilon"] - tp_vocab_size_ = network_config["vocab_size"] // self.tp_world_size_ - self.vob_start_id_ = tp_vocab_size_ * self.tp_rank_ - self.vob_end_id_ = tp_vocab_size_ * (self.tp_rank_ + 1) return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layernorm_forward(input, layer_weight.pre_norm_weight_, layer_weight.pre_norm_bias_, eps=self.eps_) + return layer_weight.pre_norm_weight_.layernorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - total_token_num = infer_state.total_token_num - input_ids = input_ids[0:total_token_num] - - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + 
input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 8299697f3..d82a23d03 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -1,16 +1,10 @@ -import time import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.bloom.triton_kernel.token_flashattention_nopad import token_attention_fwd -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.common.basemodel import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time class BloomTransformerLayerInfer(TransformerLayerInferTpl): @@ -27,20 +21,18 @@ def __init__(self, layer_num, network_config, mode): self.embed_dim_ = network_config["n_embed"] return - def _att_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.att_norm_weight_.weight, - bias=layer_weight.att_norm_weight_.bias, - eps=self.eps_, + def _att_norm( + self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.att_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) - def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.ffn_norm_weight_.weight, - bias=layer_weight.ffn_norm_weight_.bias, - eps=self.eps_, + def _ffn_norm( + self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _get_qkv( diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index b740bb62f..23d443dc8 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -1,43 +1,26 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - - def load_hf_weights(self, weights): - - if "word_embeddings_layernorm.weight" in weights: - self.pre_norm_weight_ = self._cuda(weights["word_embeddings_layernorm.weight"]) - if "word_embeddings_layernorm.bias" in weights: - 
self.pre_norm_bias_ = self._cuda(weights["word_embeddings_layernorm.bias"]) - if "ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["ln_f.weight"]) - if "ln_f.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["ln_f.bias"]) - if "word_embeddings.weight" in weights: - vob_size = self.network_config_["vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - self.wte_weight_ = self._cuda( - weights["word_embeddings.weight"][ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - ) - self.lm_head_weight_ = self.wte_weight_ - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.pre_norm_weight_, - self.pre_norm_bias_, - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.lm_head_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return + self.pre_norm_weight_ = NoTpNormWeight( + weight_name="word_embeddings_layernorm.weight", + data_type=self.data_type_, + bias_name="word_embeddings_layernorm.bias", + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="ln_f.weight", + data_type=self.data_type_, + bias_name="ln_f.bias", + ) + vocab_size = self.network_config_["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + weight_name="word_embeddings.weight", + data_type=self.data_type_, + vocab_size=vocab_size, + ) + self.lm_head_weight_ = self.wte_weight_ From 89f8fda1ab04fd26f8b6497f78e95a515890e61f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 10:39:16 +0800 Subject: [PATCH 45/79] fix rmsnorm llama call --- .../llama/layer_infer/transformer_layer_infer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 83c090495..b08b2aa1f 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -14,7 +14,6 @@ from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd @@ -190,16 +189,16 @@ def _bind_attention(self): def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - out = self.alloc_tensor(input.shape, input.dtype) - rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out) - return out + return layer_weight.att_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - out = self.alloc_tensor(input.shape, input.dtype) - rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out) - return out + return layer_weight.ffn_norm_weight_.rmsnorm_forward( + input=input, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: 
LlamaTransformerLayerWeight From 0aae4f671ae3d7c2dace7dafda556a839e49d8b3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:04:52 +0800 Subject: [PATCH 46/79] fix pos embdiing --- .../layer_weights/meta_weights/__init__.py | 2 +- .../meta_weights/embedding_weight.py | 61 +++++++++++++++--- .../pre_and_post_layer_weight.py | 3 +- .../pre_and_post_layer_weight.py | 4 +- .../starcoder/layer_infer/pre_layer_infer.py | 47 ++++---------- .../pre_and_post_layer_weight.py | 63 +++++++------------ 6 files changed, 90 insertions(+), 90 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index a0b54f3d6..1441547aa 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -9,5 +9,5 @@ from .norm_weight import NoTpGEMMANormWeight, TpNormWeight, NoTpNormWeight, TpHeadNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP -from .embedding_weight import EmbeddingWeight, LMHeadWeight +from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index 75d197138..ba3e45c8f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -4,23 +4,29 @@ from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class EmbeddingWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type, vocab_size: int): + def __init__(self, weight_name, data_type): super().__init__() self.weight_name: str = weight_name self.data_type_ = data_type self.weight: torch.Tensor = None - self.vocab_size = vocab_size - split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - self.tp_vocab_start_id = split_indexes[self.tp_rank_] - self.tp_vocab_end_id = split_indexes[self.tp_rank_ + 1] def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights and self.weight is None: t_weight = weights[self.weight_name] - assert len(t_weight) == self.vocab_size + # init some params + self.vocab_size = len(t_weight) + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = split_indexes[self.tp_rank_] + self.tp_vocab_end_id = split_indexes[self.tp_rank_ + 1] + + logger.info(f"loaded weight vocab_size: {self.vocab_size}") + self.weight = ( t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :] .to(self.data_type_) @@ -63,5 +69,44 @@ def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc class LMHeadWeight(EmbeddingWeight): - def __init__(self, weight_name, data_type, vocab_size): - super().__init__(weight_name, data_type, vocab_size) + def __init__(self, weight_name, data_type): + super().__init__(weight_name, data_type) + + +class NoTpPosEmbeddingWeight(BaseWeightTpl): + def __init__(self, weight_name, data_type): + super().__init__() + self.weight_name: str = weight_name + 
self.data_type_ = data_type + self.weight: torch.Tensor = None + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights and self.weight is None: + t_weight = weights[self.weight_name] + self.weight = t_weight.to(self.data_type_).cuda(get_current_device_id()) + self.end_position_id: int = t_weight.shape[0] + logger.info(f"loaded weight end_position_id: {self.end_position_id}") + + def verify_load(self): + load_ok = True + load_ok = load_ok and self.weight is not None + + return load_ok + + def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + if out is None: + out = alloc_func( + (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device + ) + + embedding_kernel( + input_ids=input_ids, + weight=self.weight, + vob_start_id=0, + vob_end_id=self.end_position_id, + out=out, + ) + + return out diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index 23d443dc8..afc8c9308 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -17,10 +17,9 @@ def __init__(self, data_type, network_config, mode): data_type=self.data_type_, bias_name="ln_f.bias", ) - vocab_size = self.network_config_["vocab_size"] + self.wte_weight_ = EmbeddingWeight( weight_name="word_embeddings.weight", data_type=self.data_type_, - vocab_size=vocab_size, ) self.lm_head_weight_ = self.wte_weight_ diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 9754094ad..9dd24da1e 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -9,11 +9,9 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - vocab_size = self.network_config_["vocab_size"] self.embdding_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", data_type=self.data_type_, - vocab_size=vocab_size, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: @@ -22,8 +20,8 @@ def __init__(self, data_type, network_config, mode): self.lm_head_weight_ = LMHeadWeight( weight_name="lm_head.weight", data_type=self.data_type_, - vocab_size=vocab_size, ) + self.final_norm_weight_ = NoTpNormWeight( weight_name="model.norm.weight", data_type=self.data_type_, diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index a98323ab5..52072a348 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -1,13 +1,8 @@ -import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from lightllm.models.starcoder.layer_weights.pre_and_post_layer_weight import PreAndPostLayerWeight +from lightllm.models.starcoder.layer_weights.pre_and_post_layer_weight import StarcoderPreAndPostLayerWeight from lightllm.common.basemodel.infer_struct import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time from lightllm.common.basemodel import PreLayerInfer -from 
lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce @@ -16,47 +11,27 @@ class StarcoderPreLayerInfer(PreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) - assert network_config["vocab_size"] % self.tp_world_size_ == 0 - self.tp_vocab_size_ = network_config["vocab_size"] // self.tp_world_size_ - self.embed_dim_ = network_config["hidden_size"] self.layer_norm_eps_ = network_config["layer_norm_epsilon"] - self.vob_start_id_ = self.tp_vocab_size_ * self.tp_rank_ - self.vob_end_id_ = self.tp_vocab_size_ * (self.tp_rank_ + 1) - def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: PreAndPostLayerWeight): - total_token_num = infer_state.total_token_num - input_ids = input_ids[0:total_token_num] - - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = self.alloc_tensor( - (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding( - infer_state.position_ids, layer_weight.wpe_weight_, 0, layer_weight.wpe_weight_.shape[0], position_embeds + position_embeds = layer_weight.wpe_weight_.embedding( + input_ids=infer_state.position_ids, + alloc_func=self.alloc_tensor, ) return input_embdings.add_(position_embeds) - def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: PreAndPostLayerWeight): - # import ipdb;ipdb.set_trace() - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = self.alloc_tensor( - (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding( - infer_state.position_ids, layer_weight.wpe_weight_, 0, layer_weight.wpe_weight_.shape[0], position_embeds + position_embeds = layer_weight.wpe_weight_.embedding( + input_ids=infer_state.position_ids, + alloc_func=self.alloc_tensor, ) - return input_embdings.add_(position_embeds) diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 8d87c1163..c44725382 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -1,51 +1,34 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + 
NoTpNormWeight, + NoTpPosEmbeddingWeight, + LMHeadWeight, +) class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - def load_hf_weights(self, weights): + self.wte_weight_ = EmbeddingWeight( + weight_name="transformer.wte.weight", + data_type=self.data_type_, + ) + self.wpe_weight_ = NoTpPosEmbeddingWeight( + weight_name="transformer.wpe.weight", + data_type=self.data_type_, + ) - vob_size = self.network_config_["vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - if "transformer.wte.weight" in weights: - # print(weights['transformer.wte.weight'].shape) - self.wte_weight_ = ( - weights["transformer.wte.weight"][ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - .contiguous() - .to(self.data_type_) - .cuda() - ) - if "transformer.wpe.weight" in weights: - # print(weights['transformer.wpe.weight'].shape) - self.wpe_weight_ = weights["transformer.wpe.weight"].to(self.data_type_).cuda() - if "lm_head.weight" in weights: - self.lm_head_weight_ = ( - weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] - .contiguous() - .to(self.data_type_) - .cuda() - ) - if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = weights["transformer.ln_f.weight"].contiguous().to(self.data_type_).cuda() - if "transformer.ln_f.bias" in weights: - self.final_norm_bias_ = weights["transformer.ln_f.bias"].contiguous().to(self.data_type_).cuda() - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.wpe_weight_, - self.lm_head_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.final_norm_weight_ = NoTpNormWeight( + weight_name="transformer.ln_f.weight", + bias_name="transformer.ln_f.bias", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) return From edd3f35e2ebbcd593892c8122122d17071fa4e4a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:07:19 +0800 Subject: [PATCH 47/79] fix starcoder --- .../layer_infer/transformer_layer_infer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index ca04bee0c..796a96bc4 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,4 @@ import torch -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.models.starcoder2.layer_weights.transformer_layer_weight import Starcoder2TransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -12,21 +11,19 @@ def __init__(self, layer_num, network_config, mode=[]): def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.att_norm_weight_.weight, - bias=layer_weight.att_norm_weight_.bias, + return layer_weight.att_norm_weight_.layernorm_forward( + input=input.view(-1, 
self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.ffn_norm_weight_.weight, - bias=layer_weight.ffn_norm_weight_.bias, + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) def _ffn( From d97cc4001526cae508a5b80213c67c886ac55151 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:12:40 +0800 Subject: [PATCH 48/79] fix rmsnorm call --- .../layer_infer/transformer_layer_infer.py | 1 - .../layer_infer/transformer_layer_infer.py | 18 ++++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index c69c7f4fb..30ae2242d 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -8,7 +8,6 @@ from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index bd7bca39e..10a734e5c 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -8,7 +8,6 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from functools import partial @@ -62,17 +61,17 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rmsnorm_forward( + + layer_weight.q_norm_weight_.rmsnorm_forward( q.view(-1, self.head_dim_), - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, out=q.view(-1, self.head_dim_), ) - cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward( - cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), - weight=layer_weight.k_norm_weight_.weight, + cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( + input=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), eps=self.eps_, + alloc_func=self.alloc_tensor, ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) rotary_emb_fwd( @@ -101,17 +100,16 @@ def _tpsp_get_qkv( q = layer_weight.q_proj.mm(input) cache_kv = 
layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rmsnorm_forward( + layer_weight.q_norm_weight_.rmsnorm_forward( q.view(-1, self.head_dim_), - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, out=q.view(-1, self.head_dim_), ) - cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, + alloc_func=self.alloc_tensor, ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) rotary_emb_fwd( From 41e82e0747147a6b7868514498fe45863f91ac65 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:13:54 +0800 Subject: [PATCH 49/79] fix chatglm --- .../chatglm2/layer_infer/transformer_layer_infer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py index 89061dae3..07ffc4bea 100755 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py @@ -1,17 +1,8 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np - -from lightllm.utils.infer_utils import mark_cost_time from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight -from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward - class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ From 55666de692920daa011b51accc1d9088bfe0b139 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:16:19 +0800 Subject: [PATCH 50/79] fix stablelm --- .../layer_infer/transformer_layer_infer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 53171ce53..df64f2bcc 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -5,7 +5,6 @@ from functools import partial from typing import Tuple from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.models.stablelm.layer_weights.transformer_layer_weight import StablelmTransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -57,19 +56,17 @@ def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, t def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.att_norm_weight_.weight, - bias=layer_weight.att_norm_weight_.bias, + return layer_weight.att_norm_weight_.layernorm_forward( 
+ input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.ffn_norm_weight_.weight, - bias=layer_weight.ffn_norm_weight_.bias, + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) From 435543765799a7320247c96fd536b0df43b8dbef Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:24:14 +0800 Subject: [PATCH 51/79] fix wte name --- .../gemma_2b/layer_infer/pre_layer_infer.py | 14 ++++----- .../pre_and_post_layer_weight.py | 30 ++++++++----------- .../llama/layer_infer/pre_layer_infer.py | 4 +-- .../pre_and_post_layer_weight.py | 4 +-- 4 files changed, 23 insertions(+), 29 deletions(-) diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index c63432a92..ce9737820 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -5,8 +5,6 @@ from lightllm.models.gemma_2b.layer_weights.pre_and_post_layer_weight import Gemma_2bPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel import PreLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce @@ -24,20 +22,20 @@ def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight) return input * self.normfactor def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ + input_embdings = layer_weight.wte_weight_.embedding( + input_ids=input_ids, + alloc_func=self.alloc_tensor, ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ + input_embdings = layer_weight.wte_weight_.embedding( + input_ids=input_ids, + alloc_func=self.alloc_tensor, ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index c119960c5..d5d0438fa 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -1,25 +1,21 @@ -import torch -import numpy as np -from 
lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpGEMMANormWeight -class Gemma_2bPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - self.lm_head_weight_ = self.wte_weight_ + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = self.wte_weight_ - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - self.final_norm_weight_ = self.final_norm_weight_ + 1 + self.final_norm_weight_ = NoTpGEMMANormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) return diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index e8fa70626..ddb99e262 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -15,13 +15,13 @@ def __init__(self, network_config, mode): return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.embdding_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.embdding_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 9dd24da1e..5dbd0eb63 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -9,13 +9,13 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - self.embdding_weight_ = EmbeddingWeight( + self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: - self.lm_head_weight_ = self.embdding_weight_ + 
self.lm_head_weight_ = self.wte_weight_ else: self.lm_head_weight_ = LMHeadWeight( weight_name="lm_head.weight", From 0fad6c6f5d27f2013308c5153dfbbcd1b1960725 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 11:33:08 +0800 Subject: [PATCH 52/79] fix rmsnorm --- .../layer_infer/transformer_layer_infer.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 76dd1d568..ff20bc6ee 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -18,7 +18,6 @@ from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo @@ -158,14 +157,18 @@ def _get_qkv( q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q = layer_weight.q_a_layernorm_.rmsnorm_forward( + input=q, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( + + layer_weight.kv_a_layernorm_.rmsnorm_forward( cache_kv[:, :, : self.kv_lora_rank], - weight=layer_weight.kv_a_layernorm_.weight, eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank], ) @@ -197,9 +200,8 @@ def _tpsp_get_qkv( cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( + layer_weight.kv_a_layernorm_.rmsnorm_forward( cache_kv[:, :, : self.kv_lora_rank], - weight=layer_weight.kv_a_layernorm_.weight, eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank], ) @@ -234,14 +236,17 @@ def _tpsp_get_qkv( position_sin = infer_state.position_sin q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q = layer_weight.q_a_layernorm_.rmsnorm_forward( + q, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( + layer_weight.kv_a_layernorm_.rmsnorm_forward( cache_kv[:, :, : self.kv_lora_rank], - 
weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
             out=cache_kv[:, :, : self.kv_lora_rank],
         )

From e4be2167bb085926f119f516d6d7e536a6402db1 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 2 Jan 2026 11:34:20 +0800
Subject: [PATCH 53/79] fix

---
 lightllm/common/basemodel/triton_kernel/rmsnorm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/common/basemodel/triton_kernel/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/rmsnorm.py
index 19f7c84d3..ca8f9a1c8 100644
--- a/lightllm/common/basemodel/triton_kernel/rmsnorm.py
+++ b/lightllm/common/basemodel/triton_kernel/rmsnorm.py
@@ -50,7 +50,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None)
     # reshape input data into 2D tensor
     x_arg = x.view(-1, x.shape[-1])
     y_arg = y.view(-1, x.shape[-1])
-    assert x_arg.shape[-1] == weight.shape[0]
+    assert x_arg.shape[-1] == weight.shape[0] and x_arg.shape == y_arg.shape
     assert y.data_ptr() == y_arg.data_ptr()
     M, N = x_arg.shape
     # Less than 64KB per feature: enqueue fused kernel

From 741532e3722140a9b9aa3c31efe89c52d3a54ef5 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 2 Jan 2026 11:35:10 +0800
Subject: [PATCH 54/79] fix

---
 lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
index a9a6954d6..175340a77 100644
--- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
@@ -11,7 +11,6 @@
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from lightllm.distributed import all_reduce

From 4e9941a37966ca65541dd1acb5e16c7557ed018a Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 2 Jan 2026 11:35:52 +0800
Subject: [PATCH 55/79] fix

---
 .../models/gpt_oss/layer_weights/transformer_layer_weight.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
index 94f3efc15..fb3c66e22 100644
--- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
@@ -4,7 +4,6 @@
 from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight
-from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight
 from lightllm.common.basemodel.layer_weights.meta_weights import TpAttSinkWeight
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
 from lightllm.utils.log_utils import init_logger

From c53cdf9748791fa5497372ef220fb3cc5251862c Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 2 Jan 2026 11:56:48 +0800
Subject: [PATCH 56/79] fix mtp deepseek

---
 .../layer_infer/pre_layer_infer.py | 25 +++++++--
 .../pre_and_post_layer_weight.py | 56 
+++++++++++-------- 2 files changed, 54 insertions(+), 27 deletions(-) diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index b81fbc510..c030e001a 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -3,7 +3,6 @@ from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): @@ -22,9 +21,17 @@ def _mtp_context_forward( assert ( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) ans_logics = self.alloc_tensor( @@ -38,9 +45,17 @@ def _mtp_token_forward( ): tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) ans_logics = self.alloc_tensor( diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index f5b805647..582afb560 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,29 +1,41 @@ -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpNormWeight, + ROWMMWeight, +) -class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - # 与DeepseekV3模型共享 - self.wte_weight_ = None - self.lm_head_weight_ = None - return - def load_hf_weights(self, weights): - if "model.layers.0.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() - if "model.layers.0.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) - if "model.layers.0.hnorm.weight" in weights: - self.hnorm_weight_ = 
self._cuda(weights["model.layers.0.hnorm.weight"]) - if "model.layers.0.shared_head.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"]) - return + self.eh_proj_weight_ = ROWMMWeight( + weight_names="model.layers.0.eh_proj.weight", + data_type=self.data_type_, + layer_num=0, + name="eh_proj", + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.enorm.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.hnorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.hnorm.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.shared_head.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) - def verify_load(self): - errors = "weights load not ok" - weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + # 与DeepseekV3模型共享, 不通过 load 加载 + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None return From 0b74a3cb211da1ac983d6690d5bc3c436abc2ebd Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 12:03:16 +0800 Subject: [PATCH 57/79] fix mtp mistral --- .../layer_infer/pre_layer_infer.py | 39 ++++++++++++--- .../pre_and_post_layer_weight.py | 48 +++++++++++-------- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index f43b85d63..ac0f39487 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -2,7 +2,6 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward class MistralMTPPreLayerInfer(LlamaPreLayerInfer): @@ -19,10 +18,23 @@ def _mtp_context_forward( assert ( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.final_norm_weight_, eps=self.eps_) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + + tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) @@ -37,10 +49,23 @@ def _mtp_token_forward( ): tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.final_norm_weight_, eps=self.eps_) - rmsnorm_forward(tgt_embdings, 
weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + + tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index e974ac2b0..2fbc89cfd 100644 --- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,26 +1,36 @@ -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpNormWeight, + ROWMMWeight, +) -class MistralMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - self.wte_weight_ = None - self.lm_head_weight_ = None - self.final_norm_weight_ = None - return - def load_hf_weights(self, weights): - if "mtp.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["mtp.eh_proj.weight"]).t() - if "mtp.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["mtp.enorm.weight"]) - if "mtp.hnorm.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["mtp.hnorm.weight"]) - return + self.eh_proj_weight_ = ROWMMWeight( + weight_names="mtp.eh_proj.weight", + data_type=self.data_type_, + layer_num=0, + name="eh_proj", + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpNormWeight( + weight_name="mtp.enorm.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.hnorm_weight_ = NoTpNormWeight( + weight_name="mtp.hnorm.weight", + data_type=self.data_type_, + bias_name=None, + ) - def verify_load(self): - errors = "weights load not ok" - weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + self.final_norm_weight_: NoTpNormWeight = None return From bfe3f98024a5eb894bf1d8f7028c420d25b50754 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 12:05:27 +0800 Subject: [PATCH 58/79] fix --- .../models/mistral_mtp/layer_infer/pre_layer_infer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index ac0f39487..25bea1aa6 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -38,10 +38,7 @@ def _mtp_context_forward( cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + ans_logics = 
layer_weight.eh_proj_weight_.mm(cat_embdings) return ans_logics def _mtp_token_forward( @@ -69,10 +66,7 @@ def _mtp_token_forward( cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) return ans_logics def context_forward( From a5f006995a87dabb119ab8d5ec7f336814b755b0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 12:06:36 +0800 Subject: [PATCH 59/79] fix mtp deepseek --- .../models/deepseek_mtp/layer_infer/pre_layer_infer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index c030e001a..26bfc865e 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -34,10 +34,7 @@ def _mtp_context_forward( ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) return ans_logics def _mtp_token_forward( @@ -58,10 +55,7 @@ def _mtp_token_forward( ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) return ans_logics def context_forward( From af9f50db49762c39b2ebf15bf299f502e97f690f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 12:27:01 +0800 Subject: [PATCH 60/79] fix --- lightllm/common/basemodel/basemodel.py | 6 ++---- .../layer_weights/meta_weights/embedding_weight.py | 4 ++-- lightllm/server/api_start.py | 3 ++- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index b6635fbab..3b8d6bd5f 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -498,11 +498,9 @@ def _decode( if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True - model_output: ModelOutput = self.graph.capture_decode( - self._token_forward, padded_model_input.input_ids, infer_state - ) + model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: - model_output: ModelOutput = self.graph.replay(padded_model_input.input_ids, infer_state) + model_output: ModelOutput = self.graph.replay(infer_state) model_output = self._create_unpad_decode_model_output( model_output, origin_batch_size=model_input.batch_size diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index ba3e45c8f..1ea2d1331 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -22,8 +22,8 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): # init some params self.vocab_size 
= len(t_weight) split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - self.tp_vocab_start_id = split_indexes[self.tp_rank_] - self.tp_vocab_end_id = split_indexes[self.tp_rank_ + 1] + self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) + self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) logger.info(f"loaded weight vocab_size: {self.vocab_size}") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 014d98766..4ead3cbbf 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -197,7 +197,8 @@ def normal_or_p_d_start(args): assert ( args.batch_max_tokens >= args.chunked_prefill_size - ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" + ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, " + f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}" # help to manage data stored on Ceph if "s3://" in args.model_dir: From 977a038b33a7116d983c10c778a47eef7ae82f03 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 2 Jan 2026 16:23:02 +0800 Subject: [PATCH 61/79] fix all weights --- .../pre_and_post_layer_weight.py | 41 +++----- .../cohere/layer_infer/post_layer_infer.py | 96 ++----------------- .../pre_and_post_layer_weight.py | 49 ++++------ .../gemma3/layer_infer/post_layer_infer.py | 48 ---------- .../gemma3/layer_infer/pre_layer_infer.py | 10 +- .../layer_infer/transformer_layer_infer.py | 94 +++++++----------- .../pre_and_post_layer_weight.py | 33 +++---- .../layer_weights/transformer_layer_weight.py | 12 +-- .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 26 ++--- .../layer_infer/post_layer_infer.py | 12 +-- .../pre_and_post_layer_weight.py | 37 +++---- lightllm/models/internlm2_reward/model.py | 3 - .../pre_and_post_layer_weight.py | 3 - .../pre_and_post_layer_weight.py | 5 +- .../pre_and_post_layer_weight.py | 23 +---- .../layer_infer/transformer_layer_infer.py | 4 - .../pre_and_post_layer_weight.py | 44 +++------ .../layer_weights/transformer_layer_weight.py | 2 - .../pre_and_post_layer_weight.py | 34 +------ .../layer_weights/transformer_layer_weight.py | 3 - .../layer_infer/post_layer_infer.py | 5 +- .../pre_and_post_layer_weight.py | 59 ++++-------- .../layer_infer/transformer_layer_infer.py | 5 - .../layer_infer/transformer_layer_infer.py | 5 - .../pre_and_post_layer_weight.py | 58 +++++------ .../qwen3_vl/layer_infer/pre_layer_infer.py | 15 +-- .../layer_infer/transformer_layer_infer.py | 6 -- .../pre_and_post_layer_weight.py | 47 ++++----- .../qwen_vl/layer_infer/pre_layer_infer.py | 8 +- .../layer_infer/transformer_layer_infer.py | 3 - .../pre_and_post_layer_weight.py | 34 ++----- .../pre_and_post_layer_weight.py | 2 - .../pre_and_post_layer_weight.py | 47 ++++----- 34 files changed, 263 insertions(+), 614 deletions(-) diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py index 0b125bea3..0139eb883 100644 --- a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py @@ -1,31 +1,20 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight -class 
ChatGLM2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class ChatGLM2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - def load_hf_weights(self, weights): - # input layernorm params - - vob_size = self.network_config_["padded_vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - if "transformer.embedding.word_embeddings.weight" in weights: - self.wte_weight_ = weights["transformer.embedding.word_embeddings.weight"] - self.wte_weight_ = self.wte_weight_[ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.wte_weight_ = self._cuda(self.wte_weight_) - if "transformer.output_layer.weight" in weights: - self.lm_head_weight_ = weights["transformer.output_layer.weight"] - self.lm_head_weight_ = self.lm_head_weight_[ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.lm_head_weight_ = self._cuda(self.lm_head_weight_) - if "transformer.encoder.final_layernorm.weight" in weights: - self.final_norm_weight_ = weights["transformer.encoder.final_layernorm.weight"] - self.final_norm_weight_ = self._cuda(self.final_norm_weight_) - - return + self.wte_weight_ = EmbeddingWeight( + weight_name="transformer.embedding.word_embeddings.weight", data_type=self.data_type_ + ) + self.lm_head_weight_ = LMHeadWeight( + weight_name="transformer.output_layer.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="transformer.encoder.final_layernorm.weight", + data_type=self.data_type_, + bias_name=None, + ) diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py index 8b9d4b268..e2fc47581 100644 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py @@ -1,101 +1,21 @@ import torch -import torch.distributed as dist -import numpy as np - from lightllm.models.cohere.infer_struct import CohereInferStateInfo from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward -from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight - -from einops import rearrange -from lightllm.common.basemodel import PostLayerInferTpl -from lightllm.distributed.communication_op import all_gather +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.common.build_utils import repair_config -class CoherePostLayerInfer(PostLayerInferTpl): +class CoherePostLayerInfer(LlamaPostLayerInfer): def __init__(self, network_config, mode): + repair_config(config=network_config, same_names=["layer_norm_eps", "rms_norm_eps"]) super().__init__(network_config, mode) self.eps_ = network_config["layer_norm_eps"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] - self.logits_scale = network_config["logit_scale"] return - def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor: + def _norm( + self, input: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight + ) -> torch.Tensor: return layernorm_forward( - input.unsqueeze(1), layer_weight.final_norm_weight_.unsqueeze(0), eps=self.eps_ + input.unsqueeze(1), layer_weight.final_norm_weight_.weight.unsqueeze(0), eps=self.eps_ ).squeeze(1) - - def 
_slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo): - - if infer_state.is_prefill and infer_state.is_token_healing: - batch_size = infer_state.batch_size - b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() - select_index = [] - start_index = 0 - select_token_num = 0 - for cur_len in b_seq_len_numpy: - - select_index.append(start_index + cur_len - 1) - start_index += cur_len - select_token_num += 1 - - last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor( - (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype - ) - - last_input[:, :] = input_embdings[last_index, :] - return last_input, select_token_num - - if infer_state.is_prefill and not infer_state.return_all_prompt_logics: - batch_size = infer_state.batch_size - last_input = self.alloc_tensor( - (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype - ) - last_index = ( - torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 - ) - last_input[:, :] = input_embdings[last_index, :] - return last_input, batch_size - - if infer_state.is_prefill and infer_state.return_all_prompt_logics: - total_tokens = infer_state.total_token_num - return input_embdings, total_tokens - - if not infer_state.is_prefill: - batch_size = infer_state.batch_size - return input_embdings[-batch_size:, :], batch_size - - assert False, "Error State" - - def token_forward( - self, input_embdings, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = self._norm(last_input, infer_state, layer_weight) - last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num) - logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input) - - last_input = None - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor( - (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype - ) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - gather_data = gather_data * self.logits_scale - logic_batch = None - - ans_logics = gather_data.permute(1, 0).float() - gather_data = None - return ans_logics diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py index 993acd64d..f2e5f8547 100644 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py @@ -1,36 +1,25 @@ -import torch -import numpy as np +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight - -class CoherePreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] +class 
CoherePreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) tie_weight = self.network_config_.get("tie_word_embeddings", True) - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - if tie_weight: - self.lm_head_weight_ = self.wte_weight_ - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - if "model.lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"]) - return - - def verify_load(self): - super().verify_load() - errors = "tie weights load not ok" - tie_weight = self.network_config_.get("tie_word_embeddings", True) + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) if tie_weight: - assert self.lm_head_weight_ is not None, errors - assert self.wte_weight_ is self.lm_head_weight_, errors + self.lm_head_weight_ = self.wte_weight_ else: - assert self.lm_head_weight_ is not None, errors - assert self.wte_weight_ is not None, errors - assert self.wte_weight_ is not self.lm_head_weight_, errors + self.lm_head_weight_ = LMHeadWeight( + weight_name="model.lm_head.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) diff --git a/lightllm/models/gemma3/layer_infer/post_layer_infer.py b/lightllm/models/gemma3/layer_infer/post_layer_infer.py index 8004309ed..22dc59505 100644 --- a/lightllm/models/gemma3/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/post_layer_infer.py @@ -1,9 +1,4 @@ -import numpy as np -import torch - -from lightllm.distributed.communication_op import all_gather from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight class Gemma3PostLayerInfer(LlamaPostLayerInfer): @@ -13,46 +8,3 @@ def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = 1e-6 return - - def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out=None): - def _inner_norm(x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - output = _inner_norm(input.float()) - output = output * (1.0 + weight.float()) - if out is not None: - out = output.to(out.dtype) - return output - - def _norm(self, input, infer_state, layer_weight) -> torch.Tensor: - return self.gemma3_rmsnorm(input, layer_weight.final_norm_weight_, eps=self.eps_) - - def token_forward(self, input_embdings, infer_state, layer_weight): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - last_input = self._norm(last_input.float(), infer_state, layer_weight).to(torch.bfloat16) - last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = self.alloc_tensor( - (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype - ) - torch.mm(layer_weight.lm_head_weight_.to(last_input.dtype), last_input, out=logic_batch) - last_input = None - if self.tp_world_size_ == 
1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - logic_batch = None - ans_logics = self.alloc_tensor( - (token_num, self.vocab_size_), - dtype=torch.float32, - ) - ans_logics[:, :] = gather_data.permute(1, 0) - gather_data = None - return ans_logics diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index 0df5c0f06..20fc9aa9e 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -16,8 +16,8 @@ def context_forward(self, input_ids, infer_state, layer_weight): img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype + device = layer_weight.wte_weight_.weight.device + dtype = layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.shape[1] weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device) @@ -70,13 +70,13 @@ def context_forward(self, input_ids, infer_state, layer_weight): img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, img_start_locs_in_cache=img_start_locs_in_cache, - tp_text_start_token_id=self.vob_start_id_, - tp_text_end_token_id=self.vob_end_id_, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) input_dtype = out.dtype if self.tp_world_size_ > 1: - all_reduce(out, group=infer_state.dist_group, op=torch.dist.ReduceOp.SUM, async_op=False) + all_reduce(out, group=infer_state.dist_group, op=torch.distributed.ReduceOp.SUM, async_op=False) return (out.float() * weight_mask.unsqueeze(1).float()).to(input_dtype) def token_forward(self, input_ids, infer_state, layer_weight): diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 09efe9a36..d4bd8c3fa 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -9,11 +9,7 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.distributed import all_reduce -from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight -from lightllm.models.gemma_2b.triton_kernel.gelu_and_mul import gelu_and_mul_fwd - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd @@ -31,54 +27,6 @@ def __init__(self, layer_num, network_config, mode=[]): self.sliding_window_pattern = 6 return - def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out=None): - def _norm(x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - output = _norm(input.float()) - output = output * 
(1.0 + weight.float()) - if out is not None: - out = output.to(out.dtype) - return output - - def _pre_feedforward_layernorm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.pre_feedforward_layernorm_weight_.weight, self.eps_, out=out) - return out - - def _post_feedforward_layernorm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.post_feedforward_layernorm_weight_.weight, self.eps_, out=out) - return out - - def _k_norm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.k_norm_weight_.weight, self.eps_, out=out) - return out - - def _q_norm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.q_norm_weight_.weight, self.eps_, out=out) - return out - - def _att_norm(self, input, infer_state, layer_weight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) - return out - - def _ffn_norm(self, input, infer_state, layer_weight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) - return out - - def _bind_norm(self): - self._att_norm = partial(Gemma3TransformerLayerInfer._att_norm, self) - self._ffn_norm = partial(Gemma3TransformerLayerInfer._ffn_norm, self) - self._q_norm = partial(Gemma3TransformerLayerInfer._q_norm, self) - self._k_norm = partial(Gemma3TransformerLayerInfer._k_norm, self) - self._pre_feedforward_layernorm = partial(Gemma3TransformerLayerInfer._pre_feedforward_layernorm, self) - self._post_feedforward_layernorm = partial(Gemma3TransformerLayerInfer._post_feedforward_layernorm, self) - def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3TransformerLayerWeight ) -> torch.Tensor: @@ -94,8 +42,16 @@ def _get_qkv( # gemma3 use qk norm q = q.view(-1, self.tp_q_head_num_, self.head_dim_) k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = self._q_norm(q.float(), infer_state, layer_weight).to(cache_kv.dtype) - cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k.float(), infer_state, layer_weight).to(cache_kv.dtype) + + q = layer_weight.q_norm_weight_.rmsnorm_forward( + input=q.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(cache_kv.dtype) + + cache_kv[:, 0 : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( + input=k.float(), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ).to(cache_kv.dtype) is_sliding = bool((self.layer_num_ + 1) % self.sliding_window_pattern) if is_sliding: @@ -125,7 +81,7 @@ def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3Tran ffn1_out = None return ffn2_out - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma3TransformerLayerWeight): input_embdings = input_embdings.to(torch.bfloat16) input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( torch.bfloat16 @@ -142,16 +98,25 @@ def context_forward(self, input_embdings, infer_state: 
InferStateInfo, layer_wei input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) + input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = self._post_feedforward_layernorm(ffn_out.float(), infer_state, layer_weight).to(torch.bfloat16) + + ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( + input=ffn_out.float(), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ).to(torch.bfloat16) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma3TransformerLayerWeight): input_embdings = input_embdings.to(torch.bfloat16) input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( torch.bfloat16 @@ -168,11 +133,20 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) + input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = self._post_feedforward_layernorm(ffn_out.float(), infer_state, layer_weight).to(torch.bfloat16) + + ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( + input=ffn_out.float(), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ).to(torch.bfloat16) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 24ea91d89..17e65268c 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -1,25 +1,20 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpGEMMANormWeight -# add key: language_model.xxx -> xxx -# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now -def rename_weight_keys(weights): - prefix = "language_model." 
- keys = list(weights.keys()) - for k in keys: - if prefix in k: - weights[k[len(prefix) :]] = weights[k] - - -class Gemma3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): - network_config["tie_word_embeddingse"] = True super().__init__(data_type, network_config, mode) - return - def load_hf_weights(self, weights): - rename_weight_keys(weights) - super().load_hf_weights(weights) + self.wte_weight_ = EmbeddingWeight( + weight_name="language_model.model.embed_tokens.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = self.wte_weight_ + + self.final_norm_weight_ = NoTpGEMMANormWeight( + weight_name="language_model.model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) return diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 71227bd9b..1e7ceeb42 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -1,8 +1,6 @@ -import torch -import numpy as np from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight class Gemma3TransformerLayerWeight(LlamaTransformerLayerWeight): @@ -66,12 +64,12 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NoTpNormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NoTpNormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) - self.pre_feedforward_layernorm_weight_ = NoTpNormWeight( + self.k_norm_weight_ = NoTpGEMMANormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) + self.q_norm_weight_ = NoTpGEMMANormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + self.pre_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None ) - self.post_feedforward_layernorm_weight_ = NoTpNormWeight( + self.post_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( self._post_feedforward_layernorm_name, self.data_type_, bias_name=None ) diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 8aacb457d..93cd7413b 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -106,7 +106,9 @@ def _context_sliding_attention_flashattention( ) return o - def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): + def _token_sliding_attention_flashattention( + self, q, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) else: diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index dd8c64915..b40330aa3 100644 --- 
a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -1,23 +1,15 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight -class Internlm2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) - if "output.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["output.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) + self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py b/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py index 42061c784..d5eaa1654 100644 --- a/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py +++ b/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py @@ -1,19 +1,17 @@ import torch -import torch.distributed as dist -import numpy as np - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from einops import rearrange +from ..layer_weights.pre_and_post_layer_weight import Internlm2RewardPreAndPostLayerWeight class Internlm2RewardPostLayerInfer(LlamaPostLayerInfer): - def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + def token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Internlm2RewardPreAndPostLayerWeight + ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) - score = torch.mm(last_input, layer_weight.lm_head_weight_) + score = layer_weight.score_head_.mm(last_input) return score diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index 78fb0c5d7..f64fb23e0 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,23 +1,26 @@ -import torch import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import 
LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight, ROWMMWeight -class Internlm2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) - if "v_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["v_head.weight"]).transpose(0, 1) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - + self.wte_weight_ = EmbeddingWeight( + weight_name="model.tok_embeddings.weight", + data_type=self.data_type_, + ) + self.score_head_ = ROWMMWeight( + weight_names="v_head.weight", + data_type=self.data_type_, + quant_cfg=None, + layer_num=0, + name="kv_a_proj_with_mqa", + tp_rank=0, + tp_world_size=1, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/internlm2_reward/model.py b/lightllm/models/internlm2_reward/model.py index 881a607a5..b9ea002a6 100644 --- a/lightllm/models/internlm2_reward/model.py +++ b/lightllm/models/internlm2_reward/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry, is_reward_model from lightllm.models.internlm2_reward.layer_infer.post_layer_infer import Internlm2RewardPostLayerInfer from lightllm.models.internlm2_reward.layer_weights.pre_and_post_layer_weight import ( diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index f19563932..7d76d202a 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -1,7 +1,4 @@ -import torch -import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight - from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 5dbd0eb63..ea59d24df 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -1,6 +1,3 @@ -import os -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight @@ -15,7 +12,7 @@ def __init__(self, data_type, network_config, mode): ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ + self.lm_head_weight_: LMHeadWeight = self.wte_weight_ else: self.lm_head_weight_ = LMHeadWeight( weight_name="lm_head.weight", diff --git 
a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 94e1a27e0..71fae2b7b 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,4 @@ -import torch -import numpy as np +import copy from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -12,23 +11,11 @@ def __init__(self, data_type, network_config, mode): self.scale_emb = self.network_config_.get("scale_emb", 1) return - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) / self.lm_head_scale - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - return - def verify_load(self): - if not hasattr(self, "lm_head_weight_"): - self.lm_head_weight_ = self.wte_weight_ / self.lm_head_scale + if self.lm_head_weight_ == self.wte_weight_: + self.lm_head_weight_ = copy.copy(self.lm_head_weight_) + + self.lm_head_weight_.weight = self.lm_head_weight_.weight / self.lm_head_scale self.wte_weight_ = self.wte_weight_ * self.scale_emb errors = "weights load not ok" weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_] diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index ce27e3ee5..806c59365 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -1,9 +1,5 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from functools import partial - from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import ( diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 95af6ecd3..00f68eee6 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -1,40 +1,22 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - - vob_size = self.network_config_["vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - - if "transformer.wte.weight" in weights: - self.wte_weight_ = self._cuda( - weights["transformer.wte.weight"][ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - ) - if "lm_head.weight" in weights: - 
self.lm_head_weight_ = self._cuda( - weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] - ) - if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["transformer.ln_f.weight"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.lm_head_weight_, - self.final_norm_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.wte_weight_ = EmbeddingWeight( + weight_name="transformer.wte.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="transformer.ln_f.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py index 7c710af2d..9afb964ad 100755 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -1,6 +1,4 @@ import torch -import math -import numpy as np from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index 5735b0339..a8a57c02e 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -1,37 +1,7 @@ -import torch -import numpy as np -from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -class Qwen2PreAndPostLayerWeight(PreAndPostLayerWeight): +class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.lm_head_weight_, - self.final_norm_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 2e2c0d3bb..6962818c4 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -1,6 +1,3 @@ -import torch -import math -import numpy as np from 
lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py b/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py index 22ec8fd43..e9b41d7ab 100644 --- a/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py +++ b/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py @@ -3,7 +3,6 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight -from einops import rearrange class Qwen2RewardPostLayerInfer(LlamaPostLayerInfer): @@ -15,8 +14,8 @@ def token_forward( input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) - last_input = torch.addmm(layer_weight.score_up_bias, last_input, layer_weight.score_up_weight) + last_input = layer_weight.score_up_weight_.mm(last_input) last_input = torch.nn.functional.relu(last_input) - score = torch.addmm(layer_weight.score_down_bias, last_input, layer_weight.score_down_weight) + score = layer_weight.score_down_weight_.mm(last_input) return score diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index 8c1e288a4..0ea4e5820 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -7,44 +7,23 @@ class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - if "score.0.weight" in weights: - self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1) - if "score.0.bias" in weights: - self.score_up_bias = self._cuda(weights["score.0.bias"]) - - if "score.2.weight" in weights: - self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1) - if "score.2.bias" in weights: - self.score_down_bias = self._cuda(weights["score.2.bias"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.final_norm_weight_, - self.score_up_weight, - self.score_up_bias, - self.score_down_weight, - self.score_down_bias, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + del self.lm_head_weight_ + self.score_up_weight_ = ROWMMWeight( + weight_names="score.0.weight", + bias_names="score.0.bias", + data_type=self.data_type_, + layer_num=0, + name="score_up_weight", + tp_rank=0, + tp_world_size=1, + ) + self.score_down_weight_ = ROWMMWeight( + weight_names="score.2.weight", + bias_names="score.2.bias", + 
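The qwen2_reward hunks above swap the explicit torch.addmm calls for the new ROWMMWeight wrappers (score_up_weight_ / score_down_weight_). For reference, a plain-torch sketch of what the two-layer reward head computes, restating the removed addmm formulation; shapes are made up and the bias handling is assumed to move inside the wrappers:

import torch

# Sketch only: score.0 (+ ReLU) followed by score.2, one scalar score per token.
hidden, tokens = 16, 2
last_input = torch.randn(tokens, hidden)
w_up, b_up = torch.randn(hidden, hidden), torch.randn(hidden)
w_down, b_down = torch.randn(hidden, 1), torch.randn(1)
h = torch.relu(torch.addmm(b_up, last_input, w_up))   # score.0 weight/bias + ReLU
score = torch.addmm(b_down, h, w_down)                # score.2 weight/bias -> (tokens, 1)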
data_type=self.data_type_, + layer_num=0, + name="score_down_weight", + tp_rank=0, + tp_world_size=1, + ) return diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index d05f8d3b5..19e17c36e 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -1,10 +1,5 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple -from functools import partial - from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 30ae2242d..20f135e76 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -1,9 +1,4 @@ -import os import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np -import triton from typing import Tuple from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 19f6958f6..5431314aa 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,35 +1,37 @@ import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + ROWMMWeight, + LMHeadWeight, + NoTpNormWeight, +) -class Qwen3MOEMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - # 与Qwen3MOE模型共享 - self.wte_weight_ = None - self.lm_head_weight_ = None - return - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.layers.0.proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t() - if "model.layers.0.norm_after_embedding.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"]) - if "model.layers.0.norm_before_output.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"]) - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_] - for i in range(len(weights)): - assert 
weights[i] is not None, "index:" + str(i) + " " + errors + self.eh_proj_weight_ = ROWMMWeight( + weight_names="model.layers.0.proj.weight", + data_type=self.data_type_, + layer_num=0, + name="eh_proj", + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.norm_after_embedding.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.hnorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.norm_before_output.weight", + data_type=self.data_type_, + bias_name=None, + ) + # 与Qwen3MOE模型共享 + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + self.final_norm_weight_: NoTpNormWeight = None return diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py index c79bb7665..f2f3f5e3d 100644 --- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -1,11 +1,10 @@ import torch import torch.distributed as dist - -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from ..layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): @@ -13,12 +12,14 @@ def __init__(self, network_config, mode): super().__init__(network_config, mode) return - def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + def context_forward( + self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: Qwen3VLPreAndPostLayerWeight + ): img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype + device = layer_weight.wte_weight_.weight.device + dtype = layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.shape[1] for batch_id, p in enumerate(infer_state.multimodal_params): @@ -60,8 +61,8 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w img_token_lens=infer_state.img_token_lens, img_start_token_ids=infer_state.img_start_token_ids, img_start_locs_in_cache=infer_state.img_start_locs_in_cache, - tp_text_start_token_id=self.vob_start_id_, - tp_text_end_token_id=self.vob_end_id_, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) if self.tp_world_size_ > 1: diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 0bc878d96..b155f8b90 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -1,15 +1,9 @@ import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from functools import partial from typing import Tuple -from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from 
lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index 0a7f82a93..b1f5ee660 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -1,37 +1,24 @@ -import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.language_model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.language_model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.lm_head_weight_, - self.final_norm_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return + self.wte_weight_ = EmbeddingWeight( + weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_: LMHeadWeight = self.wte_weight_ + else: + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.language_model.norm.weight", + data_type=self.data_type_, + ) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b8500980a..c311e746e 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -32,8 +32,8 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype + device = layer_weight.wte_weight_.weight.device + dtype = 
layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.shape[1] for batch_id, p in enumerate(infer_state.multimodal_params): @@ -73,8 +73,8 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, img_start_locs_in_cache=img_start_locs_in_cache, - tp_text_start_token_id=self.vob_start_id_, - tp_text_end_token_id=self.vob_end_id_, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) if self.tp_world_size_ > 1: diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index df64f2bcc..395ed4ba1 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -1,7 +1,4 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from functools import partial from typing import Tuple from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 80966c7b4..0ad3e07df 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -1,34 +1,12 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight, NoTpNormWeight class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - if "lm_head.weight" in weights: - # print(weights['lm_head.weight'].shape) - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_, self.final_norm_bias_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name="model.norm.bias", + ) return diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index c44725382..d5bdd79a7 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ 
b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index fd2d47575..28a26cb4b 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,37 +1,28 @@ import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - - # for starcoder2-3b and 7b which didn't use lm_head.weight (tie_word_embeddings) - self.lm_head_weight_ = self.wte_weight_ - - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) - - return - def verify_load(self): - errors = "weights load not ok" - weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_, self.final_norm_bias_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_: LMHeadWeight = self.wte_weight_ + else: + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) + + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name="model.norm.bias", + ) return From 87a214faba34066774f3cf55f94938d818c820da Mon Sep 17 00:00:00 2001 From: root Date: Sat, 3 Jan 2026 10:04:34 +0000 Subject: [PATCH 62/79] fix att sink weight --- .../meta_weights/att_sink_weight.py | 31 ++++++++++++------- .../layer_weights/transformer_layer_weight.py | 2 -- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index 7d785ea82..d426dcfbc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -5,21 +5,30 @@ class TpAttSinkWeight(BaseWeightTpl): - def __init__(self, weight_name: str, data_type, tp_head_num: int): + def __init__(self, weight_name: str, data_type): super().__init__() self.weight_name 
= weight_name self.data_type_ = data_type self.weight: torch.Tensor = None - self.tp_head_num = tp_head_num - assert self.tp_head_num > 0 def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - start_head_index = self.tp_head_num * self.tp_rank_ - end_head_index = self.tp_head_num * (self.tp_rank_ + 1) + if self.weight_name not in weights or self.weight is not None: + return - if self.weight_name in weights: - self.weight = ( - weights[self.weight_name][start_head_index:end_head_index] - .to(self.data_type_) - .cuda(get_current_device_id()) - ) + t_weight = weights[self.weight_name] + all_head_num = t_weight.shape[0] + tp_head_num = all_head_num // self.tp_world_size_ + + if tp_head_num > 0: + start_head_index = self.tp_rank_ * tp_head_num + end_head_index = (self.tp_rank_ + 1) * tp_head_num + else: + # 当 tp_world_size 大于 all_head_num 时的特殊处理 + scale_size = self.tp_world_size_ // all_head_num + assert self.tp_world_size_ % all_head_num == 0 + start_head_index = self.tp_rank_ // scale_size + end_head_index = start_head_index + 1 + + self.weight = ( + weights[self.weight_name][start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) + ) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index fb3c66e22..f6a841b1a 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -68,11 +68,9 @@ def _init_weight_names(self): def _init_weight(self): super()._init_weight() - n_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ self.attn_sinks = TpAttSinkWeight( weight_name=f"model.layers.{self.layer_num_}.self_attn.sinks", data_type=torch.bfloat16, - tp_head_num=n_split_head, ) def _init_ffn(self): From f36d861ff4f0f779f8d171949a01bc44327a98e0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 10:17:50 +0000 Subject: [PATCH 63/79] fix embeding weight --- .../meta_weights/att_sink_weight.py | 3 ++ .../meta_weights/embedding_weight.py | 52 +++++++++---------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index d426dcfbc..161c362b9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -32,3 +32,6 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): self.weight = ( weights[self.weight_name][start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) ) + + def verify_load(self): + return self.weight is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index 1ea2d1331..fc018267f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -17,27 +17,24 @@ def __init__(self, weight_name, data_type): self.weight: torch.Tensor = None def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name in weights and self.weight is None: - t_weight = weights[self.weight_name] - # init some params - self.vocab_size = len(t_weight) - split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, 
dtype=np.int64) - self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) - self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) - - logger.info(f"loaded weight vocab_size: {self.vocab_size}") - - self.weight = ( - t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :] - .to(self.data_type_) - .cuda(get_current_device_id()) - ) + if self.weight_name not in weights or self.weight is not None: + return - def verify_load(self): - load_ok = True - load_ok = load_ok and self.weight is not None + t_weight = weights[self.weight_name] + # init some params + self.vocab_size = len(t_weight) + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) + self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) + + logger.info(f"loaded weight vocab_size: {self.vocab_size}") + + self.weight = ( + t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_).cuda(get_current_device_id()) + ) - return load_ok + def verify_load(self): + return self.weight is not None def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): if out is None: @@ -83,17 +80,16 @@ def __init__(self, weight_name, data_type): self.tp_rank_ = 0 def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name in weights and self.weight is None: - t_weight = weights[self.weight_name] - self.weight = t_weight.to(self.data_type_).cuda(get_current_device_id()) - self.end_position_id: int = t_weight.shape[0] - logger.info(f"loaded weight end_position_id: {self.end_position_id}") + if self.weight_name not in weights or self.weight is not None: + return - def verify_load(self): - load_ok = True - load_ok = load_ok and self.weight is not None + t_weight = weights[self.weight_name] + self.weight = t_weight.to(self.data_type_).cuda(get_current_device_id()) + self.end_position_id: int = t_weight.shape[0] + logger.info(f"loaded weight end_position_id: {self.end_position_id}") - return load_ok + def verify_load(self): + return self.weight is not None def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): if out is None: From d43ced90ca8de3036ae05aa4ec0581ab878a9b28 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 10:47:27 +0000 Subject: [PATCH 64/79] fix tpnorm params --- .../meta_weights/att_sink_weight.py | 18 ++--------- .../layer_weights/meta_weights/base_weight.py | 29 ++++++++++++++++- .../layer_weights/meta_weights/norm_weight.py | 32 +++++++++---------- .../layer_weights/transformer_layer_weight.py | 7 ++-- 4 files changed, 48 insertions(+), 38 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index 161c362b9..3f8e1f50a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -16,22 +16,8 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): return t_weight = weights[self.weight_name] - all_head_num = t_weight.shape[0] - tp_head_num = all_head_num // self.tp_world_size_ - - if tp_head_num > 0: - start_head_index = self.tp_rank_ * tp_head_num - end_head_index = (self.tp_rank_ + 1) * tp_head_num - else: - # 当 tp_world_size 大于 all_head_num 时的特殊处理 - scale_size = self.tp_world_size_ // all_head_num - assert self.tp_world_size_ % all_head_num == 0 - 
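As an aside on the split arithmetic used by EmbeddingWeight.load_hf_weights (and by the removed load_hf_weights bodies elsewhere in this series), a tiny worked example of how tp_vocab_start_id / tp_vocab_end_id are derived; the concrete sizes below are invented:

import numpy as np

# Illustrative only: partition a vocabulary across tensor-parallel ranks.
vocab_size, tp_world_size = 151936, 4
split_indexes = np.linspace(0, vocab_size, tp_world_size + 1, dtype=np.int64)
# split_indexes == [0, 37984, 75968, 113952, 151936]
for tp_rank in range(tp_world_size):
    start, end = int(split_indexes[tp_rank]), int(split_indexes[tp_rank + 1])
    # rank tp_rank keeps embedding rows [start, end); linspace keeps the shards
    # near-equal even when vocab_size is not divisible by tp_world_size
    print(tp_rank, start, end)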
start_head_index = self.tp_rank_ // scale_size - end_head_index = start_head_index + 1 - - self.weight = ( - weights[self.weight_name][start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) - ) + start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) + self.weight = t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) def verify_load(self): return self.weight is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 0d1a1e475..2cd8ea6ae 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,6 +1,6 @@ import torch from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, Tuple from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp, get_current_device_id @@ -29,3 +29,30 @@ def load_hf_weights(self, weights): def verify_load(self) -> bool: raise NotImplementedError("verify_load must implement this method") + + def _get_head_tp_split_params(self, weight: torch.Tensor) -> Tuple[int, int]: + """ + Docstring for _get_head_tp_split_params, + 一个常用的tp 划分head获取head_index 范围的功能函数, 一些继承类可能会使用。 + :param self: Description + :param weight: Description + :type weight: torch.Tensor + :return: Description + :rtype: Tuple[int, int] + """ + assert weight.ndim == 2 + + all_head_num = weight.shape[0] + tp_head_num = all_head_num // self.tp_world_size_ + + if tp_head_num > 0: + start_head_index = self.tp_rank_ * tp_head_num + end_head_index = (self.tp_rank_ + 1) * tp_head_num + else: + # 当 tp_world_size 大于 all_head_num 时的特殊处理 + scale_size = self.tp_world_size_ // all_head_num + assert self.tp_world_size_ % all_head_num == 0 + start_head_index = self.tp_rank_ // scale_size + end_head_index = start_head_index + 1 + + return start_head_index, end_head_index diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 6c1fbc736..6458c7c2b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -27,7 +27,7 @@ def verify_load(self): def rmsnorm_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - assert input.ndim == 2 and self.weight.ndim == 1 + assert input.ndim in [2, 3] and self.weight.ndim == 1 assert self.bias is None if out is None: out = alloc_func(input.shape, dtype=input.dtype, device=input.device) @@ -54,9 +54,9 @@ def __init__(self, weight_name, data_type, bias_name=None): self.tp_rank_ = 0 def load_hf_weights(self, weights): - if self.weight_name in weights: + if self.weight_name in weights and self.weight is None: self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights: + if self.bias_name in weights and self.bias is None: self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) @@ -68,7 +68,7 @@ def __init__(self, weight_name, data_type, bias_name=None): self.tp_rank_ = 0 def load_hf_weights(self, weights): - if self.weight_name in weights: + if self.weight_name in weights and self.weight is None: self.weight = (weights[self.weight_name] + 
1).to(self.data_type_).cuda(get_current_device_id()) @@ -81,29 +81,29 @@ def load_hf_weights(self, weights): start = self.split_n_embed * self.tp_rank_ end = self.split_n_embed * (self.tp_rank_ + 1) - if self.weight_name in weights: + if self.weight_name in weights and self.weight is None: self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights: + if self.bias_name in weights and self.bias is None: self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) class TpHeadNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, tp_head_num, bias_name=None): + def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name, data_type, bias_name) - self.tp_head_num = tp_head_num - assert self.tp_head_num > 0 def load_hf_weights(self, weights): - start = self.tp_head_num * self.tp_rank_ - end = self.tp_head_num * (self.tp_rank_ + 1) - - if self.weight_name in weights: + if self.weight_name in weights and self.weight is None: + t_weight = weights[self.weight_name] + start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) self.weight: torch.Tensor = ( - weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) ) assert self.weight.ndim == 2 - if self.bias_name in weights: + + if self.bias_name in weights and self.bias is None: + t_bias = weights[self.bias_name] + start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_bias) self.bias: torch.Tensor = ( - weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + t_bias[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) ) assert self.bias.ndim == 2 diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py index 4af0ab8f0..9c446b49e 100644 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -12,17 +12,14 @@ def _parse_config(self): self.use_qk_norm = self.network_config_.get("use_qk_norm", False) def _init_norm(self): - q_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ - k_split_head = self.network_config_["num_key_value_heads"] // self.tp_world_size_ - self.att_norm_weight_ = NoTpNormWeight(self._att_norm_weight_name, self.data_type_) if self.use_qk_norm: self.q_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_, q_split_head + f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ ) self.k_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_, k_split_head + f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ ) return From 59085b650bc02ca6d9d31028b5b8ed4b48161f93 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 11:05:46 +0000 Subject: [PATCH 65/79] fix --- lightllm/common/basemodel/basemodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 3b8d6bd5f..011f998fc 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -514,7 +514,7 @@ def _decode( 
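A stand-alone restatement of the _get_head_tp_split_params helper added to base_weight.py above, which TpAttSinkWeight and TpHeadNormWeight now both call; the head counts and ranks below are made up for illustration:

# Sketch of the head-range computation, including the branch where there are
# more TP ranks than heads and groups of ranks replicate one head each.
def head_split(all_head_num: int, tp_rank: int, tp_world_size: int):
    tp_head_num = all_head_num // tp_world_size
    if tp_head_num > 0:
        return tp_rank * tp_head_num, (tp_rank + 1) * tp_head_num
    assert tp_world_size % all_head_num == 0
    scale_size = tp_world_size // all_head_num
    start_head_index = tp_rank // scale_size
    return start_head_index, start_head_index + 1

assert head_split(all_head_num=8, tp_rank=1, tp_world_size=4) == (2, 4)  # 2 heads per rank
assert head_split(all_head_num=2, tp_rank=5, tp_world_size=8) == (1, 2)  # ranks 4..7 share head 1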
infer_state.mem_index, ) infer_state.init_some_extra_state(self) - model_output = self._token_forward(model_input.input_ids, infer_state) + model_output = self._token_forward(infer_state) return model_output From 12813015282feeef60b496bc69b39a6e4231bd56 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 11:16:54 +0000 Subject: [PATCH 66/79] fix diverset mtp only support no att mtp mode --- .../router/model_infer/mode_backend/diverse_backend/impl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 6791e778a..5a179cb62 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -25,6 +25,7 @@ def __init__(self) -> None: if get_env_start_args().mtp_mode: # 当前只有 mistral mtp 可以使用 diverse mode 的 mtp 功能。 self.prefill = self.beam_prefill + assert get_env_start_args().mtp_mode in ["vanilla_no_att", "eagle_no_att"] else: self.prefill = self.beam_prefill From 5986d3323de1fc1920fb75bf3dc8013f729eacf3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 12:26:40 +0000 Subject: [PATCH 67/79] fix cohere --- .../cohere/layer_infer/post_layer_infer.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py index e2fc47581..546394008 100644 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py @@ -1,9 +1,11 @@ import torch +import numpy as np from lightllm.models.cohere.infer_struct import CohereInferStateInfo from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.common.build_utils import repair_config +from lightllm.distributed.communication_op import all_gather class CoherePostLayerInfer(LlamaPostLayerInfer): @@ -11,6 +13,7 @@ def __init__(self, network_config, mode): repair_config(config=network_config, same_names=["layer_norm_eps", "rms_norm_eps"]) super().__init__(network_config, mode) self.eps_ = network_config["layer_norm_eps"] + self.logits_scale = network_config["logit_scale"] return def _norm( @@ -19,3 +22,50 @@ def _norm( return layernorm_forward( input.unsqueeze(1), layer_weight.final_norm_weight_.weight.unsqueeze(0), eps=self.eps_ ).squeeze(1) + + def token_forward( + self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight + ): + last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) + input_embdings_dtype = input_embdings.dtype + input_embdings = None + last_input = self._norm(last_input, infer_state, layer_weight) + last_input = last_input.permute(1, 0).view(-1, token_num) + logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) + last_input = None + vocab_size = layer_weight.lm_head_weight_.vocab_size + if self.tp_world_size_ == 1: + gather_data = logic_batch + else: + gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + all_gather( + 
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], + logic_batch, + group=infer_state.dist_group, + async_op=False, + ) + gather_data *= self.logits_scale + logic_batch = None + ans_logics = self.alloc_tensor( + (token_num, vocab_size), + dtype=torch.float32, + ) + ans_logics[:, :] = gather_data.permute(1, 0) + gather_data = None + return ans_logics + + def tpsp_token_forward( + self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight + ): + raise NotImplementedError("not impl") + + def overlap_tpsp_token_forward( + self, + input_embdings: torch.Tensor, + input_embdings1: torch.Tensor, + infer_state: CohereInferStateInfo, + infer_state1: CohereInferStateInfo, + layer_weight: CoherePreAndPostLayerWeight, + ): + raise NotImplementedError("not impl") From 68724db836d862012a82e03c991a8d51142ff328 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 12:30:04 +0000 Subject: [PATCH 68/79] fix --- lightllm/models/cohere/layer_infer/post_layer_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py index 546394008..67987a8d3 100644 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py @@ -45,7 +45,7 @@ def token_forward( group=infer_state.dist_group, async_op=False, ) - gather_data *= self.logits_scale + gather_data = gather_data * self.logits_scale logic_batch = None ans_logics = self.alloc_tensor( (token_num, vocab_size), From e7496bfe8a22c6d111ccdf7f81183f2f97e8030a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 3 Jan 2026 14:24:56 +0000 Subject: [PATCH 69/79] review fix all --- .../deepseek_mtp/layer_weights/pre_and_post_layer_weight.py | 1 - .../layer_weights/pre_and_post_layer_weight.py | 4 +--- .../models/minicpm/layer_weights/pre_and_post_layer_weight.py | 2 +- .../qwen2_reward/layer_weights/pre_and_post_layer_weight.py | 2 -- .../qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py | 1 - lightllm/models/qwen3_moe_mtp/model.py | 4 ++-- 6 files changed, 4 insertions(+), 10 deletions(-) diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 582afb560..4a5bf2e96 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -14,7 +14,6 @@ def __init__(self, data_type, network_config, mode): self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.eh_proj.weight", data_type=self.data_type_, - layer_num=0, name="eh_proj", tp_rank=0, tp_world_size=1, diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index f64fb23e0..b20b9c495 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -13,9 +13,7 @@ def __init__(self, data_type, network_config, mode): self.score_head_ = ROWMMWeight( weight_names="v_head.weight", data_type=self.data_type_, - quant_cfg=None, - layer_num=0, - name="kv_a_proj_with_mqa", + name="score_head", tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py 
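A single-process sketch of the vocab-parallel logits gather that CoherePostLayerInfer.token_forward performs; torch.cat stands in for the cross-rank all_gather, all sizes are invented, and the scale is applied out-of-place as in the follow-up fix:

import torch

token_num, vocab_size, tp_world_size, logit_scale = 3, 8, 2, 0.0625
shard = vocab_size // tp_world_size
# each rank produces logits for its vocab shard, laid out as (shard_vocab, token_num)
per_rank_logits = [torch.randn(shard, token_num) for _ in range(tp_world_size)]
gather_data = torch.cat(per_rank_logits, dim=0)           # (vocab_size, token_num)
gather_data = gather_data * logit_scale                   # out-of-place scale
ans_logics = gather_data.permute(1, 0).to(torch.float32)  # (token_num, vocab_size)
assert ans_logics.shape == (token_num, vocab_size)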
b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 71fae2b7b..0952468d0 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -16,7 +16,7 @@ def verify_load(self): self.lm_head_weight_ = copy.copy(self.lm_head_weight_) self.lm_head_weight_.weight = self.lm_head_weight_.weight / self.lm_head_scale - self.wte_weight_ = self.wte_weight_ * self.scale_emb + self.wte_weight_.weight = self.wte_weight_.weight * self.scale_emb errors = "weights load not ok" weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_] for i in range(len(weights)): diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index 0ea4e5820..7cf636622 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -12,7 +12,6 @@ def __init__(self, data_type, network_config, mode): weight_names="score.0.weight", bias_names="score.0.bias", data_type=self.data_type_, - layer_num=0, name="score_up_weight", tp_rank=0, tp_world_size=1, @@ -21,7 +20,6 @@ def __init__(self, data_type, network_config, mode): weight_names="score.2.weight", bias_names="score.2.bias", data_type=self.data_type_, - layer_num=0, name="score_down_weight", tp_rank=0, tp_world_size=1, diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 5431314aa..6cc447a59 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -15,7 +15,6 @@ def __init__(self, data_type, network_config, mode): self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.proj.weight", data_type=self.data_type_, - layer_num=0, name="eh_proj", tp_rank=0, tp_world_size=1, diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 1552695cd..72aadbda8 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -43,8 +43,8 @@ def _init_weights(self, start_layer_index=None): mtp_index = len(self.mtp_previous_draft_models) super()._init_weights(start_layer_index=mtp_index) self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ - # self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ - # self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ return def _init_infer_layer(self, start_layer_index=None): From 9766c386fcb2ba4fe1ebd5672ad71844ad6f62d6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sun, 4 Jan 2026 04:26:46 +0000 Subject: [PATCH 70/79] fix prefill dp banlance feature --- lightllm/common/basemodel/infer_struct.py | 3 +++ lightllm/models/llama/layer_infer/post_layer_infer.py | 6 ------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 305e24d83..8e7174bb3 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -213,6 
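A minimal sketch of the tied-weight hazard the minicpm verify_load change guards against: when lm_head_weight_ and wte_weight_ are the same object, scaling one would also scale the other, so a shallow copy is taken first. DummyWeight stands in for the real meta-weight class and the numbers are invented:

import copy

class DummyWeight:
    def __init__(self, weight):
        self.weight = weight

wte = DummyWeight(weight=2.0)
lm_head = wte                           # tie_word_embeddings: both names point at one object
lm_head = copy.copy(lm_head)            # shallow copy so the two can diverge
lm_head.weight = lm_head.weight / 4.0   # divide by lm_head_scale
wte.weight = wte.weight * 12.0          # multiply by scale_emb
assert (wte.weight, lm_head.weight) == (24.0, 0.5)  # independent after the copy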
+213,9 @@ def prefill_dp_balance(self, input_ids: torch.Tensor): self.position_sin = self._all_to_all_balance_get(self.position_sin) + self._unbalance_input_ids = self.input_ids + self.input_ids = new_input_ids + return new_input_ids def _all_to_all_balance_get(self, data: torch.Tensor): diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index d9b82cae7..7c7b0ea39 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -124,18 +124,12 @@ def overlap_tpsp_token_forward( infer_state.hook() infer_state.hook = None - if infer_state.need_dp_prefill_balance: - input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) - logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() infer_state1.hook = None - if infer_state1.need_dp_prefill_balance: - input_embdings1 = infer_state1._all_to_all_unbalance_get(data=input_embdings1) - logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) return logics, logics1 From e12635426ae2c4b007ef052014b29bf59e1ce1be Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sun, 4 Jan 2026 04:55:49 +0000 Subject: [PATCH 71/79] add test model acc sh --- test/acc/test_deepseekr1.sh | 5 +++++ test/acc/test_deepseekr1_mtp.sh | 3 +++ test/acc/test_deepseekr1_mtp_ep.sh | 3 +++ test/acc/test_qwen3.sh | 5 +++++ 4 files changed, 16 insertions(+) create mode 100644 test/acc/test_deepseekr1.sh create mode 100644 test/acc/test_deepseekr1_mtp.sh create mode 100644 test/acc/test_deepseekr1_mtp_ep.sh create mode 100644 test/acc/test_qwen3.sh diff --git a/test/acc/test_deepseekr1.sh b/test/acc/test_deepseekr1.sh new file mode 100644 index 000000000..e167303a3 --- /dev/null +++ b/test/acc/test_deepseekr1.sh @@ -0,0 +1,5 @@ +LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --enable_fa3 + + + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp.sh b/test/acc/test_deepseekr1_mtp.sh new file mode 100644 index 000000000..046314a72 --- /dev/null +++ b/test/acc/test_deepseekr1_mtp.sh @@ -0,0 +1,3 @@ +LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh new file mode 100644 index 000000000..4c33d75c8 --- /dev/null +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -0,0 +1,3 @@ +LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=800 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --mem_fraction 0.75 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 
--mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 100 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh new file mode 100644 index 000000000..c0da5ec96 --- /dev/null +++ b/test/acc/test_qwen3.sh @@ -0,0 +1,5 @@ +# first +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_fa3 + +# second +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file From 7d310e567e4d3a8ee6abda4bef1a41100018c0b7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sun, 4 Jan 2026 04:58:45 +0000 Subject: [PATCH 72/79] fix --- test/acc/test_deepseekr1_mtp_ep.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 4c33d75c8..74c0fde0f 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=800 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --mem_fraction 0.75 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --mem_fraction 0.6 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 100 --confirm_run_unsafe_code \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file From e5335f19ea9f9faebef0f9637424066e9554d1ca Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sun, 4 Jan 2026 05:10:49 +0000 Subject: [PATCH 73/79] fix sh --- test/acc/test_deepseekr1_mtp_ep.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 74c0fde0f..2ea5f7438 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --mem_fraction 0.6 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 MOE_MODE=EP 
NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2
HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code
\ No newline at end of file

From e3a0b4a5fa17e7bfab9950050c0f17fae98e531f Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Sun, 4 Jan 2026 05:25:37 +0000
Subject: [PATCH 74/79] add test_qwen2.sh

---
 test/acc/test_qwen2.sh | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 test/acc/test_qwen2.sh

diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh
new file mode 100644
index 000000000..7b89086ad
--- /dev/null
+++ b/test/acc/test_qwen2.sh
@@ -0,0 +1,5 @@
+# first
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-0.5B/snapshots/060db6499f32faf8b98477b0a26969ef7d8b9987 --tp 2 --port 8089 --enable_fa3
+
+# second
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-0.5B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file

From 8bee7eaa781d6c5f6163453b3da0904cbb8e685a Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Sun, 4 Jan 2026 05:31:00 +0000
Subject: [PATCH 75/79] fix sh

---
 test/acc/test_qwen2.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh
index 7b89086ad..60f19efe3 100644
--- a/test/acc/test_qwen2.sh
+++ b/test/acc/test_qwen2.sh
@@ -1,5 +1,5 @@
 # first
-LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-0.5B/snapshots/060db6499f32faf8b98477b0a26969ef7d8b9987 --tp 2 --port 8089 --enable_fa3
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28 --tp 2 --port 8089 --enable_fa3
 
 # second
-HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-0.5B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file

From 76ff894c961230bc4a246f5090e1e01936af6621 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Sun, 4 Jan 2026 05:46:25 +0000
Subject: [PATCH 76/79] fix sh

---
 test/acc/test_qwen2.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh
index 60f19efe3..265d679e8 100644
--- a/test/acc/test_qwen2.sh
+++ b/test/acc/test_qwen2.sh
@@ -1,5 +1,5 @@
 # first
-LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28 --tp 2 --port 8089 --enable_fa3
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --enable_fa3
 
 # second
-HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-Math-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file

From 00ad3b5780e0184dbfa4a49e31b9b3ce0a27695f Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Sun, 4 Jan 2026 06:03:10 +0000
Subject: [PATCH 77/79] fix unittest

---
 .../triton_kernel/test_gen_decode_params.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py
index c7c6a844d..5c3ca89c6 100644
--- a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py
+++ b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py
@@ -1,11 +1,21 @@
 import torch
 import pytest
-import numpy as np
-from lightllm.utils.log_utils import init_logger
+import easydict
 from lightllm.common.basemodel.triton_kernel.gen_decode_params import gen_decode_params
+from lightllm.utils.envs_utils import set_env_start_args
 
 
 def test_gen_decode_params_basic():
+    set_env_start_args(
+        easydict.EasyDict(
+            {
+                "mtp_step": 0,
+                "enable_flashinfer_prefill": False,
+                "enable_flashinfer_decode": False,
+            }
+        )
+    )
+
     b_seq_len = torch.ones((9,), dtype=torch.int64, device="cuda") * 8192
     (
         b_q_seq_len,

From 99fdb4cff96f3b8d80be680bdf6ae0ab38f38cf4 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Sun, 4 Jan 2026 06:51:50 +0000
Subject: [PATCH 78/79] fix vitnorm params

---
 .../layer_weights/meta_weights/__init__.py    |  2 +-
 .../layer_weights/meta_weights/norm_weight.py | 50 ++++++++++++++++---
 .../layer_weights/transformer_layer_weight.py | 15 ++----
 3 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
index 1441547aa..0fa02780c 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
@@ -6,7 +6,7 @@
     COLMMWeight,
     ROWBMMWeight,
 )
-from .norm_weight import NoTpGEMMANormWeight, TpNormWeight, NoTpNormWeight, TpHeadNormWeight
+from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight
 from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
 from .fused_moe_weight_ep import FusedMoeWeightEP
 from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
index 6458c7c2b..5a595bff6 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
@@ -4,6 +4,9 @@
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
 
 
 class _NormWeight(BaseWeightTpl):
@@ -72,19 +75,50 @@ def load_hf_weights(self, weights):
             self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id())
 
 
-class TpNormWeight(_NormWeight):
-    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+class TpVitPadNormWeight(_NormWeight):
+    def __init__(self, weight_name, data_type, head_num: int, bias_name=None):
         super().__init__(weight_name, data_type, bias_name)
-        self.split_n_embed = split_n_embed
+        self.head_num = head_num
 
-    def load_hf_weights(self, weights):
-        start = self.split_n_embed * self.tp_rank_
-        end = self.split_n_embed * (self.tp_rank_ + 1)
+    def _pad_tensor_param(self, weight: torch.Tensor):
+        assert weight.ndim == 1
+        hidden_size = weight.shape[0]
+        head_dim = hidden_size // self.head_num
+        assert hidden_size % self.head_num == 0
+
+        if self.head_num % self.tp_world_size_ == 0:
+            return weight
+        else:
+            logger.warning(f"padding {self.weight_name} weights in TpVitPadNormWeight")
+            pad_head_num = self.tp_world_size_ - (self.head_num % self.tp_world_size_)
+            pad_dims = pad_head_num * head_dim
+            weight = torch.nn.functional.pad(weight, (0, pad_dims), mode="constant", value=0.0)
+            return weight
+
+    def load_hf_weights(self, weights):
         if self.weight_name in weights and self.weight is None:
-            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
+            t_weight = weights[self.weight_name]
+            t_weight = self._pad_tensor_param(t_weight)
+            new_hidden_size = t_weight.shape[0]
+            split_n_embed = new_hidden_size // self.tp_world_size_
+            assert new_hidden_size % self.tp_world_size_ == 0
+
+            start = split_n_embed * self.tp_rank_
+            end = split_n_embed * (self.tp_rank_ + 1)
+
+            self.weight = t_weight[start:end].to(self.data_type_).cuda(get_current_device_id())
+
         if self.bias_name in weights and self.bias is None:
-            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
+            t_bias = weights[self.bias_name]
+            t_bias = self._pad_tensor_param(t_bias)
+            new_hidden_size = t_bias.shape[0]
+            split_n_embed = new_hidden_size // self.tp_world_size_
+            assert new_hidden_size % self.tp_world_size_ == 0
+
+            start = split_n_embed * self.tp_rank_
+            end = split_n_embed * (self.tp_rank_ + 1)
+
+            self.bias = t_bias[start:end].to(self.data_type_).cuda(get_current_device_id())
 
 
 class TpHeadNormWeight(_NormWeight):

diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py
index 881a378e2..c6024594e 100644
--- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py
@@ -8,7 +8,7 @@
     ROWMMWeight,
     COLMMWeight,
     NoTpNormWeight,
-    TpNormWeight,
+    TpVitPadNormWeight,
 )
 from lightllm.utils.dist_utils import get_current_device_id
 
@@ -126,10 +126,9 @@ def _init_norm(self):
             self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name
         )
         if self.qk_norm:
-            n_embed = self.network_config_["hidden_size"]
-            split_n_embed = (n_embed + self.padding_hidden_size) // self.tp_world_size_
-            self.q_norm_weight_ = TpNormWeight(self._q_norm_weight_name, self.data_type_, split_n_embed)
-            self.k_norm_weight_ = TpNormWeight(self._k_norm_weight_name, self.data_type_, split_n_embed)
+            head_num = self.network_config_["num_attention_heads"]
+            self.q_norm_weight_ = TpVitPadNormWeight(self._q_norm_weight_name, self.data_type_, head_num=head_num)
+            self.k_norm_weight_ = TpVitPadNormWeight(self._k_norm_weight_name, self.data_type_, head_num=head_num)
 
     def load_hf_weights(self, weights):
         if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights:
@@ -159,12 +158,6 @@ def load_hf_weights(self, weights):
             weights[self._v_bias_name] = v_bias_
             del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"]
 
-        if self.qk_norm and self._q_norm_weight_name in weights:
-            weights[self._q_norm_weight_name] = F.pad(weights[self._q_norm_weight_name], (0, self.padding_hidden_size))
-
-        if self.qk_norm and self._k_norm_weight_name in weights:
-            weights[self._k_norm_weight_name] = F.pad(weights[self._k_norm_weight_name], (0, self.padding_hidden_size))
-
         if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights:
             ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"]
             self.ls1 = self._cuda(ls1)

From 38e92dd0840d8125440e149e03e65f43cb7c1de4 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Sun, 4 Jan 2026 07:11:06 +0000
Subject: [PATCH 79/79] fix vl weight

---
 lightllm/models/gemma3/layer_infer/pre_layer_infer.py   | 4 ++--
 lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py | 4 ++--
 lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py  | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
index 20fc9aa9e..dc8a46ad9 100644
--- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -18,7 +18,7 @@ def context_forward(self, input_ids, infer_state, layer_weight):
         img_start_locs_in_cache = []
         device = layer_weight.wte_weight_.weight.device
         dtype = layer_weight.wte_weight_.weight.dtype
-        hidden_size = layer_weight.wte_weight_.shape[1]
+        hidden_size = layer_weight.wte_weight_.weight.shape[1]
         weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)
 
         # TODO
@@ -65,7 +65,7 @@ def context_forward(self, input_ids, infer_state, layer_weight):
         multimodal_emb(
             out=out,
             prompt_ids=input_ids,
-            text_weight_embs=layer_weight.wte_weight_,
+            text_weight_embs=layer_weight.wte_weight_.weight,
             embed_cache=cpu_embed_cache_tensor,
             img_token_lens=img_token_lens,
             img_start_token_ids=img_start_token_ids,

diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
index f2f3f5e3d..96e453ebe 100644
--- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
@@ -20,7 +20,7 @@ def context_forward(
         img_start_locs_in_cache = []
         device = layer_weight.wte_weight_.weight.device
         dtype = layer_weight.wte_weight_.weight.dtype
-        hidden_size = layer_weight.wte_weight_.shape[1]
+        hidden_size = layer_weight.wte_weight_.weight.shape[1]
 
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
@@ -56,7 +56,7 @@ def context_forward(
         multimodal_emb(
             out=out,
             prompt_ids=input_ids,
-            text_weight_embs=layer_weight.wte_weight_,
+            text_weight_embs=layer_weight.wte_weight_.weight,
             embed_cache=cpu_embed_cache_tensor,
             img_token_lens=infer_state.img_token_lens,
             img_start_token_ids=infer_state.img_start_token_ids,

diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
index c311e746e..f43907307 100644
--- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -34,7 +34,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         img_start_locs_in_cache = []
         device = layer_weight.wte_weight_.weight.device
         dtype = layer_weight.wte_weight_.weight.dtype
-        hidden_size = layer_weight.wte_weight_.shape[1]
+        hidden_size = layer_weight.wte_weight_.weight.shape[1]
 
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
@@ -68,7 +68,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         multimodal_emb(
             out=out,
             prompt_ids=input_ids,
-            text_weight_embs=layer_weight.wte_weight_,
+            text_weight_embs=layer_weight.wte_weight_.weight,
             embed_cache=cpu_embed_cache_tensor,
             img_token_lens=img_token_lens,
             img_start_token_ids=img_start_token_ids,
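For reference, the head-based padding that PATCH 78 adds for the ViT q/k norm weights can be sketched outside the repo. The helper name, head counts, and CPU-only tensors below are illustrative assumptions, not code from lightllm; the real TpVitPadNormWeight additionally casts to the target dtype and moves each shard to the current CUDA device.

import torch


def pad_and_split_head_norm(weight: torch.Tensor, head_num: int, tp_rank: int, tp_world_size: int) -> torch.Tensor:
    # weight is a per-head norm parameter (q_norm / k_norm) of shape [head_num * head_dim].
    assert weight.ndim == 1 and weight.shape[0] % head_num == 0
    head_dim = weight.shape[0] // head_num
    if head_num % tp_world_size != 0:
        # Pad with whole zero heads so every tensor-parallel rank gets the same number of heads.
        pad_heads = tp_world_size - (head_num % tp_world_size)
        weight = torch.nn.functional.pad(weight, (0, pad_heads * head_dim), mode="constant", value=0.0)
    split_n_embed = weight.shape[0] // tp_world_size
    return weight[tp_rank * split_n_embed : (tp_rank + 1) * split_n_embed]


if __name__ == "__main__":
    w = torch.arange(25 * 4, dtype=torch.float32)  # 25 heads, head_dim = 4, tp_world_size = 2
    shards = [pad_and_split_head_norm(w, head_num=25, tp_rank=r, tp_world_size=2) for r in range(2)]
    # 25 heads are padded to 26, so each of the 2 ranks receives 13 heads (52 values).
    assert [s.shape[0] for s in shards] == [52, 52]

PATCH 79 is the matching call-site cleanup: wte_weight_ is a weight object rather than a raw tensor, so its shape, dtype, and device are read through wte_weight_.weight instead of the attribute itself.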