-
Notifications
You must be signed in to change notification settings - Fork 291
Add neo chat #1161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add neo chat #1161
Changes from all commits
8a67a47
fdc1369
e8e7416
ba44983
4d41a33
0e8845c
b48cd49
7a904f3
4b757dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -110,6 +110,8 @@ def _init_custom(self): | |||||||||||||||||||||||||||||||||||||
| rope_scaling = self.config.get("rope_scaling", None) | ||||||||||||||||||||||||||||||||||||||
| if rope_scaling is None: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_rotary() | ||||||||||||||||||||||||||||||||||||||
| if "rope_theta_hw" in self.config: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_hw_rotary() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if "rope_type" in rope_scaling: | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -132,6 +134,8 @@ def _init_custom(self): | |||||||||||||||||||||||||||||||||||||
| self._init_to_get_rotary() | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Unknown RoPE scaling type {scaling_type}") | ||||||||||||||||||||||||||||||||||||||
| if "rope_theta_hw" in self.config: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_hw_rotary() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_weights(self): | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000): | |||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| base = self.config.get("rope_theta", float(default_base)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| print(f"base is {base}") | ||||||||||||||||||||||||||||||||||||||
| if "max_sequence_length" in self.config: | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = self.config["max_sequence_length"] | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000): | |||||||||||||||||||||||||||||||||||||
| self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_to_get_hw_rotary(self, default_base=10000): | ||||||||||||||||||||||||||||||||||||||
| partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) | ||||||||||||||||||||||||||||||||||||||
| if self.config.get("rope_scaling", {}) is None: | ||||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = 1.0 | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| base = self.config.get("rope_theta_hw", float(default_base)) | ||||||||||||||||||||||||||||||||||||||
| print(f"hw_base is {base}") | ||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||
| if "max_sequence_length" in self.config: | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = self.config["max_sequence_length"] | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| max_position_embeddings = self.config.get( | ||||||||||||||||||||||||||||||||||||||
| "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = max_position_embeddings * rope_scaling_factor | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # NTK | ||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) | ||||||||||||||||||||||||||||||||||||||
| assert ntk_alpha >= 1 | ||||||||||||||||||||||||||||||||||||||
| if ntk_alpha > 1: | ||||||||||||||||||||||||||||||||||||||
| logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") | ||||||||||||||||||||||||||||||||||||||
| max_seq_len *= ntk_alpha | ||||||||||||||||||||||||||||||||||||||
| base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula | ||||||||||||||||||||||||||||||||||||||
| except: | ||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+236
to
+244
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using a bare
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| inv_freq = 1.0 / ( | ||||||||||||||||||||||||||||||||||||||
| base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| t = ( | ||||||||||||||||||||||||||||||||||||||
| torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) | ||||||||||||||||||||||||||||||||||||||
| / rope_scaling_factor | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| freqs = torch.outer(t, inv_freq) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_to_get_dynamic_ntk_rotary(self): | ||||||||||||||||||||||||||||||||||||||
| partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) | ||||||||||||||||||||||||||||||||||||||
| max_position_embeddings = self.config.get("max_position_embeddings", 2048) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| import torch | ||
| from functools import partial | ||
| from typing import Tuple | ||
| from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward | ||
| from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd | ||
| from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo | ||
| from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo | ||
| from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd | ||
| from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd | ||
| from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer | ||
| from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight | ||
| from lightllm.distributed import all_reduce | ||
| import torch.distributed as dist | ||
| from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer | ||
| from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward | ||
|
|
||
|
|
||
class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
    """Transformer layer inference for NeoChat.

    Extends Qwen3 attention with an extra 2-D (height/width) rotary pathway:
    separate q/k projections (``q_hw_proj`` / ``k_hw_proj``) are split into H
    and W halves, RMS-normed and rotated independently, then concatenated onto
    the regular q/k so attention runs over heads of width ``2 * head_dim_``.
    """

    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)
        return

    def _bind_attention(self):
        # Bind the attention kernel implementations used by the base class.
        # NOTE(review): the first assignment is a no-op (binds the attribute
        # to itself); presumably the intent is to use this class's own
        # _context_attention_kernel — confirm and drop the line if so.
        self._context_attention_kernel = self._context_attention_kernel
        self._token_attention_kernel = self._token_decode_attention_normal
        self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
        return

    def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight):
        """Project input to q / cache_kv with the extra H/W rotary channels.

        Returns:
            q: flattened per-head concat of (rotary q | rotary q_h | rotary q_w),
               i.e. head width ``2 * head_dim_``.
            cache_kv: [T, Hk + Hv, 2*head_dim_] — k concatenated with (k_h, k_w);
               v zero-padded to the same per-head width.
        """
        input = input.view(-1, self.embed_dim_)
        q = layer_weight.q_proj.mm(input)  # [T, Hq*D]

        # Extra H/W query projection, split into height/width halves.
        q_hw = layer_weight.q_hw_proj.mm(input)
        q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
        q_h, q_w = q_hw.chunk(2, dim=-1)

        # Extra H/W key projection, split the same way.
        k_hw = layer_weight.k_hw_proj.mm(input)
        k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
        k_h, k_w = k_hw.chunk(2, dim=-1)

        cache_kv = layer_weight.kv_proj.mm(input)  # [T, (Hk+Hv)*D]

        # qk-norm on the main query (in-place kernel).
        qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_)

        # Per-axis qk-norm on the H/W query halves; the kernel is fed a 2-D
        # view and the result is reshaped back to per-head layout.
        q_h_2d = q_h.reshape(q.shape[0], -1)
        q_w_2d = q_w.reshape(q.shape[0], -1)
        qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_)
        qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_)
        q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
        q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)

        # qk-norm on the key slice of cache_kv.
        qk_rmsnorm_forward(
            cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
            weight=layer_weight.k_norm_weight_.weight,
            eps=self.eps_,
        )

        k_h_2d = k_h.reshape(q.shape[0], -1)  # [T, Hk*(D/2)]
        k_w_2d = k_w.reshape(q.shape[0], -1)
        qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_)
        qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_)
        k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
        k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)

        cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

        # Standard 1-D rotary on the main q and k...
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        # ...and the H/W rotaries on their respective halves.
        rotary_emb_fwd(
            q_h,
            k_h,
            infer_state.position_cos_h,
            infer_state.position_sin_h,
        )
        rotary_emb_fwd(
            q_w,
            k_w,
            infer_state.position_cos_w,
            infer_state.position_sin_w,
        )

        # Concatenate per-head: [D | D/2 | D/2] -> per-head width 2*D.
        q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
        q3 = torch.cat([q3, q_h, q_w], dim=-1)
        q = q3.reshape(q3.shape[0], -1)

        k = cache_kv[:, : self.tp_k_head_num_, :]
        k = torch.cat([k, k_h, k_w], dim=-1)

        # Zero-pad v so k and v share the same per-head width in the kv cache.
        v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
        v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
        v = torch.cat([v, v_pad], dim=-1)

        cache_kv = torch.cat([k, v], dim=1)
        return q, cache_kv

    def _context_attention_kernel(
        self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        """Prefill attention over the widened (2*D) heads.

        Returns only the first ``head_dim_`` of each output head, dropping the
        half that corresponds to the zero-padded part of v.
        """
        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
        # NOTE(review): the `kv` argument is ignored and re-read from the
        # mem_manager buffer — confirm this is intended.
        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
        context_attention_fwd_neo(
            q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
            kv[:, 0 : self.tp_k_head_num_, :],
            kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
            infer_state.position_ids[0],  # [0,0,1,2,3,3,3,4]
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.b_ready_cache_len,
            infer_state.max_len_in_batch,
            infer_state.req_manager.req_to_token_indexs,
        )
        # Keep only the real (un-padded) half of each output head.
        o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
        o3 = o3[:, :, : self.head_dim_].contiguous()
        return o3.view(o3.shape[0], -1)

    def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
        """Decode-phase attention: scores over 2*D-wide q/k, value reduction
        over the original ``head_dim_`` slice of v."""
        total_token_num = infer_state.total_token_num
        batch_size = infer_state.batch_size

        q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)

        att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)

        k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
        token_att_fwd(
            q_3d,
            k_3d,
            att_m_tensor,
            infer_state.req_manager.req_to_token_indexs,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.max_len_in_batch,
        )

        from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

        # Only the first head_dim_ of v holds real data (rest is zero padding).
        v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
        ]

        o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out

        token_softmax_reducev_fwd(
            att_m_tensor,
            v_3d,
            o_3d,
            infer_state.req_manager.req_to_token_indexs,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
        )
        return o_3d.view(batch_size, -1)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| import torch | ||
