
Conversation

@hiworldwzj (Collaborator)

No description provided.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the Multi-Token Prediction (MTP) framework by introducing native support for Qwen3 MoE and Mistral models. It refines the core inference mechanisms, particularly FlashAttention, to be more adaptive to MTP operations and updates the model loading infrastructure to handle multiple and diverse MTP draft models. These changes aim to enhance the system's flexibility and performance when utilizing various draft models for accelerated inference.

Highlights

  • Expanded MTP Draft Model Support: The system now recognizes and supports 'Qwen3MOEMTPModel' and 'MistralMTPModel' within the Multi-Token Prediction (MTP) draft model detection logic, alongside the existing 'Deepseek3MTPModel'.
  • Dynamic FlashAttention State Management: FlashAttention state initialization has been enhanced to dynamically adjust page table sizing and request index slicing based on the mtp_step environment argument, improving memory management for MTP models. The cache_seqlens and max_seqlen_q parameters in transformer layer inference are now also dynamically set.
  • New Mistral MTP Model Integration: Comprehensive support for Mistral models in MTP mode has been added, including new pre-layer, post-layer, and transformer layer inference classes, along with dedicated weight loading utilities and a new MistralMTPModel definition.
  • New Qwen3 MoE MTP Model Integration: Full integration for Qwen3 MoE models in MTP mode has been implemented, introducing new transformer layer inference, pre/post layer weight handling, and a Qwen3MOEMTPModel definition to leverage MTP capabilities.
  • Flexible MTP Draft Model Loading: The command-line interface for specifying MTP draft models (--mtp_draft_model_dir) now accepts multiple model directories, and the backend model loading logic has been updated to dynamically instantiate the correct MTP model type (Deepseek, Qwen3, or Mistral) based on its configuration (a sketch of this dispatch follows this list).
  • Memory Manager Adaptation for MTP: The memory manager's layer_num calculation has been updated to account for additional layers used by MTP draft models, ensuring proper memory allocation when operating in MTP modes like 'deepseekv3_eagle' or 'deepseekv3_vanilla'.
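
For illustration, here is a minimal sketch of the configuration-based dispatch mentioned in the loading highlight above. The model class names come from this PR; the "model_type" key, the helper name, and the loader shape are assumptions made only for this sketch, not the actual backend code.

    # Hypothetical helper: map a draft model directory's config to an MTP model class name.
    import json
    import os

    def pick_mtp_draft_model_class(model_dir: str) -> str:
        with open(os.path.join(model_dir, "config.json")) as f:
            cfg = json.load(f)
        model_type = cfg.get("model_type", "")
        if "qwen3" in model_type:
            return "Qwen3MOEMTPModel"
        if "mistral" in model_type:
            return "MistralMTPModel"
        return "Deepseek3MTPModel"

    # Because --mtp_draft_model_dir now accepts multiple directories, the backend can
    # resolve one draft model per directory, e.g.:
    #   draft_classes = [pick_mtp_draft_model_class(d) for d in mtp_draft_model_dirs]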


@gemini-code-assist bot left a comment


Code Review

This pull request introduces Multi-Token Prediction (MTP) support for Qwen3 (dense and MoE) and Mistral models. This is a significant feature addition, involving new model classes, layer implementations, and weight loading utilities for MTP draft models. The changes also update base model components and server logic to accommodate MTP configurations. My review focuses on improving code clarity and maintainability, and on addressing a design issue related to model coupling. Overall, the changes are extensive and well-structured to support MTP.

Comment on lines 47 to 51
added_mtp_layer_num = 0
if get_env_start_args().mtp_mode == "deepseekv3_eagle":
    added_mtp_layer_num += 1
elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
    added_mtp_layer_num += get_env_start_args().mtp_step

Severity: high

The logic for added_mtp_layer_num is hardcoded for deepseekv3 MTP modes. This creates a tight coupling between the qwen2 model and other models' MTP configurations. This should be refactored to be more generic, for example by having the number of extra layers passed in or configured in a more abstract way, to improve modularity and reduce dependencies between model implementations.
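
One possible direction, shown here as a self-contained sketch (the StartArgs stand-in and the helper below are illustrations, not code from this PR), is to keep the mode-to-extra-layers mapping in a single shared helper that model-specific memory managers query, instead of repeating deepseekv3-specific mode names inside the qwen2 code path.

    from dataclasses import dataclass

    @dataclass
    class StartArgs:  # stand-in for the object returned by get_env_start_args()
        mtp_mode: str = ""
        mtp_step: int = 0

    # One shared place that knows how many extra draft-model layers each MTP mode adds.
    def added_mtp_layer_num(args: StartArgs) -> int:
        extra_layers = {
            "deepseekv3_eagle": 1,
            "deepseekv3_vanilla": args.mtp_step,
        }
        return extra_layers.get(args.mtp_mode, 0)

    print(added_mtp_layer_num(StartArgs("deepseekv3_vanilla", mtp_step=2)))  # -> 2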

Comment on lines 911 to 1001
is_deepseekv3_mtp_draft_model = (
    "Deepseek3MTPModel" in str(self.__class__)
    or "Qwen3MOEMTPModel" in str(self.__class__)
    or "MistralMTPModel" in str(self.__class__)
)
if is_deepseekv3_mtp_draft_model:

Severity: medium

The variable is_deepseekv3_mtp_draft_model is now used to check for Qwen3MOEMTPModel and MistralMTPModel as well. Its name has become misleading. To improve code clarity and maintainability, please rename it to something more generic, like is_mtp_draft_model.

Suggested change
-is_deepseekv3_mtp_draft_model = (
-    "Deepseek3MTPModel" in str(self.__class__)
-    or "Qwen3MOEMTPModel" in str(self.__class__)
-    or "MistralMTPModel" in str(self.__class__)
-)
-if is_deepseekv3_mtp_draft_model:
+is_mtp_draft_model = (
+    "Deepseek3MTPModel" in str(self.__class__)
+    or "Qwen3MOEMTPModel" in str(self.__class__)
+    or "MistralMTPModel" in str(self.__class__)
+)
+if is_mtp_draft_model:

Comment on lines +60 to +63
if args_mtp_step > 0:
    self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
else:
    self.b_att_seq_len = self.b_seq_len

Severity: medium

The if/else block can be simplified to a single line. When args_mtp_step is 0, the slice [0::1] selects the whole tensor, so the else branch is redundant. Using the single-line form for both cases is more concise, and the unconditional .contiguous() call also guards against self.b_seq_len not being contiguous.

            self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
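
As a quick standalone check of that equivalence (illustrative only, not code from the PR):

    import torch

    b_seq_len = torch.arange(8)

    # args_mtp_step == 0: the slice [0::1] keeps every element, so it matches the full tensor.
    assert torch.equal(b_seq_len[0::1].contiguous(), b_seq_len)

    # args_mtp_step == 1: keep every second element, starting at index 1.
    print(b_seq_len[1::2])  # tensor([1, 3, 5, 7])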

Comment on lines 1 to 14
import os
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from einops import rearrange
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.common.basemodel import PostLayerInferTpl
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.distributed.communication_op import all_gather

Severity: medium

There are several unused imports in this file. Please remove them to keep the code clean:

  • torch.functional as F (line 3)
  • from einops import rearrange (line 9)
  • from lightllm.utils.infer_utils import mark_cost_time (line 13)

input_ids = infer_state.prefill_dp_balance(input_ids=input_ids)

input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight)
from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy

Severity: medium

The import from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy is repeated in multiple functions (tpsp_context_forward, tpsp_token_forward, overlap_tpsp_token_forward, overlap_tpsp_context_forward). To improve code readability and avoid duplication, please move this import to the top of the file.

errors = "weights load not ok"
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
for i in range(len(weights)):
    assert weights[i] is not None, "index:" + str(i) + " " + errors

Severity: medium

The assertion message can be made more informative and readable by using an f-string.

Suggested change
-    assert weights[i] is not None, "index:" + str(i) + " " + errors
+    assert weights[i] is not None, f"index: {i} {errors}"

weight_dict=self.weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]

Severity: medium

Using a list comprehension for side effects (like calling verify_load) is not idiomatic Python. A standard for loop is more readable and conventional for this purpose.

Suggested change
-[weight.verify_load() for weight in self.trans_layers_weight]
+for weight in self.trans_layers_weight:
+    weight.verify_load()

Comment on lines +21 to +37
    def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings

    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings

Severity: medium

The methods context_forward and token_forward have identical implementations. To avoid code duplication, you can extract the common logic into a private method and call it from both.

    def _forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings

    def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        return self._forward(input_embdings, infer_state, layer_weight)

    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        return self._forward(input_embdings, infer_state, layer_weight)

weight_dict=self.weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]

Severity: medium

Using a list comprehension for side effects (like calling verify_load) is not idiomatic Python. A standard for loop is more readable and conventional for this purpose.

Suggested change
-[weight.verify_load() for weight in self.trans_layers_weight]
+for weight in self.trans_layers_weight:
+    weight.verify_load()

Comment on lines +67 to +68
# self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
# self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_

Severity: medium

These lines are commented out. If they are no longer needed, please remove them to improve code clarity.
