Sync all skills and memories 2026-04-14 07:27
This commit is contained in:
3
skills/mlops/models/DESCRIPTION.md
Normal file
3
skills/mlops/models/DESCRIPTION.md
Normal file
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Specific model architectures and tools — computer vision (CLIP, SAM, Stable Diffusion), speech (Whisper), audio generation (AudioCraft), and multimodal models (LLaVA).
|
||||
---
|
||||
567
skills/mlops/models/audiocraft/SKILL.md
Normal file
567
skills/mlops/models/audiocraft/SKILL.md
Normal file
@@ -0,0 +1,567 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
666
skills/mlops/models/audiocraft/references/advanced-usage.md
Normal file
666
skills/mlops/models/audiocraft/references/advanced-usage.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
504
skills/mlops/models/audiocraft/references/troubleshooting.md
Normal file
504
skills/mlops/models/audiocraft/references/troubleshooting.md
Normal file
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
256
skills/mlops/models/clip/SKILL.md
Normal file
256
skills/mlops/models/clip/SKILL.md
Normal file
@@ -0,0 +1,256 @@
|
||||
---
|
||||
name: clip
|
||||
description: OpenAI's model connecting vision and language. Enables zero-shot image classification, image-text matching, and cross-modal retrieval. Trained on 400M image-text pairs. Use for image search, content moderation, or vision-language tasks without fine-tuning. Best for general-purpose image understanding.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers, torch, pillow]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, CLIP, Vision-Language, Zero-Shot, Image Classification, OpenAI, Image Search, Cross-Modal Retrieval, Content Moderation]
|
||||
|
||||
---
|
||||
|
||||
# CLIP - Contrastive Language-Image Pre-Training
|
||||
|
||||
OpenAI's model that understands images from natural language.
|
||||
|
||||
## When to use CLIP
|
||||
|
||||
**Use when:**
|
||||
- Zero-shot image classification (no training data needed)
|
||||
- Image-text similarity/matching
|
||||
- Semantic image search
|
||||
- Content moderation (detect NSFW, violence)
|
||||
- Visual question answering
|
||||
- Cross-modal retrieval (image→text, text→image)
|
||||
|
||||
**Metrics**:
|
||||
- **25,300+ GitHub stars**
|
||||
- Trained on 400M image-text pairs
|
||||
- Matches ResNet-50 on ImageNet (zero-shot)
|
||||
- MIT License
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **BLIP-2**: Better captioning
|
||||
- **LLaVA**: Vision-language chat
|
||||
- **Segment Anything**: Image segmentation
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
pip install torch torchvision ftfy regex tqdm
|
||||
```
|
||||
|
||||
### Zero-shot classification
|
||||
|
||||
```python
|
||||
import torch
|
||||
import clip
|
||||
from PIL import Image
|
||||
|
||||
# Load model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model, preprocess = clip.load("ViT-B/32", device=device)
|
||||
|
||||
# Load image
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
|
||||
|
||||
# Define possible labels
|
||||
text = clip.tokenize(["a dog", "a cat", "a bird", "a car"]).to(device)
|
||||
|
||||
# Compute similarity
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
# Cosine similarity
|
||||
logits_per_image, logits_per_text = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
# Print results
|
||||
labels = ["a dog", "a cat", "a bird", "a car"]
|
||||
for label, prob in zip(labels, probs[0]):
|
||||
print(f"{label}: {prob:.2%}")
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
```python
|
||||
# Models (sorted by size)
|
||||
models = [
|
||||
"RN50", # ResNet-50
|
||||
"RN101", # ResNet-101
|
||||
"ViT-B/32", # Vision Transformer (recommended)
|
||||
"ViT-B/16", # Better quality, slower
|
||||
"ViT-L/14", # Best quality, slowest
|
||||
]
|
||||
|
||||
model, preprocess = clip.load("ViT-B/32")
|
||||
```
|
||||
|
||||
| Model | Parameters | Speed | Quality |
|
||||
|-------|------------|-------|---------|
|
||||
| RN50 | 102M | Fast | Good |
|
||||
| ViT-B/32 | 151M | Medium | Better |
|
||||
| ViT-L/14 | 428M | Slow | Best |
|
||||
|
||||
## Image-text similarity
|
||||
|
||||
```python
|
||||
# Compute embeddings
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
# Normalize
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Cosine similarity
|
||||
similarity = (image_features @ text_features.T).item()
|
||||
print(f"Similarity: {similarity:.4f}")
|
||||
```
|
||||
|
||||
## Semantic image search
|
||||
|
||||
```python
|
||||
# Index images
|
||||
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
|
||||
image_embeddings = []
|
||||
|
||||
for img_path in image_paths:
|
||||
image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
embedding = model.encode_image(image)
|
||||
embedding /= embedding.norm(dim=-1, keepdim=True)
|
||||
image_embeddings.append(embedding)
|
||||
|
||||
image_embeddings = torch.cat(image_embeddings)
|
||||
|
||||
# Search with text query
|
||||
query = "a sunset over the ocean"
|
||||
text_input = clip.tokenize([query]).to(device)
|
||||
with torch.no_grad():
|
||||
text_embedding = model.encode_text(text_input)
|
||||
text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Find most similar images
|
||||
similarities = (text_embedding @ image_embeddings.T).squeeze(0)
|
||||
top_k = similarities.topk(3)
|
||||
|
||||
for idx, score in zip(top_k.indices, top_k.values):
|
||||
print(f"{image_paths[idx]}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Content moderation
|
||||
|
||||
```python
|
||||
# Define categories
|
||||
categories = [
|
||||
"safe for work",
|
||||
"not safe for work",
|
||||
"violent content",
|
||||
"graphic content"
|
||||
]
|
||||
|
||||
text = clip.tokenize(categories).to(device)
|
||||
|
||||
# Check image
|
||||
with torch.no_grad():
|
||||
logits_per_image, _ = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1)
|
||||
|
||||
# Get classification
|
||||
max_idx = probs.argmax().item()
|
||||
max_prob = probs[0, max_idx].item()
|
||||
|
||||
print(f"Category: {categories[max_idx]} ({max_prob:.2%})")
|
||||
```
|
||||
|
||||
## Batch processing
|
||||
|
||||
```python
|
||||
# Process multiple images
|
||||
images = [preprocess(Image.open(f"img{i}.jpg")) for i in range(10)]
|
||||
images = torch.stack(images).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Batch text
|
||||
texts = ["a dog", "a cat", "a bird"]
|
||||
text_tokens = clip.tokenize(texts).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_tokens)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Similarity matrix (10 images × 3 texts)
|
||||
similarities = image_features @ text_features.T
|
||||
print(similarities.shape) # (10, 3)
|
||||
```
|
||||
|
||||
## Integration with vector databases
|
||||
|
||||
```python
|
||||
# Store CLIP embeddings in Chroma/FAISS
|
||||
import chromadb
|
||||
|
||||
client = chromadb.Client()
|
||||
collection = client.create_collection("image_embeddings")
|
||||
|
||||
# Add image embeddings
|
||||
for img_path, embedding in zip(image_paths, image_embeddings):
|
||||
collection.add(
|
||||
embeddings=[embedding.cpu().numpy().tolist()],
|
||||
metadatas=[{"path": img_path}],
|
||||
ids=[img_path]
|
||||
)
|
||||
|
||||
# Query with text
|
||||
query = "a sunset"
|
||||
text_embedding = model.encode_text(clip.tokenize([query]))
|
||||
results = collection.query(
|
||||
query_embeddings=[text_embedding.cpu().numpy().tolist()],
|
||||
n_results=5
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use ViT-B/32 for most cases** - Good balance
|
||||
2. **Normalize embeddings** - Required for cosine similarity
|
||||
3. **Batch processing** - More efficient
|
||||
4. **Cache embeddings** - Expensive to recompute
|
||||
5. **Use descriptive labels** - Better zero-shot performance
|
||||
6. **GPU recommended** - 10-50× faster
|
||||
7. **Preprocess images** - Use provided preprocess function
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | CPU | GPU (V100) |
|
||||
|-----------|-----|------------|
|
||||
| Image encoding | ~200ms | ~20ms |
|
||||
| Text encoding | ~50ms | ~5ms |
|
||||
| Similarity compute | <1ms | <1ms |
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Not for fine-grained tasks** - Best for broad categories
|
||||
2. **Requires descriptive text** - Vague labels perform poorly
|
||||
3. **Biased on web data** - May have dataset biases
|
||||
4. **No bounding boxes** - Whole image only
|
||||
5. **Limited spatial understanding** - Position/counting weak
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/openai/CLIP ⭐ 25,300+
|
||||
- **Paper**: https://arxiv.org/abs/2103.00020
|
||||
- **Colab**: https://colab.research.google.com/github/openai/clip/
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
207
skills/mlops/models/clip/references/applications.md
Normal file
207
skills/mlops/models/clip/references/applications.md
Normal file
@@ -0,0 +1,207 @@
|
||||
# CLIP Applications Guide
|
||||
|
||||
Practical applications and use cases for CLIP.
|
||||
|
||||
## Zero-shot image classification
|
||||
|
||||
```python
|
||||
import torch
|
||||
import clip
|
||||
from PIL import Image
|
||||
|
||||
model, preprocess = clip.load("ViT-B/32")
|
||||
|
||||
# Define categories
|
||||
categories = [
|
||||
"a photo of a dog",
|
||||
"a photo of a cat",
|
||||
"a photo of a bird",
|
||||
"a photo of a car",
|
||||
"a photo of a person"
|
||||
]
|
||||
|
||||
# Prepare image
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
|
||||
text = clip.tokenize(categories)
|
||||
|
||||
# Classify
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
logits_per_image, _ = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
# Print results
|
||||
for category, prob in zip(categories, probs[0]):
|
||||
print(f"{category}: {prob:.2%}")
|
||||
```
|
||||
|
||||
## Semantic image search
|
||||
|
||||
```python
|
||||
# Index images
|
||||
image_database = []
|
||||
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
|
||||
|
||||
for img_path in image_paths:
|
||||
image = preprocess(Image.open(img_path)).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
features = model.encode_image(image)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
image_database.append((img_path, features))
|
||||
|
||||
# Search with text
|
||||
query = "a sunset over mountains"
|
||||
text_input = clip.tokenize([query])
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_input)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Find matches
|
||||
similarities = []
|
||||
for img_path, img_features in image_database:
|
||||
similarity = (text_features @ img_features.T).item()
|
||||
similarities.append((img_path, similarity))
|
||||
|
||||
# Sort by similarity
|
||||
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
for img_path, score in similarities[:3]:
|
||||
print(f"{img_path}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Content moderation
|
||||
|
||||
```python
|
||||
# Define safety categories
|
||||
categories = [
|
||||
"safe for work content",
|
||||
"not safe for work content",
|
||||
"violent or graphic content",
|
||||
"hate speech or offensive content",
|
||||
"spam or misleading content"
|
||||
]
|
||||
|
||||
text = clip.tokenize(categories)
|
||||
|
||||
# Check image
|
||||
with torch.no_grad():
|
||||
logits, _ = model(image, text)
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Get classification
|
||||
max_idx = probs.argmax().item()
|
||||
confidence = probs[0, max_idx].item()
|
||||
|
||||
if confidence > 0.7:
|
||||
print(f"Classified as: {categories[max_idx]} ({confidence:.2%})")
|
||||
else:
|
||||
print(f"Uncertain classification (confidence: {confidence:.2%})")
|
||||
```
|
||||
|
||||
## Image-to-text retrieval
|
||||
|
||||
```python
|
||||
# Text database
|
||||
captions = [
|
||||
"A beautiful sunset over the ocean",
|
||||
"A cute dog playing in the park",
|
||||
"A modern city skyline at night",
|
||||
"A delicious pizza with toppings"
|
||||
]
|
||||
|
||||
# Encode captions
|
||||
caption_features = []
|
||||
for caption in captions:
|
||||
text = clip.tokenize([caption])
|
||||
with torch.no_grad():
|
||||
features = model.encode_text(text)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
caption_features.append(features)
|
||||
|
||||
caption_features = torch.cat(caption_features)
|
||||
|
||||
# Find matching captions for image
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarities = (image_features @ caption_features.T).squeeze(0)
|
||||
top_k = similarities.topk(3)
|
||||
|
||||
for idx, score in zip(top_k.indices, top_k.values):
|
||||
print(f"{captions[idx]}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Visual question answering
|
||||
|
||||
```python
|
||||
# Create yes/no questions
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
|
||||
|
||||
questions = [
|
||||
"a photo showing people",
|
||||
"a photo showing animals",
|
||||
"a photo taken indoors",
|
||||
"a photo taken outdoors",
|
||||
"a photo taken during daytime",
|
||||
"a photo taken at night"
|
||||
]
|
||||
|
||||
text = clip.tokenize(questions)
|
||||
|
||||
with torch.no_grad():
|
||||
logits, _ = model(image, text)
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Answer questions
|
||||
for question, prob in zip(questions, probs[0]):
|
||||
answer = "Yes" if prob > 0.5 else "No"
|
||||
print(f"{question}: {answer} ({prob:.2%})")
|
||||
```
|
||||
|
||||
## Image deduplication
|
||||
|
||||
```python
|
||||
# Detect duplicate/similar images
|
||||
def compute_similarity(img1_path, img2_path):
|
||||
img1 = preprocess(Image.open(img1_path)).unsqueeze(0)
|
||||
img2 = preprocess(Image.open(img2_path)).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
feat1 = model.encode_image(img1)
|
||||
feat2 = model.encode_image(img2)
|
||||
|
||||
feat1 /= feat1.norm(dim=-1, keepdim=True)
|
||||
feat2 /= feat2.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = (feat1 @ feat2.T).item()
|
||||
|
||||
return similarity
|
||||
|
||||
# Check for duplicates
|
||||
threshold = 0.95
|
||||
image_pairs = [("img1.jpg", "img2.jpg"), ("img1.jpg", "img3.jpg")]
|
||||
|
||||
for img1, img2 in image_pairs:
|
||||
sim = compute_similarity(img1, img2)
|
||||
if sim > threshold:
|
||||
print(f"{img1} and {img2} are duplicates (similarity: {sim:.3f})")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use descriptive labels** - "a photo of X" works better than just "X"
|
||||
2. **Normalize embeddings** - Always normalize for cosine similarity
|
||||
3. **Batch processing** - Process multiple images/texts together
|
||||
4. **Cache embeddings** - Expensive to recompute
|
||||
5. **Set appropriate thresholds** - Test on validation data
|
||||
6. **Use GPU** - 10-50× faster than CPU
|
||||
7. **Consider model size** - ViT-B/32 good default, ViT-L/14 for best quality
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2103.00020
|
||||
- **GitHub**: https://github.com/openai/CLIP
|
||||
- **Colab**: https://colab.research.google.com/github/openai/clip/
|
||||
503
skills/mlops/models/segment-anything/SKILL.md
Normal file
503
skills/mlops/models/segment-anything/SKILL.md
Normal file
@@ -0,0 +1,503 @@
|
||||
---
|
||||
name: segment-anything-model
|
||||
description: Foundation model for image segmentation with zero-shot transfer. Use when you need to segment any object in images using points, boxes, or masks as prompts, or automatically generate all object masks in an image.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
|
||||
|
||||
---
|
||||
|
||||
# Segment Anything Model (SAM)
|
||||
|
||||
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
|
||||
|
||||
## When to use SAM
|
||||
|
||||
**Use SAM when:**
|
||||
- Need to segment any object in images without task-specific training
|
||||
- Building interactive annotation tools with point/box prompts
|
||||
- Generating training data for other vision models
|
||||
- Need zero-shot transfer to new image domains
|
||||
- Building object detection/segmentation pipelines
|
||||
- Processing medical, satellite, or domain-specific images
|
||||
|
||||
**Key features:**
|
||||
- **Zero-shot segmentation**: Works on any image domain without fine-tuning
|
||||
- **Flexible prompts**: Points, bounding boxes, or previous masks
|
||||
- **Automatic segmentation**: Generate all object masks automatically
|
||||
- **High quality**: Trained on 1.1 billion masks from 11 million images
|
||||
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
|
||||
- **ONNX export**: Deploy in browsers and edge devices
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **YOLO/Detectron2**: For real-time object detection with classes
|
||||
- **Mask2Former**: For semantic/panoptic segmentation with categories
|
||||
- **GroundingDINO + SAM**: For text-prompted segmentation
|
||||
- **SAM 2**: For video segmentation tasks
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From GitHub
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
|
||||
# Optional dependencies
|
||||
pip install opencv-python pycocotools matplotlib
|
||||
|
||||
# Or use HuggingFace transformers
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
### Download checkpoints
|
||||
|
||||
```bash
|
||||
# ViT-H (largest, most accurate) - 2.4GB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
|
||||
# ViT-L (medium) - 1.2GB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
|
||||
|
||||
# ViT-B (smallest, fastest) - 375MB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
|
||||
```
|
||||
|
||||
### Basic usage with SamPredictor
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
# Load model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to(device="cuda")
|
||||
|
||||
# Create predictor
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
# Set image (computes embeddings once)
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
predictor.set_image(image)
|
||||
|
||||
# Predict with point prompts
|
||||
input_point = np.array([[500, 375]]) # (x, y) coordinates
|
||||
input_label = np.array([1]) # 1 = foreground, 0 = background
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True # Returns 3 mask options
|
||||
)
|
||||
|
||||
# Select best mask
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
```
|
||||
|
||||
### HuggingFace Transformers
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import SamModel, SamProcessor
|
||||
|
||||
# Load model and processor
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
model.to("cuda")
|
||||
|
||||
# Process image with point prompt
|
||||
image = Image.open("image.jpg")
|
||||
input_points = [[[450, 600]]] # Batch of points
|
||||
|
||||
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
||||
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||
|
||||
# Generate masks
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-process masks to original size
|
||||
masks = processor.image_processor.post_process_masks(
|
||||
outputs.pred_masks.cpu(),
|
||||
inputs["original_sizes"].cpu(),
|
||||
inputs["reshaped_input_sizes"].cpu()
|
||||
)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Model architecture
|
||||
|
||||
```
|
||||
SAM Architecture:
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Image Encoder │────▶│ Prompt Encoder │────▶│ Mask Decoder │
|
||||
│ (ViT) │ │ (Points/Boxes) │ │ (Transformer) │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
Image Embeddings Prompt Embeddings Masks + IoU
|
||||
(computed once) (per prompt) predictions
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Checkpoint | Size | Speed | Accuracy |
|
||||
|-------|------------|------|-------|----------|
|
||||
| ViT-H | `vit_h` | 2.4 GB | Slowest | Best |
|
||||
| ViT-L | `vit_l` | 1.2 GB | Medium | Good |
|
||||
| ViT-B | `vit_b` | 375 MB | Fastest | Good |
|
||||
|
||||
### Prompt types
|
||||
|
||||
| Prompt | Description | Use Case |
|
||||
|--------|-------------|----------|
|
||||
| Point (foreground) | Click on object | Single object selection |
|
||||
| Point (background) | Click outside object | Exclude regions |
|
||||
| Bounding box | Rectangle around object | Larger objects |
|
||||
| Previous mask | Low-res mask input | Iterative refinement |
|
||||
|
||||
## Interactive segmentation
|
||||
|
||||
### Point prompts
|
||||
|
||||
```python
|
||||
# Single foreground point
|
||||
input_point = np.array([[500, 375]])
|
||||
input_label = np.array([1])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Multiple points (foreground + background)
|
||||
input_points = np.array([[500, 375], [600, 400], [450, 300]])
|
||||
input_labels = np.array([1, 1, 0]) # 2 foreground, 1 background
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False # Single mask when prompts are clear
|
||||
)
|
||||
```
|
||||
|
||||
### Box prompts
|
||||
|
||||
```python
|
||||
# Bounding box [x1, y1, x2, y2]
|
||||
input_box = np.array([425, 600, 700, 875])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
box=input_box,
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
### Combined prompts
|
||||
|
||||
```python
|
||||
# Box + points for precise control
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
box=np.array([400, 300, 700, 600]),
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
### Iterative refinement
|
||||
|
||||
```python
|
||||
# Initial prediction
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Refine with additional point using previous mask
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375], [550, 400]]),
|
||||
point_labels=np.array([1, 0]), # Add background point
|
||||
mask_input=logits[np.argmax(scores)][None, :, :], # Use best mask
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
## Automatic mask generation
|
||||
|
||||
### Basic automatic segmentation
|
||||
|
||||
```python
|
||||
from segment_anything import SamAutomaticMaskGenerator
|
||||
|
||||
# Create generator
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
|
||||
# Generate all masks
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
# Each mask contains:
|
||||
# - segmentation: binary mask
|
||||
# - bbox: [x, y, w, h]
|
||||
# - area: pixel count
|
||||
# - predicted_iou: quality score
|
||||
# - stability_score: robustness score
|
||||
# - point_coords: generating point
|
||||
```
|
||||
|
||||
### Customized generation
|
||||
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=32, # Grid density (more = more masks)
|
||||
pred_iou_thresh=0.88, # Quality threshold
|
||||
stability_score_thresh=0.95, # Stability threshold
|
||||
crop_n_layers=1, # Multi-scale crops
|
||||
crop_n_points_downscale_factor=2,
|
||||
min_mask_region_area=100, # Remove tiny masks
|
||||
)
|
||||
|
||||
masks = mask_generator.generate(image)
|
||||
```
|
||||
|
||||
### Filtering masks
|
||||
|
||||
```python
|
||||
# Sort by area (largest first)
|
||||
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
||||
|
||||
# Filter by predicted IoU
|
||||
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
|
||||
|
||||
# Filter by stability score
|
||||
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
|
||||
```
|
||||
|
||||
## Batched inference
|
||||
|
||||
### Multiple images
|
||||
|
||||
```python
|
||||
# Process multiple images efficiently
|
||||
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
|
||||
|
||||
all_masks = []
|
||||
for image in images:
|
||||
predictor.set_image(image)
|
||||
masks, _, _ = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
all_masks.append(masks)
|
||||
```
|
||||
|
||||
### Multiple prompts per image
|
||||
|
||||
```python
|
||||
# Process multiple prompts efficiently (one image encoding)
|
||||
predictor.set_image(image)
|
||||
|
||||
# Batch of point prompts
|
||||
points = [
|
||||
np.array([[100, 100]]),
|
||||
np.array([[200, 200]]),
|
||||
np.array([[300, 300]])
|
||||
]
|
||||
|
||||
all_masks = []
|
||||
for point in points:
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point,
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
all_masks.append(masks[np.argmax(scores)])
|
||||
```
|
||||
|
||||
## ONNX deployment
|
||||
|
||||
### Export model
|
||||
|
||||
```bash
|
||||
python scripts/export_onnx_model.py \
|
||||
--checkpoint sam_vit_h_4b8939.pth \
|
||||
--model-type vit_h \
|
||||
--output sam_onnx.onnx \
|
||||
--return-single-mask
|
||||
```
|
||||
|
||||
### Use ONNX model
|
||||
|
||||
```python
|
||||
import onnxruntime
|
||||
|
||||
# Load ONNX model
|
||||
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
|
||||
|
||||
# Run inference (image embeddings computed separately)
|
||||
masks = ort_session.run(
|
||||
None,
|
||||
{
|
||||
"image_embeddings": image_embeddings,
|
||||
"point_coords": point_coords,
|
||||
"point_labels": point_labels,
|
||||
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
|
||||
"has_mask_input": np.array([0], dtype=np.float32),
|
||||
"orig_im_size": np.array([h, w], dtype=np.float32)
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Annotation tool
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Load model
|
||||
predictor = SamPredictor(sam)
|
||||
predictor.set_image(image)
|
||||
|
||||
def on_click(event, x, y, flags, param):
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Foreground point
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([[x, y]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
# Display best mask
|
||||
display_mask(masks[np.argmax(scores)])
|
||||
```
|
||||
|
||||
### Workflow 2: Object extraction
|
||||
|
||||
```python
|
||||
def extract_object(image, point):
|
||||
"""Extract object at point with transparent background."""
|
||||
predictor.set_image(image)
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([point]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
|
||||
# Create RGBA output
|
||||
rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
||||
rgba[:, :, :3] = image
|
||||
rgba[:, :, 3] = best_mask * 255
|
||||
|
||||
return rgba
|
||||
```
|
||||
|
||||
### Workflow 3: Medical image segmentation
|
||||
|
||||
```python
|
||||
# Process medical images (grayscale to RGB)
|
||||
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
|
||||
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
predictor.set_image(rgb_image)
|
||||
|
||||
# Segment region of interest
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]), # ROI bounding box
|
||||
multimask_output=True
|
||||
)
|
||||
```
|
||||
|
||||
## Output format
|
||||
|
||||
### Mask data structure
|
||||
|
||||
```python
|
||||
# SamAutomaticMaskGenerator output
|
||||
{
|
||||
"segmentation": np.ndarray, # H×W binary mask
|
||||
"bbox": [x, y, w, h], # Bounding box
|
||||
"area": int, # Pixel count
|
||||
"predicted_iou": float, # 0-1 quality score
|
||||
"stability_score": float, # 0-1 robustness score
|
||||
"crop_box": [x, y, w, h], # Generation crop region
|
||||
"point_coords": [[x, y]], # Input point
|
||||
}
|
||||
```
|
||||
|
||||
### COCO RLE format
|
||||
|
||||
```python
|
||||
from pycocotools import mask as mask_utils
|
||||
|
||||
# Encode mask to RLE
|
||||
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
|
||||
# Decode RLE to mask
|
||||
decoded_mask = mask_utils.decode(rle)
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### GPU memory
|
||||
|
||||
```python
|
||||
# Use smaller model for limited VRAM
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Process images in batches
|
||||
# Clear CUDA cache between large batches
|
||||
torch.cuda.empty_cache()
|
||||
```
|
||||
|
||||
### Speed optimization
|
||||
|
||||
```python
|
||||
# Use half precision
|
||||
sam = sam.half()
|
||||
|
||||
# Reduce points for automatic generation
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Default is 32
|
||||
)
|
||||
|
||||
# Use ONNX for deployment
|
||||
# Export with --return-single-mask for faster inference
|
||||
```
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Out of memory | Use ViT-B model, reduce image size |
|
||||
| Slow inference | Use ViT-B, reduce points_per_side |
|
||||
| Poor mask quality | Try different prompts, use box + points |
|
||||
| Edge artifacts | Use stability_score filtering |
|
||||
| Small objects missed | Increase points_per_side |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/segment-anything
|
||||
- **Paper**: https://arxiv.org/abs/2304.02643
|
||||
- **Demo**: https://segment-anything.com
|
||||
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
|
||||
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge
|
||||
@@ -0,0 +1,589 @@
|
||||
# Segment Anything Advanced Usage Guide
|
||||
|
||||
## SAM 2 (Video Segmentation)
|
||||
|
||||
### Overview
|
||||
|
||||
SAM 2 extends SAM to video segmentation with streaming memory architecture:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/facebookresearch/segment-anything-2.git
|
||||
```
|
||||
|
||||
### Video segmentation
|
||||
|
||||
```python
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
|
||||
|
||||
# Initialize with video
|
||||
predictor.init_state(video_path="video.mp4")
|
||||
|
||||
# Add prompt on first frame
|
||||
predictor.add_new_points(
|
||||
frame_idx=0,
|
||||
obj_id=1,
|
||||
points=[[100, 200]],
|
||||
labels=[1]
|
||||
)
|
||||
|
||||
# Propagate through video
|
||||
for frame_idx, masks in predictor.propagate_in_video():
|
||||
# masks contains segmentation for all tracked objects
|
||||
process_frame(frame_idx, masks)
|
||||
```
|
||||
|
||||
### SAM 2 vs SAM comparison
|
||||
|
||||
| Feature | SAM | SAM 2 |
|
||||
|---------|-----|-------|
|
||||
| Input | Images only | Images + Videos |
|
||||
| Architecture | ViT + Decoder | Hiera + Memory |
|
||||
| Memory | Per-image | Streaming memory bank |
|
||||
| Tracking | No | Yes, across frames |
|
||||
| Models | ViT-B/L/H | Hiera-T/S/B+/L |
|
||||
|
||||
## Grounded SAM (Text-Prompted Segmentation)
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
pip install groundingdino-py
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
```
|
||||
|
||||
### Text-to-mask pipeline
|
||||
|
||||
```python
|
||||
from groundingdino.util.inference import load_model, predict
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
import cv2
|
||||
|
||||
# Load Grounding DINO
|
||||
grounding_model = load_model("groundingdino_swint_ogc.pth", "GroundingDINO_SwinT_OGC.py")
|
||||
|
||||
# Load SAM
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
def text_to_mask(image, text_prompt, box_threshold=0.3, text_threshold=0.25):
|
||||
"""Generate masks from text description."""
|
||||
# Get bounding boxes from text
|
||||
boxes, logits, phrases = predict(
|
||||
model=grounding_model,
|
||||
image=image,
|
||||
caption=text_prompt,
|
||||
box_threshold=box_threshold,
|
||||
text_threshold=text_threshold
|
||||
)
|
||||
|
||||
# Generate masks with SAM
|
||||
predictor.set_image(image)
|
||||
|
||||
masks = []
|
||||
for box in boxes:
|
||||
# Convert normalized box to pixel coordinates
|
||||
h, w = image.shape[:2]
|
||||
box_pixels = box * np.array([w, h, w, h])
|
||||
|
||||
mask, score, _ = predictor.predict(
|
||||
box=box_pixels,
|
||||
multimask_output=False
|
||||
)
|
||||
masks.append(mask[0])
|
||||
|
||||
return masks, boxes, phrases
|
||||
|
||||
# Usage
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks, boxes, phrases = text_to_mask(image, "person . dog . car")
|
||||
```
|
||||
|
||||
## Batched Processing
|
||||
|
||||
### Efficient multi-image processing
|
||||
|
||||
```python
|
||||
import torch
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
class BatchedSAM:
|
||||
def __init__(self, checkpoint, model_type="vit_h", device="cuda"):
|
||||
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
self.sam.to(device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
self.device = device
|
||||
|
||||
def process_batch(self, images, prompts):
|
||||
"""Process multiple images with corresponding prompts."""
|
||||
results = []
|
||||
|
||||
for image, prompt in zip(images, prompts):
|
||||
self.predictor.set_image(image)
|
||||
|
||||
if "point" in prompt:
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=prompt["point"],
|
||||
point_labels=prompt["label"],
|
||||
multimask_output=True
|
||||
)
|
||||
elif "box" in prompt:
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
box=prompt["box"],
|
||||
multimask_output=False
|
||||
)
|
||||
|
||||
results.append({
|
||||
"masks": masks,
|
||||
"scores": scores,
|
||||
"best_mask": masks[np.argmax(scores)]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
batch_sam = BatchedSAM("sam_vit_h_4b8939.pth")
|
||||
|
||||
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
|
||||
prompts = [{"point": np.array([[100, 100]]), "label": np.array([1])} for _ in range(10)]
|
||||
|
||||
results = batch_sam.process_batch(images, prompts)
|
||||
```
|
||||
|
||||
### Parallel automatic mask generation
|
||||
|
||||
```python
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from segment_anything import SamAutomaticMaskGenerator
|
||||
|
||||
def generate_masks_parallel(images, num_workers=4):
|
||||
"""Generate masks for multiple images in parallel."""
|
||||
# Note: Each worker needs its own model instance
|
||||
def worker_init():
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
return SamAutomaticMaskGenerator(sam)
|
||||
|
||||
generators = [worker_init() for _ in range(num_workers)]
|
||||
|
||||
def process_image(args):
|
||||
idx, image = args
|
||||
generator = generators[idx % num_workers]
|
||||
return generator.generate(image)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
results = list(executor.map(process_image, enumerate(images)))
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Custom Integration
|
||||
|
||||
### FastAPI service
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
import cv2
|
||||
import io
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model once
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
class PointPrompt(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
label: int = 1
|
||||
|
||||
@app.post("/segment/point")
|
||||
async def segment_with_point(
|
||||
file: UploadFile = File(...),
|
||||
points: list[PointPrompt] = []
|
||||
):
|
||||
# Read image
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Set image
|
||||
predictor.set_image(image)
|
||||
|
||||
# Prepare prompts
|
||||
point_coords = np.array([[p.x, p.y] for p in points])
|
||||
point_labels = np.array([p.label for p in points])
|
||||
|
||||
# Generate masks
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_idx = np.argmax(scores)
|
||||
|
||||
return {
|
||||
"mask": masks[best_idx].tolist(),
|
||||
"score": float(scores[best_idx]),
|
||||
"all_scores": scores.tolist()
|
||||
}
|
||||
|
||||
@app.post("/segment/auto")
|
||||
async def segment_automatic(file: UploadFile = File(...)):
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
return {
|
||||
"num_masks": len(masks),
|
||||
"masks": [
|
||||
{
|
||||
"bbox": m["bbox"],
|
||||
"area": m["area"],
|
||||
"predicted_iou": m["predicted_iou"],
|
||||
"stability_score": m["stability_score"]
|
||||
}
|
||||
for m in masks
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Gradio interface
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
# Load model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
def segment_image(image, evt: gr.SelectData):
|
||||
"""Segment object at clicked point."""
|
||||
predictor.set_image(image)
|
||||
|
||||
point = np.array([[evt.index[0], evt.index[1]]])
|
||||
label = np.array([1])
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point,
|
||||
point_labels=label,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
|
||||
# Overlay mask on image
|
||||
overlay = image.copy()
|
||||
overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
|
||||
|
||||
return overlay
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# SAM Interactive Segmentation")
|
||||
gr.Markdown("Click on an object to segment it")
|
||||
|
||||
with gr.Row():
|
||||
input_image = gr.Image(label="Input Image", interactive=True)
|
||||
output_image = gr.Image(label="Segmented Image")
|
||||
|
||||
input_image.select(segment_image, inputs=[input_image], outputs=[output_image])
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Fine-Tuning SAM
|
||||
|
||||
### LoRA fine-tuning (experimental)
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from transformers import SamModel
|
||||
|
||||
# Load model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["qkv"], # Attention layers
|
||||
lora_dropout=0.1,
|
||||
bias="none",
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
# Training loop (simplified)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(
|
||||
pixel_values=batch["pixel_values"],
|
||||
input_points=batch["input_points"],
|
||||
input_labels=batch["input_labels"]
|
||||
)
|
||||
|
||||
# Custom loss (e.g., IoU loss with ground truth)
|
||||
loss = compute_loss(outputs.pred_masks, batch["gt_masks"])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### MedSAM (Medical imaging)
|
||||
|
||||
```python
|
||||
# MedSAM is a fine-tuned SAM for medical images
|
||||
# https://github.com/bowang-lab/MedSAM
|
||||
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
import torch
|
||||
|
||||
# Load MedSAM checkpoint
|
||||
medsam = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth")
|
||||
medsam.to("cuda")
|
||||
|
||||
predictor = SamPredictor(medsam)
|
||||
|
||||
# Process medical image
|
||||
# Convert grayscale to RGB if needed
|
||||
medical_image = cv2.imread("ct_scan.png", cv2.IMREAD_GRAYSCALE)
|
||||
rgb_image = np.stack([medical_image] * 3, axis=-1)
|
||||
|
||||
predictor.set_image(rgb_image)
|
||||
|
||||
# Segment with box prompt (common for medical imaging)
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Mask Processing
|
||||
|
||||
### Mask refinement
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from scipy import ndimage
|
||||
|
||||
def refine_mask(mask, kernel_size=5, iterations=2):
|
||||
"""Refine mask with morphological operations."""
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
|
||||
# Close small holes
|
||||
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iterations)
|
||||
|
||||
# Remove small noise
|
||||
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=iterations)
|
||||
|
||||
return opened.astype(bool)
|
||||
|
||||
def fill_holes(mask):
|
||||
"""Fill holes in mask."""
|
||||
filled = ndimage.binary_fill_holes(mask)
|
||||
return filled
|
||||
|
||||
def remove_small_regions(mask, min_area=100):
|
||||
"""Remove small disconnected regions."""
|
||||
labeled, num_features = ndimage.label(mask)
|
||||
sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
|
||||
|
||||
# Keep only regions larger than min_area
|
||||
mask_clean = np.zeros_like(mask)
|
||||
for i, size in enumerate(sizes, 1):
|
||||
if size >= min_area:
|
||||
mask_clean[labeled == i] = True
|
||||
|
||||
return mask_clean
|
||||
```
|
||||
|
||||
### Mask to polygon conversion
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
def mask_to_polygons(mask, epsilon_factor=0.01):
|
||||
"""Convert binary mask to polygon coordinates."""
|
||||
contours, _ = cv2.findContours(
|
||||
mask.astype(np.uint8),
|
||||
cv2.RETR_EXTERNAL,
|
||||
cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for contour in contours:
|
||||
epsilon = epsilon_factor * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
polygon = approx.squeeze().tolist()
|
||||
if len(polygon) >= 3: # Valid polygon
|
||||
polygons.append(polygon)
|
||||
|
||||
return polygons
|
||||
|
||||
def polygons_to_mask(polygons, height, width):
|
||||
"""Convert polygons back to binary mask."""
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon in polygons:
|
||||
pts = np.array(polygon, dtype=np.int32)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
return mask.astype(bool)
|
||||
```
|
||||
|
||||
### Multi-scale segmentation
|
||||
|
||||
```python
|
||||
def multiscale_segment(image, predictor, point, scales=[0.5, 1.0, 2.0]):
|
||||
"""Generate masks at multiple scales and combine."""
|
||||
h, w = image.shape[:2]
|
||||
masks_all = []
|
||||
|
||||
for scale in scales:
|
||||
# Resize image
|
||||
new_h, new_w = int(h * scale), int(w * scale)
|
||||
scaled_image = cv2.resize(image, (new_w, new_h))
|
||||
scaled_point = (point * scale).astype(int)
|
||||
|
||||
# Segment
|
||||
predictor.set_image(scaled_image)
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=scaled_point.reshape(1, 2),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Resize mask back
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
original_mask = cv2.resize(best_mask.astype(np.uint8), (w, h)) > 0.5
|
||||
|
||||
masks_all.append(original_mask)
|
||||
|
||||
# Combine masks (majority voting)
|
||||
combined = np.stack(masks_all, axis=0)
|
||||
final_mask = np.sum(combined, axis=0) >= len(scales) // 2 + 1
|
||||
|
||||
return final_mask
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### TensorRT acceleration
|
||||
|
||||
```python
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit
|
||||
|
||||
def export_to_tensorrt(onnx_path, engine_path, fp16=True):
|
||||
"""Convert ONNX model to TensorRT engine."""
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
||||
with open(onnx_path, 'rb') as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
return None
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = 1 << 30 # 1GB
|
||||
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
engine = builder.build_engine(network, config)
|
||||
|
||||
with open(engine_path, 'wb') as f:
|
||||
f.write(engine.serialize())
|
||||
|
||||
return engine
|
||||
```
|
||||
|
||||
### Memory-efficient inference
|
||||
|
||||
```python
|
||||
class MemoryEfficientSAM:
|
||||
def __init__(self, checkpoint, model_type="vit_b"):
|
||||
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
self.sam.eval()
|
||||
self.predictor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.sam.to("cuda")
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.sam.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def segment(self, image, points, labels):
|
||||
self.predictor.set_image(image)
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=points,
|
||||
point_labels=labels,
|
||||
multimask_output=True
|
||||
)
|
||||
return masks, scores
|
||||
|
||||
# Usage with context manager (auto-cleanup)
|
||||
with MemoryEfficientSAM("sam_vit_b_01ec64.pth") as sam:
|
||||
masks, scores = sam.segment(image, points, labels)
|
||||
# CUDA memory freed automatically
|
||||
```
|
||||
|
||||
## Dataset Generation
|
||||
|
||||
### Create segmentation dataset
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
def generate_dataset(images_dir, output_dir, mask_generator):
|
||||
"""Generate segmentation dataset from images."""
|
||||
annotations = []
|
||||
|
||||
for img_path in Path(images_dir).glob("*.jpg"):
|
||||
image = cv2.imread(str(img_path))
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Generate masks
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
# Filter high-quality masks
|
||||
good_masks = [m for m in masks if m["predicted_iou"] > 0.9]
|
||||
|
||||
# Save annotations
|
||||
for i, mask_data in enumerate(good_masks):
|
||||
annotation = {
|
||||
"image_id": img_path.stem,
|
||||
"mask_id": i,
|
||||
"bbox": mask_data["bbox"],
|
||||
"area": mask_data["area"],
|
||||
"segmentation": mask_to_rle(mask_data["segmentation"]),
|
||||
"predicted_iou": mask_data["predicted_iou"],
|
||||
"stability_score": mask_data["stability_score"]
|
||||
}
|
||||
annotations.append(annotation)
|
||||
|
||||
# Save dataset
|
||||
with open(output_dir / "annotations.json", "w") as f:
|
||||
json.dump(annotations, f)
|
||||
|
||||
return annotations
|
||||
```
|
||||
@@ -0,0 +1,484 @@
|
||||
# Segment Anything Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### CUDA not available
|
||||
|
||||
**Error**: `RuntimeError: CUDA not available`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check CUDA availability
|
||||
import torch
|
||||
print(torch.cuda.is_available())
|
||||
print(torch.version.cuda)
|
||||
|
||||
# Install PyTorch with CUDA
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# If CUDA works but SAM doesn't use it
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cuda") # Explicitly move to GPU
|
||||
```
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'segment_anything'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from GitHub
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
|
||||
# Or clone and install
|
||||
git clone https://github.com/facebookresearch/segment-anything.git
|
||||
cd segment-anything
|
||||
pip install -e .
|
||||
|
||||
# Verify installation
|
||||
python -c "from segment_anything import sam_model_registry; print('OK')"
|
||||
```
|
||||
|
||||
### Missing dependencies
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'cv2'` or similar
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install all optional dependencies
|
||||
pip install opencv-python pycocotools matplotlib onnxruntime onnx
|
||||
|
||||
# For pycocotools on Windows
|
||||
pip install pycocotools-windows
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Checkpoint not found
|
||||
|
||||
**Error**: `FileNotFoundError: checkpoint file not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Download correct checkpoint
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
|
||||
# Verify file integrity
|
||||
md5sum sam_vit_h_4b8939.pth
|
||||
# Expected: a7bf3b02f3ebf1267aba913ff637d9a2
|
||||
|
||||
# Use absolute path
|
||||
sam = sam_model_registry["vit_h"](checkpoint="/full/path/to/sam_vit_h_4b8939.pth")
|
||||
```
|
||||
|
||||
### Model type mismatch
|
||||
|
||||
**Error**: `KeyError: 'unexpected key in state_dict'`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Ensure model type matches checkpoint
|
||||
# vit_h checkpoint → vit_h model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
|
||||
# vit_l checkpoint → vit_l model
|
||||
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
||||
|
||||
# vit_b checkpoint → vit_b model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
```
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `CUDA out of memory` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Load to CPU first, then move
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
sam.to("cuda")
|
||||
|
||||
# Use half precision
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam = sam.half()
|
||||
sam.to("cuda")
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Image format errors
|
||||
|
||||
**Error**: `ValueError: expected input to have 3 channels`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Ensure RGB format
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR to RGB
|
||||
|
||||
# Convert grayscale to RGB
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
# Handle RGBA
|
||||
if image.shape[2] == 4:
|
||||
image = image[:, :, :3] # Drop alpha channel
|
||||
```
|
||||
|
||||
### Coordinate errors
|
||||
|
||||
**Error**: `IndexError: index out of bounds` or incorrect mask location
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Ensure points are (x, y) not (row, col)
|
||||
# x = column index, y = row index
|
||||
point = np.array([[x, y]]) # Correct
|
||||
|
||||
# Verify coordinates are within image bounds
|
||||
h, w = image.shape[:2]
|
||||
assert 0 <= x < w and 0 <= y < h, "Point outside image"
|
||||
|
||||
# For bounding boxes: [x1, y1, x2, y2]
|
||||
box = np.array([x1, y1, x2, y2])
|
||||
assert x1 < x2 and y1 < y2, "Invalid box coordinates"
|
||||
```
|
||||
|
||||
### Empty or incorrect masks
|
||||
|
||||
**Problem**: Masks don't match expected object
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Try multiple prompts
|
||||
input_points = np.array([[x1, y1], [x2, y2]])
|
||||
input_labels = np.array([1, 1]) # Multiple foreground points
|
||||
|
||||
# Add background points
|
||||
input_points = np.array([[obj_x, obj_y], [bg_x, bg_y]])
|
||||
input_labels = np.array([1, 0]) # 1=foreground, 0=background
|
||||
|
||||
# Use box prompt for large objects
|
||||
box = np.array([x1, y1, x2, y2])
|
||||
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
|
||||
|
||||
# Combine box and point
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([[center_x, center_y]]),
|
||||
point_labels=np.array([1]),
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Check scores and select best
|
||||
print(f"Scores: {scores}")
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
```
|
||||
|
||||
### Slow inference
|
||||
|
||||
**Problem**: Prediction takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Reuse image embeddings
|
||||
predictor.set_image(image) # Compute once
|
||||
for point in points:
|
||||
masks, _, _ = predictor.predict(...) # Fast, reuses embeddings
|
||||
|
||||
# Reduce automatic generation points
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Default is 32
|
||||
)
|
||||
|
||||
# Use ONNX for deployment
|
||||
# Export: python scripts/export_onnx_model.py --return-single-mask
|
||||
```
|
||||
|
||||
## Automatic Mask Generation Issues
|
||||
|
||||
### Too many masks
|
||||
|
||||
**Problem**: Generating thousands of overlapping masks
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Reduce from 32
|
||||
pred_iou_thresh=0.92, # Increase from 0.88
|
||||
stability_score_thresh=0.98, # Increase from 0.95
|
||||
box_nms_thresh=0.5, # More aggressive NMS
|
||||
min_mask_region_area=500, # Remove small masks
|
||||
)
|
||||
```
|
||||
|
||||
### Too few masks
|
||||
|
||||
**Problem**: Missing objects in automatic generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=64, # Increase density
|
||||
pred_iou_thresh=0.80, # Lower threshold
|
||||
stability_score_thresh=0.85, # Lower threshold
|
||||
crop_n_layers=2, # Add multi-scale
|
||||
min_mask_region_area=0, # Keep all masks
|
||||
)
|
||||
```
|
||||
|
||||
### Small objects missed
|
||||
|
||||
**Problem**: Automatic generation misses small objects
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use crop layers for multi-scale detection
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
crop_n_layers=2,
|
||||
crop_n_points_downscale_factor=1, # Don't reduce points in crops
|
||||
min_mask_region_area=10, # Very small minimum
|
||||
)
|
||||
|
||||
# Or process image patches
|
||||
def segment_with_patches(image, patch_size=512, overlap=64):
|
||||
h, w = image.shape[:2]
|
||||
all_masks = []
|
||||
|
||||
for y in range(0, h, patch_size - overlap):
|
||||
for x in range(0, w, patch_size - overlap):
|
||||
patch = image[y:y+patch_size, x:x+patch_size]
|
||||
masks = mask_generator.generate(patch)
|
||||
|
||||
# Offset masks to original coordinates
|
||||
for m in masks:
|
||||
m['bbox'][0] += x
|
||||
m['bbox'][1] += y
|
||||
# Offset segmentation mask too
|
||||
|
||||
all_masks.extend(masks)
|
||||
|
||||
return all_masks
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Clear cache between images
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Process images sequentially, not batched
|
||||
for image in images:
|
||||
predictor.set_image(image)
|
||||
masks, _, _ = predictor.predict(...)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Reduce image size
|
||||
max_size = 1024
|
||||
h, w = image.shape[:2]
|
||||
if max(h, w) > max_size:
|
||||
scale = max_size / max(h, w)
|
||||
image = cv2.resize(image, (int(w*scale), int(h*scale)))
|
||||
|
||||
# Use CPU for large batch processing
|
||||
sam.to("cpu")
|
||||
```
|
||||
|
||||
### RAM out of memory
|
||||
|
||||
**Problem**: System runs out of RAM
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Process images one at a time
|
||||
for img_path in image_paths:
|
||||
image = cv2.imread(img_path)
|
||||
masks = process_image(image)
|
||||
save_results(masks)
|
||||
del image, masks
|
||||
gc.collect()
|
||||
|
||||
# Use generators instead of lists
|
||||
def generate_masks_lazy(image_paths):
|
||||
for path in image_paths:
|
||||
image = cv2.imread(path)
|
||||
masks = mask_generator.generate(image)
|
||||
yield path, masks
|
||||
```
|
||||
|
||||
## ONNX Export Issues
|
||||
|
||||
### Export fails
|
||||
|
||||
**Error**: Various export errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install correct ONNX version
|
||||
pip install onnx==1.14.0 onnxruntime==1.15.0
|
||||
|
||||
# Use correct opset version
|
||||
python scripts/export_onnx_model.py \
|
||||
--checkpoint sam_vit_h_4b8939.pth \
|
||||
--model-type vit_h \
|
||||
--output sam.onnx \
|
||||
--opset 17
|
||||
```
|
||||
|
||||
### ONNX runtime errors
|
||||
|
||||
**Error**: `ONNXRuntimeError` during inference
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import onnxruntime
|
||||
|
||||
# Check available providers
|
||||
print(onnxruntime.get_available_providers())
|
||||
|
||||
# Use CPU provider if GPU fails
|
||||
session = onnxruntime.InferenceSession(
|
||||
"sam.onnx",
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# Verify input shapes
|
||||
for input in session.get_inputs():
|
||||
print(f"{input.name}: {input.shape}")
|
||||
```
|
||||
|
||||
## HuggingFace Integration Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with SamProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import SamModel, SamProcessor
|
||||
|
||||
# Use matching processor and model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
# Ensure input format
|
||||
input_points = [[[x, y]]] # Nested list for batch dimension
|
||||
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
||||
|
||||
# Post-process correctly
|
||||
masks = processor.image_processor.post_process_masks(
|
||||
outputs.pred_masks.cpu(),
|
||||
inputs["original_sizes"].cpu(),
|
||||
inputs["reshaped_input_sizes"].cpu()
|
||||
)
|
||||
```
|
||||
|
||||
## Quality Issues
|
||||
|
||||
### Jagged mask edges
|
||||
|
||||
**Problem**: Masks have rough, pixelated edges
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import cv2
|
||||
from scipy import ndimage
|
||||
|
||||
def smooth_mask(mask, sigma=2):
|
||||
"""Smooth mask edges."""
|
||||
# Gaussian blur
|
||||
smooth = ndimage.gaussian_filter(mask.astype(float), sigma=sigma)
|
||||
return smooth > 0.5
|
||||
|
||||
def refine_edges(mask, kernel_size=5):
|
||||
"""Refine mask edges with morphological operations."""
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
# Close small gaps
|
||||
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
|
||||
# Open to remove noise
|
||||
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)
|
||||
return opened.astype(bool)
|
||||
```
|
||||
|
||||
### Incomplete segmentation
|
||||
|
||||
**Problem**: Mask doesn't cover entire object
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Add multiple points
|
||||
input_points = np.array([
|
||||
[obj_center_x, obj_center_y],
|
||||
[obj_left_x, obj_center_y],
|
||||
[obj_right_x, obj_center_y],
|
||||
[obj_center_x, obj_top_y],
|
||||
[obj_center_x, obj_bottom_y]
|
||||
])
|
||||
input_labels = np.array([1, 1, 1, 1, 1])
|
||||
|
||||
# Use bounding box
|
||||
masks, _, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=False
|
||||
)
|
||||
|
||||
# Iterative refinement
|
||||
mask_input = None
|
||||
for point in points:
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=point.reshape(1, 2),
|
||||
point_labels=np.array([1]),
|
||||
mask_input=mask_input,
|
||||
multimask_output=False
|
||||
)
|
||||
mask_input = logits
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | GPU memory full | Use smaller model, clear cache |
|
||||
| `expected 3 channels` | Wrong image format | Convert to RGB |
|
||||
| `index out of bounds` | Invalid coordinates | Check point/box bounds |
|
||||
| `checkpoint not found` | Wrong path | Use absolute path |
|
||||
| `unexpected key` | Model/checkpoint mismatch | Match model type |
|
||||
| `invalid box coordinates` | x1 > x2 or y1 > y2 | Fix box format |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/segment-anything/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2304.02643
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
|
||||
- CUDA version: `python -c "import torch; print(torch.version.cuda)"`
|
||||
- SAM model type (vit_b/l/h)
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
522
skills/mlops/models/stable-diffusion/SKILL.md
Normal file
522
skills/mlops/models/stable-diffusion/SKILL.md
Normal file
@@ -0,0 +1,522 @@
|
||||
---
|
||||
name: stable-diffusion-image-generation
|
||||
description: State-of-the-art text-to-image generation with Stable Diffusion models via HuggingFace Diffusers. Use when generating images from text prompts, performing image-to-image translation, inpainting, or building custom diffusion pipelines.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [diffusers>=0.30.0, transformers>=4.41.0, accelerate>=0.31.0, torch>=2.0.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Image Generation, Stable Diffusion, Diffusers, Text-to-Image, Multimodal, Computer Vision]
|
||||
|
||||
---
|
||||
|
||||
# Stable Diffusion Image Generation
|
||||
|
||||
Comprehensive guide to generating images with Stable Diffusion using the HuggingFace Diffusers library.
|
||||
|
||||
## When to use Stable Diffusion
|
||||
|
||||
**Use Stable Diffusion when:**
|
||||
- Generating images from text descriptions
|
||||
- Performing image-to-image translation (style transfer, enhancement)
|
||||
- Inpainting (filling in masked regions)
|
||||
- Outpainting (extending images beyond boundaries)
|
||||
- Creating variations of existing images
|
||||
- Building custom image generation workflows
|
||||
|
||||
**Key features:**
|
||||
- **Text-to-Image**: Generate images from natural language prompts
|
||||
- **Image-to-Image**: Transform existing images with text guidance
|
||||
- **Inpainting**: Fill masked regions with context-aware content
|
||||
- **ControlNet**: Add spatial conditioning (edges, poses, depth)
|
||||
- **LoRA Support**: Efficient fine-tuning and style adaptation
|
||||
- **Multiple Models**: SD 1.5, SDXL, SD 3.0, Flux support
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **DALL-E 3**: For API-based generation without GPU
|
||||
- **Midjourney**: For artistic, stylized outputs
|
||||
- **Imagen**: For Google Cloud integration
|
||||
- **Leonardo.ai**: For web-based creative workflows
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install diffusers transformers accelerate torch
|
||||
pip install xformers # Optional: memory-efficient attention
|
||||
```
|
||||
|
||||
### Basic text-to-image
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
# Load pipeline (auto-detects model type)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Generate image
|
||||
image = pipe(
|
||||
"A serene mountain landscape at sunset, highly detailed",
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
### Using SDXL (higher quality)
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Enable memory optimization
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = pipe(
|
||||
prompt="A futuristic city with flying cars, cinematic lighting",
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=30
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Architecture overview
|
||||
|
||||
### Three-pillar design
|
||||
|
||||
Diffusers is built around three core components:
|
||||
|
||||
```
|
||||
Pipeline (orchestration)
|
||||
├── Model (neural networks)
|
||||
│ ├── UNet / Transformer (noise prediction)
|
||||
│ ├── VAE (latent encoding/decoding)
|
||||
│ └── Text Encoder (CLIP/T5)
|
||||
└── Scheduler (denoising algorithm)
|
||||
```
|
||||
|
||||
### Pipeline inference flow
|
||||
|
||||
```
|
||||
Text Prompt → Text Encoder → Text Embeddings
|
||||
↓
|
||||
Random Noise → [Denoising Loop] ← Scheduler
|
||||
↓
|
||||
Predicted Noise
|
||||
↓
|
||||
VAE Decoder → Final Image
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Pipelines
|
||||
|
||||
Pipelines orchestrate complete workflows:
|
||||
|
||||
| Pipeline | Purpose |
|
||||
|----------|---------|
|
||||
| `StableDiffusionPipeline` | Text-to-image (SD 1.x/2.x) |
|
||||
| `StableDiffusionXLPipeline` | Text-to-image (SDXL) |
|
||||
| `StableDiffusion3Pipeline` | Text-to-image (SD 3.0) |
|
||||
| `FluxPipeline` | Text-to-image (Flux models) |
|
||||
| `StableDiffusionImg2ImgPipeline` | Image-to-image |
|
||||
| `StableDiffusionInpaintPipeline` | Inpainting |
|
||||
|
||||
### Schedulers
|
||||
|
||||
Schedulers control the denoising process:
|
||||
|
||||
| Scheduler | Steps | Quality | Use Case |
|
||||
|-----------|-------|---------|----------|
|
||||
| `EulerDiscreteScheduler` | 20-50 | Good | Default choice |
|
||||
| `EulerAncestralDiscreteScheduler` | 20-50 | Good | More variation |
|
||||
| `DPMSolverMultistepScheduler` | 15-25 | Excellent | Fast, high quality |
|
||||
| `DDIMScheduler` | 50-100 | Good | Deterministic |
|
||||
| `LCMScheduler` | 4-8 | Good | Very fast |
|
||||
| `UniPCMultistepScheduler` | 15-25 | Excellent | Fast convergence |
|
||||
|
||||
### Swapping schedulers
|
||||
|
||||
```python
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
# Swap for faster generation
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipe.scheduler.config
|
||||
)
|
||||
|
||||
# Now generate with fewer steps
|
||||
image = pipe(prompt, num_inference_steps=20).images[0]
|
||||
```
|
||||
|
||||
## Generation parameters
|
||||
|
||||
### Key parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `prompt` | Required | Text description of desired image |
|
||||
| `negative_prompt` | None | What to avoid in the image |
|
||||
| `num_inference_steps` | 50 | Denoising steps (more = better quality) |
|
||||
| `guidance_scale` | 7.5 | Prompt adherence (7-12 typical) |
|
||||
| `height`, `width` | 512/1024 | Output dimensions (multiples of 8) |
|
||||
| `generator` | None | Torch generator for reproducibility |
|
||||
| `num_images_per_prompt` | 1 | Batch size |
|
||||
|
||||
### Reproducible generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(42)
|
||||
|
||||
image = pipe(
|
||||
prompt="A cat wearing a top hat",
|
||||
generator=generator,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Negative prompts
|
||||
|
||||
```python
|
||||
image = pipe(
|
||||
prompt="Professional photo of a dog in a garden",
|
||||
negative_prompt="blurry, low quality, distorted, ugly, bad anatomy",
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Image-to-image
|
||||
|
||||
Transform existing images with text guidance:
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
from PIL import Image
|
||||
|
||||
pipe = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
init_image = Image.open("input.jpg").resize((512, 512))
|
||||
|
||||
image = pipe(
|
||||
prompt="A watercolor painting of the scene",
|
||||
image=init_image,
|
||||
strength=0.75, # How much to transform (0-1)
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Inpainting
|
||||
|
||||
Fill masked regions:
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from PIL import Image
|
||||
|
||||
pipe = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = Image.open("photo.jpg")
|
||||
mask = Image.open("mask.png") # White = inpaint region
|
||||
|
||||
result = pipe(
|
||||
prompt="A red car parked on the street",
|
||||
image=image,
|
||||
mask_image=mask,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## ControlNet
|
||||
|
||||
Add spatial conditioning for precise control:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
import torch
|
||||
|
||||
# Load ControlNet for edge conditioning
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Use Canny edge image as control
|
||||
control_image = get_canny_image(input_image)
|
||||
|
||||
image = pipe(
|
||||
prompt="A beautiful house in the style of Van Gogh",
|
||||
image=control_image,
|
||||
num_inference_steps=30
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Available ControlNets
|
||||
|
||||
| ControlNet | Input Type | Use Case |
|
||||
|------------|------------|----------|
|
||||
| `canny` | Edge maps | Preserve structure |
|
||||
| `openpose` | Pose skeletons | Human poses |
|
||||
| `depth` | Depth maps | 3D-aware generation |
|
||||
| `normal` | Normal maps | Surface details |
|
||||
| `mlsd` | Line segments | Architectural lines |
|
||||
| `scribble` | Rough sketches | Sketch-to-image |
|
||||
|
||||
## LoRA adapters
|
||||
|
||||
Load fine-tuned style adapters:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load LoRA weights
|
||||
pipe.load_lora_weights("path/to/lora", weight_name="style.safetensors")
|
||||
|
||||
# Generate with LoRA style
|
||||
image = pipe("A portrait in the trained style").images[0]
|
||||
|
||||
# Adjust LoRA strength
|
||||
pipe.fuse_lora(lora_scale=0.8)
|
||||
|
||||
# Unload LoRA
|
||||
pipe.unload_lora_weights()
|
||||
```
|
||||
|
||||
### Multiple LoRAs
|
||||
|
||||
```python
|
||||
# Load multiple LoRAs
|
||||
pipe.load_lora_weights("lora1", adapter_name="style")
|
||||
pipe.load_lora_weights("lora2", adapter_name="character")
|
||||
|
||||
# Set weights for each
|
||||
pipe.set_adapters(["style", "character"], adapter_weights=[0.7, 0.5])
|
||||
|
||||
image = pipe("A portrait").images[0]
|
||||
```
|
||||
|
||||
## Memory optimization
|
||||
|
||||
### Enable CPU offloading
|
||||
|
||||
```python
|
||||
# Model CPU offload - moves models to CPU when not in use
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Sequential CPU offload - more aggressive, slower
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
```
|
||||
|
||||
### Attention slicing
|
||||
|
||||
```python
|
||||
# Reduce memory by computing attention in chunks
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
# Or specific chunk size
|
||||
pipe.enable_attention_slicing("max")
|
||||
```
|
||||
|
||||
### xFormers memory-efficient attention
|
||||
|
||||
```python
|
||||
# Requires xformers package
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
### VAE slicing for large images
|
||||
|
||||
```python
|
||||
# Decode latents in tiles for large images
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_vae_tiling()
|
||||
```
|
||||
|
||||
## Model variants
|
||||
|
||||
### Loading different precisions
|
||||
|
||||
```python
|
||||
# FP16 (recommended for GPU)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
|
||||
# BF16 (better precision, requires Ampere+ GPU)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
### Loading specific components
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel, AutoencoderKL
|
||||
|
||||
# Load custom VAE
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
|
||||
|
||||
# Use with pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
vae=vae,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Batch generation
|
||||
|
||||
Generate multiple images efficiently:
|
||||
|
||||
```python
|
||||
# Multiple prompts
|
||||
prompts = [
|
||||
"A cat playing piano",
|
||||
"A dog reading a book",
|
||||
"A bird painting a picture"
|
||||
]
|
||||
|
||||
images = pipe(prompts, num_inference_steps=30).images
|
||||
|
||||
# Multiple images per prompt
|
||||
images = pipe(
|
||||
"A beautiful sunset",
|
||||
num_images_per_prompt=4,
|
||||
num_inference_steps=30
|
||||
).images
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: High-quality generation
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
|
||||
# 1. Load SDXL with optimizations
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# 2. Generate with quality settings
|
||||
image = pipe(
|
||||
prompt="A majestic lion in the savanna, golden hour lighting, 8k, detailed fur",
|
||||
negative_prompt="blurry, low quality, cartoon, anime, sketch",
|
||||
num_inference_steps=30,
|
||||
guidance_scale=7.5,
|
||||
height=1024,
|
||||
width=1024
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Workflow 2: Fast prototyping
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForText2Image, LCMScheduler
|
||||
import torch
|
||||
|
||||
# Use LCM for 4-8 step generation
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load LCM LoRA for fast generation
|
||||
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.fuse_lora()
|
||||
|
||||
# Generate in ~1 second
|
||||
image = pipe(
|
||||
"A beautiful landscape",
|
||||
num_inference_steps=4,
|
||||
guidance_scale=1.0
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Common issues
|
||||
|
||||
**CUDA out of memory:**
|
||||
```python
|
||||
# Enable memory optimizations
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
# Or use lower precision
|
||||
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
**Black/noise images:**
|
||||
```python
|
||||
# Check VAE configuration
|
||||
# Use safety checker bypass if needed
|
||||
pipe.safety_checker = None
|
||||
|
||||
# Ensure proper dtype consistency
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
```
|
||||
|
||||
**Slow generation:**
|
||||
```python
|
||||
# Use faster scheduler
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# Reduce steps
|
||||
image = pipe(prompt, num_inference_steps=20).images[0]
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Custom pipelines, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://huggingface.co/docs/diffusers
|
||||
- **Repository**: https://github.com/huggingface/diffusers
|
||||
- **Model Hub**: https://huggingface.co/models?library=diffusers
|
||||
- **Discord**: https://discord.gg/diffusers
|
||||
@@ -0,0 +1,716 @@
|
||||
# Stable Diffusion Advanced Usage Guide
|
||||
|
||||
## Custom Pipelines
|
||||
|
||||
### Building from components
|
||||
|
||||
```python
|
||||
from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionPipeline
|
||||
)
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
|
||||
# Load components individually
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="unet"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="vae"
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="text_encoder"
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="tokenizer"
|
||||
)
|
||||
scheduler = DDPMScheduler.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="scheduler"
|
||||
)
|
||||
|
||||
# Assemble pipeline
|
||||
pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False
|
||||
)
|
||||
```
|
||||
|
||||
### Custom denoising loop
|
||||
|
||||
```python
|
||||
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
|
||||
def custom_generate(
|
||||
prompt: str,
|
||||
num_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512
|
||||
):
|
||||
# Load components
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
unet = UNet2DConditionModel.from_pretrained("sd-model", subfolder="unet")
|
||||
vae = AutoencoderKL.from_pretrained("sd-model", subfolder="vae")
|
||||
scheduler = DDIMScheduler.from_pretrained("sd-model", subfolder="scheduler")
|
||||
|
||||
device = "cuda"
|
||||
text_encoder.to(device)
|
||||
unet.to(device)
|
||||
vae.to(device)
|
||||
|
||||
# Encode prompt
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
|
||||
|
||||
# Unconditional embeddings for classifier-free guidance
|
||||
uncond_input = tokenizer(
|
||||
"",
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
|
||||
|
||||
# Concatenate for batch processing
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# Initialize latents
|
||||
latents = torch.randn(
|
||||
(1, 4, height // 8, width // 8),
|
||||
device=device
|
||||
)
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
# Denoising loop
|
||||
scheduler.set_timesteps(num_steps)
|
||||
for t in scheduler.timesteps:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# Predict noise
|
||||
with torch.no_grad():
|
||||
noise_pred = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings
|
||||
).sample
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_cond - noise_pred_uncond
|
||||
)
|
||||
|
||||
# Update latents
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
# Decode latents
|
||||
latents = latents / vae.config.scaling_factor
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latents).sample
|
||||
|
||||
# Convert to PIL
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
image = (image * 255).round().astype("uint8")[0]
|
||||
|
||||
return Image.fromarray(image)
|
||||
```
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
Use image prompts alongside text:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load IP-Adapter
|
||||
pipe.load_ip_adapter(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="models",
|
||||
weight_name="ip-adapter_sd15.bin"
|
||||
)
|
||||
|
||||
# Set IP-Adapter scale
|
||||
pipe.set_ip_adapter_scale(0.6)
|
||||
|
||||
# Load reference image
|
||||
ip_image = load_image("reference_style.jpg")
|
||||
|
||||
# Generate with image + text prompt
|
||||
image = pipe(
|
||||
prompt="A portrait in a garden",
|
||||
ip_adapter_image=ip_image,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Multiple IP-Adapter images
|
||||
|
||||
```python
|
||||
# Use multiple reference images
|
||||
pipe.set_ip_adapter_scale([0.5, 0.7])
|
||||
|
||||
images = [
|
||||
load_image("style_reference.jpg"),
|
||||
load_image("composition_reference.jpg")
|
||||
]
|
||||
|
||||
result = pipe(
|
||||
prompt="A landscape painting",
|
||||
ip_adapter_image=images,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## SDXL Refiner
|
||||
|
||||
Two-stage generation for higher quality:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
base = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
# Load refiner
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
# Generate with base (partial denoising)
|
||||
image = base(
|
||||
prompt="A majestic eagle soaring over mountains",
|
||||
num_inference_steps=40,
|
||||
denoising_end=0.8,
|
||||
output_type="latent"
|
||||
).images
|
||||
|
||||
# Refine with refiner
|
||||
refined = refiner(
|
||||
prompt="A majestic eagle soaring over mountains",
|
||||
image=image,
|
||||
num_inference_steps=40,
|
||||
denoising_start=0.8
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## T2I-Adapter
|
||||
|
||||
Lightweight conditioning without full ControlNet:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
|
||||
import torch
|
||||
|
||||
# Load adapter
|
||||
adapter = T2IAdapter.from_pretrained(
|
||||
"TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Get canny edges
|
||||
canny_image = get_canny_image(input_image)
|
||||
|
||||
image = pipe(
|
||||
prompt="A colorful anime character",
|
||||
image=canny_image,
|
||||
num_inference_steps=30,
|
||||
adapter_conditioning_scale=0.8
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Fine-tuning with DreamBooth
|
||||
|
||||
Train on custom subjects:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline, DDPMScheduler
|
||||
from diffusers.optimization import get_scheduler
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
def __init__(self, instance_images_path, instance_prompt, tokenizer, size=512):
|
||||
self.instance_images_path = instance_images_path
|
||||
self.instance_prompt = instance_prompt
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
|
||||
self.instance_images = [
|
||||
os.path.join(instance_images_path, f)
|
||||
for f in os.listdir(instance_images_path)
|
||||
if f.endswith(('.png', '.jpg', '.jpeg'))
|
||||
]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.instance_images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image = Image.open(self.instance_images[idx]).convert("RGB")
|
||||
image = image.resize((self.size, self.size))
|
||||
image = torch.tensor(np.array(image)).permute(2, 0, 1) / 127.5 - 1.0
|
||||
|
||||
tokens = self.tokenizer(
|
||||
self.instance_prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
return {"image": image, "input_ids": tokens.input_ids.squeeze()}
|
||||
|
||||
def train_dreambooth(
|
||||
pretrained_model: str,
|
||||
instance_data_dir: str,
|
||||
instance_prompt: str,
|
||||
output_dir: str,
|
||||
learning_rate: float = 5e-6,
|
||||
max_train_steps: int = 800,
|
||||
train_batch_size: int = 1
|
||||
):
|
||||
# Load pipeline
|
||||
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model)
|
||||
|
||||
unet = pipe.unet
|
||||
vae = pipe.vae
|
||||
text_encoder = pipe.text_encoder
|
||||
tokenizer = pipe.tokenizer
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
|
||||
|
||||
# Freeze VAE and text encoder
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
# Create dataset
|
||||
dataset = DreamBoothDataset(
|
||||
instance_data_dir, instance_prompt, tokenizer
|
||||
)
|
||||
dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
|
||||
lr_scheduler = get_scheduler(
|
||||
"constant",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=max_train_steps
|
||||
)
|
||||
|
||||
# Training loop
|
||||
unet.train()
|
||||
device = "cuda"
|
||||
unet.to(device)
|
||||
vae.to(device)
|
||||
text_encoder.to(device)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(max_train_steps // len(dataloader) + 1):
|
||||
for batch in dataloader:
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
# Encode images to latents
|
||||
latents = vae.encode(batch["image"].to(device)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise
|
||||
noise = torch.randn_like(latents)
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],))
|
||||
timesteps = timesteps.to(device)
|
||||
|
||||
# Add noise
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get text embeddings
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
|
||||
|
||||
# Predict noise
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Compute loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, noise)
|
||||
|
||||
# Backprop
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % 100 == 0:
|
||||
print(f"Step {global_step}, Loss: {loss.item():.4f}")
|
||||
|
||||
# Save model
|
||||
pipe.unet = unet
|
||||
pipe.save_pretrained(output_dir)
|
||||
```
|
||||
|
||||
## LoRA Training
|
||||
|
||||
Efficient fine-tuning with Low-Rank Adaptation:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
def train_lora(
|
||||
base_model: str,
|
||||
train_dataset,
|
||||
output_dir: str,
|
||||
lora_rank: int = 4,
|
||||
learning_rate: float = 1e-4,
|
||||
max_train_steps: int = 1000
|
||||
):
|
||||
pipe = StableDiffusionPipeline.from_pretrained(base_model)
|
||||
unet = pipe.unet
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_rank,
|
||||
target_modules=["to_q", "to_v", "to_k", "to_out.0"],
|
||||
lora_dropout=0.1
|
||||
)
|
||||
|
||||
# Apply LoRA to UNet
|
||||
unet = get_peft_model(unet, lora_config)
|
||||
unet.print_trainable_parameters() # Shows ~0.1% trainable
|
||||
|
||||
# Train (similar to DreamBooth but only LoRA params)
|
||||
optimizer = torch.optim.AdamW(
|
||||
unet.parameters(),
|
||||
lr=learning_rate
|
||||
)
|
||||
|
||||
# ... training loop ...
|
||||
|
||||
# Save LoRA weights only
|
||||
unet.save_pretrained(output_dir)
|
||||
```
|
||||
|
||||
## Textual Inversion
|
||||
|
||||
Learn new concepts through embeddings:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
# Load with textual inversion
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load learned embedding
|
||||
pipe.load_textual_inversion(
|
||||
"sd-concepts-library/cat-toy",
|
||||
token="<cat-toy>"
|
||||
)
|
||||
|
||||
# Use in prompts
|
||||
image = pipe("A photo of <cat-toy> on a beach").images[0]
|
||||
```
|
||||
|
||||
## Quantization
|
||||
|
||||
Reduce memory with quantization:
|
||||
|
||||
```python
|
||||
from diffusers import BitsAndBytesConfig, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
# 8-bit quantization
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### NF4 quantization (4-bit)
|
||||
|
||||
```python
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: str = ""
|
||||
num_inference_steps: int = 30
|
||||
guidance_scale: float = 7.5
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
seed: int = None
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
image_base64: str
|
||||
seed: int
|
||||
|
||||
@app.post("/generate", response_model=GenerationResponse)
|
||||
async def generate(request: GenerationRequest):
|
||||
try:
|
||||
generator = None
|
||||
seed = request.seed or torch.randint(0, 2**32, (1,)).item()
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
|
||||
image = pipe(
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
num_inference_steps=request.num_inference_steps,
|
||||
guidance_scale=request.guidance_scale,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
generator=generator
|
||||
).images[0]
|
||||
|
||||
# Convert to base64
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
return GenerationResponse(image_base64=image_base64, seed=seed)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
```
|
||||
|
||||
### Docker deployment
|
||||
|
||||
```dockerfile
|
||||
FROM nvidia/cuda:12.1-runtime-ubuntu22.04
|
||||
|
||||
RUN apt-get update && apt-get install -y python3 python3-pip
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip3 install -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
# Pre-download model
|
||||
RUN python3 -c "from diffusers import DiffusionPipeline; DiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5')"
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
### Kubernetes deployment
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: stable-diffusion
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: stable-diffusion
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: stable-diffusion
|
||||
spec:
|
||||
containers:
|
||||
- name: sd
|
||||
image: your-registry/stable-diffusion:latest
|
||||
ports:
|
||||
- containerPort: 8000
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 1
|
||||
memory: "16Gi"
|
||||
requests:
|
||||
nvidia.com/gpu: 1
|
||||
memory: "8Gi"
|
||||
env:
|
||||
- name: TRANSFORMERS_CACHE
|
||||
value: "/cache/huggingface"
|
||||
volumeMounts:
|
||||
- name: model-cache
|
||||
mountPath: /cache
|
||||
volumes:
|
||||
- name: model-cache
|
||||
persistentVolumeClaim:
|
||||
claimName: model-cache-pvc
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: stable-diffusion
|
||||
spec:
|
||||
selector:
|
||||
app: stable-diffusion
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 8000
|
||||
type: LoadBalancer
|
||||
```
|
||||
|
||||
## Callback System
|
||||
|
||||
Monitor and modify generation:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.callbacks import PipelineCallback
|
||||
import torch
|
||||
|
||||
class ProgressCallback(PipelineCallback):
|
||||
def __init__(self):
|
||||
self.progress = []
|
||||
|
||||
def callback_fn(self, pipe, step_index, timestep, callback_kwargs):
|
||||
self.progress.append({
|
||||
"step": step_index,
|
||||
"timestep": timestep.item()
|
||||
})
|
||||
|
||||
# Optionally modify latents
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
# Use callback
|
||||
callback = ProgressCallback()
|
||||
|
||||
image = pipe(
|
||||
prompt="A sunset",
|
||||
callback_on_step_end=callback.callback_fn,
|
||||
callback_on_step_end_tensor_inputs=["latents"]
|
||||
).images[0]
|
||||
|
||||
print(f"Generation completed in {len(callback.progress)} steps")
|
||||
```
|
||||
|
||||
### Early stopping
|
||||
|
||||
```python
|
||||
def early_stop_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
# Stop after 20 steps
|
||||
if step_index >= 20:
|
||||
pipe._interrupt = True
|
||||
return callback_kwargs
|
||||
|
||||
image = pipe(
|
||||
prompt="A landscape",
|
||||
num_inference_steps=50,
|
||||
callback_on_step_end=early_stop_callback
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Multi-GPU Inference
|
||||
|
||||
### Device map auto
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
device_map="auto", # Automatically distribute across GPUs
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Manual distribution
|
||||
|
||||
```python
|
||||
from accelerate import infer_auto_device_map, dispatch_model
|
||||
|
||||
# Create device map
|
||||
device_map = infer_auto_device_map(
|
||||
pipe.unet,
|
||||
max_memory={0: "10GiB", 1: "10GiB"}
|
||||
)
|
||||
|
||||
# Dispatch model
|
||||
pipe.unet = dispatch_model(pipe.unet, device_map=device_map)
|
||||
```
|
||||
@@ -0,0 +1,555 @@
|
||||
# Stable Diffusion Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Package conflicts
|
||||
|
||||
**Error**: `ImportError: cannot import name 'cached_download' from 'huggingface_hub'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Update huggingface_hub
|
||||
pip install --upgrade huggingface_hub
|
||||
|
||||
# Reinstall diffusers
|
||||
pip install --upgrade diffusers
|
||||
```
|
||||
|
||||
### xFormers installation fails
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available for execution`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
|
||||
# Install matching xformers
|
||||
pip install xformers --index-url https://download.pytorch.org/whl/cu121 # For CUDA 12.1
|
||||
|
||||
# Or build from source
|
||||
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
||||
```
|
||||
|
||||
### Torch/CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check versions
|
||||
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
|
||||
|
||||
# Reinstall PyTorch with correct CUDA
|
||||
pip uninstall torch torchvision
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
```python
|
||||
# Solution 1: Enable CPU offloading
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Solution 2: Sequential CPU offload (more aggressive)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
# Solution 3: Attention slicing
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
# Solution 4: VAE slicing for large images
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
# Solution 5: Use lower precision
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
torch_dtype=torch.float16 # or torch.bfloat16
|
||||
)
|
||||
|
||||
# Solution 6: Reduce batch size
|
||||
image = pipe(prompt, num_images_per_prompt=1).images[0]
|
||||
|
||||
# Solution 7: Generate smaller images
|
||||
image = pipe(prompt, height=512, width=512).images[0]
|
||||
|
||||
# Solution 8: Clear cache between generations
|
||||
import gc
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
```
|
||||
|
||||
### Memory grows over time
|
||||
|
||||
**Problem**: Memory usage increases with each generation
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(pipe, prompt, **kwargs):
|
||||
try:
|
||||
image = pipe(prompt, **kwargs).images[0]
|
||||
return image
|
||||
finally:
|
||||
# Clear cache after generation
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
```
|
||||
|
||||
### Large model loading fails
|
||||
|
||||
**Error**: `RuntimeError: Unable to load model weights`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use low CPU memory mode
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"large-model-id",
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Black images
|
||||
|
||||
**Problem**: Output images are completely black
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Disable safety checker
|
||||
pipe.safety_checker = None
|
||||
|
||||
# Solution 2: Check VAE scaling
|
||||
# The issue might be with VAE encoding/decoding
|
||||
latents = latents / pipe.vae.config.scaling_factor # Before decode
|
||||
|
||||
# Solution 3: Ensure proper dtype
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
pipe.vae = pipe.vae.to(dtype=torch.float32) # VAE often needs fp32
|
||||
|
||||
# Solution 4: Check guidance scale
|
||||
# Too high can cause issues
|
||||
image = pipe(prompt, guidance_scale=7.5).images[0] # Not 20+
|
||||
```
|
||||
|
||||
### Noise/static images
|
||||
|
||||
**Problem**: Output looks like random noise
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Increase inference steps
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
|
||||
# Solution 2: Check scheduler configuration
|
||||
pipe.scheduler = pipe.scheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# Solution 3: Verify model was loaded correctly
|
||||
print(pipe.unet) # Should show model architecture
|
||||
```
|
||||
|
||||
### Blurry images
|
||||
|
||||
**Problem**: Output images are low quality or blurry
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Use more steps
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
|
||||
# Solution 2: Use better VAE
|
||||
from diffusers import AutoencoderKL
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
|
||||
pipe.vae = vae
|
||||
|
||||
# Solution 3: Use SDXL or refiner
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
)
|
||||
|
||||
# Solution 4: Upscale with img2img
|
||||
upscale_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(...)
|
||||
upscaled = upscale_pipe(
|
||||
prompt=prompt,
|
||||
image=image.resize((1024, 1024)),
|
||||
strength=0.3
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Prompt not being followed
|
||||
|
||||
**Problem**: Generated image doesn't match the prompt
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Increase guidance scale
|
||||
image = pipe(prompt, guidance_scale=10.0).images[0]
|
||||
|
||||
# Solution 2: Use negative prompts
|
||||
image = pipe(
|
||||
prompt="A red car",
|
||||
negative_prompt="blue, green, yellow, wrong color",
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
|
||||
# Solution 3: Use prompt weighting
|
||||
# Emphasize important words
|
||||
prompt = "A (red:1.5) car on a street"
|
||||
|
||||
# Solution 4: Use longer, more detailed prompts
|
||||
prompt = """
|
||||
A bright red sports car, ferrari style, parked on a city street,
|
||||
photorealistic, high detail, 8k, professional photography
|
||||
"""
|
||||
```
|
||||
|
||||
### Distorted faces/hands
|
||||
|
||||
**Problem**: Faces and hands look deformed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Use negative prompts
|
||||
negative_prompt = """
|
||||
bad hands, bad anatomy, deformed, ugly, blurry,
|
||||
extra fingers, mutated hands, poorly drawn hands,
|
||||
poorly drawn face, mutation, deformed face
|
||||
"""
|
||||
|
||||
# Solution 2: Use face-specific models
|
||||
# ADetailer or similar post-processing
|
||||
|
||||
# Solution 3: Use ControlNet for poses
|
||||
# Load pose estimation and condition generation
|
||||
|
||||
# Solution 4: Inpaint problematic areas
|
||||
mask = create_face_mask(image)
|
||||
fixed = inpaint_pipe(
|
||||
prompt="beautiful detailed face",
|
||||
image=image,
|
||||
mask_image=mask
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Scheduler Issues
|
||||
|
||||
### Scheduler not compatible
|
||||
|
||||
**Error**: `ValueError: Scheduler ... is not compatible with pipeline`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
# Create scheduler from config
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config
|
||||
)
|
||||
|
||||
# Check compatible schedulers
|
||||
print(pipe.scheduler.compatibles)
|
||||
```
|
||||
|
||||
### Wrong number of steps
|
||||
|
||||
**Problem**: Model generates different quality with same steps
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Reset timesteps explicitly
|
||||
pipe.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Check scheduler's step count
|
||||
print(len(pipe.scheduler.timesteps))
|
||||
```
|
||||
|
||||
## LoRA Issues
|
||||
|
||||
### LoRA weights not loading
|
||||
|
||||
**Error**: `RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check weight file format
|
||||
# Should be .safetensors or .bin
|
||||
|
||||
# Load with correct key prefix
|
||||
pipe.load_lora_weights(
|
||||
"path/to/lora",
|
||||
weight_name="lora.safetensors"
|
||||
)
|
||||
|
||||
# Try loading into specific component
|
||||
pipe.unet.load_attn_procs("path/to/lora")
|
||||
```
|
||||
|
||||
### LoRA not affecting output
|
||||
|
||||
**Problem**: Generated images look the same with/without LoRA
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Fuse LoRA weights
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
|
||||
# Or set scale explicitly
|
||||
pipe.set_adapters(["lora_name"], adapter_weights=[1.0])
|
||||
|
||||
# Verify LoRA is loaded
|
||||
print(list(pipe.unet.attn_processors.keys()))
|
||||
```
|
||||
|
||||
### Multiple LoRAs conflict
|
||||
|
||||
**Problem**: Multiple LoRAs produce artifacts
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Load with different adapter names
|
||||
pipe.load_lora_weights("lora1", adapter_name="style")
|
||||
pipe.load_lora_weights("lora2", adapter_name="subject")
|
||||
|
||||
# Balance weights
|
||||
pipe.set_adapters(
|
||||
["style", "subject"],
|
||||
adapter_weights=[0.5, 0.5] # Lower weights
|
||||
)
|
||||
|
||||
# Or use LoRA merge before loading
|
||||
# Merge LoRAs offline with appropriate ratios
|
||||
```
|
||||
|
||||
## ControlNet Issues
|
||||
|
||||
### ControlNet not conditioning
|
||||
|
||||
**Problem**: ControlNet has no effect on output
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check control image format
|
||||
# Should be RGB, matching generation size
|
||||
control_image = control_image.resize((512, 512))
|
||||
|
||||
# Increase conditioning scale
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=control_image,
|
||||
controlnet_conditioning_scale=1.0, # Try 0.5-1.5
|
||||
num_inference_steps=30
|
||||
).images[0]
|
||||
|
||||
# Verify ControlNet is loaded
|
||||
print(pipe.controlnet)
|
||||
```
|
||||
|
||||
### Control image preprocessing
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from controlnet_aux import CannyDetector
|
||||
|
||||
# Proper preprocessing
|
||||
canny = CannyDetector()
|
||||
control_image = canny(input_image)
|
||||
|
||||
# Ensure correct format
|
||||
control_image = control_image.convert("RGB")
|
||||
control_image = control_image.resize((512, 512))
|
||||
```
|
||||
|
||||
## Hub/Download Issues
|
||||
|
||||
### Model download fails
|
||||
|
||||
**Error**: `requests.exceptions.ConnectionError`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Set longer timeout
|
||||
export HF_HUB_DOWNLOAD_TIMEOUT=600
|
||||
|
||||
# Use mirror if available
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# Or download manually
|
||||
huggingface-cli download stable-diffusion-v1-5/stable-diffusion-v1-5
|
||||
```
|
||||
|
||||
### Cache issues
|
||||
|
||||
**Error**: `OSError: Can't load model from cache`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Clear cache
|
||||
rm -rf ~/.cache/huggingface/hub
|
||||
|
||||
# Or set different cache location
|
||||
export HF_HOME=/path/to/cache
|
||||
|
||||
# Force re-download
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
force_download=True
|
||||
)
|
||||
```
|
||||
|
||||
### Access denied for gated models
|
||||
|
||||
**Error**: `401 Client Error: Unauthorized`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Login to Hugging Face
|
||||
huggingface-cli login
|
||||
|
||||
# Or use token
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
token="hf_xxxxx"
|
||||
)
|
||||
|
||||
# Accept model license on Hub website first
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Use faster scheduler
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipe.scheduler.config
|
||||
)
|
||||
|
||||
# Solution 2: Reduce steps
|
||||
image = pipe(prompt, num_inference_steps=20).images[0]
|
||||
|
||||
# Solution 3: Use LCM
|
||||
from diffusers import LCMScheduler
|
||||
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
image = pipe(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
|
||||
|
||||
# Solution 4: Enable xFormers
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Solution 5: Compile model
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
### First generation is slow
|
||||
|
||||
**Problem**: First image takes much longer
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Warm up the model
|
||||
_ = pipe("warmup", num_inference_steps=1)
|
||||
|
||||
# Then run actual generation
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
|
||||
# Compile for faster subsequent runs
|
||||
pipe.unet = torch.compile(pipe.unet)
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
### Enable debug logging
|
||||
|
||||
```python
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Or for specific modules
|
||||
logging.getLogger("diffusers").setLevel(logging.DEBUG)
|
||||
logging.getLogger("transformers").setLevel(logging.DEBUG)
|
||||
```
|
||||
|
||||
### Check model components
|
||||
|
||||
```python
|
||||
# Print pipeline components
|
||||
print(pipe.components)
|
||||
|
||||
# Check model config
|
||||
print(pipe.unet.config)
|
||||
print(pipe.vae.config)
|
||||
print(pipe.scheduler.config)
|
||||
|
||||
# Verify device placement
|
||||
print(pipe.device)
|
||||
for name, module in pipe.components.items():
|
||||
if hasattr(module, 'device'):
|
||||
print(f"{name}: {module.device}")
|
||||
```
|
||||
|
||||
### Validate inputs
|
||||
|
||||
```python
|
||||
# Check image dimensions
|
||||
print(f"Height: {height}, Width: {width}")
|
||||
assert height % 8 == 0, "Height must be divisible by 8"
|
||||
assert width % 8 == 0, "Width must be divisible by 8"
|
||||
|
||||
# Check prompt tokenization
|
||||
tokens = pipe.tokenizer(prompt, return_tensors="pt")
|
||||
print(f"Token count: {tokens.input_ids.shape[1]}") # Max 77 for SD
|
||||
```
|
||||
|
||||
### Save intermediate results
|
||||
|
||||
```python
|
||||
def save_latents_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
# Decode and save intermediate
|
||||
with torch.no_grad():
|
||||
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
||||
Image.fromarray((image * 255).astype("uint8")).save(f"step_{step_index}.png")
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
callback_on_step_end=save_latents_callback,
|
||||
callback_on_step_end_tensor_inputs=["latents"]
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Documentation**: https://huggingface.co/docs/diffusers
|
||||
2. **GitHub Issues**: https://github.com/huggingface/diffusers/issues
|
||||
3. **Discord**: https://discord.gg/diffusers
|
||||
4. **Forum**: https://discuss.huggingface.co
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Diffusers version: `pip show diffusers`
|
||||
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
|
||||
- CUDA version: `nvcc --version`
|
||||
- GPU model: `nvidia-smi`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Model name/ID used
|
||||
320
skills/mlops/models/whisper/SKILL.md
Normal file
320
skills/mlops/models/whisper/SKILL.md
Normal file
@@ -0,0 +1,320 @@
|
||||
---
|
||||
name: whisper
|
||||
description: OpenAI's general-purpose speech recognition model. Supports 99 languages, transcription, translation to English, and language identification. Six model sizes from tiny (39M params) to large (1550M params). Use for speech-to-text, podcast transcription, or multilingual audio processing. Best for robust, multilingual ASR.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [openai-whisper, transformers, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Whisper, Speech Recognition, ASR, Multimodal, Multilingual, OpenAI, Speech-To-Text, Transcription, Translation, Audio Processing]
|
||||
|
||||
---
|
||||
|
||||
# Whisper - Robust Speech Recognition
|
||||
|
||||
OpenAI's multilingual speech recognition model.
|
||||
|
||||
## When to use Whisper
|
||||
|
||||
**Use when:**
|
||||
- Speech-to-text transcription (99 languages)
|
||||
- Podcast/video transcription
|
||||
- Meeting notes automation
|
||||
- Translation to English
|
||||
- Noisy audio transcription
|
||||
- Multilingual audio processing
|
||||
|
||||
**Metrics**:
|
||||
- **72,900+ GitHub stars**
|
||||
- 99 languages supported
|
||||
- Trained on 680,000 hours of audio
|
||||
- MIT License
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **AssemblyAI**: Managed API, speaker diarization
|
||||
- **Deepgram**: Real-time streaming ASR
|
||||
- **Google Speech-to-Text**: Cloud-based
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Requires Python 3.8-3.11
|
||||
pip install -U openai-whisper
|
||||
|
||||
# Requires ffmpeg
|
||||
# macOS: brew install ffmpeg
|
||||
# Ubuntu: sudo apt install ffmpeg
|
||||
# Windows: choco install ffmpeg
|
||||
```
|
||||
|
||||
### Basic transcription
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
# Load model
|
||||
model = whisper.load_model("base")
|
||||
|
||||
# Transcribe
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
# Print text
|
||||
print(result["text"])
|
||||
|
||||
# Access segments
|
||||
for segment in result["segments"]:
|
||||
print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
|
||||
```
|
||||
|
||||
## Model sizes
|
||||
|
||||
```python
|
||||
# Available models
|
||||
models = ["tiny", "base", "small", "medium", "large", "turbo"]
|
||||
|
||||
# Load specific model
|
||||
model = whisper.load_model("turbo") # Fastest, good quality
|
||||
```
|
||||
|
||||
| Model | Parameters | English-only | Multilingual | Speed | VRAM |
|
||||
|-------|------------|--------------|--------------|-------|------|
|
||||
| tiny | 39M | ✓ | ✓ | ~32x | ~1 GB |
|
||||
| base | 74M | ✓ | ✓ | ~16x | ~1 GB |
|
||||
| small | 244M | ✓ | ✓ | ~6x | ~2 GB |
|
||||
| medium | 769M | ✓ | ✓ | ~2x | ~5 GB |
|
||||
| large | 1550M | ✗ | ✓ | 1x | ~10 GB |
|
||||
| turbo | 809M | ✗ | ✓ | ~8x | ~6 GB |
|
||||
|
||||
**Recommendation**: Use `turbo` for best speed/quality, `base` for prototyping
|
||||
|
||||
## Transcription options
|
||||
|
||||
### Language specification
|
||||
|
||||
```python
|
||||
# Auto-detect language
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
# Specify language (faster)
|
||||
result = model.transcribe("audio.mp3", language="en")
|
||||
|
||||
# Supported: en, es, fr, de, it, pt, ru, ja, ko, zh, and 89 more
|
||||
```
|
||||
|
||||
### Task selection
|
||||
|
||||
```python
|
||||
# Transcription (default)
|
||||
result = model.transcribe("audio.mp3", task="transcribe")
|
||||
|
||||
# Translation to English
|
||||
result = model.transcribe("spanish.mp3", task="translate")
|
||||
# Input: Spanish audio → Output: English text
|
||||
```
|
||||
|
||||
### Initial prompt
|
||||
|
||||
```python
|
||||
# Improve accuracy with context
|
||||
result = model.transcribe(
|
||||
"audio.mp3",
|
||||
initial_prompt="This is a technical podcast about machine learning and AI."
|
||||
)
|
||||
|
||||
# Helps with:
|
||||
# - Technical terms
|
||||
# - Proper nouns
|
||||
# - Domain-specific vocabulary
|
||||
```
|
||||
|
||||
### Timestamps
|
||||
|
||||
```python
|
||||
# Word-level timestamps
|
||||
result = model.transcribe("audio.mp3", word_timestamps=True)
|
||||
|
||||
for segment in result["segments"]:
|
||||
for word in segment["words"]:
|
||||
print(f"{word['word']} ({word['start']:.2f}s - {word['end']:.2f}s)")
|
||||
```
|
||||
|
||||
### Temperature fallback
|
||||
|
||||
```python
|
||||
# Retry with different temperatures if confidence low
|
||||
result = model.transcribe(
|
||||
"audio.mp3",
|
||||
temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
||||
)
|
||||
```
|
||||
|
||||
## Command line usage
|
||||
|
||||
```bash
|
||||
# Basic transcription
|
||||
whisper audio.mp3
|
||||
|
||||
# Specify model
|
||||
whisper audio.mp3 --model turbo
|
||||
|
||||
# Output formats
|
||||
whisper audio.mp3 --output_format txt # Plain text
|
||||
whisper audio.mp3 --output_format srt # Subtitles
|
||||
whisper audio.mp3 --output_format vtt # WebVTT
|
||||
whisper audio.mp3 --output_format json # JSON with timestamps
|
||||
|
||||
# Language
|
||||
whisper audio.mp3 --language Spanish
|
||||
|
||||
# Translation
|
||||
whisper spanish.mp3 --task translate
|
||||
```
|
||||
|
||||
## Batch processing
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
audio_files = ["file1.mp3", "file2.mp3", "file3.mp3"]
|
||||
|
||||
for audio_file in audio_files:
|
||||
print(f"Transcribing {audio_file}...")
|
||||
result = model.transcribe(audio_file)
|
||||
|
||||
# Save to file
|
||||
output_file = audio_file.replace(".mp3", ".txt")
|
||||
with open(output_file, "w") as f:
|
||||
f.write(result["text"])
|
||||
```
|
||||
|
||||
## Real-time transcription
|
||||
|
||||
```python
|
||||
# For streaming audio, use faster-whisper
|
||||
# pip install faster-whisper
|
||||
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
model = WhisperModel("base", device="cuda", compute_type="float16")
|
||||
|
||||
# Transcribe with streaming
|
||||
segments, info = model.transcribe("audio.mp3", beam_size=5)
|
||||
|
||||
for segment in segments:
|
||||
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
# Automatically uses GPU if available
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# Force CPU
|
||||
model = whisper.load_model("turbo", device="cpu")
|
||||
|
||||
# Force GPU
|
||||
model = whisper.load_model("turbo", device="cuda")
|
||||
|
||||
# 10-20× faster on GPU
|
||||
```
|
||||
|
||||
## Integration with other tools
|
||||
|
||||
### Subtitle generation
|
||||
|
||||
```bash
|
||||
# Generate SRT subtitles
|
||||
whisper video.mp4 --output_format srt --language English
|
||||
|
||||
# Output: video.srt
|
||||
```
|
||||
|
||||
### With LangChain
|
||||
|
||||
```python
|
||||
from langchain.document_loaders import WhisperTranscriptionLoader
|
||||
|
||||
loader = WhisperTranscriptionLoader(file_path="audio.mp3")
|
||||
docs = loader.load()
|
||||
|
||||
# Use transcription in RAG
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vectorstore = Chroma.from_documents(docs, OpenAIEmbeddings())
|
||||
```
|
||||
|
||||
### Extract audio from video
|
||||
|
||||
```bash
|
||||
# Use ffmpeg to extract audio
|
||||
ffmpeg -i video.mp4 -vn -acodec pcm_s16le audio.wav
|
||||
|
||||
# Then transcribe
|
||||
whisper audio.wav
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use turbo model** - Best speed/quality for English
|
||||
2. **Specify language** - Faster than auto-detect
|
||||
3. **Add initial prompt** - Improves technical terms
|
||||
4. **Use GPU** - 10-20× faster
|
||||
5. **Batch process** - More efficient
|
||||
6. **Convert to WAV** - Better compatibility
|
||||
7. **Split long audio** - <30 min chunks
|
||||
8. **Check language support** - Quality varies by language
|
||||
9. **Use faster-whisper** - 4× faster than openai-whisper
|
||||
10. **Monitor VRAM** - Scale model size to hardware
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | Real-time factor (CPU) | Real-time factor (GPU) |
|
||||
|-------|------------------------|------------------------|
|
||||
| tiny | ~0.32 | ~0.01 |
|
||||
| base | ~0.16 | ~0.01 |
|
||||
| turbo | ~0.08 | ~0.01 |
|
||||
| large | ~1.0 | ~0.05 |
|
||||
|
||||
*Real-time factor: 0.1 = 10× faster than real-time*
|
||||
|
||||
## Language support
|
||||
|
||||
Top-supported languages:
|
||||
- English (en)
|
||||
- Spanish (es)
|
||||
- French (fr)
|
||||
- German (de)
|
||||
- Italian (it)
|
||||
- Portuguese (pt)
|
||||
- Russian (ru)
|
||||
- Japanese (ja)
|
||||
- Korean (ko)
|
||||
- Chinese (zh)
|
||||
|
||||
Full list: 99 languages total
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May repeat or invent text
|
||||
2. **Long-form accuracy** - Degrades on >30 min audio
|
||||
3. **Speaker identification** - No diarization
|
||||
4. **Accents** - Quality varies
|
||||
5. **Background noise** - Can affect accuracy
|
||||
6. **Real-time latency** - Not suitable for live captioning
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/openai/whisper ⭐ 72,900+
|
||||
- **Paper**: https://arxiv.org/abs/2212.04356
|
||||
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
|
||||
- **Colab**: Available in repo
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
189
skills/mlops/models/whisper/references/languages.md
Normal file
189
skills/mlops/models/whisper/references/languages.md
Normal file
@@ -0,0 +1,189 @@
|
||||
# Whisper Language Support Guide
|
||||
|
||||
Complete guide to Whisper's multilingual capabilities.
|
||||
|
||||
## Supported languages (99 total)
|
||||
|
||||
### Top-tier support (WER < 10%)
|
||||
|
||||
- English (en)
|
||||
- Spanish (es)
|
||||
- French (fr)
|
||||
- German (de)
|
||||
- Italian (it)
|
||||
- Portuguese (pt)
|
||||
- Dutch (nl)
|
||||
- Polish (pl)
|
||||
- Russian (ru)
|
||||
- Japanese (ja)
|
||||
- Korean (ko)
|
||||
- Chinese (zh)
|
||||
|
||||
### Good support (WER 10-20%)
|
||||
|
||||
- Arabic (ar)
|
||||
- Turkish (tr)
|
||||
- Vietnamese (vi)
|
||||
- Swedish (sv)
|
||||
- Finnish (fi)
|
||||
- Czech (cs)
|
||||
- Romanian (ro)
|
||||
- Hungarian (hu)
|
||||
- Danish (da)
|
||||
- Norwegian (no)
|
||||
- Thai (th)
|
||||
- Hebrew (he)
|
||||
- Greek (el)
|
||||
- Indonesian (id)
|
||||
- Malay (ms)
|
||||
|
||||
### Full list (99 languages)
|
||||
|
||||
Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Bashkir, Basque, Belarusian, Bengali, Bosnian, Breton, Bulgarian, Burmese, Cantonese, Catalan, Chinese, Croatian, Czech, Danish, Dutch, English, Estonian, Faroese, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hungarian, Icelandic, Indonesian, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Lao, Latin, Latvian, Lingala, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Moldavian, Mongolian, Myanmar, Nepali, Norwegian, Nynorsk, Occitan, Pashto, Persian, Polish, Portuguese, Punjabi, Pushto, Romanian, Russian, Sanskrit, Serbian, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tagalog, Tajik, Tamil, Tatar, Telugu, Thai, Tibetan, Turkish, Turkmen, Ukrainian, Urdu, Uzbek, Vietnamese, Welsh, Yiddish, Yoruba
|
||||
|
||||
## Usage examples
|
||||
|
||||
### Auto-detect language
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# Auto-detect language
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
print(f"Detected language: {result['language']}")
|
||||
print(f"Text: {result['text']}")
|
||||
```
|
||||
|
||||
### Specify language (faster)
|
||||
|
||||
```python
|
||||
# Specify language for faster transcription
|
||||
result = model.transcribe("audio.mp3", language="es") # Spanish
|
||||
result = model.transcribe("audio.mp3", language="fr") # French
|
||||
result = model.transcribe("audio.mp3", language="ja") # Japanese
|
||||
```
|
||||
|
||||
### Translation to English
|
||||
|
||||
```python
|
||||
# Translate any language to English
|
||||
result = model.transcribe(
|
||||
"spanish_audio.mp3",
|
||||
task="translate" # Translates to English
|
||||
)
|
||||
|
||||
print(f"Original language: {result['language']}")
|
||||
print(f"English translation: {result['text']}")
|
||||
```
|
||||
|
||||
## Language-specific tips
|
||||
|
||||
### Chinese
|
||||
|
||||
```python
|
||||
# Chinese works well with larger models
|
||||
model = whisper.load_model("large")
|
||||
|
||||
result = model.transcribe(
|
||||
"chinese_audio.mp3",
|
||||
language="zh",
|
||||
initial_prompt="这是一段关于技术的讨论" # Context helps
|
||||
)
|
||||
```
|
||||
|
||||
### Japanese
|
||||
|
||||
```python
|
||||
# Japanese benefits from initial prompt
|
||||
result = model.transcribe(
|
||||
"japanese_audio.mp3",
|
||||
language="ja",
|
||||
initial_prompt="これは技術的な会議の録音です"
|
||||
)
|
||||
```
|
||||
|
||||
### Arabic
|
||||
|
||||
```python
|
||||
# Arabic: Use large model for best results
|
||||
model = whisper.load_model("large")
|
||||
|
||||
result = model.transcribe(
|
||||
"arabic_audio.mp3",
|
||||
language="ar"
|
||||
)
|
||||
```
|
||||
|
||||
## Model size recommendations
|
||||
|
||||
| Language Tier | Recommended Model | WER |
|
||||
|---------------|-------------------|-----|
|
||||
| Top-tier (en, es, fr, de) | base/turbo | < 10% |
|
||||
| Good (ar, tr, vi) | medium/large | 10-20% |
|
||||
| Lower-resource | large | 20-30% |
|
||||
|
||||
## Performance by language
|
||||
|
||||
### English
|
||||
|
||||
- **tiny**: WER ~15%
|
||||
- **base**: WER ~8%
|
||||
- **small**: WER ~5%
|
||||
- **medium**: WER ~4%
|
||||
- **large**: WER ~3%
|
||||
- **turbo**: WER ~3.5%
|
||||
|
||||
### Spanish
|
||||
|
||||
- **tiny**: WER ~20%
|
||||
- **base**: WER ~12%
|
||||
- **medium**: WER ~6%
|
||||
- **large**: WER ~4%
|
||||
|
||||
### Chinese
|
||||
|
||||
- **small**: WER ~15%
|
||||
- **medium**: WER ~8%
|
||||
- **large**: WER ~5%
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use English-only models** - Better for small models (tiny/base)
|
||||
2. **Specify language** - Faster than auto-detect
|
||||
3. **Add initial prompt** - Improves accuracy for technical terms
|
||||
4. **Use larger models** - For low-resource languages
|
||||
5. **Test on sample** - Quality varies by accent/dialect
|
||||
6. **Consider audio quality** - Clear audio = better results
|
||||
7. **Check language codes** - Use ISO 639-1 codes (2 letters)
|
||||
|
||||
## Language detection
|
||||
|
||||
```python
|
||||
# Detect language only (no transcription)
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("base")
|
||||
|
||||
# Load audio
|
||||
audio = whisper.load_audio("audio.mp3")
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
|
||||
# Make log-Mel spectrogram
|
||||
mel = whisper.log_mel_spectrogram(audio).to(model.device)
|
||||
|
||||
# Detect language
|
||||
_, probs = model.detect_language(mel)
|
||||
detected_language = max(probs, key=probs.get)
|
||||
|
||||
print(f"Detected language: {detected_language}")
|
||||
print(f"Confidence: {probs[detected_language]:.2%}")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2212.04356
|
||||
- **GitHub**: https://github.com/openai/whisper
|
||||
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
|
||||
Reference in New Issue
Block a user