| import numpy as np | ||
| from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight | ||
|
|
||
# add key: language_model.xxx -> xxx
# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now
def rename_weight_keys(weights):
    """Strip the leading ``language_model.`` prefix from checkpoint keys, in place.

    Args:
        weights: mutable mapping of weight-name -> tensor; modified in place.
    """
    prefix = "language_model."
    for k in list(weights.keys()):
        # Only strip a true prefix: `prefix in k` would also match keys that
        # merely contain the substring, and `str.replace` would rewrite every
        # occurrence, corrupting such keys.
        if k.startswith(prefix):
            weights[k[len(prefix):]] = weights.pop(k)
|
|
||
|
|
||
class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
    """Pre/post layer weights for NeoChat.

    Identical to the Qwen2 implementation except that checkpoint keys are
    normalized (``language_model.`` prefix removed) before loading.
    """

    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)

    def load_hf_weights(self, weights):
        # Normalize key names first so the Qwen2 loader can find them.
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight | ||
| from lightllm.common.basemodel.layer_weights.meta_weights import ( | ||
| NormWeight, | ||
| ROWMMWeight, | ||
| ) | ||
|
|
||
|
|
||
class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight):
    """Qwen3 transformer layer weights extended with the extra H/W rotary
    projections (``q_proj_hw`` / ``k_proj_hw``) and their per-axis qk-norm
    weights (``q_norm_h/w``, ``k_norm_h/w``)."""

    def __init__(self, layer_num, data_type, network_config, mode=None, quant_cfg=None):
        # Avoid a mutable default argument (`mode=[]` would be shared across
        # all instances); None is normalized to a fresh list per call.
        super().__init__(layer_num, data_type, network_config, mode if mode is not None else [], quant_cfg)

    def _init_weight_names(self):
        """Register the checkpoint key names for the extra H/W weights."""
        super()._init_weight_names()
        attn_prefix = f"model.layers.{self.layer_num_}.self_attn"

        self._q_weight_hw_name = f"{attn_prefix}.q_proj_hw.weight"
        self._q_bias_hw_name = None  # H/W projections carry no bias
        self._k_weight_hw_name = f"{attn_prefix}.k_proj_hw.weight"
        self._k_bias_hw_name = None

        self._q_norm_h_name = f"{attn_prefix}.q_norm_h.weight"
        self._q_norm_w_name = f"{attn_prefix}.q_norm_w.weight"

        self._k_norm_h_name = f"{attn_prefix}.k_norm_h.weight"
        self._k_norm_w_name = f"{attn_prefix}.k_norm_w.weight"

    def _init_qkv(self):
        """Create the H/W q/k projection weights in addition to the base qkv."""
        super()._init_qkv()
        self.q_hw_proj = ROWMMWeight(
            weight_names=self._q_weight_hw_name,
            data_type=self.data_type_,
            bias_names=self._q_bias_hw_name,
            quant_cfg=self.quant_cfg,
            layer_num=self.layer_num_,
            name="q_hw_proj",
        )
        self.k_hw_proj = ROWMMWeight(
            weight_names=self._k_weight_hw_name,
            data_type=self.data_type_,
            bias_names=self._k_bias_hw_name,
            quant_cfg=self.quant_cfg,
            layer_num=self.layer_num_,
            name="k_hw_proj",
        )

    def _init_norm(self):
        """Create the per-axis qk-norm weights in addition to the base norms."""
        super()._init_norm()

        self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_)
        self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_)
        self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_)
        self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| import os | ||
