diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index 5c53dca84..ce7a79ab9 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -447,10 +447,12 @@ MTP 多预测参数 .. option:: --mtp_mode - 支持的 mtp 模式,建议使用 deepseekv3_eagle获得更好的性能体验,可选值: + 支持的 mtp 模式,建议使用 eagle_with_att获得更好的性能体验,可选值: - * ``deepseekv3_vanilla`` - * ``deepseekv3_eagle`` + * ``vanilla_with_att`` + * ``eagle_with_att`` + * ``vanilla_no_att`` + * ``eagle_no_att`` * ``None``: 不启用 mtp(默认) .. option:: --mtp_draft_model_dir diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index 1b72287b6..1644bbab5 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -444,10 +444,12 @@ MTP Multi-Prediction Parameters .. option:: --mtp_mode - Supported mtp modes, it is recommended to use deepseekv3_eagle for better performance, optional values: + Supported mtp modes, it is recommended to use eagle_with_att for better performance, optional values: - * ``deepseekv3_vanilla`` - * ``deepseekv3_eagle`` + * ``vanilla_with_att`` + * ``eagle_with_att`` + * ``vanilla_no_att`` + * ``eagle_no_att`` * ``None``: Do not enable mtp (default) .. option:: --mtp_draft_model_dir diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 2d4209028..011f998fc 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -24,7 +24,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size -from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.common.triton_utils.autotuner import AutotuneLevel @@ -89,7 +89,12 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode - self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"] + self.is_mtp_mode = self.args.mtp_mode in [ + "vanilla_with_att", + "eagle_with_att", + "vanilla_no_att", + "eagle_no_att", + ] self.prefill_graph: PrefillCudaGraph = None self._init_config() @@ -156,7 +161,7 @@ def _init_quant(self): self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") - def _init_weights(self): + def _init_weights(self, start_layer_index=0): self.pre_post_weight = self.pre_and_post_weight_class( self.data_type, network_config=self.config, mode=self.mode ) @@ -168,7 +173,7 @@ def _init_weights(self): mode=self.mode, quant_cfg=self.quant_cfg, ) - for i in range(self.config["n_layer"]) + for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] load_hf_weights( self.data_type, @@ -188,7 +193,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], - layer_num=self.config["n_layer"], + layer_num=self.config["n_layer"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return @@ -214,12 +219,12 @@ def _init_req_manager(self): self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, self.mem_manager) return - def _init_infer_layer(self): + def _init_infer_layer(self, start_layer_index=0): self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) self.layers_infer = [ self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) - for i in range(self.config["n_layer"]) + for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] return @@ -270,6 +275,7 @@ def forward(self, model_input: ModelInput): def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): infer_state = self.infer_state_class() + infer_state.input_ids = model_input.input_ids infer_state.is_prefill = model_input.is_prefill infer_state.is_token_healing = self.is_token_healing infer_state.return_all_prompt_logics = self.return_all_prompt_logics @@ -303,7 +309,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.dist_group = dist_group_manager.get_group(microbatch_index) # 特殊模型,特殊模式的特定变量初始化操作。 - infer_state.deepseekv3_mtp_draft_input_hiddens = model_input.deepseekv3_mtp_draft_input_hiddens + infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens return infer_state @@ -343,9 +349,9 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s ) # 特殊模型,特殊模式的特殊变量的特殊 padding - if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: - new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( - input=new_model_input.deepseekv3_mtp_draft_input_hiddens, + if new_model_input.mtp_draft_input_hiddens is not None: + new_model_input.mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( + input=new_model_input.mtp_draft_input_hiddens, new_batch_size=new_batch_size, ) @@ -388,9 +394,9 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle ] # 特殊模型,特殊模式的特殊变量的特殊 padding - if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: - new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( - input=new_model_input.deepseekv3_mtp_draft_input_hiddens, + if new_model_input.mtp_draft_input_hiddens is not None: + new_model_input.mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( + input=new_model_input.mtp_draft_input_hiddens, new_batch_size=new_handle_token_num, ) @@ -405,9 +411,9 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba new_model_output.logits = 
new_model_output.logits[0:origin_batch_size] # 特殊模型,特殊模式的特殊变量的特殊 unpad - if new_model_output.deepseekv3_mtp_main_output_hiddens is not None: - _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens - new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_batch_size] + if new_model_output.mtp_main_output_hiddens is not None: + _hidden_states = new_model_output.mtp_main_output_hiddens + new_model_output.mtp_main_output_hiddens = _hidden_states[0:origin_batch_size] return new_model_output @@ -421,9 +427,9 @@ def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, o new_model_output.logits = new_model_output.logits[0:-1] # 特殊模型,特殊模式的特殊变量的特殊 unpad - if new_model_output.deepseekv3_mtp_main_output_hiddens is not None: - _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens - new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num] + if new_model_output.mtp_main_output_hiddens is not None: + _hidden_states = new_model_output.mtp_main_output_hiddens + new_model_output.mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num] return new_model_output @@ -457,8 +463,8 @@ def _prefill( prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() - infer_state.init_some_extra_state(self, model_input.input_ids) - model_output = self._context_forward(model_input.input_ids, infer_state) + infer_state.init_some_extra_state(self) + model_output = self._context_forward(infer_state) if is_padded_model_input: model_output = self._create_unpad_prefill_model_output( model_output, origin_handle_token_num=origin_handle_token_num @@ -488,15 +494,13 @@ def _decode( infer_state.b_seq_len, infer_state.mem_index, ) - infer_state.init_some_extra_state(self, padded_model_input.input_ids) + infer_state.init_some_extra_state(self) if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True - model_output: ModelOutput = self.graph.capture_decode( - self._token_forward, padded_model_input.input_ids, infer_state - ) + model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: - model_output: ModelOutput = self.graph.replay(padded_model_input.input_ids, infer_state) + model_output: ModelOutput = self.graph.replay(infer_state) model_output = self._create_unpad_decode_model_output( model_output, origin_batch_size=model_input.batch_size @@ -509,14 +513,15 @@ def _decode( infer_state.b_seq_len, infer_state.mem_index, ) - infer_state.init_some_extra_state(self, model_input.input_ids) - model_output = self._token_forward(model_input.input_ids, infer_state) + infer_state.init_some_extra_state(self) + model_output = self._token_forward(infer_state) return model_output @final - def _context_forward(self, input_ids, infer_state: InferStateInfo): + def _context_forward(self, infer_state: InferStateInfo): run_mode_index = 1 if self.enable_tpsp_mix_mode else 0 + input_ids = infer_state.input_ids cuda_input_ids = input_ids pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index] @@ -559,8 +564,8 @@ def prefill_func(input_tensors, infer_state): model_output = ModelOutput(logits=predict_logits) # 特殊模型特殊模式的额外输出 - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = input_embs + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = input_embs # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 @@ -568,8 +573,9 @@ def 
prefill_func(input_tensors, infer_state): return model_output @final - def _token_forward(self, input_ids, infer_state: InferStateInfo): + def _token_forward(self, infer_state: InferStateInfo): run_mode_index = 1 if self.enable_tpsp_mix_mode else 0 + input_ids = infer_state.input_ids cuda_input_ids = input_ids pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index] input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight) @@ -581,14 +587,14 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo): post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight) - if self.is_deepseekv3_mtp_mode: + if self.is_mtp_mode: graph_out_hiddens = input_embs.contiguous() model_output = ModelOutput(logits=predict_logits.contiguous()) # 特殊模型特殊模式的额外输出 - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = graph_out_hiddens # 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。 if infer_state.is_cuda_graph: @@ -615,7 +621,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state0.mem_index, max_q_seq_len=infer_state0.max_q_seq_len, ) - infer_state0.init_some_extra_state(self, input_ids0) + infer_state0.init_some_extra_state(self) infer_state1 = self._create_inferstate(model_input1, 1) init_req_to_token_indexes( @@ -627,7 +633,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state1.mem_index, max_q_seq_len=infer_state1.max_q_seq_len, ) - infer_state1.init_some_extra_state(self, input_ids1) + infer_state1.init_some_extra_state(self) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -681,7 +687,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.b_seq_len, infer_state0.mem_index, ) - infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) + infer_state0.init_some_extra_state(self) infer_state1 = self._create_inferstate(padded_model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -689,7 +695,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.b_seq_len, infer_state1.mem_index, ) - infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) + infer_state1.init_some_extra_state(self) if self.graph.need_capture(find_graph_batch_size): infer_state0.is_cuda_graph = True @@ -697,16 +703,12 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output0, model_output1 = self.graph.capture_decode( self._overlap_tpsp_token_forward, - padded_model_input0.input_ids, infer_state0, - input_ids1=padded_model_input1.input_ids, infer_state1=infer_state1, ) else: model_output0, model_output1 = self.graph.replay( - padded_model_input0.input_ids, infer_state0, - input_ids1=padded_model_input1.input_ids, infer_state1=infer_state1, ) @@ -721,7 +723,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.b_seq_len, infer_state0.mem_index, ) - infer_state0.init_some_extra_state(self, model_input0.input_ids) + infer_state0.init_some_extra_state(self) infer_state1 = self._create_inferstate(model_input1, 1) copy_kv_index_to_req( 
self.req_manager.req_to_token_indexs, @@ -729,20 +731,16 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.b_seq_len, infer_state1.mem_index, ) - infer_state1.init_some_extra_state(self, model_input1.input_ids) + infer_state1.init_some_extra_state(self) - model_output0, model_output1 = self._overlap_tpsp_token_forward( - model_input0.input_ids, infer_state0, input_ids1=model_input1.input_ids, infer_state1=infer_state1 - ) + model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) return model_output0, model_output1 @final - def _overlap_tpsp_context_forward( - self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo - ): + def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): g_cache_manager.cache_env_in() input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( - input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( @@ -756,18 +754,16 @@ def _overlap_tpsp_context_forward( model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = ModelOutput(logits=predict_logits1.contiguous()) - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = input_embs.contiguous() - model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1.contiguous() + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = input_embs.contiguous() + model_output1.mtp_main_output_hiddens = input_embs1.contiguous() return model_output, model_output1 @final - def _overlap_tpsp_token_forward( - self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo - ): + def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( - input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) for i in range(self.layers_num): @@ -779,16 +775,16 @@ def _overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight ) - if self.is_deepseekv3_mtp_mode: + if self.is_mtp_mode: graph_out_hiddens = input_embs.contiguous() graph_out_hiddens1 = input_embs1.contiguous() model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = ModelOutput(logits=predict_logits1.contiguous()) - if self.is_deepseekv3_mtp_mode: - model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens - model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1 + if self.is_mtp_mode: + model_output.mtp_main_output_hiddens = graph_out_hiddens + model_output1.mtp_main_output_hiddens = graph_out_hiddens1 if infer_state.is_cuda_graph: model_output.to_no_ref_tensor() @@ -993,12 +989,16 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) - if is_deepseekv3_mtp_draft_model: - special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn( + is_mtp_draft_model = ( + "Deepseek3MTPModel" in str(self.__class__) + or "Qwen3MOEMTPModel" in str(self.__class__) + or "MistralMTPModel" 
in str(self.__class__) + ) + if is_mtp_draft_model: + special_model_input["mtp_draft_input_hiddens"] = torch.randn( token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" ) else: - special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None + special_model_input["mtp_draft_input_hiddens"] = None return special_model_input diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 5a98a13df..138f08427 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -46,9 +46,9 @@ class ModelInput: # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输入变量。只在特殊的模型模式下才会具体使用和生效。 - # deepseekv3_mtp_draft_input_hiddens 用于 deepseekv3 模型 mtp 模式下 + # mtp_draft_input_hiddens 用于模型 mtp 模式下 # 的 draft 模型的输入 - deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None + mtp_draft_input_hiddens: Optional[torch.Tensor] = None def to_cuda(self): if self.input_ids is not None: @@ -90,12 +90,12 @@ class ModelOutput: # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输出变量。只在特殊的模型模式下才会具体使用和生效。 - # deepseekv3_mtp_main_output_hiddens 用于在mtp模式下,llm main model - # 输出最后一层的hidden state 状态用于 draft 模型的 deepseekv3_mtp_draft_input_hiddens + # mtp_main_output_hiddens 用于在mtp模式下,llm main model + # 输出最后一层的hidden state 状态用于 draft 模型的 mtp_draft_input_hiddens # 输入 - deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None + mtp_main_output_hiddens: Optional[torch.Tensor] = None def to_no_ref_tensor(self): self.logits = tensor_to_no_ref_tensor(self.logits) - if self.deepseekv3_mtp_main_output_hiddens is not None: - self.deepseekv3_mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.deepseekv3_mtp_main_output_hiddens) + if self.mtp_main_output_hiddens is not None: + self.mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.mtp_main_output_hiddens) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 15c55e91c..9eeab7270 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -62,9 +62,10 @@ def find_closest_graph_batch_size(self, batch_size): else: return None - def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo): + def _capture_decode(self, decode_func, infer_state: InferStateInfo): dist_group: CustomProcessGroup = infer_state.dist_group graph_obj = torch.cuda.CUDAGraph() + input_ids = infer_state.input_ids batch_size = input_ids.shape[0] infer_state.max_len_in_batch = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size @@ -78,27 +79,26 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf # 中的 tensor。 for _ in range(1): torch.cuda.synchronize() - decode_func(input_ids, copy.copy(infer_state)) + decode_func(copy.copy(infer_state)) torch.cuda.synchronize() with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - model_output = decode_func(input_ids, infer_state) - self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) + model_output = decode_func(infer_state) + self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() return model_output def _capture_decode_overlap( self, decode_func, - input_ids: torch.Tensor, infer_state: InferStateInfo, - input_ids1: torch.Tensor, infer_state1: InferStateInfo, ): dist_group: CustomProcessGroup = infer_state.dist_group dist_group1 = infer_state1.dist_group graph_obj = torch.cuda.CUDAGraph() + input_ids = 
infer_state.input_ids batch_size = input_ids.shape[0] infer_state.max_len_in_batch = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size @@ -107,17 +107,15 @@ def _capture_decode_overlap( # warmup for _ in range(1): torch.cuda.synchronize() - decode_func(input_ids, copy.copy(infer_state), input_ids1, copy.copy(infer_state1)) + decode_func(copy.copy(infer_state), copy.copy(infer_state1)) torch.cuda.synchronize() with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) + model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, - input_ids, infer_state, - input_ids1, infer_state1, model_output, model_output1, @@ -128,59 +126,50 @@ def _capture_decode_overlap( def capture_decode( self, decode_func, - input_ids: torch.Tensor, infer_state: InferStateInfo, - input_ids1: Optional[torch.Tensor] = None, - infer_state1: Optional[torch.Tensor] = None, + infer_state1: Optional[InferStateInfo] = None, ): """ Capture the cuda graph for the decoding stage. input_ids1 and infer_state1 is used for the overlap. """ if self.enable_decode_microbatch_overlap: - return self._capture_decode_overlap(decode_func, input_ids, infer_state, input_ids1, infer_state1) + return self._capture_decode_overlap(decode_func, infer_state, infer_state1) else: - assert input_ids1 is None and infer_state1 is None - return self._capture_decode(decode_func, input_ids, infer_state) + assert infer_state1 is None + return self._capture_decode(decode_func, infer_state) - def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo): - batch_size = input_ids.shape[0] - graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size] - graph_input_ids.copy_(input_ids) + def _replay(self, infer_state: InferStateInfo): + batch_size = infer_state.input_ids.shape[0] + graph_obj, graph_infer_state, graph_output = self.graph[batch_size] graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() return graph_output def _replay_overlap( self, - input_ids: torch.Tensor, infer_state: InferStateInfo, - input_ids1: torch.Tensor, infer_state1: InferStateInfo, ): - batch_size = input_ids.shape[0] + batch_size = infer_state.input_ids.shape[0] ( graph_obj, - graph_input_ids, graph_infer_state, - graph_input_ids1, graph_infer_state1, graph_model_output, graph_model_output1, ) = self.graph[batch_size] - graph_input_ids.copy_(input_ids) graph_infer_state.copy_for_cuda_graph(infer_state) - graph_input_ids1.copy_(input_ids1) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() return graph_model_output, graph_model_output1 - def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None): + def replay(self, infer_state, infer_state1=None): if self.enable_decode_microbatch_overlap: - return self._replay_overlap(input_ids, infer_state, input_ids1, infer_state1) + return self._replay_overlap(infer_state, infer_state1) else: - assert input_ids1 is None and infer_state1 is None - return self._replay(input_ids, infer_state) + assert infer_state1 is None + return self._replay(infer_state) @torch.no_grad() def warmup(self, model): diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 3c28a47b8..8e7174bb3 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ 
b/lightllm/common/basemodel/infer_struct.py @@ -19,6 +19,7 @@ class InferStateInfo: """ def __init__(self): + self.input_ids: torch.Tensor = None self.batch_size: int = None self.total_token_num: int = None self.b_req_idx: torch.Tensor = None @@ -71,10 +72,10 @@ def __init__(self): # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 - # deepseekv3 mtp draft model 使用的额外输入参数, - # 在开启 mtp_mode == deepseekv3 时,mtp draft model + # mtp draft model 使用的额外输入参数, + # 在开启 mtp_mode 时,mtp draft model # 的输入会用到,其他模型和场景都不会用到 - self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None + self.mtp_draft_input_hiddens: Optional[torch.Tensor] = None # 在单节点多dp的运行模式下,在进行prefill的阶段,如果出现了dp之间数据不平衡的现象, # 可以将推理的数据,进行重新分配到各个dp,在做 att 之前,重新 all to all 到各自的 @@ -88,7 +89,8 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): + def init_some_extra_state(self, model): + if self.is_prefill: ( self.b_q_seq_len, @@ -97,7 +99,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.b1_cu_kv_seq_len, self.position_ids, ) = gen_prefill_params( - input_token_num=input_ids.shape[0], + input_token_num=self.input_ids.shape[0], b_ready_cache_len=self.b_ready_cache_len, b_seq_len=self.b_seq_len, ) @@ -211,6 +213,9 @@ def prefill_dp_balance(self, input_ids: torch.Tensor): self.position_sin = self._all_to_all_balance_get(self.position_sin) + self._unbalance_input_ids = self.input_ids + self.input_ids = new_input_ids + return new_input_ids def _all_to_all_balance_get(self, data: torch.Tensor): diff --git a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py index 6dc1ebb50..e7a084079 100644 --- a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py @@ -8,8 +8,6 @@ class PreLayerInferTpl(PreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = 1e-5 - self.vob_start_id_ = -1 - self.vob_end_id_ = -1 return def _norm(self, input, infer_state, layer_weight) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index b3dab0614..0fa02780c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -6,6 +6,8 @@ COLMMWeight, ROWBMMWeight, ) -from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight +from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP +from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight +from .att_sink_weight import TpAttSinkWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py new file mode 100644 index 000000000..3f8e1f50a --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -0,0 +1,23 @@ +import torch +from typing import Dict +from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id + + +class 
TpAttSinkWeight(BaseWeightTpl): + def __init__(self, weight_name: str, data_type): + super().__init__() + self.weight_name = weight_name + self.data_type_ = data_type + self.weight: torch.Tensor = None + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name not in weights or self.weight is not None: + return + + t_weight = weights[self.weight_name] + start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) + self.weight = t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) + + def verify_load(self): + return self.weight is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 0d1a1e475..2cd8ea6ae 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,6 +1,6 @@ import torch from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, Tuple from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp, get_current_device_id @@ -29,3 +29,30 @@ def load_hf_weights(self, weights): def verify_load(self) -> bool: raise NotImplementedError("verify_load must implement this method") + + def _get_head_tp_split_params(self, weight: torch.Tensor) -> Tuple[int, int]: + """ + Docstring for _get_head_tp_split_params, + 一个常用的tp 划分head获取head_index 范围的功能函数, 一些继承类可能会使用。 + :param self: Description + :param weight: Description + :type weight: torch.Tensor + :return: Description + :rtype: Tuple[int, int] + """ + assert weight.ndim == 2 + + all_head_num = weight.shape[0] + tp_head_num = all_head_num // self.tp_world_size_ + + if tp_head_num > 0: + start_head_index = self.tp_rank_ * tp_head_num + end_head_index = (self.tp_rank_ + 1) * tp_head_num + else: + # 当 tp_world_size 大于 all_head_num 时的特殊处理 + scale_size = self.tp_world_size_ // all_head_num + assert self.tp_world_size_ % all_head_num == 0 + start_head_index = self.tp_rank_ // scale_size + end_head_index = start_head_index + 1 + + return start_head_index, end_head_index diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py new file mode 100644 index 000000000..fc018267f --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -0,0 +1,108 @@ +import torch +import numpy as np +from typing import Dict, Optional +from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class EmbeddingWeight(BaseWeightTpl): + def __init__(self, weight_name, data_type): + super().__init__() + self.weight_name: str = weight_name + self.data_type_ = data_type + self.weight: torch.Tensor = None + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name not in weights or self.weight is not None: + return + + t_weight = weights[self.weight_name] + # init some params + self.vocab_size = len(t_weight) + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) + self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) + + logger.info(f"loaded weight 
vocab_size: {self.vocab_size}") + + self.weight = ( + t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_).cuda(get_current_device_id()) + ) + + def verify_load(self): + return self.weight is not None + + def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + if out is None: + out = alloc_func( + (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device + ) + + embedding_kernel( + input_ids=input_ids, + weight=self.weight, + vob_start_id=self.tp_vocab_start_id, + vob_end_id=self.tp_vocab_end_id, + out=out, + ) + + return out + + def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + assert input.ndim == 2 + if out is None: + out = alloc_func( + (self.weight.shape[0], input.shape[1]), + dtype=input.dtype, + device=input.device, + ) + + torch.mm(self.weight, input, out=out) + return out + + +class LMHeadWeight(EmbeddingWeight): + def __init__(self, weight_name, data_type): + super().__init__(weight_name, data_type) + + +class NoTpPosEmbeddingWeight(BaseWeightTpl): + def __init__(self, weight_name, data_type): + super().__init__() + self.weight_name: str = weight_name + self.data_type_ = data_type + self.weight: torch.Tensor = None + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name not in weights or self.weight is not None: + return + + t_weight = weights[self.weight_name] + self.weight = t_weight.to(self.data_type_).cuda(get_current_device_id()) + self.end_position_id: int = t_weight.shape[0] + logger.info(f"loaded weight end_position_id: {self.end_position_id}") + + def verify_load(self): + return self.weight is not None + + def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + if out is None: + out = alloc_func( + (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device + ) + + embedding_kernel( + input_ids=input_ids, + weight=self.weight, + vob_start_id=0, + vob_end_id=self.end_position_id, + out=out, + ) + + return out diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 7ec672ab8..5a595bff6 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,22 +1,22 @@ import torch +from typing import Optional from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) -class NormWeight(BaseWeightTpl): + +class _NormWeight(BaseWeightTpl): def __init__(self, weight_name, data_type, bias_name=None): super().__init__() self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type - self.weight = None - self.bias = None - - def load_hf_weights(self, weights): - if self.weight_name in weights: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + self.weight: torch.Tensor = None + self.bias: Optional[torch.Tensor] 
= None def verify_load(self): load_ok = True @@ -27,26 +27,117 @@ def verify_load(self): load_ok = load_ok and self.bias is not None return load_ok + def rmsnorm_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim in [2, 3] and self.weight.ndim == 1 + assert self.bias is None + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) + + def layernorm_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + assert self.bias is not None + + _tout = layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps) + if out is None: + return _tout + else: + out.copy_(_tout) + return out + + +class NoTpNormWeight(_NormWeight): + def __init__(self, weight_name, data_type, bias_name=None): + super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + def load_hf_weights(self, weights): + if self.weight_name in weights and self.weight is None: + self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights and self.bias is None: + self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + -class GEMMANormWeight(NormWeight): +class NoTpGEMMANormWeight(_NormWeight): def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name, data_type, bias_name) + assert self.bias_name is None + self.tp_world_size_ = 1 + self.tp_rank_ = 0 def load_hf_weights(self, weights): - if self.weight_name in weights: + if self.weight_name in weights and self.weight is None: self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id()) -class TpNormWeight(NormWeight): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): +class TpVitPadNormWeight(_NormWeight): + def __init__(self, weight_name, data_type, head_num: int, bias_name=None): + super().__init__(weight_name, data_type, bias_name) + self.head_num = head_num + + def _pad_tensor_param(self, weight: torch.Tensor): + assert weight.ndim == 1 + hidden_size = weight.shape[0] + head_dim = hidden_size // self.head_num + assert hidden_size % self.head_num == 0 + + if self.head_num % self.tp_world_size_ == 0: + return weight + else: + logger.warning(f"padding {self.weight_name} weights in TpVitPadNormWeight") + pad_head_num = self.tp_world_size_ - (self.head_num % self.tp_world_size_) + pad_dims = pad_head_num * head_dim + weight = torch.nn.functional.pad(weight, (0, pad_dims), mode="constant", value=0.0) + return weight + + def load_hf_weights(self, weights): + if self.weight_name in weights and self.weight is None: + t_weight = weights[self.weight_name] + t_weight = self._pad_tensor_param(t_weight) + new_hidden_size = t_weight.shape[0] + split_n_embed = new_hidden_size // self.tp_world_size_ + assert new_hidden_size % self.tp_world_size_ == 0 + + start = split_n_embed * self.tp_rank_ + end = split_n_embed * (self.tp_rank_ + 1) + + self.weight = t_weight[start:end].to(self.data_type_).cuda(get_current_device_id()) + + if self.bias_name in weights and self.bias is None: + t_bias = weights[self.bias_name] + t_bias = self._pad_tensor_param(t_bias) + new_hidden_size = t_bias.shape[0] + split_n_embed = 
new_hidden_size // self.tp_world_size_ + assert new_hidden_size % self.tp_world_size_ == 0 + + start = split_n_embed * self.tp_rank_ + end = split_n_embed * (self.tp_rank_ + 1) + + self.bias = t_bias[start:end].to(self.data_type_).cuda(get_current_device_id()) + + +class TpHeadNormWeight(_NormWeight): + def __init__(self, weight_name, data_type, bias_name=None): super().__init__(weight_name, data_type, bias_name) - self.split_n_embed = split_n_embed def load_hf_weights(self, weights): - start = self.split_n_embed * self.tp_rank_ - end = self.split_n_embed * (self.tp_rank_ + 1) + if self.weight_name in weights and self.weight is None: + t_weight = weights[self.weight_name] + start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) + self.weight: torch.Tensor = ( + t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) + ) + assert self.weight.ndim == 2 - if self.weight_name in weights: - self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights: - self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights and self.bias is None: + t_bias = weights[self.bias_name] + start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_bias) + self.bias: torch.Tensor = ( + t_bias[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) + ) + assert self.bias.ndim == 2 diff --git a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py index bb670b289..19eb67017 100644 --- a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py @@ -1,4 +1,5 @@ from .base_layer_weight import BaseLayerWeight +from .meta_weights import BaseWeight, MMWeightTpl class PreAndPostLayerWeight(BaseLayerWeight): @@ -9,3 +10,15 @@ def __init__(self, data_type, network_config, mode): self.mode = mode self.init_static_params() return + + def load_hf_weights(self, weights): + """ + load weights + """ + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: + with self.lock: + attr.load_hf_weights(weights) + elif isinstance(attr, BaseWeight): + attr.load_hf_weights(weights) diff --git a/lightllm/models/llama/triton_kernel/embedding.py b/lightllm/common/basemodel/triton_kernel/embedding.py similarity index 100% rename from lightllm/models/llama/triton_kernel/embedding.py rename to lightllm/common/basemodel/triton_kernel/embedding.py diff --git a/lightllm/models/bloom/triton_kernel/layernorm.py b/lightllm/common/basemodel/triton_kernel/layernorm.py similarity index 84% rename from lightllm/models/bloom/triton_kernel/layernorm.py rename to lightllm/common/basemodel/triton_kernel/layernorm.py index 6911d707b..538bb2b13 100644 --- a/lightllm/models/bloom/triton_kernel/layernorm.py +++ b/lightllm/common/basemodel/triton_kernel/layernorm.py @@ -24,15 +24,15 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _mean += a mean = tl.sum(_mean, axis=0) / N # Compute variance _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off 
in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -42,7 +42,7 @@ def _layer_norm_fwd_fused( mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) b = tl.load(B + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = (x - mean) * rstd y = x_hat * w + b # Write output @@ -72,17 +72,18 @@ def _layer_norm_fwd_fused( # BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) # return y + def layernorm_forward(x, weight, bias, eps): return torch.layer_norm(x, (x.shape[-1],), weight, bias, eps) -def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): +def test_layer_norm(M, N, dtype, eps=1e-5, device="cuda"): # create data x_shape = (M, N) - w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda') - bias = torch.rand(w_shape, dtype=dtype, device='cuda') - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + bias = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") # forward pass y_tri = layernorm_forward(x, weight, bias, eps) y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) diff --git a/lightllm/models/llama/triton_kernel/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/rmsnorm.py similarity index 95% rename from lightllm/models/llama/triton_kernel/rmsnorm.py rename to lightllm/common/basemodel/triton_kernel/rmsnorm.py index 0140847af..ca8f9a1c8 100644 --- a/lightllm/models/llama/triton_kernel/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/rmsnorm.py @@ -44,12 +44,13 @@ def _rms_norm_fwd_fused( tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) -def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None): +def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): # allocate output y = torch.empty_like(x) if out is None else out # reshape input data into 2D tensor x_arg = x.view(-1, x.shape[-1]) y_arg = y.view(-1, x.shape[-1]) + assert x_arg.shape[-1] == weight.shape[0] and x_arg.shape == y_arg.shape assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index 0cf8f8e99..7938869f5 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -1,62 +1,21 @@ import torch import torch.functional as F import numpy as np - from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight -from einops import rearrange -from lightllm.common.basemodel import InferStateInfo, PostLayerInferTpl -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.distributed.communication_op import all_gather +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.common.build_utils import repair_config -class BloomPostLayerInfer(PostLayerInferTpl): +class 
BloomPostLayerInfer(LlamaPostLayerInfer): """ """ def __init__(self, network_config, mode): + repair_config(config=network_config, same_names=["layer_norm_epsilon", "rms_norm_eps"]) super().__init__(network_config, mode) - assert network_config["vocab_size"] % self.tp_world_size_ == 0 - self.eps_ = network_config["layer_norm_epsilon"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layernorm_forward(input, layer_weight.final_norm_weight_, layer_weight.final_norm_bias_, eps=self.eps_) - - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - batch_size = infer_state.batch_size - last_input = self.alloc_tensor( - (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype + return layer_weight.final_norm_weight_.layernorm_forward( + input=input, eps=self.eps_, alloc_func=self.alloc_tensor ) - if infer_state.is_prefill: - last_index = ( - torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 - ) - last_input[:, :] = input_embdings[last_index, :] - else: - last_input[:, :] = input_embdings[-batch_size:, :] - - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = self._norm(last_input, infer_state, layer_weight) - last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, batch_size) - logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input) - last_input = None - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor( - (self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings_dtype - ) - split_size = self.vocab_size_ // self.tp_world_size_ - all_gather( - [gather_data[i * split_size : (i + 1) * split_size, :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - logic_batch = None - - ans_logics = gather_data.permute(1, 0).float() - gather_data = None - return ans_logics diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index 1733c20b0..baf1d3084 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -3,9 +3,6 @@ from lightllm.common.basemodel import PreLayerInferTpl from lightllm.common.basemodel import InferStateInfo from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce @@ -15,32 +12,20 @@ class BloomPreLayerInfer(PreLayerInferTpl): def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = network_config["layer_norm_epsilon"] - tp_vocab_size_ = network_config["vocab_size"] // self.tp_world_size_ - self.vob_start_id_ = tp_vocab_size_ * self.tp_rank_ - self.vob_end_id_ = tp_vocab_size_ * (self.tp_rank_ + 1) return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layernorm_forward(input, layer_weight.pre_norm_weight_, layer_weight.pre_norm_bias_, eps=self.eps_) + return 
layer_weight.pre_norm_weight_.layernorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - total_token_num = infer_state.total_token_num - input_ids = input_ids[0:total_token_num] - - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 8299697f3..d82a23d03 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -1,16 +1,10 @@ -import time import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.bloom.triton_kernel.token_flashattention_nopad import token_attention_fwd -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.common.basemodel import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time class BloomTransformerLayerInfer(TransformerLayerInferTpl): @@ -27,20 +21,18 @@ def __init__(self, layer_num, network_config, mode): self.embed_dim_ = network_config["n_embed"] return - def _att_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.att_norm_weight_.weight, - bias=layer_weight.att_norm_weight_.bias, - eps=self.eps_, + def _att_norm( + self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.att_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) - def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.ffn_norm_weight_.weight, - bias=layer_weight.ffn_norm_weight_.bias, - eps=self.eps_, + def 
_ffn_norm( + self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _get_qkv( diff --git a/lightllm/models/bloom/layer_weights/hf_load_utils.py b/lightllm/models/bloom/layer_weights/hf_load_utils.py deleted file mode 100755 index 01c4c5862..000000000 --- a/lightllm/models/bloom/layer_weights/hf_load_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import os -import gc -from safetensors import safe_open - - -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - if isinstance(data_type, str): - data_type = torch.float16 if data_type == 'fp16' else torch.float32 - if pre_post_layer is not None: - assert pre_post_layer.data_type_ == data_type, "type is not right" - if transformer_layer_list is not None: - assert transformer_layer_list[0].data_type_ == data_type, "type is not right" - if weight_dict: - new_w = {} - for k,v in weight_dict.items(): - if "transformer." in k: - new_w[k[len("transformer."):]] = v - else: - new_w[k] = v - del weight_dict - weight_dict = new_w - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weight_dict) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weight_dict) - del weight_dict - return - use_safetensors = True - files = os.listdir(weight_dir) - candidate_files = list(filter(lambda x : x.endswith('.safetensors'), files)) - if len(candidate_files) == 0: - use_safetensors = False - candidate_files = list(filter(lambda x : x.endswith('.bin'), files)) - assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." - for file_ in candidate_files: - if use_safetensors: - weights = safe_open(os.path.join(weight_dir, file_), 'pt', 'cpu') - weights = {k: weights.get_tensor(k) for k in weights.keys()} - else: - weights = torch.load(os.path.join(weight_dir, file_), 'cpu') - new_w = {} - for k,v in weights.items(): - if "transformer." 
in k: - new_w[k[len("transformer."):]] = v - else: - new_w[k] = v - del weights - weights = new_w - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights) - del weights - gc.collect() - return diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index b740bb62f..afc8c9308 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -1,43 +1,25 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self.pre_norm_weight_ = NoTpNormWeight( + weight_name="word_embeddings_layernorm.weight", + data_type=self.data_type_, + bias_name="word_embeddings_layernorm.bias", + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="ln_f.weight", + data_type=self.data_type_, + bias_name="ln_f.bias", + ) - def load_hf_weights(self, weights): - - if "word_embeddings_layernorm.weight" in weights: - self.pre_norm_weight_ = self._cuda(weights["word_embeddings_layernorm.weight"]) - if "word_embeddings_layernorm.bias" in weights: - self.pre_norm_bias_ = self._cuda(weights["word_embeddings_layernorm.bias"]) - if "ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["ln_f.weight"]) - if "ln_f.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["ln_f.bias"]) - if "word_embeddings.weight" in weights: - vob_size = self.network_config_["vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - self.wte_weight_ = self._cuda( - weights["word_embeddings.weight"][ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - ) - self.lm_head_weight_ = self.wte_weight_ - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.pre_norm_weight_, - self.pre_norm_bias_, - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.lm_head_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return + self.wte_weight_ = EmbeddingWeight( + weight_name="word_embeddings.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = self.wte_weight_ diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 2c341a790..7e44ec2eb 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -1,17 +1,11 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.bloom.layer_infer.pre_layer_infer import BloomPreLayerInfer from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight -from lightllm.models.bloom.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel -from 
lightllm.common.build_utils import repair_config - @ModelRegistry("bloom") class BloomTpPartModel(TpPartBaseModel): @@ -41,28 +35,3 @@ def _init_config(self): def _reset_num_key_value_heads(self): self.config["num_key_value_heads"] = self.config["num_attention_heads"] return - - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(self.config["n_layer"]) - ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py index 923d6d83b..174aea6c1 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py @@ -7,15 +7,27 @@ @triton.jit def _fwd_kernel_token_att1( - Q, K, sm_scale, Alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 + Q, + K, + sm_scale, + Alibi, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 Att_Out, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - att_stride_h, att_stride_bs, - + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -36,11 +48,19 @@ def _fwd_kernel_token_att1( block_stard_index = start_n * BLOCK_N block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + stride_req_to_tokens_s = tl.cast(stride_req_to_tokens_s, tl.int64) + cur_batch_req_id = tl.cast(cur_batch_req_id, tl.int64) + stride_kbs = tl.cast(stride_kbs, tl.int64) + for start_mark in range(0, block_mask, 1): # 用来判断当前 mask 是否需要计算 alibi_m = tl.load(Alibi + cur_head) q = tl.load(Q + off_q + start_mark) offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -68,12 +88,25 @@ def token_att_fwd(q, k, att_out, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B num_warps = 2 _fwd_kernel_token_att1[grid]( - q, k, sm_scale, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + q, + k, + sm_scale, + alibi, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, att_out, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - att_out.stride(0), 
att_out.stride(1), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, @@ -88,7 +121,9 @@ def torch_att(xq, xk, bs, seqlen, num_head, head_dim): keys = xk xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) - scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + scores = ( + (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + ) print("s ", scores.shape) return scores @@ -99,4 +134,4 @@ def torch_att1(xq, xk, seqlen, num_head, head_dim): logics = torch.sum(xq * xk, dim=-1, keepdim=False) logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics \ No newline at end of file + return logics diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py index 3960083fa..8ee01c340 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py @@ -29,6 +29,9 @@ def _fwd_kernel_token_att2( cur_batch = tl.program_id(0) cur_head = tl.program_id(1) + stride_vbs = tl.cast(stride_vbs, tl.int64) + stride_pbs = tl.cast(stride_pbs, tl.int64) + offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py index 89061dae3..07ffc4bea 100755 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py @@ -1,17 +1,8 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np - -from lightllm.utils.infer_utils import mark_cost_time from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight -from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward - class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py index 0b125bea3..0139eb883 100644 --- a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py @@ -1,31 +1,20 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight -class ChatGLM2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class ChatGLM2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, 
network_config, mode) - def load_hf_weights(self, weights): - # input layernorm params - - vob_size = self.network_config_["padded_vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - if "transformer.embedding.word_embeddings.weight" in weights: - self.wte_weight_ = weights["transformer.embedding.word_embeddings.weight"] - self.wte_weight_ = self.wte_weight_[ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.wte_weight_ = self._cuda(self.wte_weight_) - if "transformer.output_layer.weight" in weights: - self.lm_head_weight_ = weights["transformer.output_layer.weight"] - self.lm_head_weight_ = self.lm_head_weight_[ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.lm_head_weight_ = self._cuda(self.lm_head_weight_) - if "transformer.encoder.final_layernorm.weight" in weights: - self.final_norm_weight_ = weights["transformer.encoder.final_layernorm.weight"] - self.final_norm_weight_ = self._cuda(self.final_norm_weight_) - - return + self.wte_weight_ = EmbeddingWeight( + weight_name="transformer.embedding.word_embeddings.weight", data_type=self.data_type_ + ) + self.lm_head_weight_ = LMHeadWeight( + weight_name="transformer.output_layer.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="transformer.encoder.final_layernorm.weight", + data_type=self.data_type_, + bias_name=None, + ) diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py index 8b9d4b268..67987a8d3 100644 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py @@ -1,92 +1,44 @@ import torch -import torch.distributed as dist import numpy as np - from lightllm.models.cohere.infer_struct import CohereInferStateInfo from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward -from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight - -from einops import rearrange -from lightllm.common.basemodel import PostLayerInferTpl +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.common.build_utils import repair_config from lightllm.distributed.communication_op import all_gather -class CoherePostLayerInfer(PostLayerInferTpl): +class CoherePostLayerInfer(LlamaPostLayerInfer): def __init__(self, network_config, mode): + repair_config(config=network_config, same_names=["layer_norm_eps", "rms_norm_eps"]) super().__init__(network_config, mode) self.eps_ = network_config["layer_norm_eps"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] self.logits_scale = network_config["logit_scale"] return - def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor: + def _norm( + self, input: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight + ) -> torch.Tensor: return layernorm_forward( - input.unsqueeze(1), layer_weight.final_norm_weight_.unsqueeze(0), eps=self.eps_ + input.unsqueeze(1), layer_weight.final_norm_weight_.weight.unsqueeze(0), eps=self.eps_ ).squeeze(1) - def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo): - - if infer_state.is_prefill and infer_state.is_token_healing: - batch_size = infer_state.batch_size - b_seq_len_numpy = 
(infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() - select_index = [] - start_index = 0 - select_token_num = 0 - for cur_len in b_seq_len_numpy: - - select_index.append(start_index + cur_len - 1) - start_index += cur_len - select_token_num += 1 - - last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor( - (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype - ) - - last_input[:, :] = input_embdings[last_index, :] - return last_input, select_token_num - - if infer_state.is_prefill and not infer_state.return_all_prompt_logics: - batch_size = infer_state.batch_size - last_input = self.alloc_tensor( - (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype - ) - last_index = ( - torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 - ) - last_input[:, :] = input_embdings[last_index, :] - return last_input, batch_size - - if infer_state.is_prefill and infer_state.return_all_prompt_logics: - total_tokens = infer_state.total_token_num - return input_embdings, total_tokens - - if not infer_state.is_prefill: - batch_size = infer_state.batch_size - return input_embdings[-batch_size:, :], batch_size - - assert False, "Error State" - def token_forward( - self, input_embdings, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight + self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) input_embdings_dtype = input_embdings.dtype input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) - last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num) - logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input) - + last_input = last_input.permute(1, 0).view(-1, token_num) + logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) last_input = None + vocab_size = layer_weight.lm_head_weight_.vocab_size if self.tp_world_size_ == 1: gather_data = logic_batch else: - gather_data = self.alloc_tensor( - (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype - ) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) + gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) all_gather( [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], logic_batch, @@ -95,7 +47,25 @@ def token_forward( ) gather_data = gather_data * self.logits_scale logic_batch = None - - ans_logics = gather_data.permute(1, 0).float() + ans_logics = self.alloc_tensor( + (token_num, vocab_size), + dtype=torch.float32, + ) + ans_logics[:, :] = gather_data.permute(1, 0) gather_data = None return ans_logics + + def tpsp_token_forward( + self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight + ): + raise NotImplementedError("not impl") + + def overlap_tpsp_token_forward( + self, + input_embdings: torch.Tensor, + input_embdings1: torch.Tensor, + infer_state: CohereInferStateInfo, + infer_state1: CohereInferStateInfo, + layer_weight: CoherePreAndPostLayerWeight, + ): + raise 
NotImplementedError("not impl") diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py index 993acd64d..f2e5f8547 100644 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py @@ -1,36 +1,25 @@ -import torch -import numpy as np +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight - -class CoherePreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] +class CoherePreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) tie_weight = self.network_config_.get("tie_word_embeddings", True) - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - if tie_weight: - self.lm_head_weight_ = self.wte_weight_ - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - if "model.lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"]) - return - - def verify_load(self): - super().verify_load() - errors = "tie weights load not ok" - tie_weight = self.network_config_.get("tie_word_embeddings", True) + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) if tie_weight: - assert self.lm_head_weight_ is not None, errors - assert self.wte_weight_ is self.lm_head_weight_, errors + self.lm_head_weight_ = self.wte_weight_ else: - assert self.lm_head_weight_ is not None, errors - assert self.wte_weight_ is not None, errors - assert self.wte_weight_ is not self.lm_head_weight_, errors + self.lm_head_weight_ = LMHeadWeight( + weight_name="model.lm_head.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py index fff92abf5..9c446b49e 100644 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -1,8 +1,5 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, - TpNormWeight, -) +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight, TpHeadNormWeight class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): @@ -14,18 +11,15 @@ def _parse_config(self): super()._parse_config() self.use_qk_norm = self.network_config_.get("use_qk_norm", False) - def _init_norm(self, weights): - q_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ - k_split_head = 
self.network_config_["num_key_value_heads"] // self.tp_world_size_ - - self.att_norm_weight_ = NormWeight(self._att_norm_weight_name, self.data_type_) + def _init_norm(self): + self.att_norm_weight_ = NoTpNormWeight(self._att_norm_weight_name, self.data_type_) if self.use_qk_norm: - self.q_norm_weight_ = TpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_, q_split_head + self.q_norm_weight_ = TpHeadNormWeight( + f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ ) - self.k_norm_weight_ = TpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_, k_split_head + self.k_norm_weight_ = TpHeadNormWeight( + f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ ) return diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index 0725fa337..72ba8a43b 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -23,8 +23,8 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): ] return cls._shared_page_table_buffer - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) args_mtp_step = get_env_start_args().mtp_step if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len @@ -51,7 +51,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): ].view(att_batch_size, model.graph_max_len_in_batch) else: self.page_table = torch.empty((att_batch_size, self.max_len_in_batch), dtype=torch.int32).to( - input_ids.device + self.input_ids.device ) page_table_copy( page_table=self.page_table[:, :max_seq_len_k], diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index a00c45601..db6386f79 100644 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -14,15 +14,15 @@ def __init__(self): self.decode_wrapper = None self.flashinfer_extra_state = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) self.flashinfer_extra_state = model.flashinfer_extra_state import flashinfer if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: - self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) + self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(self.input_ids.device) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ : self.batch_size * self.flashinfer_extra_state.max_seq_length @@ -31,7 +31,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.kv_indices = torch.empty( self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32, - device=input_ids.device, + device=self.input_ids.device, ) repack_kv_index( self.req_manager.req_to_token_indexs, diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index f05f52f2f..0c2ef3048 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -10,8 +10,8 @@ def __init__(self): super().__init__() 
self.kv_starts = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) if not self.is_prefill: self.kv_starts = self.b1_cu_kv_seq_len diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 30d37d1df..ff20bc6ee 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -18,7 +18,6 @@ from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo @@ -158,14 +157,18 @@ def _get_qkv( q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q = layer_weight.q_a_layernorm_.rmsnorm_forward( + input=q, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( + + layer_weight.kv_a_layernorm_.rmsnorm_forward( cache_kv[:, :, : self.kv_lora_rank], - weight=layer_weight.kv_a_layernorm_.weight, eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank], ) @@ -190,16 +193,15 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device ) all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.position_cos), :] + input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) q = layer_weight.q_weight_.mm(input) cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( + layer_weight.kv_a_layernorm_.rmsnorm_forward( cache_kv[:, :, : self.kv_lora_rank], - weight=layer_weight.kv_a_layernorm_.weight, eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank], ) @@ -223,7 +225,7 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, qkv_dim), dtype=qkv.dtype, device=qkv.device ) all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False) - qkv = gather_qkv[0 : len(infer_state.position_cos), :] + qkv = gather_qkv[0 : len(infer_state.input_ids), :] if infer_state.need_dp_prefill_balance: qkv = infer_state._all_to_all_unbalance_get(data=qkv) @@ -234,14 +236,17 @@ def _tpsp_get_qkv( position_sin = infer_state.position_sin q, cache_kv = 
qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q = layer_weight.q_a_layernorm_.rmsnorm_forward( + q, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( + layer_weight.kv_a_layernorm_.rmsnorm_forward( cache_kv[:, :, : self.kv_lora_rank], - weight=layer_weight.kv_a_layernorm_.weight, eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank], ) @@ -273,8 +278,8 @@ def _tpsp_get_o( input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :]) - e_o_tensor = o_tensor[len(infer_state.position_cos) :, :] + layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) + e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] if e_o_tensor.shape[0] > 0: e_o_tensor.fill_(0) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c899751eb..611878f9e 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -7,7 +7,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NormWeight, + NoTpNormWeight, FusedMoeWeightEP, ROWBMMWeight, create_tp_moe_wegiht_obj, @@ -299,14 +299,16 @@ def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") def _init_norm(self): - self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) - self.ffn_norm_weight_ = NormWeight( + self.att_norm_weight_ = NoTpNormWeight( + f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_ + ) + self.ffn_norm_weight_ = NoTpNormWeight( f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ ) - self.kv_a_layernorm_ = NormWeight( + self.kv_a_layernorm_ = NoTpNormWeight( f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ ) if self.q_lora_rank is not None: - self.q_a_layernorm_ = NormWeight( + self.q_a_layernorm_ = NoTpNormWeight( f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ ) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 6dfd88970..e4ce7c826 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -12,7 +12,7 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id @@ -95,61 
+95,16 @@ def _verify_params(self): def _init_mem_manager(self): manager_class = select_mem_manager_class() - # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = 0 - if get_env_start_args().mtp_mode == "deepseekv3_eagle": - added_mtp_layer_num += 1 - elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": - added_mtp_layer_num += get_env_start_args().mtp_step - self.mem_manager = manager_class( self.max_total_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], - layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(self.config["n_layer"]) - ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return - - def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) - self.layers_infer = [ - self.transformer_layer_infer_class( - i, - network_config=self.config, - mode=self.mode, - ) - for i in range(self.config["n_layer"]) - ] - return - def _init_to_get_yarn_rotary(self): from lightllm.models.llama.yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_deepseek_mscale @@ -191,8 +146,3 @@ def _init_to_get_yarn_rotary(self): self._sin_cached = (freqs.sin() * _mscale).to(self.data_type).cuda() return - - @final - def _context_forward(self, input_ids, infer_state): - predict_logics = super()._context_forward(input_ids, infer_state) - return predict_logics diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index 0991103e1..26bfc865e 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -3,7 +3,6 @@ from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): @@ -18,35 +17,45 @@ def __init__(self, network_config, mode): def _mtp_context_forward( self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + tgt_embdings = infer_state.mtp_draft_input_hiddens assert ( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, 
eps=self.eps_, out=tgt_embdings) + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) return ans_logics def _mtp_token_forward( self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight ): - tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) - rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - ans_logics = self.alloc_tensor( - (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype - ) - torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) return ans_logics def context_forward( diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index f5b805647..4a5bf2e96 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,29 +1,40 @@ -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpNormWeight, + ROWMMWeight, +) -class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - # 与DeepseekV3模型共享 - self.wte_weight_ = None - self.lm_head_weight_ = None - return - def load_hf_weights(self, weights): - if "model.layers.0.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() - if "model.layers.0.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) - if "model.layers.0.hnorm.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) - if "model.layers.0.shared_head.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"]) - return + self.eh_proj_weight_ = ROWMMWeight( + weight_names="model.layers.0.eh_proj.weight", + data_type=self.data_type_, + name="eh_proj", + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.enorm.weight", + 
data_type=self.data_type_, + bias_name=None, + ) + self.hnorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.hnorm.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.shared_head.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) - def verify_load(self): - errors = "weights load not ok" - weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + # Shared with the DeepseekV3 model; not loaded via load + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None return diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 2e2e95187..0204e292a 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -1,3 +1,4 @@ +from typing import List from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight @@ -16,7 +17,7 @@ def __init__(self, kvargs: dict): def _pre_init(self, kvargs: dict): self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mem_layer_start = kvargs.pop("mem_layer_start", 0) + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") return def _init_custom(self): @@ -32,15 +33,18 @@ def _init_mem_manager(self): self.mem_manager = self.main_model.mem_manager return - def _init_weights(self): - super()._init_weights() + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + super()._init_weights(start_layer_index=0) self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ return - def _init_infer_layer(self): - super()._init_infer_layer() - # reset the layer_num_ of the self.layers_infer - for layer in self.layers_infer: - layer.layer_num_ = layer.layer_num_ + self.mem_layer_start + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) + super()._init_infer_layer(start_layer_index=total_pre_layers_num) return diff --git a/lightllm/models/gemma3/infer_struct.py b/lightllm/models/gemma3/infer_struct.py index 4145124af..33bd44815 100644 --- a/lightllm/models/gemma3/infer_struct.py +++ b/lightllm/models/gemma3/infer_struct.py @@ -12,8 +12,8 @@ def __init__(self): self.position_sin_local = None self.position_cos_local = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len position_ids = self.position_ids diff --git a/lightllm/models/gemma3/layer_infer/post_layer_infer.py b/lightllm/models/gemma3/layer_infer/post_layer_infer.py index 8004309ed..22dc59505 100644 --- a/lightllm/models/gemma3/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/post_layer_infer.py @@ -1,9 +1,4 @@ -import numpy as np -import torch - -from 
lightllm.distributed.communication_op import all_gather from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight class Gemma3PostLayerInfer(LlamaPostLayerInfer): @@ -13,46 +8,3 @@ def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = 1e-6 return - - def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out=None): - def _inner_norm(x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - output = _inner_norm(input.float()) - output = output * (1.0 + weight.float()) - if out is not None: - out = output.to(out.dtype) - return output - - def _norm(self, input, infer_state, layer_weight) -> torch.Tensor: - return self.gemma3_rmsnorm(input, layer_weight.final_norm_weight_, eps=self.eps_) - - def token_forward(self, input_embdings, infer_state, layer_weight): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - last_input = self._norm(last_input.float(), infer_state, layer_weight).to(torch.bfloat16) - last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = self.alloc_tensor( - (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype - ) - torch.mm(layer_weight.lm_head_weight_.to(last_input.dtype), last_input, out=logic_batch) - last_input = None - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - logic_batch = None - ans_logics = self.alloc_tensor( - (token_num, self.vocab_size_), - dtype=torch.float32, - ) - ans_logics[:, :] = gather_data.permute(1, 0) - gather_data = None - return ans_logics diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index 0df5c0f06..dc8a46ad9 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -16,9 +16,9 @@ def context_forward(self, input_ids, infer_state, layer_weight): img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype - hidden_size = layer_weight.wte_weight_.shape[1] + device = layer_weight.wte_weight_.weight.device + dtype = layer_weight.wte_weight_.weight.dtype + hidden_size = layer_weight.wte_weight_.weight.shape[1] weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device) # TODO @@ -65,18 +65,18 @@ def context_forward(self, input_ids, infer_state, layer_weight): multimodal_emb( out=out, prompt_ids=input_ids, - text_weight_embs=layer_weight.wte_weight_, + text_weight_embs=layer_weight.wte_weight_.weight, embed_cache=cpu_embed_cache_tensor, img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, img_start_locs_in_cache=img_start_locs_in_cache, - tp_text_start_token_id=self.vob_start_id_, - tp_text_end_token_id=self.vob_end_id_, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) 
input_dtype = out.dtype if self.tp_world_size_ > 1: - all_reduce(out, group=infer_state.dist_group, op=torch.dist.ReduceOp.SUM, async_op=False) + all_reduce(out, group=infer_state.dist_group, op=torch.distributed.ReduceOp.SUM, async_op=False) return (out.float() * weight_mask.unsqueeze(1).float()).to(input_dtype) def token_forward(self, input_ids, infer_state, layer_weight): diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 09efe9a36..d4bd8c3fa 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -9,11 +9,7 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.distributed import all_reduce -from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight -from lightllm.models.gemma_2b.triton_kernel.gelu_and_mul import gelu_and_mul_fwd - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd @@ -31,54 +27,6 @@ def __init__(self, layer_num, network_config, mode=[]): self.sliding_window_pattern = 6 return - def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out=None): - def _norm(x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - output = _norm(input.float()) - output = output * (1.0 + weight.float()) - if out is not None: - out = output.to(out.dtype) - return output - - def _pre_feedforward_layernorm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.pre_feedforward_layernorm_weight_.weight, self.eps_, out=out) - return out - - def _post_feedforward_layernorm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.post_feedforward_layernorm_weight_.weight, self.eps_, out=out) - return out - - def _k_norm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.k_norm_weight_.weight, self.eps_, out=out) - return out - - def _q_norm(self, input, infer_state, layer_weight: Gemma3TransformerLayerWeight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.q_norm_weight_.weight, self.eps_, out=out) - return out - - def _att_norm(self, input, infer_state, layer_weight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) - return out - - def _ffn_norm(self, input, infer_state, layer_weight): - out = self.alloc_tensor(input.shape, input.dtype) - out = self.gemma3_rmsnorm(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) - return out - - def _bind_norm(self): - self._att_norm = partial(Gemma3TransformerLayerInfer._att_norm, self) - self._ffn_norm = partial(Gemma3TransformerLayerInfer._ffn_norm, self) - self._q_norm = partial(Gemma3TransformerLayerInfer._q_norm, self) - 
self._k_norm = partial(Gemma3TransformerLayerInfer._k_norm, self) - self._pre_feedforward_layernorm = partial(Gemma3TransformerLayerInfer._pre_feedforward_layernorm, self) - self._post_feedforward_layernorm = partial(Gemma3TransformerLayerInfer._post_feedforward_layernorm, self) - def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3TransformerLayerWeight ) -> torch.Tensor: @@ -94,8 +42,16 @@ def _get_qkv( # gemma3 use qk norm q = q.view(-1, self.tp_q_head_num_, self.head_dim_) k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = self._q_norm(q.float(), infer_state, layer_weight).to(cache_kv.dtype) - cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k.float(), infer_state, layer_weight).to(cache_kv.dtype) + + q = layer_weight.q_norm_weight_.rmsnorm_forward( + input=q.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(cache_kv.dtype) + + cache_kv[:, 0 : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( + input=k.float(), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ).to(cache_kv.dtype) is_sliding = bool((self.layer_num_ + 1) % self.sliding_window_pattern) if is_sliding: @@ -125,7 +81,7 @@ def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3Tran ffn1_out = None return ffn2_out - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma3TransformerLayerWeight): input_embdings = input_embdings.to(torch.bfloat16) input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( torch.bfloat16 @@ -142,16 +98,25 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) + input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = self._post_feedforward_layernorm(ffn_out.float(), infer_state, layer_weight).to(torch.bfloat16) + + ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( + input=ffn_out.float(), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ).to(torch.bfloat16) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma3TransformerLayerWeight): input_embdings = input_embdings.to(torch.bfloat16) input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( torch.bfloat16 @@ -168,11 +133,20 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) + input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None if 
self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = self._post_feedforward_layernorm(ffn_out.float(), infer_state, layer_weight).to(torch.bfloat16) + + ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( + input=ffn_out.float(), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ).to(torch.bfloat16) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 24ea91d89..17e65268c 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -1,25 +1,20 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpGEMMANormWeight -# add key: language_model.xxx -> xxx -# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now -def rename_weight_keys(weights): - prefix = "language_model." - keys = list(weights.keys()) - for k in keys: - if prefix in k: - weights[k[len(prefix) :]] = weights[k] - - -class Gemma3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): - network_config["tie_word_embeddingse"] = True super().__init__(data_type, network_config, mode) - return - def load_hf_weights(self, weights): - rename_weight_keys(weights) - super().load_hf_weights(weights) + self.wte_weight_ = EmbeddingWeight( + weight_name="language_model.model.embed_tokens.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = self.wte_weight_ + + self.final_norm_weight_ = NoTpGEMMANormWeight( + weight_name="language_model.model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) return diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 6f5530461..1e7ceeb42 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -1,8 +1,6 @@ -import torch -import numpy as np from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight -from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight class Gemma3TransformerLayerWeight(LlamaTransformerLayerWeight): @@ -66,12 +64,12 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) - self.pre_feedforward_layernorm_weight_ = NormWeight( + self.k_norm_weight_ = NoTpGEMMANormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) + self.q_norm_weight_ = NoTpGEMMANormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + self.pre_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( 
self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None ) - self.post_feedforward_layernorm_weight_ = NormWeight( + self.post_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( self._post_feedforward_layernorm_name, self.data_type_, bias_name=None ) diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index 42326169a..dc4f03b7e 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -6,6 +6,7 @@ from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.models.gemma3.infer_struct import Gemma3InferStateInfo from lightllm.models.gemma3.layer_infer.post_layer_infer import Gemma3PostLayerInfer from lightllm.models.gemma3.layer_infer.pre_layer_infer import Gemma3PreLayerInfer @@ -148,7 +149,7 @@ def _init_mem_manager(self): dtype=torch.bfloat16, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=256, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index c63432a92..ce9737820 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -5,8 +5,6 @@ from lightllm.models.gemma_2b.layer_weights.pre_and_post_layer_weight import Gemma_2bPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel import PreLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce @@ -24,20 +22,20 @@ def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight) return input * self.normfactor def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ + input_embdings = layer_weight.wte_weight_.embedding( + input_ids=input_ids, + alloc_func=self.alloc_tensor, ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ + input_embdings = layer_weight.wte_weight_.embedding( + input_ids=input_ids, + alloc_func=self.alloc_tensor, ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git 
a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index c119960c5..d5d0438fa 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -1,25 +1,21 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpGEMMANormWeight -class Gemma_2bPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - self.lm_head_weight_ = self.wte_weight_ + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = self.wte_weight_ - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - self.final_norm_weight_ = self.final_norm_weight_ + 1 + self.final_norm_weight_ = NoTpGEMMANormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) return diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 32248e6dd..1916bd095 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import GEMMANormWeight, ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight, ROWMMWeight class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight): @@ -29,5 +29,5 @@ def _init_qkv(self): ) def _init_norm(self): - self.att_norm_weight_ = GEMMANormWeight(self._att_norm_weight_name, self.data_type_) - self.ffn_norm_weight_ = GEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NoTpGEMMANormWeight(self._att_norm_weight_name, self.data_type_) + self.ffn_norm_weight_ = NoTpGEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) diff --git a/lightllm/models/gemma_2b/model.py b/lightllm/models/gemma_2b/model.py index 4b425c9ce..2563c7e79 100644 --- a/lightllm/models/gemma_2b/model.py +++ b/lightllm/models/gemma_2b/model.py @@ -7,6 +7,7 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @ModelRegistry("gemma") @@ -43,7 +44,7 @@ def _init_mem_manager(self): 
dtype=self.data_type, head_num=self.config["num_key_value_heads"], head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 1246af090..93cd7413b 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -69,7 +69,7 @@ def _ffn( return hidden_states.view(num_tokens, hidden_dim) def _context_sliding_attention_flashattention( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None + self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) @@ -106,7 +106,9 @@ def _context_sliding_attention_flashattention( ) return o - def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): + def _token_sliding_attention_flashattention( + self, q, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) else: diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7e6035dc5..f6a841b1a 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -4,7 +4,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight -from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import TpAttSinkWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.utils.log_utils import init_logger @@ -57,8 +57,6 @@ def _init_moe(self): def _init_weight_names(self): super()._init_weight_names() - self._attn_sink_name = f"model.layers.{self.layer_num_}.self_attn.sinks" - self._q_bias_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" @@ -70,8 +68,10 @@ def _init_weight_names(self): def _init_weight(self): super()._init_weight() - n_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.attn_sinks = TpNormWeight(self._attn_sink_name, torch.bfloat16, n_split_head) + self.attn_sinks = TpAttSinkWeight( + weight_name=f"model.layers.{self.layer_num_}.self_attn.sinks", + data_type=torch.bfloat16, + ) def _init_ffn(self): self._init_moe() diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index dd8c64915..b40330aa3 100644 --- 
a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -1,23 +1,15 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight -class Internlm2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) - if "output.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["output.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) + self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py b/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py index 42061c784..d5eaa1654 100644 --- a/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py +++ b/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py @@ -1,19 +1,17 @@ import torch -import torch.distributed as dist -import numpy as np - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from einops import rearrange +from ..layer_weights.pre_and_post_layer_weight import Internlm2RewardPreAndPostLayerWeight class Internlm2RewardPostLayerInfer(LlamaPostLayerInfer): - def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + def token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Internlm2RewardPreAndPostLayerWeight + ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) - score = torch.mm(last_input, layer_weight.lm_head_weight_) + score = layer_weight.score_head_.mm(last_input) return score diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index 78fb0c5d7..b20b9c495 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,23 +1,24 @@ -import torch import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import 
LlamaPreAndPostLayerWeight +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight, ROWMMWeight -class Internlm2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): +class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) - if "v_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["v_head.weight"]).transpose(0, 1) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - + self.wte_weight_ = EmbeddingWeight( + weight_name="model.tok_embeddings.weight", + data_type=self.data_type_, + ) + self.score_head_ = ROWMMWeight( + weight_names="v_head.weight", + data_type=self.data_type_, + name="score_head", + tp_rank=0, + tp_world_size=1, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/internlm2_reward/model.py b/lightllm/models/internlm2_reward/model.py index 881a607a5..b9ea002a6 100644 --- a/lightllm/models/internlm2_reward/model.py +++ b/lightllm/models/internlm2_reward/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry, is_reward_model from lightllm.models.internlm2_reward.layer_infer.post_layer_infer import Internlm2RewardPostLayerInfer from lightllm.models.internlm2_reward.layer_weights.pre_and_post_layer_weight import ( diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index f19563932..7d76d202a 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -1,7 +1,4 @@ -import torch -import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight - from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index c6e7aa560..9f71cbbc5 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -25,12 +25,12 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): ] return cls._shared_page_table_buffer - def _init_flash_attention_state(self, model, input_ids: torch.Tensor): + def _init_flash_attention_state(self, model): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device + (self.batch_size, self.max_seq_len), dtype=torch.int32, device=self.input_ids.device ) self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) else: 
@@ -38,26 +38,33 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len + args_mtp_step = get_env_start_args().mtp_step + att_batch_size = self.batch_size // (args_mtp_step + 1) if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: page_buffer = FlashAttentionStateInfo.get_page_table_buffer( model.graph_max_batch_size, model.graph_max_len_in_batch ) self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) else: self.page_table = torch.empty( - (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device + (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=self.input_ids.device ) + page_table_copy( page_table=self.page_table[:, :max_seq_len_k], req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx, + b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], ) + if args_mtp_step > 0: + self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + else: + self.b_att_seq_len = self.b_seq_len if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: - device = input_ids.device + device = self.input_ids.device # q_scale and token_batch_ids are used for per-head quantization of q; initialized outside the inference call to save resources self.q_scale = torch.empty( (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device ) @@ -77,7 +84,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): else torch.ones( (self.mem_manager.layer_num, self.batch_size, head_num), dtype=torch.float32, - device=input_ids.device, + device=self.input_ids.device, ) ) self.v_descale = ( @@ -88,12 +95,12 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): else torch.ones( (self.mem_manager.layer_num, self.batch_size, head_num), dtype=torch.float32, - device=input_ids.device, + device=self.input_ids.device, ) ) return - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) - self._init_flash_attention_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + self._init_flash_attention_state(model) return diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index a0c40b57a..7f9beac1d 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -14,8 +14,8 @@ def __init__(self): self.decode_wrapper = None self.flashinfer_extra_state = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) self.flashinfer_extra_state = model.flashinfer_extra_state import flashinfer @@ -23,7 +23,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device + (self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device ) if self.batch_size <= model.graph_max_batch_size: 
self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ @@ -33,7 +33,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.kv_indices = torch.empty( self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32, - device=input_ids.device, + device=self.input_ids.device, ) repack_kv_index( @@ -71,11 +71,11 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if get_env_start_args().enable_flashinfer_prefill: q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() - kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device) kv_indices = torch.empty( self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32, - device=input_ids.device, + device=self.input_ids.device, ) repack_kv_index( self.req_manager.req_to_token_indexs, diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index 6373b3782..3bba43976 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -10,8 +10,8 @@ def __init__(self): self.position_cos = None self.position_sin = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def init_some_extra_state(self, model): + super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len self.q_max_seq_len = self.max_q_seq_len diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 28e60a952..7c7b0ea39 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -4,13 +4,9 @@ import torch.distributed as dist import numpy as np from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight - from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from einops import rearrange from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.common.basemodel import PostLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time from lightllm.distributed.communication_op import all_gather @@ -20,15 +16,13 @@ class LlamaPostLayerInfer(PostLayerInferTpl): def __init__(self, network_config, mode): super().__init__(network_config, mode) self.eps_ = network_config["rms_norm_eps"] - self.vocab_size_ = network_config["vocab_size"] - self.embed_dim_ = network_config["n_embed"] return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_) - - def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo): + return layer_weight.final_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo): + embed_dim_ = input_embdings.shape[1] if infer_state.is_prefill and infer_state.is_token_healing: batch_size = infer_state.batch_size b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() @@ -41,13 +35,13 @@ def _slice_get_last_input(self, input_embdings, 
infer_state: LlamaInferStateInfo select_token_num += 1 last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor((select_token_num, self.embed_dim_), dtype=input_embdings.dtype) + last_input = self.alloc_tensor((select_token_num, embed_dim_), dtype=input_embdings.dtype) last_input[:, :] = input_embdings[last_index, :] return last_input, select_token_num if infer_state.is_prefill and not infer_state.return_all_prompt_logics: batch_size = infer_state.batch_size - last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype) + last_input = self.alloc_tensor((batch_size, embed_dim_), dtype=input_embdings.dtype) last_index = ( torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 ) @@ -64,23 +58,22 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo assert False, "Error State" - def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + def token_forward( + self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight + ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) input_embdings_dtype = input_embdings.dtype input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = self.alloc_tensor( - (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype - ) - torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch) - + logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) last_input = None + vocab_size = layer_weight.lm_head_weight_.vocab_size if self.tp_world_size_ == 1: gather_data = logic_batch else: - gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64) + gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) all_gather( [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], logic_batch, @@ -89,7 +82,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_ ) logic_batch = None ans_logics = self.alloc_tensor( - (token_num, self.vocab_size_), + (token_num, vocab_size), dtype=torch.float32, ) ans_logics[:, :] = gather_data.permute(1, 0) @@ -111,8 +104,8 @@ def tpsp_token_forward( group=infer_state.dist_group, async_op=False, ) - # len(infer_state.position_sin) gets the real input length - input_embdings = gather_data[0 : len(infer_state.position_sin)] + # len(infer_state.input_ids) gets the real input length + input_embdings = gather_data[0 : len(infer_state.input_ids)] if infer_state.need_dp_prefill_balance: input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) @@ -131,18 +124,12 @@ def overlap_tpsp_token_forward( infer_state.hook() infer_state.hook = None - if infer_state.need_dp_prefill_balance: - input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) - logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() infer_state1.hook = None - if infer_state1.need_dp_prefill_balance: - input_embdings1 = 
infer_state1._all_to_all_unbalance_get(data=input_embdings1) - logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) return logics, logics1 diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index 99b7db5bf..ddb99e262 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -1,13 +1,8 @@ -import os import torch import torch.distributed as dist -import numpy as np - from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel import PreLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce from lightllm.utils.envs_utils import get_env_start_args @@ -17,25 +12,16 @@ class LlamaPreLayerInfer(PreLayerInferTpl): def __init__(self, network_config, mode): super().__init__(network_config, mode) - tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) - self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) - return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 8c6015677..b08b2aa1f 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -14,7 +14,6 @@ from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd @@ -190,16 +189,16 @@ def _bind_attention(self): def _att_norm( self, input, infer_state: LlamaInferStateInfo, 
layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - out = self.alloc_tensor(input.shape, input.dtype) - rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out) - return out + return layer_weight.att_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - out = self.alloc_tensor(input.shape, input.dtype) - rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out) - return out + return layer_weight.ffn_norm_weight_.rmsnorm_forward( + input=input, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight @@ -224,7 +223,7 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device ) all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.position_cos), :] + input = gather_input[0 : len(infer_state.input_ids), :] q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) @@ -415,8 +414,8 @@ def _tpsp_get_o( input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :]) - e_o_tensor = o_tensor[len(infer_state.position_cos) :, :] + layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) + e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] if e_o_tensor.shape[0] > 0: e_o_tensor.fill_(0) @@ -883,10 +882,10 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS k_cache=cache_k, v_cache=cache_v, page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, + cache_seqlens=infer_state.b_att_seq_len, cu_seqlens_q=infer_state.cu_seqlens_q, cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, + max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=sm_scale, causal=True, window_size=(-1, -1), diff --git a/lightllm/models/llama/layer_weights/ds_load_utils.py b/lightllm/models/llama/layer_weights/ds_load_utils.py deleted file mode 100644 index 091c056ca..000000000 --- a/lightllm/models/llama/layer_weights/ds_load_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import collections -import torch -import os -import gc - -def load_ds_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None, prefix="", num_layer=0): - if weight_dict: - return weight_dict - files = os.listdir(weight_dir) - candidate_files = sorted(list(filter(lambda x : x.endswith('.pt') and x.startswith('layer'), files))) - assert len(candidate_files) != 0, "can only support pytorch tensor format for weights." - if weight_dict: - weights_all = weight_dict - else: - weights_all = {} - for file_ in candidate_files: - file_split = file_.split('-') - layer_num = int(file_split[0].split('_')[-1]) - rank_num = int(file_split[0].split('_')[-1]) - weights = torch.load(os.path.join(weight_dir, file_), 'cpu') - for k,v in weights.items(): - if layer_num >=3 and layer_num < 3 + num_layer: - k = prefix + str(layer_num - 3) + '.' 
+ k - if layer_num == num_layer + 5: - k = 'lm_head.weight' - if layer_num == num_layer + 4: - k = 'model.norm.weight' - if layer_num == 1: - k = 'model.embed_tokens.weight' - if k not in weights_all: - weights_all[k] = v - else: - if 'q_proj' in k or 'k_proj' in k or 'v_proj' in k or 'gate_proj' in k or 'up_proj' in k: - weights_all[k] = torch.cat([weights_all[k], v], dim=0) - elif 'o_proj' in k or 'down_proj' in k: - weights_all[k] = torch.cat([weights_all[k], v], dim=1) - else: - weights_all[k] = v - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights_all) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights_all) - del weights_all - gc.collect() - return - -if __name__ == '__main__': - load_ds_weight('fp16', '/nvme/baishihao/llama7b', prefix='model.layers.', num_layer=32) \ No newline at end of file diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 711406e3f..ea59d24df 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -1,34 +1,27 @@ -import os -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - return + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_: LMHeadWeight = self.wte_weight_ + else: + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) - def verify_load(self): - errors = "weights load not ok" - weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) return diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 624717007..6b92272ee 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from 
lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight class LlamaTransformerLayerWeight(TransformerLayerWeight): @@ -103,9 +103,9 @@ def _init_ffn(self): ) def _init_norm(self): - self.att_norm_weight_ = NormWeight( + self.att_norm_weight_ = NoTpNormWeight( self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) - self.ffn_norm_weight_ = NormWeight( + self.ffn_norm_weight_ = NoTpNormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index a228e0025..95465a9e6 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -8,14 +8,12 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.layer_weights.ds_load_utils import load_ds_weights -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id @@ -91,7 +89,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=head_dim_, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return @@ -134,42 +132,6 @@ def _init_custom(self): raise ValueError(f"Unknown RoPE scaling type {scaling_type}") return - def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) - self.trans_layers_weight = [ - self.transformer_weight_class( - i, - self.data_type, - network_config=self.config, - mode=self.mode, - quant_cfg=self.quant_cfg, - ) - for i in range(self.config["n_layer"]) - ] - if self.load_way == "HF": - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - else: - load_ds_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - prefix="model.layers.", - num_layer=self.config["n_layer"], - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return - def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) if 
self.config.get("rope_scaling", {}) is None: diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 94e1a27e0..0952468d0 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,4 @@ -import torch -import numpy as np +import copy from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -12,24 +11,12 @@ def __init__(self, data_type, network_config, mode): self.scale_emb = self.network_config_.get("scale_emb", 1) return - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) / self.lm_head_scale - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - return - def verify_load(self): - if not hasattr(self, "lm_head_weight_"): - self.lm_head_weight_ = self.wte_weight_ / self.lm_head_scale - self.wte_weight_ = self.wte_weight_ * self.scale_emb + if self.lm_head_weight_ == self.wte_weight_: + self.lm_head_weight_ = copy.copy(self.lm_head_weight_) + + self.lm_head_weight_.weight = self.lm_head_weight_.weight / self.lm_head_scale + self.wte_weight_.weight = self.wte_weight_.weight * self.scale_emb errors = "weights load not ok" weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_] for i in range(len(weights)): diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index ef7e5d695..d32f51ae7 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -8,8 +8,11 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num +from lightllm.utils.envs_utils import get_env_start_args @ModelRegistry("mistral") @@ -40,6 +43,10 @@ def _init_custom(self): self._init_to_get_rotary() return + def _init_inferstate_cls(self): + if get_env_start_args().enable_fa3: + self.infer_state_class = FlashAttentionStateInfo + def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] @@ -49,7 +56,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=head_dim, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/mistral_mtp/__init__.py 
b/lightllm/models/mistral_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral_mtp/layer_infer/__init__.py b/lightllm/models/mistral_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..5eac249ba --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py @@ -0,0 +1,9 @@ +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer + + +class MistralMTPPostLayerInfer(LlamaPostLayerInfer): + """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + return diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..25bea1aa6 --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,80 @@ +import torch +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight + + +class MistralMTPPreLayerInfer(LlamaPreLayerInfer): + """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + return + + def _mtp_context_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + + tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) + return ans_logics + + def _mtp_token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + + tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) + return ans_logics + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_context_forward(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight): + input_embdings = 
super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..5724f32af --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,33 @@ +import torch.functional as F +import torch.distributed as dist +import numpy as np +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer +from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MistralMTPTransformerLayerInfer(MistralTransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + return + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings diff --git a/lightllm/models/mistral_mtp/layer_weights/__init__.py b/lightllm/models/mistral_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..2fbc89cfd --- /dev/null +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,36 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpNormWeight, + ROWMMWeight, +) + + +class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + + self.eh_proj_weight_ = ROWMMWeight( + weight_names="mtp.eh_proj.weight", + data_type=self.data_type_, + layer_num=0, + name="eh_proj", + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpNormWeight( + weight_name="mtp.enorm.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.hnorm_weight_ = NoTpNormWeight( + weight_name="mtp.hnorm.weight", + data_type=self.data_type_, + bias_name=None, + ) + + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + self.final_norm_weight_: NoTpNormWeight = None + return diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..6607dbb70 --- /dev/null +++ 
b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,46 @@ +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight + + +class MistralMTPTransformerLayerWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight_names(self): + self._gate_weight_name = f"mtp.layers.{self.layer_num_}.mlp.gate_proj.weight" + self._gate_bias_name = None + self._up_weight_name = f"mtp.layers.{self.layer_num_}.mlp.up_proj.weight" + self._up_bias_name = None + self._down_weight_name = f"mtp.layers.{self.layer_num_}.mlp.down_proj.weight" + self._down_bias_name = None + + self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + def _init_weight(self): + self._init_norm() + self._init_ffn() + + def _init_ffn(self): + self.gate_up_proj = ROWMMWeight( + weight_names=[self._gate_weight_name, self._up_weight_name], + data_type=self.data_type_, + bias_names=[self._gate_bias_name, self._up_bias_name], + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="gate_up_proj", + ) + self.down_proj = COLMMWeight( + weight_names=self._down_weight_name, + data_type=self.data_type_, + bias_names=self._down_bias_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="down_proj", + ) + + def _init_norm(self): + self.ffn_norm_weight_ = NoTpNormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py new file mode 100644 index 000000000..0132db80f --- /dev/null +++ b/lightllm/models/mistral_mtp/model.py @@ -0,0 +1,63 @@ +from typing import List +from lightllm.models.mistral.model import MistralTpPartModel +from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight +from lightllm.models.mistral_mtp.layer_infer.pre_layer_infer import MistralMTPPreLayerInfer +from lightllm.models.mistral_mtp.layer_infer.post_layer_infer import MistralMTPPostLayerInfer +from lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer +from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel + + +class MistralMTPModel(MistralTpPartModel): + + pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight + pre_layer_infer_class = MistralMTPPreLayerInfer + + transformer_weight_class = MistralMTPTransformerLayerWeight + transformer_layer_infer_class = MistralMTPTransformerLayerInfer + + post_layer_infer_class = MistralMTPPostLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = 1 + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + 
return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + + self.config["n_layer"] = 1 + super()._init_weights(start_layer_index=0) + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + self.config["n_layer"] = 1 + super()._init_infer_layer(start_layer_index=0) + return diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index ce27e3ee5..806c59365 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -1,9 +1,5 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from functools import partial - from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import ( diff --git a/lightllm/models/qwen/infer_struct.py b/lightllm/models/qwen/infer_struct.py index deff17ce2..d575f6d95 100644 --- a/lightllm/models/qwen/infer_struct.py +++ b/lightllm/models/qwen/infer_struct.py @@ -11,13 +11,13 @@ def __init__(self): self.position_sin = None self.logn_values = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): + def init_some_extra_state(self, model): use_dynamic_ntk = model.config.get("use_dynamic_ntk", False) if not use_dynamic_ntk: - super().init_some_extra_state(model, input_ids) + super().init_some_extra_state(model) return - InferStateInfo.init_some_extra_state(self, model, input_ids) + InferStateInfo.init_some_extra_state(self, model) if self.is_prefill: position_ids = self.position_ids self.position_sin = [] diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 95af6ecd3..00f68eee6 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -1,40 +1,22 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - - vob_size = self.network_config_["vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - - if "transformer.wte.weight" in weights: - self.wte_weight_ = self._cuda( - weights["transformer.wte.weight"][ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - ) - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda( - weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] - ) - if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["transformer.ln_f.weight"]) - - return - - def 
verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.lm_head_weight_, - self.final_norm_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.wte_weight_ = EmbeddingWeight( + weight_name="transformer.wte.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="transformer.ln_f.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py index 7c710af2d..9afb964ad 100755 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -1,6 +1,4 @@ import torch -import math -import numpy as np from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index 5735b0339..a8a57c02e 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -1,37 +1,7 @@ -import torch -import numpy as np -from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -class Qwen2PreAndPostLayerWeight(PreAndPostLayerWeight): +class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.lm_head_weight_, - self.final_norm_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 2e2c0d3bb..6962818c4 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -1,6 +1,3 @@ -import torch -import math -import numpy as np from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 5b756aadf..d2f067c42 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -3,6 +3,7 @@ 
from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @ModelRegistry("qwen2") @@ -41,12 +42,13 @@ def _init_mem_manager(self): head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=tp_k_head_num_, head_dim=head_dim_, - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py b/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py index 22ec8fd43..e9b41d7ab 100644 --- a/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py +++ b/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py @@ -3,7 +3,6 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight -from einops import rearrange class Qwen2RewardPostLayerInfer(LlamaPostLayerInfer): @@ -15,8 +14,8 @@ def token_forward( input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) - last_input = torch.addmm(layer_weight.score_up_bias, last_input, layer_weight.score_up_weight) + last_input = layer_weight.score_up_weight_.mm(last_input) last_input = torch.nn.functional.relu(last_input) - score = torch.addmm(layer_weight.score_down_bias, last_input, layer_weight.score_down_weight) + score = layer_weight.score_down_weight_.mm(last_input) return score diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index a56c5d6cb..7cf636622 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,50 +1,27 @@ import torch import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - - if "model.norm.weight" in weights: - self.final_norm_weight_ = 
self._cuda(weights["model.norm.weight"]) - - if "score.0.weight" in weights: - self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1) - if "score.0.bias" in weights: - self.score_up_bias = self._cuda(weights["score.0.bias"]) - - if "score.2.weight" in weights: - self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1) - if "score.2.bias" in weights: - self.score_down_bias = self._cuda(weights["score.2.bias"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.final_norm_weight_, - self.score_up_weight, - self.score_up_bias, - self.score_down_weight, - self.score_down_bias, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + del self.lm_head_weight_ + self.score_up_weight_ = ROWMMWeight( + weight_names="score.0.weight", + bias_names="score.0.bias", + data_type=self.data_type_, + name="score_up_weight", + tp_rank=0, + tp_world_size=1, + ) + self.score_down_weight_ = ROWMMWeight( + weight_names="score.2.weight", + bias_names="score.2.bias", + data_type=self.data_type_, + name="score_down_weight", + tp_rank=0, + tp_world_size=1, + ) return diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index ce7938b6a..838590325 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -16,10 +16,10 @@ def __init__(self): self.position_cos = None self.position_sin = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): + def init_some_extra_state(self, model): rope_scaling = model.config.get("rope_scaling", {}) self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) - InferStateInfo.init_some_extra_state(self, model, input_ids) + InferStateInfo.init_some_extra_state(self, model) if self.is_prefill: self.position_ids = self.get_mrope_position(self.multimodal_params) else: @@ -38,7 +38,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if get_env_start_args().enable_fa3: self.max_seq_len = self.max_kv_seq_len self.q_max_seq_len = self.max_q_seq_len - self.init_flash_attention_state_func(model, input_ids) + self.init_flash_attention_state_func(model) return def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index d05f8d3b5..19e17c36e 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -1,10 +1,5 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple -from functools import partial - from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index c69c7f4fb..20f135e76 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -1,14 +1,8 @@ -import os import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np -import triton from typing import Tuple from 
lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 4c0ef586f..86b9e172a 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, + NoTpNormWeight, ) @@ -20,5 +20,5 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) - self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self.q_norm_weight_ = NoTpNormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NoTpNormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 45f1f59d7..10a734e5c 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -8,7 +8,6 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from functools import partial @@ -62,17 +61,17 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rmsnorm_forward( + + layer_weight.q_norm_weight_.rmsnorm_forward( q.view(-1, self.head_dim_), - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, out=q.view(-1, self.head_dim_), ) - cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward( - cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), - weight=layer_weight.k_norm_weight_.weight, + cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( + input=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), eps=self.eps_, + alloc_func=self.alloc_tensor, ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) rotary_emb_fwd( @@ -95,23 +94,22 @@ def _tpsp_get_qkv( (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device ) all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : 
len(infer_state.position_cos), :] + input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rmsnorm_forward( + layer_weight.q_norm_weight_.rmsnorm_forward( q.view(-1, self.head_dim_), - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, out=q.view(-1, self.head_dim_), ) - cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, + alloc_func=self.alloc_tensor, ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) rotary_emb_fwd( diff --git a/lightllm/models/qwen3_moe_mtp/__init__.py b/lightllm/models/qwen3_moe_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..d21917340 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,37 @@ +import os +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +import triton +from typing import Tuple +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3MOEMTPTransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + return + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..6cc447a59 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,36 @@ +import numpy as np +from lightllm.common.basemodel import 
PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + ROWMMWeight, + LMHeadWeight, + NoTpNormWeight, +) + + +class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + + self.eh_proj_weight_ = ROWMMWeight( + weight_names="model.layers.0.proj.weight", + data_type=self.data_type_, + name="eh_proj", + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.norm_after_embedding.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.hnorm_weight_ = NoTpNormWeight( + weight_name="model.layers.0.norm_before_output.weight", + data_type=self.data_type_, + bias_name=None, + ) + # 与Qwen3MOE模型共享 + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + self.final_norm_weight_: NoTpNormWeight = None + return diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..22d4d1950 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,21 @@ +import os +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight + + +class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight(self): + self._init_norm() + if self.is_moe: + self._init_moe() + else: + self._init_ffn() + + def _init_norm(self): + self.ffn_norm_weight_ = NoTpNormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py new file mode 100644 index 000000000..72aadbda8 --- /dev/null +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -0,0 +1,57 @@ +from typing import List +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight +from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer +from lightllm.models.qwen3_moe_mtp.layer_infer.transformer_layer_infer import Qwen3MOEMTPTransformerLayerInfer +from lightllm.models.qwen3_moe_mtp.layer_weights.transformer_layer_weight import Qwen3MOEMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel + + +class Qwen3MOEMTPModel(Qwen3MOEModel): + + pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight + pre_layer_infer_class = Deepseek3MTPPreLayerInfer + + transformer_weight_class = Qwen3MOEMTPTransformerLayerWeight + transformer_layer_infer_class = Qwen3MOEMTPTransformerLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def 
_init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + mtp_index = len(self.mtp_previous_draft_models) + super()._init_weights(start_layer_index=mtp_index) + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) + super()._init_infer_layer(start_layer_index=total_pre_layers_num) + return diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py index c79bb7665..96e453ebe 100644 --- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -1,11 +1,10 @@ import torch import torch.distributed as dist - -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from ..layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): @@ -13,13 +12,15 @@ def __init__(self, network_config, mode): super().__init__(network_config, mode) return - def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + def context_forward( + self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: Qwen3VLPreAndPostLayerWeight + ): img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype - hidden_size = layer_weight.wte_weight_.shape[1] + device = layer_weight.wte_weight_.weight.device + dtype = layer_weight.wte_weight_.weight.dtype + hidden_size = layer_weight.wte_weight_.weight.shape[1] for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: @@ -55,13 +56,13 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w multimodal_emb( out=out, prompt_ids=input_ids, - text_weight_embs=layer_weight.wte_weight_, + text_weight_embs=layer_weight.wte_weight_.weight, embed_cache=cpu_embed_cache_tensor, img_token_lens=infer_state.img_token_lens, img_start_token_ids=infer_state.img_start_token_ids, img_start_locs_in_cache=infer_state.img_start_locs_in_cache, - tp_text_start_token_id=self.vob_start_id_, - tp_text_end_token_id=self.vob_end_id_, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) if self.tp_world_size_ > 1: diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py 
b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index a9a6954d6..175340a77 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -11,7 +11,6 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.distributed import all_reduce diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 0bc878d96..b155f8b90 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -1,15 +1,9 @@ import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from functools import partial from typing import Tuple -from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index 0a7f82a93..b1f5ee660 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -1,37 +1,24 @@ -import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.language_model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :]) - tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_ = self.wte_weight_ - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.language_model.norm.weight" in weights: - self.final_norm_weight_ = 
self._cuda(weights["model.language_model.norm.weight"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.wte_weight_, - self.lm_head_weight_, - self.final_norm_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return + self.wte_weight_ = EmbeddingWeight( + weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_: LMHeadWeight = self.wte_weight_ + else: + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.language_model.norm.weight", + data_type=self.data_type_, + ) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b8500980a..f43907307 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -32,9 +32,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype - hidden_size = layer_weight.wte_weight_.shape[1] + device = layer_weight.wte_weight_.weight.device + dtype = layer_weight.wte_weight_.weight.dtype + hidden_size = layer_weight.wte_weight_.weight.shape[1] for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: @@ -68,13 +68,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei multimodal_emb( out=out, prompt_ids=input_ids, - text_weight_embs=layer_weight.wte_weight_, + text_weight_embs=layer_weight.wte_weight_.weight, embed_cache=cpu_embed_cache_tensor, img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, img_start_locs_in_cache=img_start_locs_in_cache, - tp_text_start_token_id=self.vob_start_id_, - tp_text_end_token_id=self.vob_end_id_, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) if self.tp_world_size_ > 1: diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 53171ce53..395ed4ba1 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -1,11 +1,7 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from functools import partial from typing import Tuple from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.models.stablelm.layer_weights.transformer_layer_weight import StablelmTransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -57,19 +53,17 @@ def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, t def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - 
input.view(-1, self.embed_dim_), - weight=layer_weight.att_norm_weight_.weight, - bias=layer_weight.att_norm_weight_.bias, + return layer_weight.att_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.ffn_norm_weight_.weight, - bias=layer_weight.ffn_norm_weight_.bias, + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 80966c7b4..0ad3e07df 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -1,34 +1,12 @@ -import torch -import numpy as np -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight, NoTpNormWeight class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - if "lm_head.weight" in weights: - # print(weights['lm_head.weight'].shape) - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) - - return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_, self.final_norm_bias_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name="model.norm.bias", + ) return diff --git a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py index dc7b00862..a1a73f674 100755 --- a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py @@ -1,5 +1,4 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight class StablelmTransformerLayerWeight(Qwen2TransformerLayerWeight): diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index a98323ab5..52072a348 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ 
b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -1,13 +1,8 @@ -import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from lightllm.models.starcoder.layer_weights.pre_and_post_layer_weight import PreAndPostLayerWeight +from lightllm.models.starcoder.layer_weights.pre_and_post_layer_weight import StarcoderPreAndPostLayerWeight from lightllm.common.basemodel.infer_struct import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time from lightllm.common.basemodel import PreLayerInfer -from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce @@ -16,47 +11,27 @@ class StarcoderPreLayerInfer(PreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) - assert network_config["vocab_size"] % self.tp_world_size_ == 0 - self.tp_vocab_size_ = network_config["vocab_size"] // self.tp_world_size_ - self.embed_dim_ = network_config["hidden_size"] self.layer_norm_eps_ = network_config["layer_norm_epsilon"] - self.vob_start_id_ = self.tp_vocab_size_ * self.tp_rank_ - self.vob_end_id_ = self.tp_vocab_size_ * (self.tp_rank_ + 1) - def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: PreAndPostLayerWeight): - total_token_num = infer_state.total_token_num - input_ids = input_ids[0:total_token_num] - - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = self.alloc_tensor( - (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding( - infer_state.position_ids, layer_weight.wpe_weight_, 0, layer_weight.wpe_weight_.shape[0], position_embeds + position_embeds = layer_weight.wpe_weight_.embedding( + input_ids=infer_state.position_ids, + alloc_func=self.alloc_tensor, ) return input_embdings.add_(position_embeds) - def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: PreAndPostLayerWeight): - # import ipdb;ipdb.set_trace() - input_embdings = self.alloc_tensor( - (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) + def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): + input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = self.alloc_tensor( - (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), dtype=layer_weight.data_type_ - ) - embedding( - infer_state.position_ids, layer_weight.wpe_weight_, 0, layer_weight.wpe_weight_.shape[0], position_embeds + position_embeds = layer_weight.wpe_weight_.embedding( + input_ids=infer_state.position_ids, + alloc_func=self.alloc_tensor, ) - return 
input_embdings.add_(position_embeds) diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 8d87c1163..d5bdd79a7 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -1,51 +1,32 @@ -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + NoTpNormWeight, + NoTpPosEmbeddingWeight, + LMHeadWeight, +) class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - def load_hf_weights(self, weights): + self.wte_weight_ = EmbeddingWeight( + weight_name="transformer.wte.weight", + data_type=self.data_type_, + ) + self.wpe_weight_ = NoTpPosEmbeddingWeight( + weight_name="transformer.wpe.weight", + data_type=self.data_type_, + ) - vob_size = self.network_config_["vocab_size"] - split_vob_size = vob_size // self.tp_world_size_ - if "transformer.wte.weight" in weights: - # print(weights['transformer.wte.weight'].shape) - self.wte_weight_ = ( - weights["transformer.wte.weight"][ - split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - .contiguous() - .to(self.data_type_) - .cuda() - ) - if "transformer.wpe.weight" in weights: - # print(weights['transformer.wpe.weight'].shape) - self.wpe_weight_ = weights["transformer.wpe.weight"].to(self.data_type_).cuda() - if "lm_head.weight" in weights: - self.lm_head_weight_ = ( - weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] - .contiguous() - .to(self.data_type_) - .cuda() - ) - if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = weights["transformer.ln_f.weight"].contiguous().to(self.data_type_).cuda() - if "transformer.ln_f.bias" in weights: - self.final_norm_bias_ = weights["transformer.ln_f.bias"].contiguous().to(self.data_type_).cuda() - return - - def verify_load(self): - errors = "weights load not ok" - weights = [ - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.wpe_weight_, - self.lm_head_weight_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.final_norm_weight_ = NoTpNormWeight( + weight_name="transformer.ln_f.weight", + bias_name="transformer.ln_f.bias", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/starcoder/model.py b/lightllm/models/starcoder/model.py index ea2aeabbc..3fcd14208 100644 --- a/lightllm/models/starcoder/model.py +++ b/lightllm/models/starcoder/model.py @@ -6,6 +6,7 @@ from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer from lightllm.common.build_utils import repair_config from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.basemodel import InferStateInfo @@ -46,7 +47,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"], head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"], + 
layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index ca04bee0c..796a96bc4 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,4 @@ import torch -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.models.starcoder2.layer_weights.transformer_layer_weight import Starcoder2TransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -12,21 +11,19 @@ def __init__(self, layer_num, network_config, mode=[]): def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.att_norm_weight_.weight, - bias=layer_weight.att_norm_weight_.bias, + return layer_weight.att_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layernorm_forward( - input.view(-1, self.embed_dim_), - weight=layer_weight.ffn_norm_weight_.weight, - bias=layer_weight.ffn_norm_weight_.bias, + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, + alloc_func=self.alloc_tensor, ) def _ffn( diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index fd2d47575..28a26cb4b 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,37 +1,28 @@ import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) - return - - def load_hf_weights(self, weights): - vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) - - # for starcoder2-3b and 7b which didn't use lm_head.weight (tie_word_embeddings) - self.lm_head_weight_ = self.wte_weight_ - - if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) - - if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - - if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) - - return - def verify_load(self): - errors = "weights load not ok" - weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_, self.final_norm_bias_] - for i in 
range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors + self.wte_weight_ = EmbeddingWeight( + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_: LMHeadWeight = self.wte_weight_ + else: + self.lm_head_weight_ = LMHeadWeight( + weight_name="lm_head.weight", + data_type=self.data_type_, + ) + + self.final_norm_weight_ = NoTpNormWeight( + weight_name="model.norm.weight", + data_type=self.data_type_, + bias_name="model.norm.bias", + ) return diff --git a/lightllm/models/starcoder2/model.py b/lightllm/models/starcoder2/model.py index a299c08be..2b7545914 100644 --- a/lightllm/models/starcoder2/model.py +++ b/lightllm/models/starcoder2/model.py @@ -9,6 +9,7 @@ from lightllm.common.build_utils import repair_config from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.common.basemodel import TpPartBaseModel @@ -52,7 +53,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index f1de0bdc1..c6024594e 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -7,8 +7,8 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NormWeight, - TpNormWeight, + NoTpNormWeight, + TpVitPadNormWeight, ) from lightllm.utils.dist_utils import get_current_device_id @@ -119,17 +119,16 @@ def _init_ffn(self): ) def _init_norm(self): - self.att_norm_weight_ = NormWeight( + self.att_norm_weight_ = NoTpNormWeight( self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) - self.ffn_norm_weight_ = NormWeight( + self.ffn_norm_weight_ = NoTpNormWeight( self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) if self.qk_norm: - n_embed = self.network_config_["hidden_size"] - split_n_embed = (n_embed + self.padding_hidden_size) // self.tp_world_size_ - self.q_norm_weight_ = TpNormWeight(self._q_norm_weight_name, self.data_type_, split_n_embed) - self.k_norm_weight_ = TpNormWeight(self._k_norm_weight_name, self.data_type_, split_n_embed) + head_num = self.network_config_["num_attention_heads"] + self.q_norm_weight_ = TpVitPadNormWeight(self._q_norm_weight_name, self.data_type_, head_num=head_num) + self.k_norm_weight_ = TpVitPadNormWeight(self._k_norm_weight_name, self.data_type_, head_num=head_num) def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: @@ -159,12 +158,6 @@ def load_hf_weights(self, weights): weights[self._v_bias_name] = v_bias_ del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] - if self.qk_norm and self._q_norm_weight_name in weights: - weights[self._q_norm_weight_name] = F.pad(weights[self._q_norm_weight_name], (0, self.padding_hidden_size)) - - if self.qk_norm and self._k_norm_weight_name in weights: - 
weights[self._k_norm_weight_name] = F.pad(weights[self._k_norm_weight_name], (0, self.padding_hidden_size))
-
         if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights:
             ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"]
             self.ls1 = self._cuda(ls1)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 68293fc92..d193bab41 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -542,13 +542,17 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--mtp_mode",
-        choices=["deepseekv3_vanilla", "deepseekv3_eagle", None],
+        choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att", None],
         default=None,
-        help="""supported mtp mode, None is not enable mtp, """,
+        help="""Supported MTP modes.
+        None: Disables MTP.
+        *_with_att: Uses the MTP model with an attention mechanism to predict the next draft token.
+        *_no_att: Uses the MTP model without an attention module to predict the next draft token.""",
     )
     parser.add_argument(
         "--mtp_draft_model_dir",
         type=str,
+        nargs="+",
         default=None,
         help="""Path to the draft model for the MTP multi-prediction feature,
         used for loading the MTP multi-output token model.""",
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 014d98766..4ead3cbbf 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -197,7 +197,8 @@ def normal_or_p_d_start(args):
         assert (
             args.batch_max_tokens >= args.chunked_prefill_size
-        ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size"
+        ), ("chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
+            f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}")

     # help to manage data stored on Ceph
     if "s3://" in args.model_dir:
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 1e1796335..f489aac9c 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -352,8 +352,8 @@ def get_decode_need_tokens(self):
         need_tokens = min(self.input_len + self.shm_cur_output_len - self.shm_cur_kv_len, self.chunked_prefill_size)
         if need_tokens == 1 and self._mtp_step > 0:
             # self._mtp_step > 0 时,说明开启了mtp 模式,每次decode需要额外的mem token 资源
-            # "deepseekv3_vanilla" 模式需要的 mem 用量为 self._mtp_step + 1
-            # "deepseekv3_eagle" 模式需要的 mem 用量为 (self._mtp_step + 1)* 2
+            # "vanilla_with_att" 模式需要的 mem 用量为 self._mtp_step + 1
+            # "eagle_with_att" 模式需要的 mem 用量为 (self._mtp_step + 1)* 2
             # 为了简化统一 返回 (self._mtp_step + 1)* 2
             need_tokens = (self._mtp_step + 1) * 2
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 611ff772b..5ebadaf16 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -124,7 +124,9 @@ class StartArgs:
     )
     ep_redundancy_expert_config_path: Optional[str] = field(default=None)
     auto_update_redundancy_expert: bool = field(default=False)
-    mtp_mode: Optional[str] = field(default=None)
+    mtp_mode: Optional[str] = field(
+        default=None, metadata={"choices": ["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"]}
+    )
     mtp_draft_model_dir: Optional[str] = field(default=None)
     mtp_step: int = field(default=0)
     kv_quant_calibration_config_path: Optional[str] = field(default=None)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 4c40814b8..613722787 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
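The hunk below replaces the single read of metadata["mtp_accepted_token_num"] with a per-sub-request dict so that requests with best_of > 1 aggregate acceptances correctly. A minimal sketch of that accounting, with an illustrative helper name that is not part of the codebase:

from typing import Dict

def mtp_avg_token_per_step(out_token_counter: int, accepted_per_sub_req: Dict[int, int]) -> float:
    # every emitted token costs one main-model decode step unless it was an accepted draft
    # token, so the denominator is the number of decode steps actually taken
    accepted_total = sum(accepted_per_sub_req.values())
    return out_token_counter / max(out_token_counter - accepted_total, 1)

# e.g. 100 output tokens with 40 accepted draft tokens -> 100 / 60 ≈ 1.67 tokens per step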
@@ -550,6 +550,7 @@ async def _wait_to_token_package( event = req_status.event unfinished_count = sampling_params.best_of out_token_counter = 0 + sub_req_id_to_mtp_accepted_token_num: Dict[int, int] = {} first_token_cost_ms = sys.float_info.max prompt_tokens = len(prompt_ids) is_first_token = True @@ -579,6 +580,8 @@ async def _wait_to_token_package( prompt_cache_len = metadata.pop("prompt_cache_len", 0) cpu_prompt_cache_len = metadata.pop("cpu_prompt_cache_len", 0) disk_prompt_cache_len = metadata.pop("disk_prompt_cache_len", 0) + sub_req_id_to_mtp_accepted_token_num[sub_req_id] = metadata.get("mtp_accepted_token_num", 0) + if is_first_token: first_token_cost_ms = (time.time() - start_time) * 1000 is_first_token = False @@ -605,7 +608,7 @@ async def _wait_to_token_package( disk_prompt_cache_ratio = disk_prompt_cache_len / prompt_tokens mtp_avg_token_per_step = out_token_counter / max( - (out_token_counter - metadata["mtp_accepted_token_num"]), 1 + (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") logger.info( diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c37b8d74e..92653bc0c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -35,11 +35,12 @@ from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel +from lightllm.models.mistral_mtp.model import MistralMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule -from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient class ModeBackend: @@ -288,17 +289,17 @@ def init_mtp_draft_model(self, main_kvargs: dict): os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode == "deepseekv3_vanilla": + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode == "deepseekv3_eagle": + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) mtp_model_kvargs = { - "weight_dir": self.args.mtp_draft_model_dir, + "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, "load_way": main_kvargs["load_way"], "mode": main_kvargs["mode"], @@ -317,13 +318,21 @@ def init_mtp_draft_model(self, main_kvargs: dict): "quant_cfg": main_kvargs.get("quant_cfg", None), "run_mode": "normal", "main_model": self.model, - "mem_layer_start": self.model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + "mtp_previous_draft_models": self.draft_models.copy(), } - mtp_model_cfg, _ = 
PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) - assert mtp_model_cfg["model_type"] == "deepseek_v3" - assert mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN" - self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) + if mtp_model_cfg["model_type"] == "deepseek_v3": + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "qwen3_moe": + assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] + self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "mistral": + assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] + self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) + else: + assert False, f"error mtp mode {mtp_model_cfg['model_type']}" self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 858f4713a..f3450261b 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -40,7 +40,7 @@ def __init__(self) -> None: if get_env_start_args().mtp_mode: self.prefill = self.prefill_mtp self.decode = self.decode_mtp - self.is_mtp_eagle = get_env_start_args().mtp_mode == "deepseekv3_eagle" + self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla else: @@ -325,7 +325,7 @@ def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOut draft_model_input = prepare_mtp_prefill_inputs( model_input=draft_model_input, b_next_token_ids=draft_next_token_ids_gpu, - deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output.mtp_main_output_hiddens, ) draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) @@ -349,7 +349,7 @@ def _draft_decode_vanilla( for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) @@ -393,7 +393,7 @@ def _draft_decode_eagle( for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py 
b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5b0bf3335..5a179cb62 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -15,12 +15,20 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from ..chunked_prefill.impl import ChunkedPrefillBackend from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.utils.envs_utils import get_env_start_args class DiversehBackend(ChunkedPrefillBackend): def __init__(self) -> None: super().__init__() - self.prefill = self.beam_prefill + + if get_env_start_args().mtp_mode: + # 当前只有 mistral mtp 可以使用 diverse mode 的 mtp 功能。 + self.prefill = self.beam_prefill + assert get_env_start_args().mtp_mode in ["vanilla_no_att", "eagle_no_att"] + else: + self.prefill = self.beam_prefill + self.classed_req_strict_prefill = True def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]): diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index a1414b8b2..df10a6d4e 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -34,7 +34,7 @@ def __init__(self) -> None: # 在 mtp 模式下切换绑定的prefill 和 decode 函数 if get_env_start_args().mtp_mode: - self.is_mtp_eagle = get_env_start_args().mtp_mode == "deepseekv3_eagle" + self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step if self.enable_prefill_microbatch_overlap: self.prefill = self.prefill_overlap_mtp @@ -534,7 +534,7 @@ def _draft_decode_vanilla( for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids_gpu - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) @@ -585,7 +585,7 @@ def _draft_decode_eagle( for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids_gpu - draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) @@ -672,13 +672,13 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I draft_model_input0 = prepare_mtp_prefill_inputs( model_input=draft_model_input0, b_next_token_ids=draft_next_token_ids_gpu0, - deepseekv3_mtp_draft_input_hiddens=draft_model_output0.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output0.mtp_main_output_hiddens, ) draft_model_input1 = prepare_mtp_prefill_inputs( model_input=draft_model_input1, b_next_token_ids=draft_next_token_ids_gpu1, - deepseekv3_mtp_draft_input_hiddens=draft_model_output1.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output1.mtp_main_output_hiddens, ) draft_model_output0, 
draft_model_output1 = self.draft_models[ @@ -836,7 +836,7 @@ def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOut draft_model_input = prepare_mtp_prefill_inputs( model_input=draft_model_input, b_next_token_ids=draft_next_token_ids_gpu, - deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens, + mtp_draft_input_hiddens=draft_model_output.mtp_main_output_hiddens, ) draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) @@ -874,13 +874,9 @@ def _draft_decode_vanilla_overlap( for draft_model_idx in range(self.mtp_step): draft_model_input0.input_ids = draft_next_token_ids_gpu0 - draft_model_input0.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output0.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 - draft_model_input1.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output1.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input1.mtp_draft_input_hiddens = draft_model_output1.mtp_main_output_hiddens draft_model_output0, draft_model_output1 = self.draft_models[draft_model_idx].microbatch_overlap_decode( draft_model_input0, draft_model_input1 @@ -949,13 +945,9 @@ def _draft_decode_eagle_overlap( for _step in range(self.mtp_step): draft_model_input0.input_ids = draft_next_token_ids_gpu0 - draft_model_input0.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output0.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 - draft_model_input1.deepseekv3_mtp_draft_input_hiddens = ( - draft_model_output1.deepseekv3_mtp_main_output_hiddens - ) + draft_model_input1.mtp_draft_input_hiddens = draft_model_output1.mtp_main_output_hiddens draft_model_idx = _step % self.num_mtp_models draft_model_output0, draft_model_output1 = self.draft_models[draft_model_idx].microbatch_overlap_decode( diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py index 3991e42de..dbce73a94 100644 --- a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py @@ -5,7 +5,7 @@ def prepare_mtp_prefill_inputs( - model_input: ModelInput, b_next_token_ids: torch.Tensor, deepseekv3_mtp_draft_input_hiddens: torch.Tensor + model_input: ModelInput, b_next_token_ids: torch.Tensor, mtp_draft_input_hiddens: torch.Tensor ): new_model_input = copy.copy(model_input) new_input_ids = gen_mtp_new_input_ids( @@ -15,5 +15,5 @@ def prepare_mtp_prefill_inputs( b_ready_cache_len=model_input.b_ready_cache_len, ) new_model_input.input_ids = new_input_ids - new_model_input.deepseekv3_mtp_draft_input_hiddens = deepseekv3_mtp_draft_input_hiddens + new_model_input.mtp_draft_input_hiddens = mtp_draft_input_hiddens return new_model_input diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 8995afbc5..06f53b307 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -235,3 +235,15 @@ def enable_huge_page(): "sudo reboot" """ return enable_env_vars("LIGHTLLM_HUGE_PAGE_ENABLE") + + +@lru_cache(maxsize=None) +def get_added_mtp_kv_layer_num() -> int: + # mtp 模式下需要在mem manger上扩展draft model使用的layer + 
diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py
index ed183e393..4875b4eee 100644
--- a/lightllm/utils/kv_cache_utils.py
+++ b/lightllm/utils/kv_cache_utils.py
@@ -8,7 +8,12 @@
 import numpy as np
 import triton
 from functools import lru_cache
-from lightllm.utils.envs_utils import get_env_start_args, enable_huge_page, get_llm_data_type
+from lightllm.utils.envs_utils import (
+    get_env_start_args,
+    enable_huge_page,
+    get_llm_data_type,
+    get_added_mtp_kv_layer_num,
+)
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
@@ -111,7 +116,7 @@

     if args.mtp_mode is not None:
         # TODO 可能会存在不同mtp模式的精度问题
-        cpu_cache_meta.layer_num += 1
+        cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num()

     cpu_cache_page_num = int(
         (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size())
diff --git a/test/acc/test_deepseekr1.sh b/test/acc/test_deepseekr1.sh
new file mode 100644
index 000000000..e167303a3
--- /dev/null
+++ b/test/acc/test_deepseekr1.sh
@@ -0,0 +1,5 @@
+LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --enable_fa3
+
+
+
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file
diff --git a/test/acc/test_deepseekr1_mtp.sh b/test/acc/test_deepseekr1_mtp.sh
new file mode 100644
index 000000000..046314a72
--- /dev/null
+++ b/test/acc/test_deepseekr1_mtp.sh
@@ -0,0 +1,3 @@
+LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2
+
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file
diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh
new file mode 100644
index 000000000..2ea5f7438
--- /dev/null
+++ b/test/acc/test_deepseekr1_mtp_ep.sh
@@ -0,0 +1,3 @@
+LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2
+
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code
\ No newline at end of file
diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh
new file mode 100644
index 000000000..265d679e8
--- /dev/null
+++ b/test/acc/test_qwen2.sh
@@ -0,0 +1,5 @@
+# first
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --enable_fa3
+
+# second
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-Math-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file
diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh
new file mode 100644
index 000000000..c0da5ec96
--- /dev/null
+++ b/test/acc/test_qwen3.sh
@@ -0,0 +1,5 @@
+# first
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_fa3
+
+# second
+HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
\ No newline at end of file
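The servers launched by the scripts above expose the OpenAI-compatible /v1/completions route that lm_eval points at. A minimal request sketch (assuming the server from test/acc/test_qwen3.sh is already running on localhost:8089; not part of the patch):

# Illustrative smoke test only: assumes the qwen3 server is up on localhost:8089.
import requests

resp = requests.post(
    "http://localhost:8089/v1/completions",
    json={"model": "qwen/qwen3-8b", "prompt": "Question: 1 + 1 = ?\nAnswer:", "max_tokens": 16},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])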
diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py
index ef2ada64c..942af0f88 100644
--- a/test/benchmark/static_inference/model_infer_mtp.py
+++ b/test/benchmark/static_inference/model_infer_mtp.py
@@ -148,7 +148,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_

     # Draft model Prefill
     # For simplicity, we'll just take the input of main_model to draft model.
-    model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+    model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens
     for draft_model_id in range(len(draft_models)):
         draft_model = draft_models[draft_model_id]
         model_output = draft_model.forward(model_input)
@@ -156,7 +156,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
         predict_ids = predict_ids.detach().cpu().numpy()
         draft_ids.append(predict_ids)
-        model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+        model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens

     torch.cuda.synchronize()
     prefill_end_time = time.time()
@@ -218,7 +218,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_

         # draft decode
         model_input.input_ids = predict_ids.reshape(-1)
-        model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+        model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens

         for draft_model_id in range(len(draft_models)):
             draft_model = draft_models[draft_model_id]
@@ -228,11 +228,11 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
             prob_out = torch.softmax(model_output.logits, dim=-1)
             predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
             model_input.input_ids = predict_ids.reshape(-1)
-            model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+            model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens

         # accept all draft ids by default.
         model_input.input_ids = predict_ids.reshape(-1)
-        model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+        model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens
         torch.cuda.synchronize()
         if i % 100 == 0 or i == output_len - 1:
             step_end_time = time.time()
diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py
index c7c6a844d..5c3ca89c6 100644
--- a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py
+++ b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py
@@ -1,11 +1,21 @@
 import torch
 import pytest
-import numpy as np
-from lightllm.utils.log_utils import init_logger
+import easydict
 from lightllm.common.basemodel.triton_kernel.gen_decode_params import gen_decode_params
+from lightllm.utils.envs_utils import set_env_start_args


 def test_gen_decode_params_basic():
+    set_env_start_args(
+        easydict.EasyDict(
+            {
+                "mtp_step": 0,
+                "enable_flashinfer_prefill": False,
+                "enable_flashinfer_decode": False,
+            }
+        )
+    )
+
     b_seq_len = torch.ones((9,), dtype=torch.int64, device="cuda") * 8192
     (
         b_q_seq_len,