# SFT Training Guide

Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.

## Overview

SFT trains a model on input-output pairs by minimizing next-token cross-entropy loss. Use it for:

- Instruction following
- Task-specific fine-tuning
- Chatbot training
- Domain adaptation

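
To make the cross-entropy objective concrete, here is a minimal sketch in plain Python. It assumes we already know the probability the model assigned to each correct target token; the probabilities below are made up for illustration, and the helper name `cross_entropy` is hypothetical rather than part of TRL.

```python
import math


def cross_entropy(token_probs):
    """Mean negative log-likelihood of the correct target tokens.

    token_probs: probability the model assigned to each correct token.
    """
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


# Hypothetical probabilities for a three-token completion.
confident = cross_entropy([0.9, 0.8, 0.95])
uncertain = cross_entropy([0.3, 0.2, 0.4])
print(f"confident: {confident:.3f}, uncertain: {uncertain:.3f}")
```

A model that puts more probability on the reference tokens gets a lower loss, which is what SFT optimizes for.
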
## Dataset Formats

### Format 1: Prompt-Completion

```json
[
  {
    "prompt": "What is the capital of France?",
    "completion": "The capital of France is Paris."
  }
]
```

### Format 2: Conversational (ChatML)

```json
[
  {
    "messages": [
      {"role": "user", "content": "What is Python?"},
      {"role": "assistant", "content": "Python is a programming language."}
    ]
  }
]
```

### Format 3: Text-only

```json
[
  {"text": "User: Hello\nAssistant: Hi! How can I help?"}
]
```

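
The formats above carry the same information in different shapes, so records can be converted between them. The following is a small sketch that turns a prompt-completion record into the conversational format; the function name `to_messages` is hypothetical, and it assumes a single-turn exchange with no system message.

```python
def to_messages(example):
    """Convert a prompt-completion record into the conversational format."""
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["completion"]},
        ]
    }


record = {
    "prompt": "What is the capital of France?",
    "completion": "The capital of France is Paris.",
}
converted = to_messages(record)
print(converted["messages"][1]["content"])
```

With a Hugging Face `datasets` dataset, a function like this could be applied with `dataset.map(to_messages, remove_columns=["prompt", "completion"])`.
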
## Basic Training

```python
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Load dataset
dataset = load_dataset("trl-lib/Capybara", split="train")

# Configure
config = SFTConfig(
    output_dir="Qwen2.5-SFT",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-5,
    save_strategy="epoch",
)

# Train
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    processing_class=tokenizer,  # this argument is named `tokenizer` in older TRL releases
)
trainer.train()
```

## Chat Templates

When the dataset uses the conversational (`messages`) format, `SFTTrainer` applies the tokenizer's chat template automatically:

```python
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,  # Messages format
    processing_class=tokenizer,  # this argument is named `tokenizer` in older TRL releases
    # Chat template applied automatically
)
```

Or apply the template manually and train on the resulting `text` column:

```python
def format_chat(example):
    messages = example["messages"]
    # tokenize=False returns the formatted string instead of token ids
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}


dataset = dataset.map(format_chat)
```

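
To show what a chat template produces, here is a plain-Python sketch of a ChatML-style layout. It is illustrative only: the function `render_chatml` is hypothetical, and the real output of `apply_chat_template` depends on the template shipped with each tokenizer, which may add a system prompt or other special tokens.

```python
def render_chatml(messages):
    """Render messages in a ChatML-style layout (illustrative only)."""
    parts = []
    for message in messages:
        # Each turn is wrapped in start/end markers with the role on the first line.
        parts.append(
            f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        )
    return "".join(parts)


messages = [
    {"role": "user", "content": "What is Python?"},
    {"role": "assistant", "content": "Python is a programming language."},
]
rendered = render_chatml(messages)
print(rendered)
```
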
## Packing for Efficiency

Pack multiple short sequences into one fixed-length sequence to reduce padding and improve GPU utilization:

```python
config = SFTConfig(
    packing=True,  # Enable packing
    max_seq_length=2048,  # named `max_length` in newer TRL releases
    dataset_text_field="text",
)
```

**Benefits**: Up to 2-3× faster training when most examples are much shorter than `max_seq_length`

**Trade-off**: Slightly more complex batching, since one training sequence can contain several examples

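
The idea behind packing can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not TRL's implementation: the function `pack_sequences` is hypothetical, examples are joined with an EOS token, and any leftover tail shorter than one block is dropped.

```python
def pack_sequences(sequences, max_seq_length, eos_token_id):
    """Concatenate tokenized examples and cut them into fixed-length blocks."""
    stream = []
    for tokens in sequences:
        stream.extend(tokens)
        stream.append(eos_token_id)  # mark the boundary between examples
    # Keep only full blocks; the leftover tail is dropped in this sketch.
    n_blocks = len(stream) // max_seq_length
    return [
        stream[i * max_seq_length:(i + 1) * max_seq_length]
        for i in range(n_blocks)
    ]


# Hypothetical token ids; 0 stands in for the EOS token.
examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13], [14]]
packed = pack_sequences(examples, max_seq_length=4, eos_token_id=0)
print(packed)
```

Every block is completely filled, so no compute is spent on padding tokens, at the cost of examples being split across block boundaries.
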
## Multi-GPU Training

Launch the training script on multiple GPUs with Accelerate:

```bash
accelerate launch --num_processes 4 train_sft.py
```

Inside the script, `SFTConfig` sets the per-GPU batch size and gradient accumulation:

```python
config = SFTConfig(
    output_dir="model-sft",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
)
```

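
With data-parallel training, the number of examples that contribute to each optimizer step is the product of the per-device batch size, the gradient accumulation steps, and the number of processes. A minimal sketch of that arithmetic (the helper name `effective_batch_size` is hypothetical):

```python
def effective_batch_size(per_device_train_batch_size,
                         gradient_accumulation_steps,
                         num_processes):
    """Number of examples contributing to each optimizer step."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_processes


# Values from the launch command and config above:
# batch size 4 per device, 4 accumulation steps, 4 processes.
total = effective_batch_size(4, 4, 4)
print(total)
```

When changing the number of GPUs, adjusting `gradient_accumulation_steps` is one way to keep this product, and therefore the optimization behavior, roughly constant.
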
## LoRA Fine-Tuning

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    peft_config=lora_config,  # Add LoRA
)
```

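
To give a feel for what `r` and `lora_alpha` control, the sketch below computes how many trainable parameters LoRA adds to a single linear layer and the scaling applied to the low-rank update. It assumes the standard LoRA formulation (two matrices of shape `r × d_in` and `d_out × r`, scaled by `lora_alpha / r`); the helper names and the 1024 × 1024 layer size are hypothetical.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one linear layer.

    A has shape (r, d_in) and B has shape (d_out, r).
    """
    return r * (d_in + d_out)


def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update in standard LoRA."""
    return lora_alpha / r


# Hypothetical 1024 x 1024 projection with the r and lora_alpha values above.
full = 1024 * 1024
added = lora_param_count(1024, 1024, r=16)
scaling = lora_scaling(32, 16)
print(added, f"{added / full:.2%}", scaling)
```
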
## Hyperparameters

| Model Size | Learning Rate | Batch Size | Epochs |
|------------|---------------|------------|--------|
| <1B        | 5e-5          | 8-16       | 1-3    |
| 1-7B       | 2e-5          | 4-8        | 1-2    |
| 7-13B      | 1e-5          | 2-4        | 1      |
| 13B+       | 5e-6          | 1-2        | 1      |

## References

- TRL docs: https://huggingface.co/docs/trl/sft_trainer
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts