diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md
index 377e610aaa..d051a262f9 100644
--- a/docs/source/Instruction/GKD.md
+++ b/docs/source/Instruction/GKD.md
@@ -210,3 +210,78 @@ swift rlhf \
 ```
 For reference scripts, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh)
+
+## Conditional Distillation
+
+Conditional distillation allows the teacher and student models to be trained with **different contexts or prompts**, enabling more flexible knowledge-transfer strategies. For example:
+- The teacher model receives prompts containing additional expert guidance
+- The teacher model receives task-reformulated inputs (e.g., summarization, translation)
+- The teacher model uses longer context information
+
+### The TeacherAdapter Plugin System
+
+By implementing the `TeacherAdapter` interface, you can customize how the student's context is transformed for the teacher:
+
+```python
+# swift/plugin/teacher_adapter.py
+from swift.plugin import TeacherAdapter
+
+class MyTeacherAdapter(TeacherAdapter):
+    def shape_context(self, data_dict):
+        """Transform the student's messages into the teacher's messages.
+
+        Args:
+            data_dict: the full data dict; `data_dict['messages']` holds the
+                student model's message list (OpenAI format)
+
+        Returns:
+            The teacher model's message list
+        """
+        # Add an extra system prompt for the teacher
+        teacher_history = data_dict['messages'].copy()
+        if teacher_history and teacher_history[0]['role'] == 'system':
+            teacher_history[0] = {
+                'role': 'system',
+                'content': teacher_history[0]['content'] + '\n\nYou are a domain expert.'
+            }
+        else:
+            teacher_history.insert(0, {
+                'role': 'system',
+                'content': 'You are a domain expert.'
+            })
+        return teacher_history
+
+# Register with the plugin system
+from swift.plugin import teacher_adapters
+teacher_adapters['my_adapter'] = MyTeacherAdapter
+```
+
+### Built-in Adapters
+
+SWIFT provides two built-in teacher adapters:
+
+| Adapter | Description |
+|---------|------|
+| `default` | Default: the teacher uses the same context as the student |
+| `math_teacher` | Example: appends math-expert instructions to the teacher's system prompt |
+
+### Usage
+
+```bash
+swift rlhf \
+    --rlhf_type gkd \
+    --model Qwen/Qwen2.5-0.5B-Instruct \
+    --teacher_model Qwen/Qwen2.5-7B-Instruct \
+    --teacher_adapter math_teacher \
+    --dataset your_dataset.jsonl \
+    ...
+```
+
+### How It Works
+
+In conditional distillation:
+
+1. The **student model** processes the original input: `[prompt_student] + [response]`
+2. The **teacher model** processes the transformed input: `[prompt_teacher] + [response]`
+3. Both models compute logits over the **same response tokens**
+4. The distillation loss is computed from these logits
+
+Here, `prompt_teacher` is derived from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part remains unchanged.
+
+For a reference training script, see [here](../../../examples/train/on_policy_condition_distillation.sh)
diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md
index 98a0b6b2cb..b676d78528 100644
--- a/docs/source_en/Instruction/GKD.md
+++ b/docs/source_en/Instruction/GKD.md
@@ -212,3 +212,78 @@ We can achieve the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-
 ```
 For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).
+
+## Conditional Distillation
+
+Conditional distillation enables the teacher and student models to use **different contexts or prompts** during training, allowing for more flexible knowledge transfer strategies. For example:
+- The teacher model receives prompts with additional expert guidance
+- The teacher model receives task-reformulated inputs (e.g., summaries, translations)
+- The teacher model uses longer context information
+
+### TeacherAdapter Plugin System
+
+You can customize the teacher model's context transformation logic by implementing the `TeacherAdapter` interface:
+
+```python
+# swift/plugin/teacher_adapter.py
+from swift.plugin import TeacherAdapter
+
+class MyTeacherAdapter(TeacherAdapter):
+    def shape_context(self, data_dict):
+        """Transform student messages to teacher messages.
+
+        Args:
+            data_dict: the full data dict; `data_dict['messages']` holds the
+                student model's message list (OpenAI format)
+
+        Returns:
+            The teacher model's message list
+        """
+        # Add an extra system prompt for the teacher
+        teacher_history = data_dict['messages'].copy()
+        if teacher_history and teacher_history[0]['role'] == 'system':
+            teacher_history[0] = {
+                'role': 'system',
+                'content': teacher_history[0]['content'] + '\n\nYou are an expert with extensive knowledge.'
+            }
+        else:
+            teacher_history.insert(0, {
+                'role': 'system',
+                'content': 'You are an expert with extensive knowledge.'
+            })
+        return teacher_history
+
+# Register with the plugin system
+from swift.plugin import teacher_adapters
+teacher_adapters['my_adapter'] = MyTeacherAdapter
+```
+
+A custom adapter defined outside the package can also be loaded at runtime via `--external_plugins my_adapter.py`, provided the file registers the class in `teacher_adapters`.
+
+### Built-in Adapters
+
+SWIFT provides two built-in teacher adapters:
+
+| Adapter | Description |
+|---------|-------------|
+| `default` | Default: the teacher uses the same context as the student |
+| `math_teacher` | Example: appends math-expert instructions to the teacher's system prompt |
+
+### Usage
+
+```bash
+swift rlhf \
+    --rlhf_type gkd \
+    --model Qwen/Qwen2.5-0.5B-Instruct \
+    --teacher_model Qwen/Qwen2.5-7B-Instruct \
+    --teacher_adapter math_teacher \
+    --dataset your_dataset.jsonl \
+    ...
+```
+
+### How It Works
+
+In conditional distillation:
+
+1. The **student model** processes the original input: `[prompt_student] + [response]`
+2. The **teacher model** processes the transformed input: `[prompt_teacher] + [response]`
+3. Both models compute logits on the **same response tokens**
+4. The distillation loss is calculated using these logits
+
+Here, `prompt_teacher` is obtained from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part remains unchanged.
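+
+To make the alignment concrete, below is a simplified sketch of how a batch entry is assembled for one sample (illustrative only; the actual logic lives in the template's GKD data collator, and the helper name here is hypothetical):
+
+```python
+def build_teacher_input_ids(input_ids, labels, teacher_prompts):
+    """Illustrative sketch: assemble the teacher's input for one sample.
+
+    `input_ids`/`labels` come from encoding the student context + response;
+    `teacher_prompts` comes from encoding the adapter-transformed context.
+    """
+    # Response tokens are exactly the positions where labels != -100
+    response_tokens = [tok for tok, lab in zip(input_ids, labels) if lab != -100]
+    # The teacher sees its own prompt followed by the *same* response tokens
+    return teacher_prompts + response_tokens
+```
+
+Because the response tokens are shared, the student and teacher logits can be compared position by position over the response span even though the prompts differ.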
+
+For a reference training script, see [here](../../../examples/train/on_policy_condition_distillation.sh).
diff --git a/examples/train/on_policy_condition_distillation.sh b/examples/train/on_policy_condition_distillation.sh
new file mode 100644
index 0000000000..15503b3ae9
--- /dev/null
+++ b/examples/train/on_policy_condition_distillation.sh
@@ -0,0 +1,45 @@
+# On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/
+
+# CUDA_VISIBLE_DEVICES=6 \
+# swift rollout \
+#     --template qwen3_nothinking \
+#     --model Qwen/Qwen3-8B-Base \
+#     --vllm_max_model_len 24192
+
+NPROC_PER_NODE=7 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,7 \
+swift rlhf \
+    --rlhf_type gkd \
+    --teacher_adapter math_teacher \
+    --template qwen3_nothinking \
+    --model Qwen/Qwen3-8B-Base \
+    --teacher_model_type qwen3_nothinking \
+    --teacher_model Qwen/Qwen3-8B-Base \
+    --train_type full \
+    --dataset open-thoughts/OpenThoughts3-1.2M#1000 \
+    --seq_kd false \
+    --lmbda 1 \
+    --beta 1 \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --learning_rate 1e-5 \
+    --gradient_accumulation_steps 1 \
+    --save_steps 1000 \
+    --save_total_limit 2 \
+    --logging_steps 1 \
+    --max_length 16000 \
+    --max_completion_length 100 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --save_only_model true \
+    --dataloader_num_workers 64 \
+    --dataset_num_proc 4 \
+    --deepspeed zero2 \
+    --teacher_deepspeed zero3 \
+    --attn_impl flash_attn \
+    --use_vllm true \
+    --vllm_mode server \
+    --vllm_server_host 127.0.0.1 \
+    --vllm_server_port 8000
diff --git a/examples/train/rft/rft.py b/examples/train/rft/rft.py
index 8c617fbc31..16316d9b85 100644
--- a/examples/train/rft/rft.py
+++ b/examples/train/rft/rft.py
@@ -22,7 +22,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
     for device in range(device_count):
         sample_cmd = (f'{conda_prefix} USE_OPENCOMPASS_EVALUATOR=True CUDA_VISIBLE_DEVICES={device} swift sample '
                       f'--model {model} --model_type {model_type} '
-                      f'--dataset {" ".join(dataset)} '
+                      f'--dataset {' '.join(dataset)} '
                       f'--data_range {device} {device_count} '
                       f'--max_length 2048 '
                       f'--system "You are a math model, you should **think step by step** carefully, '
@@ -61,7 +61,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
         sample_cmd = (
             f'{conda_prefix} USE_OPENCOMPASS_EVALUATOR=True CUDA_VISIBLE_DEVICES={device} swift sample '
             f'--model {model} --model_type {model_type} '  # change to --resume_from_checkpoint to use the latest optimizer state # noqa
-            f'--dataset {" ".join(dataset)} '
+            f'--dataset {' '.join(dataset)} '
             f'--data_range {device} {device_count} '
             f'--max_length 2048 '
             f'--system "You are a math model, you should **think step by step** carefully, '
@@ -91,7 +91,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
     for proc, handler in enumerate(handlers):
         handler.wait()
         assert os.path.exists(os.path.join('sample_output', f'iter_{iter}_proc_{proc}_sampling.jsonl')), (
-            f'{os.path.join("sample_output", f"iter_{iter}_proc_{proc}_sampling.jsonl")} not exists, '
+            f'{os.path.join('sample_output', f"iter_{iter}_proc_{proc}_sampling.jsonl")} not exists, '
            'please check the sample logs to get the detail error.')
        datasets.append(os.path.join('sample_output', f'iter_{iter}_proc_{proc}_sampling.jsonl'))
    print(f'Sampling done, files:{datasets}', flush=True)
@@ -110,7 +110,7 @@ def do_train(model: str, model_type: str, datasets: List[str], iter, cmd='sft'):
     ga
= 128 // get_device_count() // 2 train_cmd = (f'{conda_prefix} {gpu_prefix} swift {cmd} ' f'--model {model} --model_type {model_type} ' - f'--dataset {" ".join(datasets)} ' + f'--dataset {' '.join(datasets)} ' f'--max_length 2048 ' f'--num_train_epochs 1 ' f'--load_args false ' diff --git a/scripts/benchmark/exp_utils.py b/scripts/benchmark/exp_utils.py index b7209691c8..321a89f1a0 100644 --- a/scripts/benchmark/exp_utils.py +++ b/scripts/benchmark/exp_utils.py @@ -122,7 +122,7 @@ def run(self, exp: Experiment): exp.runtime = runtime envs = deepcopy(runtime.get('env', {})) envs.update(os.environ) - logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}') + logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}') os.makedirs('exp', exist_ok=True) log_file = os.path.join('exp', f'{exp.name}.eval.log') exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True) @@ -140,7 +140,7 @@ def run(self, exp: Experiment): exp.runtime = runtime envs = deepcopy(runtime.get('env', {})) envs.update(os.environ) - logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}') + logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}') os.makedirs('exp', exist_ok=True) log_file = os.path.join('exp', f'{exp.name}.{exp.cmd}.log') exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True) @@ -162,10 +162,10 @@ def _build_eval_cmd(self, exp: Experiment): if best_model_checkpoint is not None: if not os.path.exists(os.path.join(best_model_checkpoint, 'args.json')): cmd = f'swift eval --ckpt_dir {best_model_checkpoint} ' \ - + f'--infer_backend pt --train_type full --eval_dataset {" ".join(eval_dataset)}' + + f'--infer_backend pt --train_type full --eval_dataset {' '.join(eval_dataset)}' else: - cmd = f'swift eval --model {exp.args.get("model")} --infer_backend pt ' \ - f'--eval_dataset {" ".join(eval_dataset)}' + cmd = f'swift eval --model {exp.args.get('model')} --infer_backend pt ' \ + f'--eval_dataset {' '.join(eval_dataset)}' return { 'running_cmd': cmd, diff --git a/scripts/benchmark/generate_report.py b/scripts/benchmark/generate_report.py index 6d618151d4..88804c75ea 100644 --- a/scripts/benchmark/generate_report.py +++ b/scripts/benchmark/generate_report.py @@ -69,23 +69,23 @@ def tuner_hyper_params(self): return '' if args['sft_type'] in ('lora', 'adalora', 'longlora'): if 'lora_rank' in args: - hyper_params += f'rank={args["lora_rank"]}/' \ - f'target={args["lora_target_modules"]}/' \ - f'alpha={args["lora_alpha"]}/' \ - f'lr_ratio={args.get("lora_lr_ratio", None)}/' \ - f'use_rslora={args.get("use_rslora", False)}/' \ - f'use_dora={args.get("use_dora", False)}' + hyper_params += f'rank={args['lora_rank']}/' \ + f'target={args['lora_target_modules']}/' \ + f'alpha={args['lora_alpha']}/' \ + f'lr_ratio={args.get('lora_lr_ratio', None)}/' \ + f'use_rslora={args.get('use_rslora', False)}/' \ + f'use_dora={args.get('use_dora', False)}' else: hyper_params = '' if args['sft_type'] == 'full': if 'use_galore' in args and args['use_galore'] == 'true': - hyper_params += f'galore_rank={args["galore_rank"]}/' \ - f'galore_per_parameter={args["galore_optim_per_parameter"]}/' \ - f'galore_with_embedding={args["galore_with_embedding"]}/' + hyper_params += f'galore_rank={args['galore_rank']}/' \ + f'galore_per_parameter={args['galore_optim_per_parameter']}/' \ + f'galore_with_embedding={args['galore_with_embedding']}/' if 
args['sft_type'] == 'llamapro': - hyper_params += f'num_blocks={args["llamapro_num_new_blocks"]}/' + hyper_params += f'num_blocks={args['llamapro_num_new_blocks']}/' if 'neftune_noise_alpha' in args and args['neftune_noise_alpha']: - hyper_params += f'neftune_noise_alpha={args["neftune_noise_alpha"]}/' + hyper_params += f'neftune_noise_alpha={args['neftune_noise_alpha']}/' if hyper_params.endswith('/'): hyper_params = hyper_params[:-1] @@ -95,8 +95,8 @@ def tuner_hyper_params(self): def hyper_parameters(self): if 'learning_rate' not in self.args: return '' - return f'lr={self.args["learning_rate"]}/' \ - f'epoch={self.args["num_train_epochs"]}' + return f'lr={self.args['learning_rate']}/' \ + f'epoch={self.args['num_train_epochs']}' @property def train_speed(self): @@ -190,10 +190,10 @@ def generate_sft_report(outputs: List[ModelOutput]): ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**' line = f'|{output.name}|' \ - f'{output.args["model_type"]}|' \ - f'{output.args.get("dataset")}|' \ - f'{output.args.get("train_dataset_mix_ratio", 0.)}|' \ - f'{output.args.get("sft_type")}|' \ + f'{output.args['model_type']}|' \ + f'{output.args.get('dataset')}|' \ + f'{output.args.get('train_dataset_mix_ratio', 0.)}|' \ + f'{output.args.get('sft_type')}|' \ f'{output.tuner_hyper_params}|' \ f'{output.num_trainable_parameters}({output.trainable_parameters_percentage})|' \ f'{use_flash_attn}|' \ @@ -267,14 +267,14 @@ def generate_export_report(outputs: List[ModelOutput]): ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**' if output.train_dataset_info: - dataset_info = f'{output.args["dataset"]}/{output.train_dataset_info}' + dataset_info = f'{output.args['dataset']}/{output.train_dataset_info}' else: - dataset_info = f'{output.args["dataset"]}' + dataset_info = f'{output.args['dataset']}' line = f'|{output.name}|' \ - f'{output.args["model_type"]}|' \ + f'{output.args['model_type']}|' \ f'{dataset_info}|' \ - f'{output.args["quant_method"]}|' \ - f'{output.args["quant_bits"]}|' \ + f'{output.args['quant_method']}|' \ + f'{output.args['quant_bits']}|' \ f'{infer_speed}|' \ f'{gsm8k_acc}|' \ f'{arc_acc}|' \ diff --git a/setup.py b/setup.py index 387dea74c0..abd0273dec 100644 --- a/setup.py +++ b/setup.py @@ -16,8 +16,9 @@ def readme(): def get_version(): with open(version_file, 'r', encoding='utf-8') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + version_globals = {} + exec(compile(f.read(), version_file, 'exec'), version_globals) + return version_globals['__version__'] def parse_requirements(fname='requirements.txt', with_version=True): diff --git a/swift/llm/argument/base_args/model_args.py b/swift/llm/argument/base_args/model_args.py index e74a2d912e..5f204a388e 100644 --- a/swift/llm/argument/base_args/model_args.py +++ b/swift/llm/argument/base_args/model_args.py @@ -143,7 +143,7 @@ def _init_rope_scaling(self): if self.max_model_len is None: self.max_model_len = rope_model_len elif self.max_model_len > rope_model_len: - logger.warning(f'rope config ({rope_model_len} = {rope_scaling["factor"]} * ' + logger.warning(f'rope config ({rope_model_len} = {rope_scaling['factor']} * ' f'{origin_max_model_len}) should be bigger than max_model_len ' f'from command line ({self.max_model_len})') self.rope_scaling = rope_scaling diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 054fffa5d3..6a9294625a 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -120,6 +120,7 @@ class 
RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo lmbda: float = 0.5 seq_kd: bool = False offload_teacher_model: bool = False + teacher_adapter: Optional[str] = None # Teacher adapter plugin for conditional distillation # compat max_new_tokens: Optional[int] = None # use max_completion_length instead @@ -433,6 +434,10 @@ def _init_teacher_deepspeed(self): def _check_gkd(self): if self.rlhf_type != 'gkd': return + # Preserve extra dataset fields for conditional distillation + self.remove_unused_columns = False + logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') + if is_mp() and self.use_vllm: raise ValueError('GKD with vLLM is not compatible with `device_map`. ' 'Please set NPROC_PER_NODE equal to num_processes.') diff --git a/swift/llm/base.py b/swift/llm/base.py index addd19de26..fd6eefd578 100644 --- a/swift/llm/base.py +++ b/swift/llm/base.py @@ -44,10 +44,10 @@ def _compat_dsw_gradio(args) -> None: os.environ['GRADIO_ROOT_PATH'] = f"/{os.environ['JUPYTER_NAME']}/proxy/{args.server_port}" def main(self): - logger.info(f'Start time of running main: {dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}') + logger.info(f'Start time of running main: {dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}') logger.info(f'swift.__version__: {swift.__version__}') result = self.run() - logger.info(f'End time of running main: {dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}') + logger.info(f'End time of running main: {dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}') return result @abstractmethod diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 2cba486e3f..9a7ceb15e3 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -285,7 +285,7 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: chapter = row[f'chapter{i}'] if chapter is not None: cur_prompt += f'{chapter}' - cur_prompt += f'{row["response"]}' + cur_prompt += f'{row['response']}' return super().preprocess({'response': cur_prompt}) diff --git a/swift/llm/dataset/dataset/mllm.py b/swift/llm/dataset/dataset/mllm.py index aa8e9e6ef7..78a1bdf428 100644 --- a/swift/llm/dataset/dataset/mllm.py +++ b/swift/llm/dataset/dataset/mllm.py @@ -579,7 +579,7 @@ def preprocess_row(row: Dict[str, Any]) -> Dict[str, Any]: what = '' if ':' in action: action, what = action[:action.find(':')], action[action.find(':') + 1:] - row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{"," + what.strip()}' + row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{',' + what.strip()}' return row conversations = [] diff --git a/swift/llm/export/ollama.py b/swift/llm/export/ollama.py index c706de25b1..87f63a1584 100644 --- a/swift/llm/export/ollama.py +++ b/swift/llm/export/ollama.py @@ -38,15 +38,15 @@ def export_to_ollama(args: ExportArguments): with open(os.path.join(args.output_dir, 'Modelfile'), 'w', encoding='utf-8') as f: f.write(f'FROM {pt_engine.model_dir}\n') f.write(f'TEMPLATE """{{{{ if .System }}}}' - f'{replace_and_concat(template, template_meta.system_prefix, "{{SYSTEM}}", "{{ .System }}")}' - f'{{{{ else }}}}{replace_and_concat(template, template_meta.prefix, "", "")}' + f'{replace_and_concat(template, template_meta.system_prefix, '{{SYSTEM}}', '{{ .System }}')}' + f'{{{{ else }}}}{replace_and_concat(template, template_meta.prefix, '', '')}' f'{{{{ end }}}}') f.write(f'{{{{ if .Prompt }}}}' - f'{replace_and_concat(template, template_meta.prompt, "{{QUERY}}", "{{ .Prompt 
}}")}' + f'{replace_and_concat(template, template_meta.prompt, '{{QUERY}}', '{{ .Prompt }}')}' f'{{{{ end }}}}') f.write('{{ .Response }}') f.write(replace_and_concat(template, template_meta.suffix, '', '') + '"""\n') - f.write(f'PARAMETER stop "{replace_and_concat(template, template_meta.suffix, "", "")}"\n') + f.write(f'PARAMETER stop "{replace_and_concat(template, template_meta.suffix, '', '')}"\n') request_config = RequestConfig( temperature=args.temperature, @@ -65,5 +65,5 @@ def export_to_ollama(args: ExportArguments): logger.info('Save Modelfile done, you can start ollama by:') logger.info('> ollama serve') logger.info('In another terminal:') - logger.info('> ollama create my-custom-model ' f'-f {os.path.join(args.output_dir, "Modelfile")}') + logger.info('> ollama create my-custom-model ' f'-f {os.path.join(args.output_dir, 'Modelfile')}') logger.info('> ollama run my-custom-model') diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index a3926df386..23c174a040 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -323,7 +323,7 @@ class ChatCompletionResponse: def to_cmpl_response(self) -> 'CompletionResponse': self = deepcopy(self) choices = [choice.to_cmpl_choice() for choice in self.choices] - id_ = f'cmpl{self.id[len("chatcmpl"):]}' + id_ = f'cmpl{self.id[len('chatcmpl'):]}' return CompletionResponse(self.model, choices, self.usage, id_, created=self.created) @@ -436,7 +436,7 @@ class ChatCompletionStreamResponse: def to_cmpl_response(self) -> 'CompletionStreamResponse': self = deepcopy(self) choices = [choice.to_cmpl_choice() for choice in self.choices] - id_ = f'cmpl{self.id[len("chatcmpl"):]}' + id_ = f'cmpl{self.id[len('chatcmpl'):]}' return CompletionStreamResponse(self.model, choices, self.usage, id_, created=self.created) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index fce20eb7d2..8752c0fe91 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -380,6 +380,33 @@ def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: encoded = self._encode_truncated(inputs) encoded['prompts'] = encoded['input_ids'][:-len(encoded.pop('answer_input_ids'))] + + # If teacher_history exists, encode teacher's prompt + if 'teacher_history' in inputs.extra_kwargs: + teacher_inputs = deepcopy(inputs) + teacher_history = inputs.extra_kwargs['teacher_history'] + + # Construct teacher messages: teacher_history + empty assistant response + if teacher_history and teacher_history[-1]['role'] != 'assistant': + teacher_messages = teacher_history + [{'role': 'assistant', 'content': ''}] + else: + teacher_messages = teacher_history + + # Handle system message + if teacher_messages and teacher_messages[0]['role'] == 'system': + teacher_inputs.system = teacher_messages[0]['content'] + teacher_messages = teacher_messages[1:] + else: + teacher_inputs.system = None + + teacher_inputs.messages = teacher_messages + teacher_encoded = self._encode_truncated(teacher_inputs) + teacher_answer_len = len(teacher_encoded.get('answer_input_ids', [])) + if teacher_answer_len > 0: + encoded['teacher_prompts'] = teacher_encoded['input_ids'][:-teacher_answer_len] + else: + encoded['teacher_prompts'] = teacher_encoded['input_ids'] + for k in list(encoded.keys()): if k.startswith('prompt_') or k.endswith('answer_'): encoded.pop(k, None) @@ -1504,6 +1531,25 @@ def _gkd_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona prompts_res = 
self._data_collator(prompts_batch, padding_to=padding_to) res['prompts'] = prompts_res.pop('input_ids') res.update({f'prompt_{k}': v for k, v in prompts_res.items()}) + + # Handle teacher_prompts + response for conditional distillation + if any(b.get('teacher_prompts') is not None for b in batch): + # Extract response tokens from input_ids using labels (includes all tokens like <|im_end|>) + teacher_input_ids_batch = [] + for b in batch: + if b.get('teacher_prompts') is not None: + # Extract response tokens from input_ids where labels != -100 + response_mask = [label != -100 for label in b['labels']] + response_tokens = [token for token, keep in zip(b['input_ids'], response_mask) if keep] + # Concatenate teacher_prompts + response_tokens + teacher_input_ids = b['teacher_prompts'] + response_tokens + teacher_input_ids_batch.append({'input_ids': teacher_input_ids}) + + if teacher_input_ids_batch: + teacher_res = self._data_collator(teacher_input_ids_batch, padding_to=padding_to) + res['teacher_input_ids'] = teacher_res.pop('input_ids') + res.update({f'teacher_{k}': v for k, v in teacher_res.items()}) + return res def _embedding_data_collator(self, @@ -1684,7 +1730,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in seq_len = max(seq_lens) if padding_to is None else padding_to res['attention_mask'] = torch.tril(torch.ones( (len(seq_lens), seq_len, seq_len), dtype=torch.bool)).view(len(seq_lens), 1, seq_len, seq_len) - assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res["attention_mask"].dtype}' + assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res['attention_mask'].dtype}' for i, seq_len in enumerate(seq_lens): res['attention_mask'][i, :, seq_len:] = 0 res['attention_mask'] = ~res['attention_mask'] diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 832b458ee5..9930e51587 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -218,8 +218,11 @@ def _get_trainer_kwargs(self): trainer_kwargs['reward_funcs'] = self.args.reward_funcs if self.args.chord_sft_dataset: trainer_kwargs['chord_sft_dataset'], _ = self._prepare_chord_sft_dataset() - if self.args.rlhf_type == 'gkd' and self.args.teacher_deepspeed: - trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + if self.args.rlhf_type == 'gkd': + if self.args.teacher_deepspeed: + trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + # Pass rlhf_args to GKDTrainer for accessing custom parameters + trainer_kwargs['rlhf_args'] = self.args return trainer_kwargs diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py index 447fdaba3c..e18972e16d 100644 --- a/swift/llm/train/tuner.py +++ b/swift/llm/train/tuner.py @@ -120,11 +120,11 @@ def get_multimodal_target_regex( if not target_modules: continue target_modules = [tm for tm in target_modules if tm] - target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else '' - rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else '' + target_pattern = rf'.*\.({'|'.join(target_modules)})' if target_modules else '' + rejected_pattern = rf'(?!({'|'.join(rejected_modules)}))' if rejected_modules else '' res.append(rf'{rejected_pattern}{module}{target_pattern}') - return rf'^({"|".join(res)})$' + return rf'^({'|'.join(res)})$' def get_target_modules(args, model) -> Union[str, List[str]]: diff --git a/swift/megatron/utils/config.py b/swift/megatron/utils/config.py index b8320e809e..fa28afaea1 100644 --- 
a/swift/megatron/utils/config.py +++ b/swift/megatron/utils/config.py @@ -149,7 +149,7 @@ def convert_hf_config(config) -> Dict[str, Any]: res['mrope_interleaved'] = mrope_interleaved if first_k_dense_replace is not None: - res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' + res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res['num_layers'] - first_k_dense_replace}' if res.get('moe_router_score_function', 'softmax') == 'sigmoid': res['moe_router_enable_expert_bias'] = True if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res: diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 393e5983a9..82eb100c81 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -76,11 +76,11 @@ def get_multimodal_target_regex( if not target_modules: continue target_modules = [tm for tm in target_modules if tm] - target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else '' - rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else '' + target_pattern = rf'.*\.({'|'.join(target_modules)})' if target_modules else '' + rejected_pattern = rf'(?!({'|'.join(rejected_modules)}))' if rejected_modules else '' res.append(rf'{rejected_pattern}{module}{target_pattern}') - return rf'^({"|".join(res)})$' + return rf'^({'|'.join(res)})$' def get_target_modules(args, model): diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 3e581ee80e..680bc04769 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -17,6 +17,7 @@ from .rm_plugin import rm_plugins from .env import envs, Env from .context_manager import context_managers, ContextManager + from .teacher_adapter import teacher_adapters, TeacherAdapter else: _import_structure = { @@ -33,6 +34,7 @@ 'rm_plugin': ['rm_plugins'], 'env': ['envs', 'Env'], 'context_manager': ['context_managers', 'ContextManager'], + 'teacher_adapter': ['teacher_adapters', 'TeacherAdapter'], } import sys diff --git a/swift/plugin/agent_template/base.py b/swift/plugin/agent_template/base.py index cf5f87bc36..7bea1b164d 100644 --- a/swift/plugin/agent_template/base.py +++ b/swift/plugin/agent_template/base.py @@ -97,8 +97,8 @@ def _format_tool_calls(self, tool_call_messages) -> str: tool_calls = [] for message in tool_call_messages: tool_call = self._parse_tool_call(message['content']) - tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n' - f'{self.keyword.action_input} {tool_call["arguments"]}\n') + tool_calls.append(f'{self.keyword.action} {tool_call['name']}\n' + f'{self.keyword.action_input} {tool_call['arguments']}\n') tool_calls.append(self.keyword.observation) return ''.join(tool_calls) diff --git a/swift/plugin/agent_template/deepseek_v3_1.py b/swift/plugin/agent_template/deepseek_v3_1.py index 7d0fb694ae..5c1ffb7b88 100644 --- a/swift/plugin/agent_template/deepseek_v3_1.py +++ b/swift/plugin/agent_template/deepseek_v3_1.py @@ -31,11 +31,11 @@ def get_toolcall(self, response: str) -> List['Function']: return functions def _get_tool_responses(self, tool_messages): - return ''.join(f'<|tool▁output▁begin|>{tool_message["content"]}<|tool▁output▁end|>' + return ''.join(f'<|tool▁output▁begin|>{tool_message['content']}<|tool▁output▁end|>' for tool_message in tool_messages) def _get_tool_calls(self, tool_calls: List[str]): - return f'<|tool▁calls▁begin|>{"".join(tool_calls)}<|tool▁calls▁end|>' + return f'<|tool▁calls▁begin|>{''.join(tool_calls)}<|tool▁calls▁end|>' def 
_format_tool_responses( self, diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugin/agent_template/glm4.py index fc3461fcb7..7fb9650761 100644 --- a/swift/plugin/agent_template/glm4.py +++ b/swift/plugin/agent_template/glm4.py @@ -70,7 +70,7 @@ def _format_tool_calls(self, tool_call_messages) -> str: tool_calls = [] for message in tool_call_messages: tool_call = self._parse_tool_call(message['content']) - tool_calls.append(f'{tool_call["name"]}\n{tool_call["arguments"]}') + tool_calls.append(f'{tool_call['name']}\n{tool_call['arguments']}') return '<|assistant|>'.join(tool_calls) + '<|observation|>' diff --git a/swift/plugin/agent_template/mistral.py b/swift/plugin/agent_template/mistral.py index 120a3e21a9..9c36ca7bd5 100644 --- a/swift/plugin/agent_template/mistral.py +++ b/swift/plugin/agent_template/mistral.py @@ -47,7 +47,7 @@ def _format_tool_responses( for tool_message in tool_messages: tool_content = tool_message['content'] # append `[TOOL_RESULTS]{"content": {{ .Content }}}[/TOOL_RESULTS]` to res_tool - res_tool.append(f'[TOOL_RESULTS]{json.dumps({"content": tool_content}, ensure_ascii=False)}[/TOOL_RESULTS]') + res_tool.append(f'[TOOL_RESULTS]{json.dumps({'content': tool_content}, ensure_ascii=False)}[/TOOL_RESULTS]') total_tool = '\n'.join(res_tool) for context in prompt: if isinstance(context, str): diff --git a/swift/plugin/agent_template/seed_oss.py b/swift/plugin/agent_template/seed_oss.py index 97d6891012..bed4a42043 100644 --- a/swift/plugin/agent_template/seed_oss.py +++ b/swift/plugin/agent_template/seed_oss.py @@ -80,7 +80,7 @@ def _build_tool_def_string(self, tool: dict) -> str: ] param_str = ','.join(params) - docstring_parts = [' """', f' {func.get("description", "").strip()}'] + docstring_parts = [' """', f' {func.get('description', '').strip()}'] if properties: docstring_parts.append('\n Args:') @@ -135,7 +135,7 @@ def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = N doubao_prompt = ('You are Doubao, a helpful AI assistant. ' 'You may call one or more functions to assist with the user query.') return (f'{doubao_prompt}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n' - f'{split_token}\n{system or ""}') + f'{split_token}\n{system or ''}') def _format_tool_calls(self, tool_call_messages: List[dict]) -> str: formatted_calls = [] diff --git a/swift/plugin/teacher_adapter.py b/swift/plugin/teacher_adapter.py new file mode 100644 index 0000000000..bf533a9539 --- /dev/null +++ b/swift/plugin/teacher_adapter.py @@ -0,0 +1,90 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Any, Optional + +if TYPE_CHECKING: + from swift.llm.utils import Messages + import torch + + +class TeacherAdapter(ABC): + """Base class for transforming student context to teacher context in GKD training.""" + + @abstractmethod + def shape_context(self, data_dict: Dict[str, Any]) -> 'Messages': + """Transform student messages to teacher messages. + + Args: + data_dict: Complete data dictionary containing: + - 'messages': Student model's messages (OpenAI format) + - Other fields like 'dataset', 'images', etc. for flexible usage + + Returns: + Teacher model's messages + """ + pass + + def get_loss_mask(self, student_logits: 'torch.Tensor', teacher_logits: 'torch.Tensor', + mask: 'torch.Tensor', **kwargs) -> Optional['torch.Tensor']: + """Optionally modify the loss mask to control which tokens participate in distillation. 
+ + Args: + student_logits: Student model logits, shape (batch_size, seq_len, vocab_size) + teacher_logits: Teacher model logits, shape (batch_size, seq_len, vocab_size) + mask: Current mask indicating response tokens, shape (batch_size, seq_len) + True means the position is a response token (labels != -100 after shift) + **kwargs: Additional information like 'inputs', 'labels', etc. + + Returns: + Modified mask with same shape as input mask, or None to use original mask. + True means the position participates in loss computation. + + Example: + # Only train on first 50 tokens + last 5 tokens (to ensure learning stop token) + new_mask = torch.zeros_like(mask) + for i in range(mask.shape[0]): + response_indices = mask[i].nonzero(as_tuple=True)[0] + if len(response_indices) > 0: + new_mask[i, response_indices[:50]] = True # First 50 tokens + new_mask[i, response_indices[-6:-1]] = True # Last 5 valid predictions + return new_mask + """ + return None # Default: use original mask + + +class DefaultTeacherAdapter(TeacherAdapter): + """Default: teacher uses the same context as student.""" + + def shape_context(self, data_dict: Dict[str, Any]) -> 'Messages': + return data_dict['messages'] + + +class MathTeacherAdapter(TeacherAdapter): + """Example: add extra instructions to system prompt for teacher.""" + + def shape_context(self, data_dict: Dict[str, Any]) -> 'Messages': + # Create a copy to avoid modifying original + history = data_dict['messages'] + teacher_history = history.copy() + + # Example: enhance system prompt for teacher + if teacher_history and teacher_history[0]['role'] == 'system': + teacher_history[0] = { + 'role': 'system', + 'content': teacher_history[0]['content'] + '\n\nYou are a math expert, solve problems step by step.' + } + else: + # Insert system prompt at the beginning + teacher_history.insert(0, { + 'role': 'system', + 'content': 'You are a math expert, solve problems step by step.' + }) + + return teacher_history + + +# Registry for teacher adapter plugins +teacher_adapters = { + 'default': DefaultTeacherAdapter, + 'math_teacher': MathTeacherAdapter, +} diff --git a/swift/trainers/rlhf_trainer/gkd_trainer.py b/swift/trainers/rlhf_trainer/gkd_trainer.py index 2f87c8447a..5d44f8acd9 100644 --- a/swift/trainers/rlhf_trainer/gkd_trainer.py +++ b/swift/trainers/rlhf_trainer/gkd_trainer.py @@ -29,6 +29,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non teacher_model = kwargs.pop('teacher_model') teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) + rlhf_args = kwargs.pop('rlhf_args', None) kwargs['data_collator'] = identity_data_collator super().__init__(model, None, *_args, **kwargs) args = kwargs['args'] @@ -36,6 +37,19 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.temperature = args.temperature self.seq_kd = args.seq_kd self.generation_config = model.generation_config + + # Initialize teacher adapter for conditional distillation + teacher_adapter_name = getattr(rlhf_args, 'teacher_adapter', None) if rlhf_args else getattr(args, 'teacher_adapter', None) + self.teacher_adapter = None + if teacher_adapter_name is not None: + from swift.plugin import teacher_adapters + adapter_cls = teacher_adapters.get(teacher_adapter_name) + if adapter_cls is None: + raise ValueError(f'Unknown teacher_adapter: {teacher_adapter_name}. 
' + f'Available adapters: {list(teacher_adapters.keys())}') + self.teacher_adapter = adapter_cls() + logger.info(f"Loaded teacher_adapter: {type(self.teacher_adapter).__name__}") + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self._total_train_tokens = 0 @@ -69,7 +83,10 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None): assert not self.template.padding_free, 'generate not support padding_free/packing.' # Generate output with respect to the prompt only - model_inputs = {k: v for k, v in inputs.items() if not k.startswith('prompt') and k != 'labels'} + model_inputs = {k: v for k, v in inputs.items() + if not k.startswith('prompt') + and not k.startswith('teacher_prompt') + and k not in ('labels', 'teacher_prompts')} model_inputs['input_ids'] = inputs['prompts'] model_inputs.update({k[len('prompt_'):]: v for k, v in inputs.items() if k.startswith('prompt_')}) model_inputs.pop('position_ids', None) @@ -103,7 +120,10 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token return generated_tokens, new_attention_mask, new_labels def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}} + model_inputs = {k: v for k, v in inputs.items() + if k not in {'prompt', 'labels', 'teacher_prompts'} + and not k.startswith('prompt_') + and not k.startswith('teacher_prompt_')} # If generate is used, then use_logits_to_keep must be set to False. use_logits_to_keep = self.get_use_logits_to_keep(True) if use_logits_to_keep: @@ -114,13 +134,36 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # compute student output outputs_student = model(**model_inputs) - model_inputs.pop('labels', None) + # Prepare teacher model inputs + teacher_model_inputs = model_inputs.copy() + if 'teacher_input_ids' in inputs: + # Conditional distillation: use pre-concatenated teacher_input_ids + teacher_model_inputs['input_ids'] = inputs['teacher_input_ids'] + teacher_model_inputs['attention_mask'] = inputs['teacher_attention_mask'] + if 'teacher_position_ids' in inputs: + teacher_model_inputs['position_ids'] = inputs['teacher_position_ids'] + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() with torch.no_grad(), load_context: - outputs_teacher = self.teacher_model(**model_inputs) + outputs_teacher = self.teacher_model(**teacher_model_inputs) + # Extract logits for distillation + # With use_logits_to_keep, both student and teacher logits are aligned with labels shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) mask = shifted_labels != -100 + + # Allow teacher_adapter to modify the loss mask + if self.teacher_adapter is not None: + modified_mask = self.teacher_adapter.get_loss_mask( + student_logits=outputs_student.logits, + teacher_logits=outputs_teacher.logits, + mask=mask, + inputs=inputs, + labels=inputs['labels'] + ) + if modified_mask is not None: + mask = modified_mask + shifted_student_logits = outputs_student.logits[mask][None] shifted_teacher_logits = outputs_teacher.logits[mask][None] @@ -151,11 +194,30 @@ def _prepare_batch_inputs(self, inputs: list) -> Dict[str, torch.Tensor]: batch_encoded_inputs = [] for data in inputs: + # Save response_token_ids for later use in conditional distillation + response_token_ids = 
data.get('response_token_ids')
             if 'response_token_ids' in data and data['response_token_ids']:
                 from .utils import replace_assistant_response_with_ids
                 data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])
+
+            # Use teacher_adapter to generate teacher_history if available
+            if self.teacher_adapter is not None:
+                # Prepare student messages (without final assistant response if present)
+                student_messages = data['messages'][:-1] if data['messages'][-1]['role'] == 'assistant' else data['messages']
+                # Pass complete data dict to adapter (similar to GRPO's reward_model_plugin design)
+                # Temporarily replace messages with student_messages for adapter
+                original_messages = data['messages']
+                data['messages'] = student_messages
+                teacher_history = self.teacher_adapter.shape_context(data)
+                data['messages'] = original_messages  # Restore
+                data['teacher_history'] = teacher_history
             encoded = template.encode(data, return_length=True)
+
+            # Preserve response_token_ids for conditional distillation
+            if response_token_ids is not None:
+                encoded['response_token_ids'] = response_token_ids
+
             batch_encoded_inputs.append(encoded)
         from swift.llm import to_device
diff --git a/swift/tuners/peft.py b/swift/tuners/peft.py
index 3ad2ffc726..a689bb1ded 100644
--- a/swift/tuners/peft.py
+++ b/swift/tuners/peft.py
@@ -329,7 +329,7 @@ def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], ad
 
     # Compatible with SwiftModel
     def dummy_function(*args, **kwargs):
-        logger.warn(f'The function {kwargs["func"]} has no effects, consider using other functions.')
+        logger.warn(f'The function {kwargs['func']} has no effects, consider using other functions.')
 
     PeftModel.activate_adapter = PeftModel.set_adapter
     PeftModel.deactivate_adapter = partial(dummy_function, func='deactivate_adapter')
diff --git a/swift/ui/llm_infer/runtime.py b/swift/ui/llm_infer/runtime.py
index e6b73391d3..a14badc688 100644
--- a/swift/ui/llm_infer/runtime.py
+++ b/swift/ui/llm_infer/runtime.py
@@ -226,7 +226,7 @@ def construct_running_task(proc):
         create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')
 
     return f'pid:{pid}/create:{create_time_formatted}' \
-           f'/running:{format_time(ts - create_time)}/cmd:{" ".join(proc.cmdline())}'
+           f'/running:{format_time(ts - create_time)}/cmd:{' '.join(proc.cmdline())}'
 
     @classmethod
     def parse_info_from_cmdline(cls, task):
diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py
index 3f21cbba04..bbcf549d23 100644
--- a/swift/ui/llm_train/llm_train.py
+++ b/swift/ui/llm_train/llm_train.py
@@ -386,7 +386,7 @@ def train(cls, *args):
         cmd = train_stage
         if kwargs.get('deepspeed'):
-            more_params_cmd += f' --deepspeed {kwargs.pop("deepspeed")} '
+            more_params_cmd += f' --deepspeed {kwargs.pop('deepspeed')} '
         use_liger_kernel = kwargs.get('use_liger_kernel', None)
         if use_liger_kernel:
             kwargs.pop('use_liger_kernel')
@@ -452,7 +452,7 @@ def train(cls, *args):
         devices = [d for d in devices if d]
         if other_kwargs['use_ddp']:
             assert int(other_kwargs['ddp_num']) > 0
-            ddp_param = f'NPROC_PER_NODE={int(other_kwargs["ddp_num"])}'
+            ddp_param = f'NPROC_PER_NODE={int(other_kwargs['ddp_num'])}'
             all_envs['NPROC_PER_NODE'] = str(other_kwargs['ddp_num'])
         assert (len(devices) == 1 or 'cpu' not in devices)
         gpus = ','.join(devices)
diff --git a/swift/ui/llm_train/runtime.py b/swift/ui/llm_train/runtime.py
index eaa6166660..4468efab78 100644
--- a/swift/ui/llm_train/runtime.py
+++ b/swift/ui/llm_train/runtime.py @@ -537,7 +537,7 @@ def construct_running_task(proc): create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M') return f'pid:{pid}/create:{create_time_formatted}' \ - f'/running:{format_time(ts-create_time)}/cmd:{" ".join(proc.cmdline())}' + f'/running:{format_time(ts-create_time)}/cmd:{' '.join(proc.cmdline())}' @staticmethod def parse_info_from_cmdline(task): diff --git a/tests/test_align/test_template/test_agent.py b/tests/test_align/test_template/test_agent.py index 9f3d347606..0f97b43e22 100644 --- a/tests/test_align/test_template/test_agent.py +++ b/tests/test_align/test_template/test_agent.py @@ -103,8 +103,8 @@ def test_react_en(): ' and the condition is sunny with a humidity of 50%.') template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0] data = dataset[6] @@ -112,8 +112,8 @@ def test_react_en(): data['messages'].insert(3, data['messages'][3]) template.template_backend = 'swift' encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') def test_react_zh(): @@ -139,8 +139,8 @@ def test_qwen_en(): 'is at 50%. Enjoy the nice weather!') template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0] data = dataset[6] @@ -148,8 +148,8 @@ def test_qwen_en(): data['messages'].insert(3, data['messages'][3]) template.template_backend = 'swift' encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') def test_qwen_zh(): @@ -175,8 +175,8 @@ def test_qwen_en_parallel(): 'and the humidity is at 50%. 
Enjoy the nice weather!') template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0] data = dataset[6] @@ -184,8 +184,8 @@ def test_qwen_en_parallel(): data['messages'].insert(3, data['messages'][3]) template.template_backend = 'swift' encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') def test_qwen_zh_parallel(): @@ -213,8 +213,8 @@ def test_hermes(): 'and the humidity is at 50%. Enjoy the nice weather!') template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0] data = dataset[6] @@ -222,12 +222,12 @@ def test_hermes(): data['messages'].insert(3, data['messages'][3]) template.template_backend = 'swift' encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') template.template_backend = 'jinja' encoded2 = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}') - print(f'labels: {template.safe_decode(encoded2["labels"])}') + print(f'input_ids: {template.safe_decode(encoded2['input_ids'])}') + print(f'labels: {template.safe_decode(encoded2['labels'])}') assert encoded['input_ids'] == encoded2['input_ids'][:-1] @@ -262,8 +262,8 @@ def test_glm4_0414(): assert messages[-1]['content'] == '根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。' template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0] data = dataset[6] @@ -271,8 +271,8 @@ def test_glm4_0414(): data['messages'].insert(3, data['messages'][3]) template.template_backend = 'swift' encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') def test_llama3(): @@ -284,8 +284,8 @@ def test_llama3(): template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: 
{template.safe_decode(encoded['labels'])}') dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0] data = dataset[6] @@ -293,8 +293,8 @@ def test_llama3(): data['messages'].insert(3, data['messages'][3]) template.template_backend = 'swift' encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') def test_llama4(): @@ -305,8 +305,8 @@ def test_llama4(): messages = _infer(engine) template.set_mode('train') encoded = template.encode({'messages': messages}) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') def test_hunyuan(): @@ -322,12 +322,12 @@ def test_hunyuan(): template.template_backend = 'swift' template.set_mode('train') encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') template.template_backend = 'jinja' encoded2 = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}') - print(f'labels: {template.safe_decode(encoded2["labels"])}') + print(f'input_ids: {template.safe_decode(encoded2['input_ids'])}') + print(f'labels: {template.safe_decode(encoded2['labels'])}') assert encoded['input_ids'][:-1] == encoded2['input_ids'] @@ -344,12 +344,12 @@ def test_glm4_5(): template.template_backend = 'swift' template.set_mode('train') encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') template.template_backend = 'jinja' encoded2 = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}') - print(f'labels: {template.safe_decode(encoded2["labels"])}') + print(f'input_ids: {template.safe_decode(encoded2['input_ids'])}') + print(f'labels: {template.safe_decode(encoded2['labels'])}') assert encoded['input_ids'][:-1] == encoded2['input_ids'] @@ -368,12 +368,12 @@ def test_qwen3_coder(): template.template_backend = 'swift' template.set_mode('train') encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') template.template_backend = 'jinja' encoded2 = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}') - print(f'labels: {template.safe_decode(encoded2["labels"])}') + print(f'input_ids: {template.safe_decode(encoded2['input_ids'])}') + print(f'labels: {template.safe_decode(encoded2['labels'])}') assert encoded['input_ids'] == encoded2['input_ids'][:-1] @@ -392,12 +392,12 @@ def test_deepseek_v3_1(): template.template_backend = 'swift' template.set_mode('train') encoded = template.encode(data) - print(f'input_ids: 
{template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') template.template_backend = 'jinja' encoded2 = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}') - print(f'labels: {template.safe_decode(encoded2["labels"])}') + print(f'input_ids: {template.safe_decode(encoded2['input_ids'])}') + print(f'labels: {template.safe_decode(encoded2['labels'])}') expected_input_ids = ( '<|begin▁of▁sentence|>\n\n## Tools\n' @@ -499,15 +499,15 @@ def test_seed_oss(): template.template_backend = 'swift' template.set_mode('train') encoded = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded["input_ids"])}') - print(f'labels: {template.safe_decode(encoded["labels"])}') + print(f'input_ids: {template.safe_decode(encoded['input_ids'])}') + print(f'labels: {template.safe_decode(encoded['labels'])}') import re expected_input_ids = re.sub( r'.*?', '', template.safe_decode(encoded['input_ids']), flags=re.DOTALL) template.template_backend = 'jinja' encoded2 = template.encode(data) - print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}') - print(f'labels: {template.safe_decode(encoded2["labels"])}') + print(f'input_ids: {template.safe_decode(encoded2['input_ids'])}') + print(f'labels: {template.safe_decode(encoded2['labels'])}') assert template.safe_decode(encoded2['input_ids']) == expected_input_ids diff --git a/train_gkd_debug.py b/train_gkd_debug.py new file mode 100644 index 0000000000..262dc60c77 --- /dev/null +++ b/train_gkd_debug.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +GKD Training Script - Debuggable Version +可以直接在IDE中运行和调试,无需使用shell脚本 + +使用方法: +1. 单GPU调试 (不使用DeepSpeed): + python train_gkd_debug.py + +2. 单GPU使用DeepSpeed: + torchrun --nproc_per_node=1 train_gkd_debug.py + +3. 
多GPU: + torchrun --nproc_per_node=N train_gkd_debug.py +""" +import os +from swift.llm import rlhf_main, RLHFArguments + + +def main(): + # 设置环境变量 + os.environ['WANDB_API_KEY'] = '28e11ef52849c4640b93051377be27eafac62c44' + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + # 创建参数对象 + args = RLHFArguments( + # RLHF类型和模型配置 + rlhf_type='gkd', + model='Qwen/Qwen2.5-0.5B-Instruct', + + # Teacher模型配置 + external_plugins=['rp_teacher_adapter.py'], + teacher_adapter='rp_teacher_adapter', + teacher_model='Qwen/Qwen2.5-0.5B-Instruct', + # 调试模式: 关闭teacher的DeepSpeed + teacher_deepspeed=None, # 改为 'zero3' 启用DeepSpeed + # teacher_deepspeed='zero3', # 启用DeepSpeed时取消注释 + + # 训练类型 + train_type='full', + + # 数据集配置 + dataset=[ + 'processed_training_data_final_fixed.jsonl', + 'benchmark_datasets_filtered_14k/alignbench_v1.1.jsonl', + 'benchmark_datasets_filtered_14k/arena_hard.jsonl', + 'benchmark_datasets_filtered_14k/arena_multi_turn_10-20.jsonl', + 'benchmark_datasets_filtered_14k/creative_writing_v3.jsonl', + 'benchmark_datasets_filtered_14k/ifeval.jsonl', + 'benchmark_datasets_filtered_14k/wildchat_gpt4_10-40.jsonl', + 'benchmark_datasets_filtered_14k/writingbench.jsonl', + ], + + # GKD特定参数 + seq_kd=False, + lmbda=1, + beta=1, + + # 长度和截断配置 + truncation_strategy='delete', + max_length=17000, + max_model_len=17000, + max_completion_length=200, + + # 训练超参数 + torch_dtype='bfloat16', + num_train_epochs=2, + per_device_train_batch_size=1, + learning_rate=1e-5, + gradient_accumulation_steps=1, + warmup_ratio=0.05, + + # 保存和日志配置 + save_steps=500, + save_total_limit=8, + logging_steps=1, + output_dir='condition_distill', + save_only_model=True, + + # 数据加载配置 + dataloader_num_workers=64, + dataset_num_proc=4, + + # DeepSpeed配置 + # 调试模式: 关闭DeepSpeed以避免分布式训练的复杂性 + # 如需使用DeepSpeed,请用: torchrun --nproc_per_node=1 train_gkd_debug.py + deepspeed=None, # 改为 'zero3' 启用DeepSpeed + # deepspeed='zero3', # 启用DeepSpeed时取消注释 + + # 注意力实现 + attn_impl='flash_attn', + + # vLLM配置 + use_vllm=True, + vllm_mode='server', + vllm_server_host=['127.0.0.1'], # 必须是列表 + vllm_server_port=[8001], # 必须是列表 + + # 其他配置 + report_to=['wandb'], + use_hf=True, + ) + + # 运行训练 + # 在这里设置断点即可调试 + rlhf_main(args) + + +if __name__ == '__main__': + main() diff --git a/train_gkd_with_deepspeed.sh b/train_gkd_with_deepspeed.sh new file mode 100755 index 0000000000..159a656f2e --- /dev/null +++ b/train_gkd_with_deepspeed.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# 使用 DeepSpeed 启动训练脚本 + +export WANDB_API_KEY=28e11ef52849c4640b93051377be27eafac62c44 +export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' +export CUDA_VISIBLE_DEVICES=0 + +# 使用 torchrun 启动以支持 DeepSpeed +# 注意: 在运行前,需要在 train_gkd_debug.py 中启用 DeepSpeed 配置 +torchrun --nproc_per_node=1 train_gkd_debug.py
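+
+# Note (assumptions; adjust for your setup): train_gkd_debug.py points
+# external_plugins at rp_teacher_adapter.py, which is expected to define a
+# TeacherAdapter subclass and register it as teacher_adapters['rp_teacher_adapter']
+# (see swift/plugin/teacher_adapter.py for the interface). The script also expects
+# a vLLM rollout server on 127.0.0.1:8001, started along the lines of:
+#   CUDA_VISIBLE_DEVICES=1 swift rollout --model Qwen/Qwen2.5-0.5B-Instruct
+# (the listening port and flags may differ depending on your swift version).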