diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md
index 377e610aaa..d051a262f9 100644
--- a/docs/source/Instruction/GKD.md
+++ b/docs/source/Instruction/GKD.md
@@ -210,3 +210,78 @@ swift rlhf \
```
 For the related script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh)
+
+## Conditional Distillation
+
+Conditional distillation allows the teacher and student models to be trained with **different contexts or prompts**, enabling more flexible knowledge-transfer strategies. For example:
+- The teacher model receives prompts containing additional expert guidance
+- The teacher model receives task-reformulated inputs (e.g., summaries, translations)
+- The teacher model uses longer context information
+
+### TeacherAdapter Plugin System
+
+You can customize how the teacher model's context is derived by implementing the `TeacherAdapter` interface:
+
+```python
+# swift/plugin/teacher_adapter.py
+from swift.plugin import TeacherAdapter
+
+class MyTeacherAdapter(TeacherAdapter):
+    def shape_context(self, data_dict):
+        """Transform the student's messages into the teacher's messages.
+
+        Args:
+            data_dict: the full data dict; `data_dict['messages']` holds the
+                student model's message list (OpenAI format)
+
+        Returns:
+            The teacher model's message list
+        """
+        # Add an extra system prompt for the teacher
+        teacher_history = data_dict['messages'].copy()
+        if teacher_history and teacher_history[0]['role'] == 'system':
+            teacher_history[0] = {
+                'role': 'system',
+                'content': teacher_history[0]['content'] + '\n\nYou are a domain expert.'
+            }
+        else:
+            teacher_history.insert(0, {
+                'role': 'system',
+                'content': 'You are a domain expert.'
+            })
+        return teacher_history
+
+# Register with the plugin system
+from swift.plugin import teacher_adapters
+teacher_adapters['my_adapter'] = MyTeacherAdapter
+```
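+
+A custom adapter can also live outside the SWIFT source tree and be loaded from a file via `--external_plugins` (a sketch; `my_teacher_adapter.py` is an illustrative file name):
+
+```bash
+swift rlhf \
+    --rlhf_type gkd \
+    --external_plugins my_teacher_adapter.py \
+    --teacher_adapter my_adapter \
+    ...
+```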
+
+### Built-in Adapters
+
+SWIFT ships with two built-in teacher adapters:
+
+| Adapter | Description |
+|---------|------|
+| `default` | Default: the teacher uses the same context as the student |
+| `math_teacher` | Example: appends math-expert instructions to the teacher's system prompt |
+
+### Usage
+
+```bash
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-0.5B-Instruct \
+ --teacher_model Qwen/Qwen2.5-7B-Instruct \
+ --teacher_adapter math_teacher \
+ --dataset your_dataset.jsonl \
+ ...
+```
+
+### How It Works
+
+In conditional distillation:
+
+1. The **student model** processes the original input: `[prompt_student] + [response]`
+2. The **teacher model** processes the transformed input: `[prompt_teacher] + [response]`
+3. Both models compute logits over the **same response tokens**
+4. The distillation loss is computed from these logits
+
+Here `prompt_teacher` is produced from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part is left unchanged, as the sketch below illustrates.
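+
+A minimal sketch of this token alignment (illustrative only; the token ids are made up and this is not the actual SWIFT implementation):
+
+```python
+# Student and teacher see different prompts but share the same response tokens.
+student_prompt = [101, 102, 103]      # tokenized student prompt
+teacher_prompt = [7, 8, 9, 10, 11]    # tokenized prompt after shape_context()
+response = [42, 43, 44, 45]           # shared response tokens
+
+student_input_ids = student_prompt + response
+teacher_input_ids = teacher_prompt + response
+
+# Both sequences end with the same response tokens, so the distillation loss
+# (e.g., a divergence between the two distributions) is computed only over the
+# positions that score the response; the differing prompt lengths drop out.
+assert student_input_ids[-len(response):] == teacher_input_ids[-len(response):]
+```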
+
+For a reference training script, see [here](../../../examples/train/on_policy_condition_distillation.sh).
diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md
index 98a0b6b2cb..b676d78528 100644
--- a/docs/source_en/Instruction/GKD.md
+++ b/docs/source_en/Instruction/GKD.md
@@ -212,3 +212,78 @@ We can achieve the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-
```
For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).
+
+## Conditional Distillation
+
+Conditional distillation enables the teacher and student models to use **different contexts or prompts** during training, allowing for more flexible knowledge transfer strategies. For example:
+- Teacher model receives prompts with additional expert guidance
+- Teacher model receives task-reformulated inputs (e.g., summaries, translations)
+- Teacher model uses longer context information
+
+### TeacherAdapter Plugin System
+
+You can customize the teacher model's context transformation logic by implementing the `TeacherAdapter` interface:
+
+```python
+# swift/plugin/teacher_adapter.py
+from swift.plugin import TeacherAdapter
+
+class MyTeacherAdapter(TeacherAdapter):
+    def shape_context(self, data_dict):
+        """Transform student messages to teacher messages
+
+        Args:
+            data_dict: the full data dict; `data_dict['messages']` holds the
+                student model's message list (OpenAI format)
+
+        Returns:
+            Teacher model's message list
+        """
+        # Add an extra system prompt for the teacher
+        teacher_history = data_dict['messages'].copy()
+        if teacher_history and teacher_history[0]['role'] == 'system':
+            teacher_history[0] = {
+                'role': 'system',
+                'content': teacher_history[0]['content'] + '\n\nYou are an expert with extensive knowledge.'
+            }
+        else:
+            teacher_history.insert(0, {
+                'role': 'system',
+                'content': 'You are an expert with extensive knowledge.'
+            })
+        return teacher_history
+
+# Register to plugin system
+from swift.plugin import teacher_adapters
+teacher_adapters['my_adapter'] = MyTeacherAdapter
+```
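+
+A custom adapter can also live outside the SWIFT source tree and be loaded from a file via `--external_plugins` (a sketch; `my_teacher_adapter.py` is an illustrative file name):
+
+```bash
+swift rlhf \
+    --rlhf_type gkd \
+    --external_plugins my_teacher_adapter.py \
+    --teacher_adapter my_adapter \
+    ...
+```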
+
+### Built-in Adapters
+
+SWIFT provides two built-in teacher adapters:
+
+| Adapter | Description |
+|---------|-------------|
+| `default` | Default: teacher uses the same context as student |
+| `math_teacher` | Example: appends math-expert instructions to the teacher's system prompt |
+
+### Usage
+
+```bash
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-0.5B-Instruct \
+ --teacher_model Qwen/Qwen2.5-7B-Instruct \
+ --teacher_adapter math_teacher \
+ --dataset your_dataset.jsonl \
+ ...
+```
+
+### How It Works
+
+In conditional distillation:
+
+1. **Student model** processes original input: `[prompt_student] + [response]`
+2. **Teacher model** processes transformed input: `[prompt_teacher] + [response]`
+3. Both models compute logits on the **same response tokens**
+4. Distillation loss is calculated using these logits
+
+Here `prompt_teacher` is produced from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part remains unchanged, as the sketch below illustrates.
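+
+A minimal sketch of this token alignment (illustrative only; the token ids are made up and this is not the actual SWIFT implementation):
+
+```python
+# Student and teacher see different prompts but share the same response tokens.
+student_prompt = [101, 102, 103]      # tokenized student prompt
+teacher_prompt = [7, 8, 9, 10, 11]    # tokenized prompt after shape_context()
+response = [42, 43, 44, 45]           # shared response tokens
+
+student_input_ids = student_prompt + response
+teacher_input_ids = teacher_prompt + response
+
+# Both sequences end with the same response tokens, so the distillation loss
+# (e.g., a divergence between the two distributions) is computed only over the
+# positions that score the response; the differing prompt lengths drop out.
+assert student_input_ids[-len(response):] == teacher_input_ids[-len(response):]
+```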
+
+For a reference training script, see [here](../../../examples/train/on_policy_condition_distillation.sh).
diff --git a/examples/train/on_policy_condition_distillation.sh b/examples/train/on_policy_condition_distillation.sh
new file mode 100644
index 0000000000..15503b3ae9
--- /dev/null
+++ b/examples/train/on_policy_condition_distillation.sh
@@ -0,0 +1,45 @@
+# On-policy conditional distillation; on-policy recipe: https://thinkingmachines.ai/blog/on-policy-distillation/
+
+# CUDA_VISIBLE_DEVICES=6 \
+# swift rollout \
+# --template qwen3_nothinking \
+# --model Qwen/Qwen3-8B-Base \
+# --vllm_max_model_len 24192
+
+NPROC_PER_NODE=7 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,7 \
+swift rlhf \
+ --rlhf_type gkd \
+ --teacher_adapter math_teacher \
+ --template qwen3_nothinking \
+ --model Qwen/Qwen3-8B-Base \
+ --teacher_model_type qwen3_nothinking \
+ --teacher_model Qwen/Qwen3-8B-Base \
+ --train_type full \
+ --dataset open-thoughts/OpenThoughts3-1.2M#1000 \
+ --seq_kd false \
+ --lmbda 1 \
+ --beta 1 \
+ --torch_dtype bfloat16 \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 1 \
+ --learning_rate 1e-5 \
+ --gradient_accumulation_steps 1 \
+ --save_steps 1000 \
+ --save_total_limit 2 \
+ --logging_steps 1 \
+ --max_length 16000 \
+ --max_completion_length 100 \
+ --output_dir output \
+ --warmup_ratio 0.05 \
+ --save_only_model true \
+ --dataloader_num_workers 64 \
+ --dataset_num_proc 4 \
+ --deepspeed zero2 \
+ --teacher_deepspeed zero3 \
+ --attn_impl flash_attn \
+ --use_vllm true \
+ --vllm_mode server \
+ --vllm_server_host 127.0.0.1 \
+ --vllm_server_port 8000
diff --git a/setup.py b/setup.py
index 387dea74c0..abd0273dec 100644
--- a/setup.py
+++ b/setup.py
@@ -16,8 +16,9 @@ def readme():
def get_version():
with open(version_file, 'r', encoding='utf-8') as f:
- exec(compile(f.read(), version_file, 'exec'))
- return locals()['__version__']
+ version_globals = {}
+ exec(compile(f.read(), version_file, 'exec'), version_globals)
+ return version_globals['__version__']
def parse_requirements(fname='requirements.txt', with_version=True):
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 054fffa5d3..6a9294625a 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -120,6 +120,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo
lmbda: float = 0.5
seq_kd: bool = False
offload_teacher_model: bool = False
+ teacher_adapter: Optional[str] = None # Teacher adapter plugin for conditional distillation
# compat
max_new_tokens: Optional[int] = None # use max_completion_length instead
@@ -433,6 +434,10 @@ def _init_teacher_deepspeed(self):
def _check_gkd(self):
if self.rlhf_type != 'gkd':
return
+ # Preserve extra dataset fields for conditional distillation
+ self.remove_unused_columns = False
+ logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
+
if is_mp() and self.use_vllm:
raise ValueError('GKD with vLLM is not compatible with `device_map`. '
'Please set NPROC_PER_NODE equal to num_processes.')
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index fce20eb7d2..8752c0fe91 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -380,6 +380,33 @@ def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = self._encode_truncated(inputs)
encoded['prompts'] = encoded['input_ids'][:-len(encoded.pop('answer_input_ids'))]
+
+ # If teacher_history exists, encode teacher's prompt
+ if 'teacher_history' in inputs.extra_kwargs:
+ teacher_inputs = deepcopy(inputs)
+ teacher_history = inputs.extra_kwargs['teacher_history']
+
+ # Construct teacher messages: teacher_history + empty assistant response
+ if teacher_history and teacher_history[-1]['role'] != 'assistant':
+ teacher_messages = teacher_history + [{'role': 'assistant', 'content': ''}]
+ else:
+ teacher_messages = teacher_history
+
+ # Handle system message
+ if teacher_messages and teacher_messages[0]['role'] == 'system':
+ teacher_inputs.system = teacher_messages[0]['content']
+ teacher_messages = teacher_messages[1:]
+ else:
+ teacher_inputs.system = None
+
+ teacher_inputs.messages = teacher_messages
+ teacher_encoded = self._encode_truncated(teacher_inputs)
+ teacher_answer_len = len(teacher_encoded.get('answer_input_ids', []))
+ if teacher_answer_len > 0:
+ encoded['teacher_prompts'] = teacher_encoded['input_ids'][:-teacher_answer_len]
+ else:
+ encoded['teacher_prompts'] = teacher_encoded['input_ids']
+
for k in list(encoded.keys()):
if k.startswith('prompt_') or k.endswith('answer_'):
encoded.pop(k, None)
@@ -1504,6 +1531,25 @@ def _gkd_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona
prompts_res = self._data_collator(prompts_batch, padding_to=padding_to)
res['prompts'] = prompts_res.pop('input_ids')
res.update({f'prompt_{k}': v for k, v in prompts_res.items()})
+
+ # Handle teacher_prompts + response for conditional distillation
+ if any(b.get('teacher_prompts') is not None for b in batch):
+ # Extract response tokens from input_ids using labels (includes all tokens like <|im_end|>)
+ teacher_input_ids_batch = []
+ for b in batch:
+ if b.get('teacher_prompts') is not None:
+ # Extract response tokens from input_ids where labels != -100
+ response_mask = [label != -100 for label in b['labels']]
+ response_tokens = [token for token, keep in zip(b['input_ids'], response_mask) if keep]
+ # Concatenate teacher_prompts + response_tokens
+ teacher_input_ids = b['teacher_prompts'] + response_tokens
+ teacher_input_ids_batch.append({'input_ids': teacher_input_ids})
+
+ if teacher_input_ids_batch:
+ teacher_res = self._data_collator(teacher_input_ids_batch, padding_to=padding_to)
+ res['teacher_input_ids'] = teacher_res.pop('input_ids')
+ res.update({f'teacher_{k}': v for k, v in teacher_res.items()})
+
return res
def _embedding_data_collator(self,
diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py
index 832b458ee5..9930e51587 100644
--- a/swift/llm/train/rlhf.py
+++ b/swift/llm/train/rlhf.py
@@ -218,8 +218,11 @@ def _get_trainer_kwargs(self):
trainer_kwargs['reward_funcs'] = self.args.reward_funcs
if self.args.chord_sft_dataset:
trainer_kwargs['chord_sft_dataset'], _ = self._prepare_chord_sft_dataset()
- if self.args.rlhf_type == 'gkd' and self.args.teacher_deepspeed:
- trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed
+ if self.args.rlhf_type == 'gkd':
+ if self.args.teacher_deepspeed:
+ trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed
+ # Pass rlhf_args to GKDTrainer for accessing custom parameters
+ trainer_kwargs['rlhf_args'] = self.args
return trainer_kwargs
diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py
index 3e581ee80e..680bc04769 100644
--- a/swift/plugin/__init__.py
+++ b/swift/plugin/__init__.py
@@ -17,6 +17,7 @@
from .rm_plugin import rm_plugins
from .env import envs, Env
from .context_manager import context_managers, ContextManager
+ from .teacher_adapter import teacher_adapters, TeacherAdapter
else:
_import_structure = {
@@ -33,6 +34,7 @@
'rm_plugin': ['rm_plugins'],
'env': ['envs', 'Env'],
'context_manager': ['context_managers', 'ContextManager'],
+ 'teacher_adapter': ['teacher_adapters', 'TeacherAdapter'],
}
import sys
diff --git a/swift/plugin/teacher_adapter.py b/swift/plugin/teacher_adapter.py
new file mode 100644
index 0000000000..bf533a9539
--- /dev/null
+++ b/swift/plugin/teacher_adapter.py
@@ -0,0 +1,90 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Dict, Any, Optional
+
+if TYPE_CHECKING:
+ from swift.llm.utils import Messages
+ import torch
+
+
+class TeacherAdapter(ABC):
+ """Base class for transforming student context to teacher context in GKD training."""
+
+ @abstractmethod
+ def shape_context(self, data_dict: Dict[str, Any]) -> 'Messages':
+ """Transform student messages to teacher messages.
+
+ Args:
+ data_dict: Complete data dictionary containing:
+ - 'messages': Student model's messages (OpenAI format)
+ - Other fields like 'dataset', 'images', etc. for flexible usage
+
+ Returns:
+ Teacher model's messages
+ """
+ pass
+
+ def get_loss_mask(self, student_logits: 'torch.Tensor', teacher_logits: 'torch.Tensor',
+ mask: 'torch.Tensor', **kwargs) -> Optional['torch.Tensor']:
+ """Optionally modify the loss mask to control which tokens participate in distillation.
+
+ Args:
+ student_logits: Student model logits, shape (batch_size, seq_len, vocab_size)
+ teacher_logits: Teacher model logits, shape (batch_size, seq_len, vocab_size)
+ mask: Current mask indicating response tokens, shape (batch_size, seq_len)
+ True means the position is a response token (labels != -100 after shift)
+ **kwargs: Additional information like 'inputs', 'labels', etc.
+
+ Returns:
+ Modified mask with same shape as input mask, or None to use original mask.
+ True means the position participates in loss computation.
+
+ Example:
+ # Only train on first 50 tokens + last 5 tokens (to ensure learning stop token)
+ new_mask = torch.zeros_like(mask)
+ for i in range(mask.shape[0]):
+ response_indices = mask[i].nonzero(as_tuple=True)[0]
+ if len(response_indices) > 0:
+ new_mask[i, response_indices[:50]] = True # First 50 tokens
+ new_mask[i, response_indices[-6:-1]] = True # Last 5 valid predictions
+ return new_mask
+ """
+ return None # Default: use original mask
+
+
+class DefaultTeacherAdapter(TeacherAdapter):
+ """Default: teacher uses the same context as student."""
+
+ def shape_context(self, data_dict: Dict[str, Any]) -> 'Messages':
+ return data_dict['messages']
+
+
+class MathTeacherAdapter(TeacherAdapter):
+ """Example: add extra instructions to system prompt for teacher."""
+
+ def shape_context(self, data_dict: Dict[str, Any]) -> 'Messages':
+ # Create a copy to avoid modifying original
+ history = data_dict['messages']
+ teacher_history = history.copy()
+
+ # Example: enhance system prompt for teacher
+ if teacher_history and teacher_history[0]['role'] == 'system':
+ teacher_history[0] = {
+ 'role': 'system',
+ 'content': teacher_history[0]['content'] + '\n\nYou are a math expert, solve problems step by step.'
+ }
+ else:
+ # Insert system prompt at the beginning
+ teacher_history.insert(0, {
+ 'role': 'system',
+ 'content': 'You are a math expert, solve problems step by step.'
+ })
+
+ return teacher_history
+
+
+# Registry for teacher adapter plugins
+teacher_adapters = {
+ 'default': DefaultTeacherAdapter,
+ 'math_teacher': MathTeacherAdapter,
+}
diff --git a/swift/trainers/rlhf_trainer/gkd_trainer.py b/swift/trainers/rlhf_trainer/gkd_trainer.py
index 2f87c8447a..5d44f8acd9 100644
--- a/swift/trainers/rlhf_trainer/gkd_trainer.py
+++ b/swift/trainers/rlhf_trainer/gkd_trainer.py
@@ -29,6 +29,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
teacher_model = kwargs.pop('teacher_model')
teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None)
self.vllm_client = kwargs.pop('vllm_client', None)
+ rlhf_args = kwargs.pop('rlhf_args', None)
kwargs['data_collator'] = identity_data_collator
super().__init__(model, None, *_args, **kwargs)
args = kwargs['args']
@@ -36,6 +37,19 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
self.temperature = args.temperature
self.seq_kd = args.seq_kd
self.generation_config = model.generation_config
+
+        # Initialize the teacher adapter for conditional distillation
+        if rlhf_args is not None:
+            teacher_adapter_name = getattr(rlhf_args, 'teacher_adapter', None)
+        else:
+            teacher_adapter_name = getattr(args, 'teacher_adapter', None)
+        self.teacher_adapter = None
+        if teacher_adapter_name is not None:
+            from swift.plugin import teacher_adapters
+            adapter_cls = teacher_adapters.get(teacher_adapter_name)
+            if adapter_cls is None:
+                raise ValueError(f'Unknown teacher_adapter: {teacher_adapter_name}. '
+                                 f'Available adapters: {list(teacher_adapters.keys())}')
+            self.teacher_adapter = adapter_cls()
+            logger.info(f'Loaded teacher_adapter: {type(self.teacher_adapter).__name__}')
+
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
self._total_train_tokens = 0
@@ -69,7 +83,10 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None):
assert not self.template.padding_free, 'generate not support padding_free/packing.'
# Generate output with respect to the prompt only
- model_inputs = {k: v for k, v in inputs.items() if not k.startswith('prompt') and k != 'labels'}
+        # Exclude prompt-only and teacher-only keys; they must not be passed to generate
+        model_inputs = {k: v for k, v in inputs.items()
+                        if not k.startswith('prompt') and not k.startswith('teacher_') and k != 'labels'}
model_inputs['input_ids'] = inputs['prompts']
model_inputs.update({k[len('prompt_'):]: v for k, v in inputs.items() if k.startswith('prompt_')})
model_inputs.pop('position_ids', None)
@@ -103,7 +120,10 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token
return generated_tokens, new_attention_mask, new_labels
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}}
+        # Exclude prompt-only and teacher-only keys from the student forward pass
+        model_inputs = {k: v for k, v in inputs.items()
+                        if k not in {'prompt', 'labels'} and not k.startswith('prompt_')
+                        and not k.startswith('teacher_')}
# If generate is used, then use_logits_to_keep must be set to False.
use_logits_to_keep = self.get_use_logits_to_keep(True)
if use_logits_to_keep:
@@ -114,13 +134,36 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
# compute student output
outputs_student = model(**model_inputs)
- model_inputs.pop('labels', None)
+ # Prepare teacher model inputs
+ teacher_model_inputs = model_inputs.copy()
+ if 'teacher_input_ids' in inputs:
+ # Conditional distillation: use pre-concatenated teacher_input_ids
+ teacher_model_inputs['input_ids'] = inputs['teacher_input_ids']
+ teacher_model_inputs['attention_mask'] = inputs['teacher_attention_mask']
+ if 'teacher_position_ids' in inputs:
+ teacher_model_inputs['position_ids'] = inputs['teacher_position_ids']
+
load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext()
with torch.no_grad(), load_context:
- outputs_teacher = self.teacher_model(**model_inputs)
+ outputs_teacher = self.teacher_model(**teacher_model_inputs)
+ # Extract logits for distillation
+ # With use_logits_to_keep, both student and teacher logits are aligned with labels
shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1)
mask = shifted_labels != -100
+
+ # Allow teacher_adapter to modify the loss mask
+ if self.teacher_adapter is not None:
+ modified_mask = self.teacher_adapter.get_loss_mask(
+ student_logits=outputs_student.logits,
+ teacher_logits=outputs_teacher.logits,
+ mask=mask,
+ inputs=inputs,
+ labels=inputs['labels']
+ )
+ if modified_mask is not None:
+ mask = modified_mask
+
shifted_student_logits = outputs_student.logits[mask][None]
shifted_teacher_logits = outputs_teacher.logits[mask][None]
@@ -151,11 +194,30 @@ def _prepare_batch_inputs(self, inputs: list) -> Dict[str, torch.Tensor]:
batch_encoded_inputs = []
for data in inputs:
+            # Save response_token_ids for later use in conditional distillation
+            response_token_ids = data.get('response_token_ids')
if 'response_token_ids' in data and data['response_token_ids']:
from .utils import replace_assistant_response_with_ids
data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])
+ # Use teacher_adapter to generate teacher_history if available
+ if self.teacher_adapter is not None:
+                # Prepare student messages (drop the final assistant response if present)
+                messages = data['messages']
+                student_messages = messages[:-1] if messages and messages[-1]['role'] == 'assistant' else messages
+ # Pass complete data dict to adapter (similar to GRPO's reward_model_plugin design)
+ # Temporarily replace messages with student_messages for adapter
+ original_messages = data['messages']
+ data['messages'] = student_messages
+ teacher_history = self.teacher_adapter.shape_context(data)
+ data['messages'] = original_messages # Restore
+ data['teacher_history'] = teacher_history
encoded = template.encode(data, return_length=True)
+
+ # Preserve response_token_ids for conditional distillation
+ if response_token_ids is not None:
+ encoded['response_token_ids'] = response_token_ids
+
batch_encoded_inputs.append(encoded)
from swift.llm import to_device
diff --git a/train_gkd_debug.py b/train_gkd_debug.py
new file mode 100644
index 0000000000..262dc60c77
--- /dev/null
+++ b/train_gkd_debug.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+GKD Training Script - Debuggable Version
+可以直接在IDE中运行和调试,无需使用shell脚本
+
+使用方法:
+1. 单GPU调试 (不使用DeepSpeed):
+ python train_gkd_debug.py
+
+2. 单GPU使用DeepSpeed:
+ torchrun --nproc_per_node=1 train_gkd_debug.py
+
+3. 多GPU:
+ torchrun --nproc_per_node=N train_gkd_debug.py
+"""
+import os
+from swift.llm import rlhf_main, RLHFArguments
+
+
+def main():
+    # Set environment variables; replace the placeholder with your own W&B key
+    os.environ['WANDB_API_KEY'] = '<your_wandb_api_key>'
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+    # Build the argument object
+    args = RLHFArguments(
+        # RLHF type and model configuration
+        rlhf_type='gkd',
+        model='Qwen/Qwen2.5-0.5B-Instruct',
+
+        # Teacher model configuration
+        external_plugins=['rp_teacher_adapter.py'],
+        teacher_adapter='rp_teacher_adapter',
+        teacher_model='Qwen/Qwen2.5-0.5B-Instruct',
+        # Debug mode: disable DeepSpeed for the teacher
+        teacher_deepspeed=None,  # set to 'zero3' to enable DeepSpeed
+        # teacher_deepspeed='zero3',  # uncomment to enable DeepSpeed
+
+        # Training type
+        train_type='full',
+
+        # Dataset configuration
+ dataset=[
+ 'processed_training_data_final_fixed.jsonl',
+ 'benchmark_datasets_filtered_14k/alignbench_v1.1.jsonl',
+ 'benchmark_datasets_filtered_14k/arena_hard.jsonl',
+ 'benchmark_datasets_filtered_14k/arena_multi_turn_10-20.jsonl',
+ 'benchmark_datasets_filtered_14k/creative_writing_v3.jsonl',
+ 'benchmark_datasets_filtered_14k/ifeval.jsonl',
+ 'benchmark_datasets_filtered_14k/wildchat_gpt4_10-40.jsonl',
+ 'benchmark_datasets_filtered_14k/writingbench.jsonl',
+ ],
+
+        # GKD-specific parameters
+ seq_kd=False,
+ lmbda=1,
+ beta=1,
+
+        # Length and truncation configuration
+ truncation_strategy='delete',
+ max_length=17000,
+ max_model_len=17000,
+ max_completion_length=200,
+
+        # Training hyperparameters
+ torch_dtype='bfloat16',
+ num_train_epochs=2,
+ per_device_train_batch_size=1,
+ learning_rate=1e-5,
+ gradient_accumulation_steps=1,
+ warmup_ratio=0.05,
+
+        # Saving and logging configuration
+ save_steps=500,
+ save_total_limit=8,
+ logging_steps=1,
+ output_dir='condition_distill',
+ save_only_model=True,
+
+        # Data loading configuration
+ dataloader_num_workers=64,
+ dataset_num_proc=4,
+
+        # DeepSpeed configuration
+        # Debug mode: disable DeepSpeed to avoid the complexity of distributed training
+        # To use DeepSpeed, launch with: torchrun --nproc_per_node=1 train_gkd_debug.py
+        deepspeed=None,  # set to 'zero3' to enable DeepSpeed
+        # deepspeed='zero3',  # uncomment to enable DeepSpeed
+
+        # Attention implementation
+ attn_impl='flash_attn',
+
+        # vLLM configuration
+ use_vllm=True,
+ vllm_mode='server',
+        vllm_server_host=['127.0.0.1'],  # must be a list
+        vllm_server_port=[8001],  # must be a list
+
+        # Other configuration
+ report_to=['wandb'],
+ use_hf=True,
+ )
+
+    # Run training
+    # Set a breakpoint here to debug
+ rlhf_main(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/train_gkd_with_deepspeed.sh b/train_gkd_with_deepspeed.sh
new file mode 100755
index 0000000000..159a656f2e
--- /dev/null
+++ b/train_gkd_with_deepspeed.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Launch the training script with DeepSpeed
+
+export WANDB_API_KEY=<your_wandb_api_key>  # replace with your own key
+export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'
+export CUDA_VISIBLE_DEVICES=0
+
+# Launch with torchrun to support DeepSpeed
+# Note: enable the DeepSpeed options in train_gkd_debug.py before running
+torchrun --nproc_per_node=1 train_gkd_debug.py