Qwen3 mtp dense #1159
Conversation
Summary of Changes
Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the Multi-Token Prediction (MTP) framework by introducing native support for Qwen3 MoE and Mistral models. It refines the core inference mechanisms, particularly FlashAttention, to be more adaptive to MTP operations, and it updates the model loading infrastructure to handle multiple and diverse MTP draft models. These changes aim to enhance the system's flexibility and performance when using various draft models for accelerated inference.
Code Review
This pull request introduces Multi-Token Prediction (MTP) support for Qwen3 (dense and MoE) and Mistral models. This is a significant feature addition, involving new model classes, layer implementations, and weight loading utilities for MTP draft models. The changes also update base model components and server logic to accommodate MTP configurations. My review focuses on improving code clarity, maintainability, and addressing a design issue related to model coupling. Overall, the changes are extensive and well-structured to support MTP.
lightllm/models/qwen2/model.py
Outdated
```python
added_mtp_layer_num = 0
if get_env_start_args().mtp_mode == "deepseekv3_eagle":
    added_mtp_layer_num += 1
elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
    added_mtp_layer_num += get_env_start_args().mtp_step
```
The logic for added_mtp_layer_num is hardcoded for deepseekv3 MTP modes. This creates a tight coupling between the qwen2 model and other models' MTP configurations. This should be refactored to be more generic, for example by having the number of extra layers passed in or configured in a more abstract way, to improve modularity and reduce dependencies between model implementations.
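One possible shape for that refactor, shown as a minimal sketch; the registry name and helper function below are illustrative and not part of the PR:

```python
# Hypothetical registry mapping each mtp_mode to the number of extra draft layers,
# so qwen2/model.py no longer hardcodes deepseekv3-specific mode names.
MTP_EXTRA_LAYERS = {
    "deepseekv3_eagle": lambda args: 1,
    "deepseekv3_vanilla": lambda args: args.mtp_step,
}


def get_added_mtp_layer_num(args) -> int:
    # An unknown or unset mtp_mode contributes no extra layers.
    resolver = MTP_EXTRA_LAYERS.get(getattr(args, "mtp_mode", None))
    return resolver(args) if resolver is not None else 0
```

Model classes could then call `get_added_mtp_layer_num(get_env_start_args())` instead of branching on mode strings themselves.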
```python
is_deepseekv3_mtp_draft_model = (
    "Deepseek3MTPModel" in str(self.__class__)
    or "Qwen3MOEMTPModel" in str(self.__class__)
    or "MistralMTPModel" in str(self.__class__)
)
if is_deepseekv3_mtp_draft_model:
```
The variable is_deepseekv3_mtp_draft_model is now used to check for Qwen3MOEMTPModel and MistralMTPModel as well. Its name has become misleading. To improve code clarity and maintainability, please rename it to something more generic, like is_mtp_draft_model.
Suggested change:

```diff
-is_deepseekv3_mtp_draft_model = (
+is_mtp_draft_model = (
     "Deepseek3MTPModel" in str(self.__class__)
     or "Qwen3MOEMTPModel" in str(self.__class__)
     or "MistralMTPModel" in str(self.__class__)
 )
-if is_deepseekv3_mtp_draft_model:
+if is_mtp_draft_model:
```
```python
if args_mtp_step > 0:
    self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
else:
    self.b_att_seq_len = self.b_seq_len
```
The if/else block can be simplified into a single line. When args_mtp_step is 0, the slice [0::1] is equivalent to taking the whole tensor, so the else branch is redundant. Using a single line for both cases improves conciseness and is safer if self.b_seq_len is not always contiguous.
Suggested change:

```python
self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
```
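To see the equivalence this comment relies on, here is a small standalone check (not from the PR): with `args_mtp_step == 0` the slice `[0::1]` selects every element, so the single line covers both branches.

```python
import torch

b_seq_len = torch.arange(12)  # stand-in for self.b_seq_len
args_mtp_step = 0

# step 0 -> slice [0::1] is the whole tensor, matching the old else branch
assert torch.equal(b_seq_len[args_mtp_step :: (args_mtp_step + 1)], b_seq_len)
```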
```python
import os
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from einops import rearrange
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.common.basemodel import PostLayerInferTpl
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.distributed.communication_op import all_gather
```
```python
input_ids = infer_state.prefill_dp_balance(input_ids=input_ids)

input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight)
from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy
```
The import from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy is repeated in multiple functions (tpsp_context_forward, tpsp_token_forward, overlap_tpsp_token_forward, overlap_tpsp_context_forward). To improve code readability and avoid duplication, please move this import to the top of the file.
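A minimal sketch of the hoisted import; the class name and elided method body below are illustrative, only the import placement is the point:

```python
# Top of the post-layer-infer module (path illustrative): import once so every
# forward variant shares the same binding instead of re-importing per call.
from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy


class ExamplePostLayerInfer:
    def tpsp_context_forward(self, input_ids, infer_state, layer_weight):
        # sp_pad_copy is already in scope here; no function-local import needed.
        ...
```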
```python
errors = "weights load not ok"
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
for i in range(len(weights)):
    assert weights[i] is not None, "index:" + str(i) + " " + errors
```
lightllm/models/mistral_mtp/model.py
Outdated
```python
    weight_dict=self.weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
```
Using a list comprehension for side effects (like calling verify_load) is not idiomatic Python. A standard for loop is more readable and conventional for this purpose.
Suggested change:

```diff
-[weight.verify_load() for weight in self.trans_layers_weight]
+for weight in self.trans_layers_weight:
+    weight.verify_load()
```
```python
def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    ffn_out = self._ffn(input1, infer_state, layer_weight)
    input1 = None
    if self.tp_world_size_ > 1:
        all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
    input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
    return input_embdings


def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    ffn_out = self._ffn(input1, infer_state, layer_weight)
    input1 = None
    if self.tp_world_size_ > 1:
        all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
    input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
    return input_embdings
```
The methods context_forward and token_forward have identical implementations. To avoid code duplication, you can extract the common logic into a private method and call it from both.
```python
def _forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    ffn_out = self._ffn(input1, infer_state, layer_weight)
    input1 = None
    if self.tp_world_size_ > 1:
        all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
    input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
    return input_embdings


def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    return self._forward(input_embdings, infer_state, layer_weight)


def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    return self._forward(input_embdings, infer_state, layer_weight)
```

```python
    weight_dict=self.weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
```
Using a list comprehension for side effects (like calling verify_load) is not idiomatic Python. A standard for loop is more readable and conventional for this purpose.
Suggested change:

```diff
-[weight.verify_load() for weight in self.trans_layers_weight]
+for weight in self.trans_layers_weight:
+    weight.verify_load()
```
```python
# self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
# self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
```
Force-pushed from 491211c to 5046c53.
No description provided.