| import json | ||
| from lightllm.common.build_utils import repair_config | ||
| from lightllm.models.registry import ModelRegistry, llm_model_type_is | ||
| from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo | ||
| from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer | ||
| from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer | ||
| from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight | ||
| from lightllm.models.qwen2_vl.model import QWen2VLTokenizer | ||
| from lightllm.models.qwen3.model import Qwen3TpPartModel | ||
| from lightllm.server.core.objs import SamplingParams | ||
| from lightllm.models.qwen3_moe.model import Qwen3MOEModel | ||
| from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem | ||
| from lightllm.models.neo_chat_moe.vision_process import smart_resize | ||
| from lightllm.models.internvl.model import InternvlTokenizer | ||
| from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer | ||
| from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer | ||
| from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
| from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight | ||
| from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight | ||
| from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer | ||
| from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo | ||
|
|
||
|
|
||
@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3"))
class NeoTpPartModel(Qwen3TpPartModel):
    """Tensor-parallel NeoChat model: Qwen3 backbone with multimodal pre-layer
    inference and the NeoChat-specific layer infer/weight classes."""

    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
    transformer_layer_infer_class = NeoChatTransformerLayerInfer

    pre_and_post_weight_class = NeoChatPreAndPostLayerWeight
    transformer_weight_class = NeoChatTransformerLayerWeight

    infer_state_class = NeoChatInferStateInfo

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _init_inferstate_cls(self):
        # Intentionally a no-op: infer_state_class is fixed by the class
        # attribute above and must not be overridden by the base logic.
        pass

    def _init_config(self):
        # The LLM config lives under the "llm_config" key of the top-level
        # multimodal config.json.
        cfg_path = os.path.join(self.weight_dir_, "config.json")
        with open(cfg_path, "r") as f:
            self.config = json.load(f)["llm_config"]
        # rename keys
        for same_names in (
            ["num_attention_heads", "n_head"],
            ["hidden_size", "n_embd", "n_embed"],
            ["num_hidden_layers", "n_layer"],
        ):
            repair_config(self.config, same_names=same_names)
        if self.finetune_config:
            self.config["vocab_size"] = self.finetune_config.vocab_size
        return
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This `print` statement appears to be for debugging purposes and should be removed before merging.