75 changes: 75 additions & 0 deletions docs/source/Instruction/GKD.md
@@ -210,3 +210,78 @@ swift rlhf \
```

For related scripts, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh)

## Conditional Distillation

Conditional distillation allows the teacher and student models to be trained with **different contexts or prompts**, enabling more flexible knowledge-transfer strategies. For example:
- The teacher model receives prompts with additional expert guidance
- The teacher model receives task-reformulated inputs (e.g., summaries, translations)
- The teacher model uses longer context information

### TeacherAdapter Plugin System

You can customize the teacher model's context transformation logic by implementing the `TeacherAdapter` interface:

```python
# swift/plugin/teacher_adapter.py
from swift.plugin import TeacherAdapter

class MyTeacherAdapter(TeacherAdapter):
    def shape_context(self, history):
        """Transform the student's messages into the teacher's messages

        Args:
            history: the student model's message list (OpenAI format)

        Returns:
            The teacher model's message list
        """
        # Add an extra system prompt for the teacher
        teacher_history = history.copy()
        if teacher_history and teacher_history[0]['role'] == 'system':
            teacher_history[0]['content'] += '\n\nYou are a domain expert.'
        else:
            teacher_history.insert(0, {
                'role': 'system',
                'content': 'You are a domain expert.'
            })
Comment on lines +241 to +247
Contributor (severity: medium)

The MyTeacherAdapter example in the documentation modifies the system prompt via teacher_history[0]['content'] += .... Since history.copy() is a shallow copy, this unintentionally mutates the contents of the original history object. While this may not cause problems in the current GKD workflow, it is a risky practice. It is recommended to update the example to a safer implementation that replaces teacher_history[0] with a new dict, as in the MathTeacherAdapter implementation in swift/plugin/teacher_adapter.py, to avoid potential side effects.

Suggested change

        if teacher_history and teacher_history[0]['role'] == 'system':
            teacher_history[0]['content'] += '\n\nYou are a domain expert.'
        else:
            teacher_history.insert(0, {
                'role': 'system',
                'content': 'You are a domain expert.'
            })
        teacher_history = history.copy()
        if teacher_history and teacher_history[0]['role'] == 'system':
            # More robust: create a new dict to avoid side effects
            teacher_history[0] = {
                'role': 'system',
                'content': teacher_history[0]['content'] + '\n\nYou are a domain expert.'
            }
        else:
            teacher_history.insert(0, {
                'role': 'system',
                'content': 'You are a domain expert.'
            })

        return teacher_history

# Register with the plugin system
from swift.plugin import teacher_adapters
teacher_adapters['my_adapter'] = MyTeacherAdapter
```
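
Once registered, the adapter can be selected on the command line by passing its key, e.g. `--teacher_adapter my_adapter`.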

### Built-in Adapters

SWIFT provides two built-in teacher adapters:

| Adapter | Description |
|---------|-------------|
| `default` | Default: the teacher uses the same context as the student |
| `example` | Example: adds an extra system prompt for the teacher |

### Usage

```bash
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-0.5B-Instruct \
--teacher_model Qwen/Qwen2.5-7B-Instruct \
--teacher_adapter example \
--dataset your_dataset.jsonl \
...
```

### How It Works

In conditional distillation:

1. The **student model** processes the original input: `[prompt_student] + [response]`
2. The **teacher model** processes the transformed input: `[prompt_teacher] + [response]`
3. Both models compute logits over the **same response tokens**
4. The distillation loss is computed from these logits

Here `prompt_teacher` is produced from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part remains unchanged.
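
To make the token alignment concrete, below is a minimal, self-contained sketch of a distillation loss computed on the shared response positions. It uses random tensors and a plain reverse KL purely for illustration; in the real trainer the logits come from the two forward passes, and the actual GKD objective is a generalized JSD controlled by `beta`.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: the student and teacher prompts may differ in length,
# but both models score the same R response tokens.
R, vocab_size = 5, 100
student_logits = torch.randn(1, R, vocab_size)  # stand-in for the student forward pass
teacher_logits = torch.randn(1, R, vocab_size)  # stand-in for the teacher forward pass

student_logprobs = F.log_softmax(student_logits, dim=-1)
teacher_logprobs = F.log_softmax(teacher_logits, dim=-1)

# Reverse KL(student || teacher): kl_div(input=log q, target=log p, log_target=True)
# computes sum p * (log p - log q), here with p = student and q = teacher.
loss = F.kl_div(teacher_logprobs, student_logprobs,
                log_target=True, reduction='batchmean')
print(loss.item())
```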

For a reference training script, see [here](../../../examples/train/on_policy_condition_distillation.sh)
75 changes: 75 additions & 0 deletions docs/source_en/Instruction/GKD.md
@@ -212,3 +212,78 @@ We can achieve the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/
```

For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).

## Conditional Distillation

Conditional distillation enables the teacher and student models to use **different contexts or prompts** during training, allowing for more flexible knowledge transfer strategies. For example:
- Teacher model receives prompts with additional expert guidance
- Teacher model receives task-reformulated inputs (e.g., summaries, translations)
- Teacher model uses longer context information

### TeacherAdapter Plugin System

You can customize the teacher model's context transformation logic by implementing the `TeacherAdapter` interface:

```python
# swift/plugin/teacher_adapter.py
from swift.plugin import TeacherAdapter

class MyTeacherAdapter(TeacherAdapter):
    def shape_context(self, history):
        """Transform student messages to teacher messages

        Args:
            history: Student model's message list (OpenAI format)

        Returns:
            Teacher model's message list
        """
        # Add extra system prompt for teacher
        teacher_history = history.copy()
        if teacher_history and teacher_history[0]['role'] == 'system':
            teacher_history[0]['content'] += '\n\nYou are an expert with extensive knowledge.'
        else:
            teacher_history.insert(0, {
                'role': 'system',
                'content': 'You are an expert with extensive knowledge.'
            })
Comment on lines +243 to +249
Contributor (severity: medium)

The example code for MyTeacherAdapter in the documentation uses teacher_history[0]['content'] += ... to modify the system prompt. Since history.copy() performs a shallow copy, this approach will unintentionally modify the content of the original history object. While this might not cause issues in the current GKD workflow, it is a risky practice. It's recommended to update the example to a safer implementation by creating a new dictionary to replace teacher_history[0], similar to how MathTeacherAdapter is implemented in swift/plugin/teacher_adapter.py, to prevent potential side effects.

Suggested change

        if teacher_history and teacher_history[0]['role'] == 'system':
            teacher_history[0]['content'] += '\n\nYou are an expert with extensive knowledge.'
        else:
            teacher_history.insert(0, {
                'role': 'system',
                'content': 'You are an expert with extensive knowledge.'
            })
        teacher_history = history.copy()
        if teacher_history and teacher_history[0]['role'] == 'system':
            # More robust way: create a new dict to avoid side effects
            teacher_history[0] = {
                'role': 'system',
                'content': teacher_history[0]['content'] + '\n\nYou are an expert with extensive knowledge.'
            }
        else:
            teacher_history.insert(0, {
                'role': 'system',
                'content': 'You are an expert with extensive knowledge.'
            })

        return teacher_history

# Register to plugin system
from swift.plugin import teacher_adapters
teacher_adapters['my_adapter'] = MyTeacherAdapter
```
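
Once registered, the adapter can be selected on the command line by passing its key, e.g. `--teacher_adapter my_adapter`.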

### Built-in Adapters

SWIFT provides two built-in teacher adapters:

| Adapter | Description |
|---------|-------------|
| `default` | Default: teacher uses the same context as student |
| `example` | Example: adds extra instructions to system prompt for teacher |

### Usage

```bash
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-0.5B-Instruct \
--teacher_model Qwen/Qwen2.5-7B-Instruct \
--teacher_adapter example \
--dataset your_dataset.jsonl \
...
```

### How It Works

In conditional distillation:

1. **Student model** processes original input: `[prompt_student] + [response]`
2. **Teacher model** processes transformed input: `[prompt_teacher] + [response]`
3. Both models compute logits on the **same response tokens**
4. Distillation loss is calculated using these logits

Where `prompt_teacher` is transformed from `prompt_student` by `teacher_adapter.shape_context()`, while the `response` part remains unchanged.
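
To make the token alignment concrete, below is a minimal, self-contained sketch of a distillation loss computed on the shared response positions. It uses random tensors and a plain reverse KL purely for illustration; in the real trainer the logits come from the two forward passes, and the actual GKD objective is a generalized JSD controlled by `beta`.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: the student and teacher prompts may differ in length,
# but both models score the same R response tokens.
R, vocab_size = 5, 100
student_logits = torch.randn(1, R, vocab_size)  # stand-in for the student forward pass
teacher_logits = torch.randn(1, R, vocab_size)  # stand-in for the teacher forward pass

student_logprobs = F.log_softmax(student_logits, dim=-1)
teacher_logprobs = F.log_softmax(teacher_logits, dim=-1)

# Reverse KL(student || teacher): kl_div(input=log q, target=log p, log_target=True)
# computes sum p * (log p - log q), here with p = student and q = teacher.
loss = F.kl_div(teacher_logprobs, student_logprobs,
                log_target=True, reduction='batchmean')
print(loss.item())
```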

For a reference training script, see [here](../../../examples/train/on_policy_condition_distillation.sh)
45 changes: 45 additions & 0 deletions examples/train/on_policy_condition_distillation.sh
@@ -0,0 +1,45 @@
# On-policy conditional distillation, based on On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/
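# Note: start the vLLM rollout server first (commented command below); the
# training job connects to it via --vllm_mode server at 127.0.0.1:8000.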

# CUDA_VISIBLE_DEVICES=6 \
# swift rollout \
# --template qwen3_nothinking \
# --model Qwen/Qwen3-8B-Base \
# --vllm_max_model_len 24192

NPROC_PER_NODE=7 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,7 \
swift rlhf \
--rlhf_type gkd \
--teacher_adapter math_teacher \
--template qwen3_nothinking \
--model Qwen/Qwen3-8B-Base \
--teacher_model_type qwen3_nothinking \
--teacher_model Qwen/Qwen3-8B-Base \
--train_type full \
--dataset open-thoughts/OpenThoughts3-1.2M#1000 \
--seq_kd false \
--lmbda 1 \
--beta 1 \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 16000 \
--max_completion_length 100 \
--output_dir output \
--warmup_ratio 0.05 \
--save_only_model true \
--dataloader_num_workers 64 \
--dataset_num_proc 4 \
--deepspeed zero2 \
--teacher_deepspeed zero3 \
--attn_impl flash_attn \
--use_vllm true \
--vllm_mode server \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000
8 changes: 4 additions & 4 deletions examples/train/rft/rft.py
@@ -22,7 +22,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
for device in range(device_count):
sample_cmd = (f'{conda_prefix} USE_OPENCOMPASS_EVALUATOR=True CUDA_VISIBLE_DEVICES={device} swift sample '
f'--model {model} --model_type {model_type} '
f'--dataset {" ".join(dataset)} '
f'--dataset {' '.join(dataset)} '
f'--data_range {device} {device_count} '
f'--max_length 2048 '
f'--system "You are a math model, you should **think step by step** carefully, '
@@ -61,7 +61,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
sample_cmd = (
f'{conda_prefix} USE_OPENCOMPASS_EVALUATOR=True CUDA_VISIBLE_DEVICES={device} swift sample '
f'--model {model} --model_type {model_type} ' # change to --resume_from_checkpoint to use the latest optimizer state # noqa
f'--dataset {" ".join(dataset)} '
f'--dataset {' '.join(dataset)} '
f'--data_range {device} {device_count} '
f'--max_length 2048 '
f'--system "You are a math model, you should **think step by step** carefully, '
@@ -91,7 +91,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
for proc, handler in enumerate(handlers):
handler.wait()
assert os.path.exists(os.path.join('sample_output', f'iter_{iter}_proc_{proc}_sampling.jsonl')), (
f'{os.path.join("sample_output", f"iter_{iter}_proc_{proc}_sampling.jsonl")} not exists, '
f'{os.path.join('sample_output', f"iter_{iter}_proc_{proc}_sampling.jsonl")} not exists, '
'please check the sample logs to get the detail error.')
datasets.append(os.path.join('sample_output', f'iter_{iter}_proc_{proc}_sampling.jsonl'))
print(f'Sampling done, files:{datasets}', flush=True)
@@ -110,7 +110,7 @@ def do_train(model: str, model_type: str, datasets: List[str], iter, cmd='sft'):
ga = 128 // get_device_count() // 2
train_cmd = (f'{conda_prefix} {gpu_prefix} swift {cmd} '
f'--model {model} --model_type {model_type} '
f'--dataset {" ".join(datasets)} '
f'--dataset {' '.join(datasets)} '
f'--max_length 2048 '
f'--num_train_epochs 1 '
f'--load_args false '
10 changes: 5 additions & 5 deletions scripts/benchmark/exp_utils.py
@@ -122,7 +122,7 @@ def run(self, exp: Experiment):
exp.runtime = runtime
envs = deepcopy(runtime.get('env', {}))
envs.update(os.environ)
logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}')
os.makedirs('exp', exist_ok=True)
log_file = os.path.join('exp', f'{exp.name}.eval.log')
exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True)
@@ -140,7 +140,7 @@ def run(self, exp: Experiment):
exp.runtime = runtime
envs = deepcopy(runtime.get('env', {}))
envs.update(os.environ)
logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}')
os.makedirs('exp', exist_ok=True)
log_file = os.path.join('exp', f'{exp.name}.{exp.cmd}.log')
exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True)
@@ -162,10 +162,10 @@ def _build_eval_cmd(self, exp: Experiment):
if best_model_checkpoint is not None:
if not os.path.exists(os.path.join(best_model_checkpoint, 'args.json')):
cmd = f'swift eval --ckpt_dir {best_model_checkpoint} ' \
+ f'--infer_backend pt --train_type full --eval_dataset {" ".join(eval_dataset)}'
+ f'--infer_backend pt --train_type full --eval_dataset {' '.join(eval_dataset)}'
else:
cmd = f'swift eval --model {exp.args.get("model")} --infer_backend pt ' \
f'--eval_dataset {" ".join(eval_dataset)}'
cmd = f'swift eval --model {exp.args.get('model')} --infer_backend pt ' \
f'--eval_dataset {' '.join(eval_dataset)}'

return {
'running_cmd': cmd,
44 changes: 22 additions & 22 deletions scripts/benchmark/generate_report.py
@@ -69,23 +69,23 @@ def tuner_hyper_params(self):
return ''
if args['sft_type'] in ('lora', 'adalora', 'longlora'):
if 'lora_rank' in args:
hyper_params += f'rank={args["lora_rank"]}/' \
f'target={args["lora_target_modules"]}/' \
f'alpha={args["lora_alpha"]}/' \
f'lr_ratio={args.get("lora_lr_ratio", None)}/' \
f'use_rslora={args.get("use_rslora", False)}/' \
f'use_dora={args.get("use_dora", False)}'
hyper_params += f'rank={args['lora_rank']}/' \
f'target={args['lora_target_modules']}/' \
f'alpha={args['lora_alpha']}/' \
f'lr_ratio={args.get('lora_lr_ratio', None)}/' \
f'use_rslora={args.get('use_rslora', False)}/' \
f'use_dora={args.get('use_dora', False)}'
else:
hyper_params = ''
if args['sft_type'] == 'full':
if 'use_galore' in args and args['use_galore'] == 'true':
hyper_params += f'galore_rank={args["galore_rank"]}/' \
f'galore_per_parameter={args["galore_optim_per_parameter"]}/' \
f'galore_with_embedding={args["galore_with_embedding"]}/'
hyper_params += f'galore_rank={args['galore_rank']}/' \
f'galore_per_parameter={args['galore_optim_per_parameter']}/' \
f'galore_with_embedding={args['galore_with_embedding']}/'
if args['sft_type'] == 'llamapro':
hyper_params += f'num_blocks={args["llamapro_num_new_blocks"]}/'
hyper_params += f'num_blocks={args['llamapro_num_new_blocks']}/'
if 'neftune_noise_alpha' in args and args['neftune_noise_alpha']:
hyper_params += f'neftune_noise_alpha={args["neftune_noise_alpha"]}/'
hyper_params += f'neftune_noise_alpha={args['neftune_noise_alpha']}/'

if hyper_params.endswith('/'):
hyper_params = hyper_params[:-1]
@@ -95,8 +95,8 @@ def hyper_parameters(self):
def hyper_parameters(self):
if 'learning_rate' not in self.args:
return ''
return f'lr={self.args["learning_rate"]}/' \
f'epoch={self.args["num_train_epochs"]}'
return f'lr={self.args['learning_rate']}/' \
f'epoch={self.args['num_train_epochs']}'

@property
def train_speed(self):
@@ -190,10 +190,10 @@ def generate_sft_report(outputs: List[ModelOutput]):
ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**'

line = f'|{output.name}|' \
f'{output.args["model_type"]}|' \
f'{output.args.get("dataset")}|' \
f'{output.args.get("train_dataset_mix_ratio", 0.)}|' \
f'{output.args.get("sft_type")}|' \
f'{output.args['model_type']}|' \
f'{output.args.get('dataset')}|' \
f'{output.args.get('train_dataset_mix_ratio', 0.)}|' \
f'{output.args.get('sft_type')}|' \
f'{output.tuner_hyper_params}|' \
f'{output.num_trainable_parameters}({output.trainable_parameters_percentage})|' \
f'{use_flash_attn}|' \
@@ -267,14 +267,14 @@ def generate_export_report(outputs: List[ModelOutput]):
ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**'

if output.train_dataset_info:
dataset_info = f'{output.args["dataset"]}/{output.train_dataset_info}'
dataset_info = f'{output.args['dataset']}/{output.train_dataset_info}'
else:
dataset_info = f'{output.args["dataset"]}'
dataset_info = f'{output.args['dataset']}'
line = f'|{output.name}|' \
f'{output.args["model_type"]}|' \
f'{output.args['model_type']}|' \
f'{dataset_info}|' \
f'{output.args["quant_method"]}|' \
f'{output.args["quant_bits"]}|' \
f'{output.args['quant_method']}|' \
f'{output.args['quant_bits']}|' \
f'{infer_speed}|' \
f'{gsm8k_acc}|' \
f'{arc_acc}|' \
5 changes: 3 additions & 2 deletions setup.py
@@ -16,8 +16,9 @@ def readme():

def get_version():
with open(version_file, 'r', encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
version_globals = {}
exec(compile(f.read(), version_file, 'exec'), version_globals)
return version_globals['__version__']


def parse_requirements(fname='requirements.txt', with_version=True):
2 changes: 1 addition & 1 deletion swift/llm/argument/base_args/model_args.py
@@ -143,7 +143,7 @@ def _init_rope_scaling(self):
if self.max_model_len is None:
self.max_model_len = rope_model_len
elif self.max_model_len > rope_model_len:
logger.warning(f'rope config ({rope_model_len} = {rope_scaling["factor"]} * '
logger.warning(f'rope config ({rope_model_len} = {rope_scaling['factor']} * '
f'{origin_max_model_len}) should be bigger than max_model_len '
f'from command line ({self.max_model_len})')
self.rope_scaling = rope_scaling
5 changes: 5 additions & 0 deletions swift/llm/argument/rlhf_args.py
@@ -120,6 +120,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo
lmbda: float = 0.5
seq_kd: bool = False
offload_teacher_model: bool = False
teacher_adapter: Optional[str] = None # Teacher adapter plugin for conditional distillation
# compat
max_new_tokens: Optional[int] = None # use max_completion_length instead

@@ -433,6 +434,10 @@ def _init_teacher_deepspeed(self):
def _check_gkd(self):
if self.rlhf_type != 'gkd':
return
# Preserve extra dataset fields for conditional distillation
self.remove_unused_columns = False
logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')

if is_mp() and self.use_vllm:
raise ValueError('GKD with vLLM is not compatible with `device_map`. '
'Please set NPROC_PER_NODE equal to num_processes.')