From ca74f1e70716233c79fe9a5315923a9dae883b16 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 25 Dec 2025 09:25:17 +0000 Subject: [PATCH] add-glm4v --- lightllm/models/__init__.py | 1 + lightllm/models/glm4v/__init__.py | 0 lightllm/models/glm4v/glm4v_visual.py | 437 ++++++++++++++++++ lightllm/models/glm4v/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 104 +++++ .../models/glm4v/layer_weight/__init__.py | 0 .../layer_weight/pre_and_post_layer_weight.py | 14 + .../layer_weight/transformer_layer_weight.py | 35 ++ lightllm/models/glm4v/model.py | 87 ++++ lightllm/models/llama/model.py | 15 +- .../models/qwen2_vl/triton_kernel/mrope.py | 12 +- lightllm/server/tokenizer.py | 6 + .../visualserver/model_infer/model_rpc.py | 5 + 13 files changed, 711 insertions(+), 5 deletions(-) create mode 100644 lightllm/models/glm4v/__init__.py create mode 100644 lightllm/models/glm4v/glm4v_visual.py create mode 100644 lightllm/models/glm4v/layer_infer/__init__.py create mode 100644 lightllm/models/glm4v/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/glm4v/layer_weight/__init__.py create mode 100644 lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py create mode 100644 lightllm/models/glm4v/layer_weight/transformer_layer_weight.py create mode 100644 lightllm/models/glm4v/model.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003..825bd535e 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -32,6 +32,7 @@ from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel from lightllm.models.gemma3.model import Gemma3TpPartModel +from lightllm.models.glm4v.model import GLM4VTpPartModel from lightllm.models.tarsier2.model import ( Tarsier2Qwen2TpPartModel, Tarsier2Qwen2VLTpPartModel, diff --git a/lightllm/models/glm4v/__init__.py b/lightllm/models/glm4v/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/glm4v/glm4v_visual.py b/lightllm/models/glm4v/glm4v_visual.py new file mode 100644 index 000000000..ea5e592de --- /dev/null +++ b/lightllm/models/glm4v/glm4v_visual.py @@ -0,0 +1,437 @@ +import os +import json +import torch +import torch.nn as nn +from PIL import Image +from io import BytesIO +from typing import List, Optional +from torch.nn import LayerNorm +import torch.nn.functional as F +from safetensors import safe_open +from transformers.activations import ACT2FN +from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor + + +class Glm4vRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Glm4vRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * 
hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Glm4VisionMlp(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str, bias: bool = False): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Glm4vVisionPatchEmbed(nn.Module): + def __init__(self, patch_size: int, temporal_patch_size: int, in_channels: int, embed_dim: int) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) + return hidden_states + + +class Glm4vVisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self._seq_len_cached = 0 + self._freqs_cos_cached = None + self._freqs_sin_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim) + ) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cos_cached = freqs.cos() + self._freqs_sin_cached = freqs.sin() + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cos_cached[:seqlen], self._freqs_sin_cached[:seqlen] + + +class Glm4vVisionPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None: + super().__init__() + self.proj = nn.Linear(dim, dim, bias=bias) + self.post_projection_norm = LayerNorm(dim) + self.gate_proj = nn.Linear(dim, context_dim, bias=bias) + self.up_proj = nn.Linear(dim, context_dim, bias=bias) + self.down_proj = nn.Linear(context_dim, dim, bias=bias) + self.act1 = nn.GELU() + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.proj(hidden_state) + hidden_state = self.act1(self.post_projection_norm(hidden_state)) + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Glm4vVisionEmbeddings(nn.Module): + def __init__(self, hidden_size: int, image_size: int, patch_size: int): + super().__init__() + self.embed_dim = hidden_size + self.image_size = image_size + self.patch_size = patch_size + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = 
nn.Embedding(self.num_positions, self.embed_dim) + self.position_ids = torch.arange(self.num_positions).expand((1, -1)) + + def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: + """ + Forward pass with integrated position encoding adaptation using 2D interpolation. + + Args: + embeddings: Input embeddings tensor + lengths (torch.Tensor): Sequence lengths for each image in the batch. + image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w). + h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch. + w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch. + + Returns: + torch.Tensor: Embeddings with adapted position encoding added. + """ + # Get position embedding parameters + pos_embed_weight = self.position_embedding.weight + hidden_size = pos_embed_weight.shape[1] + total_seq = h_coords.shape[0] + device = pos_embed_weight.device + + # Move coordinates to correct device + h_coords, w_coords = h_coords.to(device), w_coords.to(device) + + # Handle empty sequence case + if total_seq == 0: + adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype) + else: + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + if not isinstance(image_shapes, torch.Tensor): + image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq ** 0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) + + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + + # Normalize coordinates to [-1, 1] range for grid_sample + h_coords = h_coords.to(device=device, dtype=torch.float32) + w_coords = w_coords.to(device=device, dtype=torch.float32) + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border" + ) + + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + + # Add adapted position encoding to embeddings + embeddings = embeddings + adapted_pos_embed + return embeddings + + +class Glm4vVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, attention_bias: bool = False, attention_dropout: float = 0.0) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(dim, dim * 3, bias=attention_bias) + self.proj = nn.Linear(dim, dim, bias=False) + self.scaling = self.head_dim ** -0.5 + self.attention_dropout = 
attention_dropout + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int = 0, + rotary_cos: torch.Tensor = None, + rotary_sin: torch.Tensor = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_triton(q, rotary_cos, rotary_sin) + k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) + + attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) + + flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Glm4vVisionBlock(nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads, hidden_act, rms_norm_eps) -> None: + super().__init__() + self.norm1 = Glm4vRMSNorm(embed_dim, eps=rms_norm_eps) + self.norm2 = Glm4vRMSNorm(embed_dim, eps=rms_norm_eps) + self.attn = Glm4vVisionAttention(embed_dim, num_heads=num_heads) + self.mlp = Glm4VisionMlp( + hidden_size=embed_dim, intermediate_size=intermediate_size, hidden_act=hidden_act, bias=False + ) + + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Glm4vVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + depth=24, + image_size=336, + hidden_size=1536, + intermediate_size=13696, + out_hidden_size=4096, + hidden_act="silu", + num_heads=12, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + self.depth = depth + self.intermediate_size = intermediate_size + self.out_hidden_size = out_hidden_size + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + self.embeddings = Glm4vVisionEmbeddings(hidden_size, image_size, patch_size) + self.patch_embed = Glm4vVisionPatchEmbed(patch_size, temporal_patch_size, in_channels, self.hidden_size) + + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Glm4vVisionBlock(self.hidden_size, self.out_hidden_size, num_heads, hidden_act, rms_norm_eps) + for _ in range(self.depth) + ] + ) + self.merger = Glm4vVisionPatchMerger( + dim=self.out_hidden_size, context_dim=self.intermediate_size, hidden_act=hidden_act + ) + + self.post_conv_layernorm = Glm4vRMSNorm(hidden_size, eps=rms_norm_eps) + self.downsample = nn.Conv2d( + in_channels=hidden_size, + out_channels=out_hidden_size, + kernel_size=spatial_merge_size, + stride=spatial_merge_size, + ) + self.post_layernorm = Glm4vRMSNorm(hidden_size, eps=rms_norm_eps) + + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + 
            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+            raise ValueError(f"Unsupported datatype {self.data_type}!")
+        return
+
+    def load_model(self, weight_dir):
+
+        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        with open(processor_config_path, "r") as f:
+            processor_config_dict = json.load(f)
+        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+        bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+        if bin_weight_files:
+            weight_dict = {}
+            for file_ in bin_weight_files:
+                f = torch.load(os.path.join(weight_dir, file_), "cpu")
+                for k, v in f.items():
+                    if "model.visual" in k:
+                        weight_dict[k[len("model.visual.") :]] = v
+        else:
+            hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+            weight_dict = {}
+            for file_ in hf_weight_files:
+                f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+                for k in f.keys():
+                    if "model.visual" in k:
+                        weight_dict[k[len("model.visual.") :]] = f.get_tensor(k)
+
+        self.load_state_dict(weight_dict)
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+        s = self.spatial_merge_size
+        for _, h, w in grid_thw:
+            pos_shape = (h // s, s, w // s, s)
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        cos_full, sin_full = self.rotary_pos_emb(max_grid_size)
+        cos = cos_full[pos_ids].flatten(1)
+        sin = sin_full[pos_ids].flatten(1)
+        return cos, sin, pos_ids
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.post_conv_layernorm(hidden_states)
+        rotary_cos, rotary_sin, pos_ids = self.rot_pos_emb(grid_thw)
+        rotary_cos = rotary_cos.to("cuda", non_blocking=True)
+        rotary_sin = rotary_sin.to("cuda", non_blocking=True)
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0, dtype=torch.int32
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, pos_ids[:, 0], pos_ids[:, 1])
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+                rotary_cos=rotary_cos,
+                rotary_sin=rotary_sin,
+            )
+        hidden_states = self.post_layernorm(hidden_states)
+        hidden_states = hidden_states.view(
+            -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
+        )
+        hidden_states = hidden_states.permute(0, 3, 1, 2)
+        hidden_states = self.downsample(hidden_states).view(-1, self.out_hidden_size)
+        return self.merger(hidden_states)
+
+    def encode(self, images: List[ImageItem]):
+        img_tensors = []
+        valid_ids = []
+        valid_id = 0
+        img_grids = []
+        uuids = []
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
+                image_data = Image.open(BytesIO(image_data))
+                pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+                img_tensors.append(pixel_values)
+                img_grids.append(image_grid_thw)
+            else:
+                raise Exception("Unsupported input type: {} for {}".format(type(img), img))
+
+            # the patch count must be divisible by merge_length (spatial_merge_size ** 2)
+            cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)
+
+            valid_ids.append([valid_id, valid_id + cur_num])
+            valid_id += cur_num
+
+        if len(img_tensors) <= 0:
+            return None
+
+        imgs = torch.cat(img_tensors, dim=0)
+        grid_thw = torch.cat(img_grids, dim=0)
+
+        pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+        image_grid_thw = grid_thw.to("cuda", non_blocking=True)
+
+        all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
+
+        return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/glm4v/layer_infer/__init__.py b/lightllm/models/glm4v/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/glm4v/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4v/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..70c884381
--- /dev/null
+++ b/lightllm/models/glm4v/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,104 @@
+import torch
+import torch.functional as F
+import torch.distributed as dist
+import numpy as np
+from typing import Tuple
+from functools import partial
+
+from lightllm.distributed import all_reduce
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight
+
+
+class Glm4VTransformerLayerInfer(LlamaTransformerLayerInfer):
+    def __init__(self, layer_num, network_config, mode=[]):
+        super().__init__(layer_num, network_config, mode)
+        mrope_section = network_config["rope_parameters"]["mrope_section"]
+        self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
+        self.partial_rotary_factor = network_config["rope_parameters"]["partial_rotary_factor"]
+
+    def _post_self_att_norm(
+        self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
+    ) -> torch.Tensor:
+        out = self.alloc_tensor(input.shape, input.dtype)
+        rmsnorm_forward(input, weight=layer_weight._post_self_att_norm_weight_.weight, eps=self.eps_, out=out)
+        return out
+
+    def _post_mlp_norm(
+        self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
+    ) -> torch.Tensor:
+        out = self.alloc_tensor(input.shape, input.dtype)
+        rmsnorm_forward(input, weight=layer_weight._post_mlp_norm_weight_.weight, eps=self.eps_, out=out)
+        return out
+
+    def _get_qkv(self, input, infer_state, layer_weight):
+        q = layer_weight.q_proj.mm(input)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        mrope_triton_fused(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+            self.mrope_section,
+            partial_rotary_factor=self.partial_rotary_factor,
+            is_interleaved=False,
+            is_glm4v=True,
+        )
+        return q, cache_kv
+
+    def context_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+
+        o = self._TransformerLayerInferTpl__context_attention_wrapper_run(
+            q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
+        )
+
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.tp_world_size_ > 1:
+            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
+        input_embdings.add_(o.view(-1, self.embed_dim_))
+        o = None
+
+        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
+        input1 = None
+        if self.tp_world_size_ > 1:
+            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+        return input_embdings
+
+    def token_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._token_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.tp_world_size_ > 1:
+            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
+        input_embdings.add_(o.view(-1, self.embed_dim_))
+        o = None
+
+        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
+        input1 = None
+        if self.tp_world_size_ > 1:
+            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+        return input_embdings
+
+    def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
+        # TODO
+        raise Exception("not impl")
diff --git a/lightllm/models/glm4v/layer_weight/__init__.py b/lightllm/models/glm4v/layer_weight/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py b/lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..52bfd76f5
--- /dev/null
+++ b/lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py
@@ -0,0 +1,14 @@
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import rename_weight_keys
+
+
+class Glm4VPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        return
+
+    def load_hf_weights(self, weights):
+        rename_weight_keys(weights)
+        super().load_hf_weights(weights)
+        return
diff --git a/lightllm/models/glm4v/layer_weight/transformer_layer_weight.py b/lightllm/models/glm4v/layer_weight/transformer_layer_weight.py
new file mode 100644
index 000000000..8302f3eea
--- /dev/null
+++ b/lightllm/models/glm4v/layer_weight/transformer_layer_weight.py
@@ -0,0 +1,35 @@
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight
+from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
+
+
+class Glm4VTransformerLayerWeight(Qwen2TransformerLayerWeight):
+    def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
+        super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+
+    def _init_weight_names(self):
+        self._post_self_att_norm_weight_name = f"model.layers.{self.layer_num_}.post_self_attn_layernorm.weight"
+        self._post_self_att_norm_bias_name = None
+        self._post_mlp_norm_weight_name = f"model.layers.{self.layer_num_}.post_mlp_layernorm.weight"
+        self._post_mlp_norm_bias_name = None
+        super()._init_weight_names()
+
+    def load_hf_weights(self, weights):
+        gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"
+        if gate_up_weight_name in weights:
+            intermediate_size = self.network_config_["intermediate_size"]
+            gate_up_proj = weights[gate_up_weight_name]
+            gate_weight_ = gate_up_proj[0:intermediate_size, :]
+            up_weight_ = gate_up_proj[intermediate_size:, :]
+            weights[self._gate_weight_name] = gate_weight_
+            weights[self._up_weight_name] = up_weight_
+            del weights[gate_up_weight_name]
+        super().load_hf_weights(weights)
+
+    def _init_norm(self):
+        self._post_self_att_norm_weight_ = NormWeight(
+            self._post_self_att_norm_weight_name, self.data_type_, bias_name=self._post_self_att_norm_bias_name
+        )
+        self._post_mlp_norm_weight_ = NormWeight(
+            self._post_mlp_norm_weight_name, self.data_type_, bias_name=self._post_mlp_norm_bias_name
+        )
+        super()._init_norm()
diff --git a/lightllm/models/glm4v/model.py b/lightllm/models/glm4v/model.py
new file mode 100644
index 000000000..78157fdf7
--- /dev/null
+++ b/lightllm/models/glm4v/model.py
@@ -0,0 +1,87 @@
+import os
+import json
+import numpy as np
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
+from lightllm.models.glm4v.layer_infer.transformer_layer_infer import Glm4VTransformerLayerInfer
+from lightllm.models.glm4v.layer_weight.pre_and_post_layer_weight import Glm4VPreAndPostLayerWeight
+from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight
+from lightllm.server.multimodal_params import MultimodalParams
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen2.model import Qwen2TpPartModel
+
+
+class GLM4VTokenizer(QWen2VLTokenizer):
+    def __init__(self, tokenizer=None, image_processor=None, **kwargs):
+        self.tokenizer = tokenizer
+        self.image_processor = image_processor
+        self.min_pixel = self.image_processor.size["shortest_edge"]
+        self.max_pixel = self.image_processor.size["longest_edge"]
+        self.patch_size = self.image_processor.patch_size
+        self.merge_size = self.image_processor.merge_size
+        self.image_start_id = kwargs["model_cfg"]["image_start_token_id"]
+        self.image_end_id = kwargs["model_cfg"]["image_end_token_id"]
+        self.image_token_id = kwargs["model_cfg"]["image_token_id"]
+
+    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+        origin_ids = self.tokenizer.encode(prompt)
+
+        # drop the raw image placeholder tokens, keeping only the image begin/end markers
+        origin_ids = [token for token in origin_ids if 
token != self.image_token_id] + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + while True: + try: + start_idx = origin_ids.index(self.image_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids) + return input_ids + + +@ModelRegistry(["glm4v"], is_multimodal=True) +class GLM4VTpPartModel(Qwen2TpPartModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = Glm4VTransformerLayerInfer + + pre_and_post_weight_class = Glm4VPreAndPostLayerWeight + transformer_weight_class = Glm4VTransformerLayerWeight + + infer_state_class = Qwen2VLInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["text_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index a228e0025..420e12e51 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -108,6 +108,8 @@ def _init_custom(self): 模型特殊的一些初始化 """ rope_scaling = self.config.get("rope_scaling", None) + if rope_scaling is None: + rope_scaling = self.config.get("rope_parameters", None) if rope_scaling is None: self._init_to_get_rotary() return @@ -171,14 +173,21 @@ def _init_weights(self): return def _init_to_get_rotary(self, default_base=10000): - partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) + rope_params = self.config.get("rope_parameters") + if rope_params is not None: + partial_rotary_factor = rope_params.get("partial_rotary_factor", 1) + base = rope_params.get("rope_theta", float(default_base)) + else: + partial_rotary_factor = self.config.get("partial_rotary_factor", 1) + base = self.config.get("rope_theta", float(default_base)) + + partial_head_dim = int(partial_rotary_factor * self.head_dim_) + if self.config.get("rope_scaling", {}) is None: rope_scaling_factor = 1.0 else: rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - base = self.config.get("rope_theta", float(default_base)) - if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: diff --git a/lightllm/models/qwen2_vl/triton_kernel/mrope.py b/lightllm/models/qwen2_vl/triton_kernel/mrope.py index 5aed65862..1d85b84c3 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/mrope.py +++ b/lightllm/models/qwen2_vl/triton_kernel/mrope.py @@ -85,6 +85,7 @@ def _mrope_triton_fused_kernel( stride_kh, stride_kd, is_interleaved: tl.constexpr, + is_glm4v: tl.constexpr, HEAD_Q: tl.constexpr, 
HEAD_K: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -95,6 +96,10 @@ def _mrope_triton_fused_kernel( dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) dim_range1 = dim_range0 + BLOCK_DMODEL // 2 + if is_glm4v: + dim_range0 = dim_range0 * 2 + dim_range1 = dim_range0 + 1 + t_cos = Cos + seq_index * stride_cosd h_cos = Cos + stride_cosld + seq_index * stride_cosd w_cos = Cos + 2 * stride_cosld + seq_index * stride_cosd @@ -192,11 +197,13 @@ def mrope_triton_fused( cos: torch.Tensor, sin: torch.Tensor, mrope_section: torch.Tensor, - is_interleaved: bool, + partial_rotary_factor: float = 1.0, + is_interleaved: bool = False, + is_glm4v: bool = False, run_config: Optional[dict] = None, ): head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = int(q.shape[2]) + head_dim = int(q.shape[2] * partial_rotary_factor) num_tokens = q.shape[0] if not run_config: @@ -228,6 +235,7 @@ def mrope_triton_fused( stride_kh=k.stride(1), stride_kd=k.stride(2), is_interleaved=is_interleaved, + is_glm4v=is_glm4v, HEAD_Q=head_num_q, HEAD_K=head_num_k, BLOCK_DMODEL=head_dim, diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425..e668156b9 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -29,6 +29,7 @@ from ..models.qwen_vl.model import QWenVLTokenizer from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer +from ..models.glm4v.model import GLM4VTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer @@ -104,5 +105,10 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "glm4v": + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = GLM4VTokenizer(tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f3..11c9b15c4 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,6 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.glm4v.glm4v_visual import Glm4vVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -78,6 +79,10 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "glm4v": + self.model = ( + Glm4vVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() + ) else: raise Exception(f"can not support {self.model_type} now")
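
Note on the mRoPE changes above: for GLM4V, mrope_triton_fused now rotates only the first head_dim * partial_rotary_factor dimensions, and with is_glm4v=True the rotated prefix is paired as adjacent (even, odd) dimensions (the dim_range0 * 2 / dim_range0 + 1 indices added to the kernel) instead of two contiguous halves. The sketch below is a minimal plain-PyTorch reference of that behaviour, not the Triton kernel itself; the [3, num_tokens, head_dim // 2] cos/sin layout and the way mrope_section assigns each dim-pair to the temporal/height/width position stream are assumptions inferred from the kernel's stride handling, not confirmed by this patch.

import torch


def mrope_glm4v_reference(
    q,                     # [num_tokens, num_heads, head_dim]
    cos,                   # [3, num_tokens, head_dim // 2], t/h/w streams (assumed layout)
    sin,                   # [3, num_tokens, head_dim // 2]
    mrope_section,         # ints summing to >= the number of rotated dim-pairs
    partial_rotary_factor=0.5,
):
    head_dim = q.shape[-1]
    rot_dim = int(head_dim * partial_rotary_factor)  # only this prefix is rotated
    half = rot_dim // 2

    # For each rotated dim-pair, pick which position stream (0=t, 1=h, 2=w)
    # supplies its cos/sin values, following the mrope_section split.
    stream = torch.repeat_interleave(
        torch.arange(3, device=q.device),
        torch.as_tensor(mrope_section, device=q.device),
    )[:half]
    pair = torch.arange(half, device=q.device)
    cos_sel = cos[stream, :, pair].transpose(0, 1).unsqueeze(1)  # [num_tokens, 1, half]
    sin_sel = sin[stream, :, pair].transpose(0, 1).unsqueeze(1)

    # GLM4V pairs adjacent (even, odd) dims inside the rotated prefix.
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    q_even, q_odd = q_rot[..., 0::2], q_rot[..., 1::2]
    out = torch.empty_like(q_rot)
    out[..., 0::2] = q_even * cos_sel - q_odd * sin_sel
    out[..., 1::2] = q_odd * cos_sel + q_even * sin_sel
    return torch.cat([out, q_pass], dim=-1)

Under these assumptions, the unrotated tail (q_pass) passing through unchanged is what partial_rotary_factor adds for GLM4V; existing Qwen2-VL call sites keep rotating the full head dimension because the new argument defaults to 1.0.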