# SFT Training Guide
Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.
## Overview
SFT trains models on input-output pairs by minimizing next-token cross-entropy loss (a minimal loss sketch follows the list below). Use it for:
- Instruction following
- Task-specific fine-tuning
- Chatbot training
- Domain adaptation
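With prompt-completion data, the loss is typically restricted to the completion tokens, with prompt positions masked out of the objective. A minimal sketch of that masking in plain PyTorch (toy tensors, not TRL internals; `-100` is the standard ignore index):

```python
import torch
import torch.nn.functional as F

# Toy token ids: the first four positions are the prompt, the last two the completion
input_ids = torch.tensor([[11, 42, 7, 99, 5, 3]])
labels = input_ids.clone()
labels[:, :4] = -100  # mask prompt tokens so loss is computed on the completion only

# Stand-in for model(input_ids).logits
vocab_size = 32000
logits = torch.randn(1, input_ids.size(1), vocab_size)

# Standard causal-LM shift: position t predicts token t+1
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
```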
## Dataset Formats
### Format 1: Prompt-Completion
```json
[
  {
    "prompt": "What is the capital of France?",
    "completion": "The capital of France is Paris."
  }
]
```
### Format 2: Conversational (ChatML)
```json
[
  {
    "messages": [
      {"role": "user", "content": "What is Python?"},
      {"role": "assistant", "content": "Python is a programming language."}
    ]
  }
]
```
### Format 3: Text-only
```json
[
  {"text": "User: Hello\nAssistant: Hi! How can I help?"}
]
```
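Any of these formats can also be built directly in Python with `datasets.Dataset.from_list`; a small sketch reusing the toy records above:

```python
from datasets import Dataset

# Prompt-completion records (Format 1)
prompt_completion = Dataset.from_list([
    {"prompt": "What is the capital of France?",
     "completion": "The capital of France is Paris."},
])

# Conversational records (Format 2)
conversational = Dataset.from_list([
    {"messages": [
        {"role": "user", "content": "What is Python?"},
        {"role": "assistant", "content": "Python is a programming language."},
    ]},
])
```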
## Basic Training
```python
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Load model
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Load dataset
dataset = load_dataset("trl-lib/Capybara", split="train")

# Configure
config = SFTConfig(
    output_dir="Qwen2.5-SFT",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-5,
    save_strategy="epoch",
)

# Train
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
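Once training finishes, the result can be saved with the standard `Trainer` methods (the directory reuses the `output_dir` from the config above):

```python
# Save final weights and tokenizer
trainer.save_model("Qwen2.5-SFT")
tokenizer.save_pretrained("Qwen2.5-SFT")

# Optionally push to the Hugging Face Hub (requires a prior `huggingface-cli login`)
# trainer.push_to_hub()
```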
## Chat Templates
For conversational (messages) datasets, the tokenizer's chat template is applied automatically:
```python
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,  # messages format
    tokenizer=tokenizer
    # Chat template applied automatically
)
```
Or manually:
```python
def format_chat(example):
    messages = example["messages"]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}

dataset = dataset.map(format_chat)
```
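It is worth printing one mapped example to confirm the template output looks right; the exact special tokens depend on the model's chat template:

```python
# Inspect one formatted example (markers such as <|im_start|> vary by tokenizer)
print(dataset[0]["text"])
```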
## Packing for Efficiency
Pack multiple sequences into one to maximize GPU utilization:
```python
config = SFTConfig(
    packing=True,  # Enable packing
    max_seq_length=2048,
    dataset_text_field="text",
)
```
- Benefits: 2-3× faster training
- Trade-off: slightly more complex batching
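The speedup comes from avoiding padding: without packing, each short example is padded to `max_seq_length` and the padded positions are wasted compute. A back-of-the-envelope sketch with made-up example lengths:

```python
# Hypothetical example lengths (in tokens) and a 2048-token context
lengths = [700, 650, 900, 500, 750]
max_seq_length = 2048

padded_tokens = len(lengths) * max_seq_length      # one padded row per example
packed_rows = -(-sum(lengths) // max_seq_length)   # ceiling division
packed_tokens = packed_rows * max_seq_length

print(padded_tokens / packed_tokens)  # ~2.5x fewer tokens processed with packing
```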
## Multi-GPU Training
```bash
accelerate launch --num_processes 4 train_sft.py
```
Scale the effective batch size further with gradient accumulation in the config:
```python
config = SFTConfig(
    output_dir="model-sft",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
)
```
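The effective batch size is the product of the per-device batch, gradient accumulation steps, and number of processes. A quick check for the values above launched on 4 GPUs:

```python
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_processes = 4  # from the accelerate launch command above

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_processes
)
print(effective_batch_size)  # 64 sequences per optimizer step
```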
## LoRA Fine-Tuning
```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    peft_config=lora_config,  # Add LoRA
)
```
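After training, the adapter can be merged back into the base model for standalone inference. A sketch using PEFT's `merge_and_unload` (the adapter path is a placeholder for wherever the trainer saved it):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Reload the base model and attach the trained adapter (path is a placeholder)
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
model = PeftModel.from_pretrained(base, "model-sft")

# Fold LoRA weights into the base weights and save a plain model
merged = model.merge_and_unload()
merged.save_pretrained("model-sft-merged")
```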
## Hyperparameters
| Model Size | Learning Rate | Batch Size | Epochs |
|---|---|---|---|
| <1B | 5e-5 | 8-16 | 1-3 |
| 1-7B | 2e-5 | 4-8 | 1-2 |
| 7-13B | 1e-5 | 2-4 | 1 |
| 13B+ | 5e-6 | 1-2 | 1 |
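As an example of reading the table, a 7B-class model lands around a 1e-5 learning rate with a small per-device batch; the gradient accumulation value below is an added assumption to keep the effective batch reasonable:

```python
from trl import SFTConfig

# Hypothetical starting point for a ~7B model, following the 7-13B row
config = SFTConfig(
    output_dir="model-7b-sft",
    learning_rate=1e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # assumption, not from the table
    num_train_epochs=1,
)
```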