-
Notifications
You must be signed in to change notification settings - Fork 1k
Add conditional distillation support for GKD trainer #6542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2cb20f5
9a118b4
5409e7d
507530b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -212,3 +212,78 @@ We can achieve the [On-Policy Distillation](https://thinkingmachines.ai/blog/on- | |||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh). | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ## Conditional Distillation | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Conditional distillation enables the teacher and student models to use **different contexts or prompts** during training, allowing for more flexible knowledge transfer strategies. For example: | ||||||||||||||||||||||||||||||||||||||||
| - Teacher model receives prompts with additional expert guidance | ||||||||||||||||||||||||||||||||||||||||
| - Teacher model receives task-reformulated inputs (e.g., summaries, translations) | ||||||||||||||||||||||||||||||||||||||||
| - Teacher model uses longer context information | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ### TeacherAdapter Plugin System | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| You can customize the teacher model's context transformation logic by implementing the `TeacherAdapter` interface: | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||
| # swift/plugin/teacher_adapter.py | ||||||||||||||||||||||||||||||||||||||||
| from swift.plugin import TeacherAdapter | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| class MyTeacherAdapter(TeacherAdapter): | ||||||||||||||||||||||||||||||||||||||||
| def shape_context(self, history): | ||||||||||||||||||||||||||||||||||||||||
| """Transform student messages to teacher messages | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||
| history: Student model's message list (OpenAI format) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||
| Teacher model's message list | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
| # Add extra system prompt for teacher | ||||||||||||||||||||||||||||||||||||||||
| teacher_history = history.copy() | ||||||||||||||||||||||||||||||||||||||||
| if teacher_history and teacher_history[0]['role'] == 'system': | ||||||||||||||||||||||||||||||||||||||||
| teacher_history[0]['content'] += '\n\nYou are an expert with extensive knowledge.' | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| teacher_history.insert(0, { | ||||||||||||||||||||||||||||||||||||||||
| 'role': 'system', | ||||||||||||||||||||||||||||||||||||||||
| 'content': 'You are an expert with extensive knowledge.' | ||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+243
to
+249
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The example code for
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| return teacher_history | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Register to plugin system | ||||||||||||||||||||||||||||||||||||||||
| from swift.plugin import teacher_adapters | ||||||||||||||||||||||||||||||||||||||||
| teacher_adapters['my_adapter'] = MyTeacherAdapter | ||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ### Built-in Adapters | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| SWIFT provides two built-in teacher adapters: | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| | Adapter | Description | | ||||||||||||||||||||||||||||||||||||||||
| |---------|-------------| | ||||||||||||||||||||||||||||||||||||||||
| | `default` | Default: teacher uses the same context as student | | ||||||||||||||||||||||||||||||||||||||||
| | `example` | Example: adds extra instructions to system prompt for teacher | | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ### Usage | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ```bash | ||||||||||||||||||||||||||||||||||||||||
| swift rlhf \ | ||||||||||||||||||||||||||||||||||||||||
| --rlhf_type gkd \ | ||||||||||||||||||||||||||||||||||||||||
| --model Qwen/Qwen2.5-0.5B-Instruct \ | ||||||||||||||||||||||||||||||||||||||||
| --teacher_model Qwen/Qwen2.5-7B-Instruct \ | ||||||||||||||||||||||||||||||||||||||||
| --teacher_adapter example \ | ||||||||||||||||||||||||||||||||||||||||
| --dataset your_dataset.jsonl \ | ||||||||||||||||||||||||||||||||||||||||
| ... | ||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ### How It Works | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| In conditional distillation: | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| 1. **Student model** processes original input: `[prompt_student] + [response]` | ||||||||||||||||||||||||||||||||||||||||
| 2. **Teacher model** processes transformed input: `[prompt_teacher] + [response]` | ||||||||||||||||||||||||||||||||||||||||
| 3. Both models compute logits on the **same response tokens** | ||||||||||||||||||||||||||||||||||||||||
| 4. Distillation loss is calculated using these logits | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Where `prompt_teacher` is transformed from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part remains unchanged. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Training script reference [here](../../../examples/train/on_policy_condition_distillation.sh) | ||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| # On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/ | ||
|
|
||
| # CUDA_VISIBLE_DEVICES=6 \ | ||
| # swift rollout \ | ||
| # --template qwen3_nothinking\ | ||
| # --model Qwen/Qwen3-8B-Base \ | ||
| # --vllm_max_model_len 24192 | ||
|
|
||
| NPROC_PER_NODE=7 \ | ||
| PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ | ||
| CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,7 \ | ||
| swift rlhf \ | ||
| --rlhf_type gkd \ | ||
| --teacher_adapter math_teacher \ | ||
| --template qwen3_nothinking\ | ||
| --model Qwen/Qwen3-8B-Base \ | ||
| --teacher_model_type qwen3_nothinking\ | ||
| --teacher_model Qwen/Qwen3-8B-Base \ | ||
| --train_type full \ | ||
| --dataset open-thoughts/OpenThoughts3-1.2M#1000 \ | ||
| --seq_kd false \ | ||
| --lmbda 1 \ | ||
| --beta 1 \ | ||
| --torch_dtype bfloat16 \ | ||
| --num_train_epochs 1 \ | ||
| --per_device_train_batch_size 1 \ | ||
| --learning_rate 1e-5 \ | ||
| --gradient_accumulation_steps 1 \ | ||
| --save_steps 1000 \ | ||
| --save_total_limit 2 \ | ||
| --logging_steps 1 \ | ||
| --max_length 16000 \ | ||
| --max_completion_length 100 \ | ||
| --output_dir output \ | ||
| --warmup_ratio 0.05 \ | ||
| --save_only_model true \ | ||
| --dataloader_num_workers 64 \ | ||
| --dataset_num_proc 4 \ | ||
| --deepspeed zero2 \ | ||
| --teacher_deepspeed zero3 \ | ||
| --attn_impl flash_attn \ | ||
| --use_vllm true \ | ||
| --vllm_mode server \ | ||
| --vllm_server_host 127.0.0.1 \ | ||
| --vllm_server_port 8000 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文档中
MyTeacherAdapter的示例代码使用了teacher_history[0]['content'] += ...的方式来修改系统提示。由于history.copy()是浅拷贝,这样做会意外地修改原始的history对象中的内容。虽然在当前 GKD 的流程中可能不会引发问题,但这是一种有风险的实践。建议将文档中的示例更新为更安全的实现方式,即创建一个新的字典来替换teacher_history[0],就像swift/plugin/teacher_adapter.py中MathTeacherAdapter的实现一样,以避免潜在的副作用。