---
name: segment-anything-model
description: Foundation model for image segmentation with zero-shot transfer. Use when you need to segment any object in images using points, boxes, or masks as prompts, or automatically generate all object masks in an image.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
metadata:
  hermes:
    tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
---

# Segment Anything Model (SAM)

Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.

## When to use SAM

**Use SAM when:**
- You need to segment any object in images without task-specific training
- You are building interactive annotation tools with point/box prompts
- You are generating training data for other vision models
- You need zero-shot transfer to new image domains
- You are building object detection/segmentation pipelines
- You are processing medical, satellite, or other domain-specific images

**Key features:**
- **Zero-shot segmentation**: Generalizes to new image domains without fine-tuning
- **Flexible prompts**: Points, bounding boxes, or previous masks
- **Automatic segmentation**: Generate all object masks automatically
- **High quality**: Trained on 1.1 billion masks from 11 million images (SA-1B)
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
- **ONNX export**: Run the prompt encoder and mask decoder in browsers and on edge devices (image embeddings still come from the PyTorch image encoder)

**Use alternatives instead:**
- **YOLO/Detectron2**: For real-time object detection with classes
- **Mask2Former**: For semantic/panoptic segmentation with categories
- **GroundingDINO + SAM**: For text-prompted segmentation
- **SAM 2**: For video segmentation tasks

## Quick start

### Installation

```bash
# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git

# Optional dependencies (mask post-processing, COCO-format export, plotting)
pip install opencv-python pycocotools matplotlib

# Or use HuggingFace transformers
pip install transformers
```

### Download checkpoints

```bash
# ViT-H (largest, most accurate) - 2.4GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# ViT-L (medium) - 1.2GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

# ViT-B (smallest, fastest) - 375MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```

### Basic usage with SamPredictor

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Create predictor
predictor = SamPredictor(sam)

# Set image (computes the image embedding once; expects an RGB uint8 HxWx3 array)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Predict with point prompts
input_point = np.array([[500, 375]])  # (x, y) pixel coordinates
input_label = np.array([1])  # 1 = foreground, 0 = background

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True  # Returns 3 candidate masks
)

# Select best mask (highest predicted IoU)
best_mask = masks[np.argmax(scores)]
```

### HuggingFace Transformers

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")

# Process image with point prompt
image = Image.open("image.jpg").convert("RGB")
input_points = [[[450, 600]]]  # One image, one (x, y) point

inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Generate masks
with torch.no_grad():
    outputs = model(**inputs)

# Post-process masks to original size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores  # Predicted IoU for each candidate mask
```

## Core concepts

### Model architecture

The heavy image encoder runs once per image. The prompt encoder and mask decoder are lightweight and run once per prompt: the decoder combines the cached image embeddings with the prompt embeddings.

```
SAM Architecture:

┌─────────────────┐
│  Image Encoder  │  image embeddings (computed once per image)
│      (ViT)      │─────────────────────┐
└─────────────────┘                     ▼
                               ┌─────────────────┐
                               │  Mask Decoder   │────▶ Masks + IoU predictions
                               │  (Transformer)  │
                               └─────────────────┘
┌─────────────────┐                     ▲
│ Prompt Encoder  │─────────────────────┘
│ (Points/Boxes)  │  prompt embeddings (per prompt)
└─────────────────┘
```

### Model variants

| Model | Registry key | Checkpoint size | Speed | Accuracy |
|-------|--------------|-----------------|-------|----------|
| ViT-H | `vit_h` | 2.4 GB | Slowest | Best |
| ViT-L | `vit_l` | 1.2 GB | Medium | Close to ViT-H |
| ViT-B | `vit_b` | 375 MB | Fastest | Lowest of the three |

### Prompt types

| Prompt | Description | Use case |
|--------|-------------|----------|
| Point (foreground) | Click on the object | Single object selection |
| Point (background) | Click outside the object | Exclude regions |
| Bounding box | Rectangle around the object | Larger objects; less ambiguous than a single point |
| Previous mask | Low-res (256×256) logits from an earlier prediction | Iterative refinement |

## Interactive segmentation

### Point prompts

```python
# Single foreground point
input_point = np.array([[500, 375]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True
)

# Multiple points (foreground + background)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0])  # 2 foreground, 1 background

masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False  # Single mask when prompts are clear
)
```

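Interactive tools usually collect clicks as two lists, one for foreground and one for background. The sketch below packs them into the `point_coords` / `point_labels` arrays that `predictor.predict` takes (shapes `(N, 2)` and `(N,)`). The helper name `build_point_prompt` is hypothetical and not part of `segment_anything`.

```python
import numpy as np


def build_point_prompt(foreground, background=()):
    """Pack (x, y) click lists into SamPredictor-style prompt arrays.

    Hypothetical helper: returns point_coords with shape (N, 2) and
    point_labels with shape (N,), where 1 = foreground and 0 = background.
    """
    points = list(foreground) + list(background)
    if not points:
        raise ValueError("at least one point is required")
    coords = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int32)
    return coords, labels


coords, labels = build_point_prompt([(500, 375), (600, 400)], [(450, 300)])
print(coords.shape, labels.tolist())  # (3, 2) [1, 1, 0]
```

The result can be passed straight through as `predictor.predict(point_coords=coords, point_labels=labels, ...)`.
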
### Box prompts

```python
# Bounding box [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])

masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False
)
```

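`predictor.predict` takes `box` as `[x1, y1, x2, y2]`, while `SamAutomaticMaskGenerator` reports `bbox` as `[x, y, w, h]`. When an automatically found box is reused as a prompt, it has to be converted first. A small sketch; the helper names are illustrative, not library functions:

```python
import numpy as np


def xywh_to_xyxy(box):
    """Convert [x, y, w, h] (generator 'bbox') to [x1, y1, x2, y2] (predictor 'box')."""
    x, y, w, h = box
    return np.array([x, y, x + w, y + h], dtype=np.float32)


def xyxy_to_xywh(box):
    """Inverse conversion, e.g. for exporting predictor boxes in COCO-style XYWH."""
    x1, y1, x2, y2 = box
    return np.array([x1, y1, x2 - x1, y2 - y1], dtype=np.float32)


box = xywh_to_xyxy([425, 600, 275, 275])
print(box.tolist())  # [425.0, 600.0, 700.0, 875.0]
```

For example, `predictor.predict(box=xywh_to_xyxy(m["bbox"]), multimask_output=False)` re-prompts with the box of a generator mask `m`.
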
### Combined prompts

```python
# Box + points for precise control
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    box=np.array([400, 300, 700, 600]),
    multimask_output=False
)
```

### Iterative refinement

```python
# Initial prediction
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True
)

# Refine with additional point using previous mask
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375], [550, 400]]),
    point_labels=np.array([1, 0]),  # Add background point
    mask_input=logits[np.argmax(scores)][None, :, :],  # Use best mask
    multimask_output=False
)
```

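Each new click in an interactive session follows the same pattern as the refinement above: append the point, keep all earlier points, and pass the best low-resolution logits (shape `(1, 256, 256)`) from the previous round as `mask_input`. The sketch below writes that loop against any callable that accepts `SamPredictor.predict`'s keyword arguments, so it can be exercised without a checkpoint. `refine_with_clicks` and `fake_predict` are hypothetical names.

```python
import numpy as np


def refine_with_clicks(predict, clicks):
    """Run one predict() call per click, feeding back the previous best logits.

    `predict` is assumed to accept SamPredictor.predict keyword arguments and
    return (masks, scores, logits). `clicks` is a list of (x, y, label).
    """
    if not clicks:
        raise ValueError("need at least one click")
    coords, labels = [], []
    mask_input = None
    for x, y, label in clicks:
        coords.append([x, y])
        labels.append(label)
        masks, scores, logits = predict(
            point_coords=np.array(coords),
            point_labels=np.array(labels),
            mask_input=mask_input,
            # First click is ambiguous: ask for 3 candidates; later clicks refine one
            multimask_output=mask_input is None,
        )
        best = int(np.argmax(scores))
        mask_input = logits[best][None, :, :]  # (1, 256, 256) for the next round
    return masks[best], scores[best]


def fake_predict(point_coords, point_labels, mask_input, multimask_output):
    # Stand-in for predictor.predict that returns arrays of the expected shapes
    k = 3 if multimask_output else 1
    masks = np.zeros((k, 64, 64), dtype=bool)
    scores = np.linspace(0.5, 0.9, k)
    logits = np.zeros((k, 256, 256), dtype=np.float32)
    return masks, scores, logits


mask, score = refine_with_clicks(fake_predict, [(10, 10, 1), (20, 20, 0)])
print(mask.shape, round(float(score), 2))
```

With a real model, pass `predictor.predict` in place of `fake_predict`.
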
## Automatic mask generation

### Basic automatic segmentation

```python
from segment_anything import SamAutomaticMaskGenerator

# Create generator
mask_generator = SamAutomaticMaskGenerator(sam)

# Generate all masks (image: RGB uint8 HxWx3)
masks = mask_generator.generate(image)

# Each mask is a dict containing:
# - segmentation: binary mask
# - bbox: [x, y, w, h]
# - area: pixel count
# - predicted_iou: quality score
# - stability_score: robustness score
# - point_coords: generating point
# - crop_box: image crop used to generate the mask, [x, y, w, h]
```

### Customized generation

```python
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,  # Grid density (denser grid finds more, smaller objects; slower)
    pred_iou_thresh=0.88,  # Quality threshold
    stability_score_thresh=0.95,  # Stability threshold
    crop_n_layers=1,  # Also run on image crops (multi-scale)
    crop_n_points_downscale_factor=2,  # Fewer points per side in each crop layer
    min_mask_region_area=100,  # Remove small regions and holes (requires opencv)
)

masks = mask_generator.generate(image)
```

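`points_per_side` controls a regular grid of single-point prompts spread over the image, so `points_per_side=32` means 32 × 32 = 1,024 point prompts on the full image. The sketch below reproduces such a grid in normalized [0, 1] coordinates, modeled on `build_point_grid` in `segment_anything/utils/amg.py` as we understand it; treat it as an illustration, not the library code. `point_grid` and `grid_for_image` are illustrative names.

```python
import numpy as np


def point_grid(n_per_side):
    """Evenly spaced (x, y) points in [0, 1] x [0, 1], centered in each grid cell."""
    offset = 1 / (2 * n_per_side)
    one_side = np.linspace(offset, 1 - offset, n_per_side)
    xs = np.tile(one_side[None, :], (n_per_side, 1))  # x varies along columns
    ys = np.tile(one_side[:, None], (1, n_per_side))  # y varies along rows
    return np.stack([xs, ys], axis=-1).reshape(-1, 2)


def grid_for_image(n_per_side, height, width):
    """Scale the normalized grid to pixel (x, y) coordinates for an H x W image."""
    return point_grid(n_per_side) * np.array([width, height])


grid = grid_for_image(32, 480, 640)
print(grid.shape)  # (1024, 2)
```

Halving `points_per_side` cuts the number of prompts by 4x, which is why it is the main speed knob for automatic generation.
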
### Filtering masks

```python
# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)

# Filter by predicted IoU
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]

# Filter by stability score
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
```

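The `stability_score` field measures how much a mask changes when the threshold applied to its logits shifts. As we understand `calculate_stability_score` in `segment_anything/utils/amg.py`, it is the IoU between the mask thresholded at `mask_threshold + offset` and at `mask_threshold - offset` (defaults 0.0 and 1.0, the latter via `stability_score_offset`). A numpy sketch for a single logit map; the zero-union fallback is our own choice:

```python
import numpy as np


def stability_score(logits, mask_threshold=0.0, offset=1.0):
    """IoU of the high-threshold mask over the low-threshold mask.

    The high-threshold mask is always a subset of the low-threshold one,
    so the IoU reduces to (pixels above t + offset) / (pixels above t - offset).
    """
    high = logits > (mask_threshold + offset)
    low = logits > (mask_threshold - offset)
    union = low.sum()
    # Assumption: report 0.0 for an empty mask instead of dividing by zero
    return float(high.sum() / union) if union else 0.0


sharp = np.array([[5.0, 5.0], [-5.0, -5.0]])  # logits far from the threshold
fuzzy = np.array([[5.0, 0.5], [-0.5, -5.0]])  # logits near the threshold
print(stability_score(sharp), stability_score(fuzzy))
```

Masks whose boundaries sit on confident logits score near 1.0; masks with many near-threshold pixels score lower, which is what `stability_score_thresh` filters out.
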
## Batched inference

### Multiple images

```python
# Process multiple images sequentially (set_image re-encodes each image)
images = [cv2.cvtColor(cv2.imread(f"image_{i}.jpg"), cv2.COLOR_BGR2RGB) for i in range(10)]

all_masks = []
for image in images:
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks)
```

### Multiple prompts per image

```python
# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)

# Batch of point prompts
points = [
    np.array([[100, 100]]),
    np.array([[200, 200]]),
    np.array([[300, 300]])
]

all_masks = []
for point in points:
    masks, scores, _ = predictor.predict(
        point_coords=point,
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks[np.argmax(scores)])
```

## ONNX deployment

### Export model

```bash
# Run from the root of the segment-anything repository
python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam_onnx.onnx \
    --return-single-mask
```

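### Use ONNX model

```python
import numpy as np
import onnxruntime

# Load ONNX model (prompt encoder + mask decoder only)
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")

# Image embeddings are computed separately by the PyTorch image encoder
predictor.set_image(image)
image_embeddings = predictor.get_image_embedding().cpu().numpy()
h, w = image.shape[:2]

# Append a padding point (label -1) when no box is given, then map
# pixel coordinates into the encoder's resized input frame
point_coords = np.array([[500, 375], [0.0, 0.0]])[None, :, :]
point_labels = np.array([[1, -1]], dtype=np.float32)
point_coords = predictor.transform.apply_coords(point_coords, (h, w)).astype(np.float32)

# Run inference
masks, scores, low_res_logits = ort_session.run(
    None,
    {
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.array([0], dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32)
    }
)
masks = masks > predictor.model.mask_threshold  # Logits -> boolean masks
```
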
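The exported decoder works in the resized frame that the image encoder saw: SAM scales the longest image side to 1024 pixels (`ResizeLongestSide` in `segment_anything.utils.transforms`), so raw pixel clicks must be rescaled, and when no box is given the official example appends a padding point `(0, 0)` with label `-1`. A numpy sketch of that preparation for a single image, mirroring what `predictor.transform.apply_coords` does as we understand it; `preprocess_shape` and `prepare_onnx_points` are hypothetical names.

```python
import numpy as np


def preprocess_shape(old_h, old_w, long_side=1024):
    """Target (h, w) after scaling the longest side to `long_side` (rounded)."""
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def prepare_onnx_points(points, labels, orig_hw, long_side=1024):
    """Scale (x, y) clicks into the resized frame and append the padding point."""
    old_h, old_w = orig_hw
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    coords = np.asarray(points, dtype=np.float32).reshape(-1, 2).copy()
    coords[:, 0] *= new_w / old_w
    coords[:, 1] *= new_h / old_h
    # Padding point (0, 0) with label -1 stands in for the absent box prompt
    coords = np.concatenate([coords, np.zeros((1, 2), dtype=np.float32)], axis=0)
    labels = np.concatenate([np.asarray(labels, dtype=np.float32), [-1.0]]).astype(np.float32)
    return coords[None, :, :], labels[None, :]  # shapes (1, N+1, 2) and (1, N+1)


coords, labels = prepare_onnx_points([[500, 375]], [1], orig_hw=(1500, 2000))
print(coords.shape, labels.tolist())  # (1, 2, 2) [[1.0, -1.0]]
```

Skipping the rescale is a common cause of masks appearing in the wrong place on non-1024 images.
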
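## Common workflows

### Workflow 1: Annotation tool

```python
import cv2
import numpy as np
from segment_anything import SamPredictor

# Reuse `sam` from above; embed the (RGB) image once, then answer each click quickly
predictor = SamPredictor(sam)
predictor.set_image(image)

def display_mask(mask):
    # Simple overlay: tint masked pixels green, convert to BGR for OpenCV display
    overlay = image.copy()
    overlay[mask] = (0.5 * overlay[mask] + [0, 127, 0]).astype(np.uint8)
    cv2.imshow("SAM", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

def on_click(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        # Left click = foreground point
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True
        )
        # Display best mask
        display_mask(masks[np.argmax(scores)])

cv2.namedWindow("SAM")
cv2.setMouseCallback("SAM", on_click)
cv2.imshow("SAM", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)
```
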
### Workflow 2: Object extraction

```python
import numpy as np

def extract_object(image, point):
    """Extract the object at `point` (x, y) from an RGB image with a transparent background."""
    predictor.set_image(image)

    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([1]),
        multimask_output=True
    )

    best_mask = masks[np.argmax(scores)]

    # Create RGBA output: original RGB channels, alpha = 255 inside the mask
    rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = best_mask.astype(np.uint8) * 255

    return rgba
```

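The RGBA image returned by `extract_object` keeps the full image canvas. To save just the object, crop it to the tight bounding box of the non-transparent pixels. A sketch; the helper name `crop_to_mask` is illustrative:

```python
import numpy as np


def crop_to_mask(rgba):
    """Crop an HxWx4 RGBA image to the smallest box containing non-transparent pixels."""
    alpha = rgba[:, :, 3] > 0
    if not alpha.any():
        return rgba[:0, :0]  # Empty mask: nothing to keep
    rows = np.flatnonzero(alpha.any(axis=1))
    cols = np.flatnonzero(alpha.any(axis=0))
    return rgba[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


rgba = np.zeros((5, 6, 4), dtype=np.uint8)
rgba[1:3, 2:5, 3] = 255  # A 2x3 opaque region
print(crop_to_mask(rgba).shape)  # (2, 3, 4)
```

To write the result with OpenCV, convert channel order first, e.g. `cv2.imwrite("object.png", cv2.cvtColor(crop_to_mask(rgba), cv2.COLOR_RGBA2BGRA))`.
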
### Workflow 3: Medical image segmentation

```python
# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)  # Loaded as 8-bit
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)

predictor.set_image(rgb_image)

# Segment region of interest (x1, y1, x2, y2: your ROI in pixel coordinates)
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),  # ROI bounding box
    multimask_output=True
)
```

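`SamPredictor.set_image` expects an `HxWx3` `uint8` image. Medical scans are often 16-bit or floating point (for example CT in Hounsfield units), and a plain `astype(np.uint8)` wraps values modulo 256. A sketch that applies an intensity window, rescales to 0-255, and stacks to three channels; the window bounds below are assumed soft-tissue values, to be chosen per modality.

```python
import numpy as np


def window_to_rgb(scan, low, high):
    """Clip a 2D scan to [low, high], rescale to 0-255, and replicate to 3 channels."""
    if high <= low:
        raise ValueError("high must be greater than low")
    scan = np.clip(scan.astype(np.float32), low, high)
    gray = np.round((scan - low) / (high - low) * 255.0).astype(np.uint8)
    return np.stack([gray, gray, gray], axis=-1)


# CT-like slice in Hounsfield units with an assumed soft-tissue window [-160, 240]
ct = np.array([[-1000, 40], [400, 3000]], dtype=np.int16)
rgb = window_to_rgb(ct, low=-160, high=240)
print(rgb.shape, rgb[..., 0].tolist())
```

The result can be passed directly to `predictor.set_image(rgb)`.
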
## Output format

### Mask data structure

```python
# SamAutomaticMaskGenerator output: a list of dicts, one per mask
{
    "segmentation": np.ndarray,  # HxW boolean mask (RLE dict if output_mode="coco_rle")
    "bbox": [x, y, w, h],  # Bounding box (XYWH)
    "area": int,  # Pixel count
    "predicted_iou": float,  # Model's own estimate of mask quality
    "stability_score": float,  # 0-1 robustness score
    "crop_box": [x, y, w, h],  # Crop region the mask was generated from
    "point_coords": [[x, y]],  # Input point that generated the mask
}
```

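`predicted_iou` is the model's own estimate of mask quality, not a measured value. When ground-truth masks are available, the actual IoU is easy to compute, and comparing it with `predicted_iou` shows how well the score tracks real quality on your data. A sketch; `mask_iou` is an illustrative helper:

```python
import numpy as np


def mask_iou(pred, target):
    """Intersection over union of two HxW binary masks (1.0 if both are empty)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred, target).sum() / union)


pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 1, 0], [0, 0, 0]])
print(mask_iou(pred, target))
```

For a generator mask `m` and a (hypothetical) ground-truth array `gt_mask`, compare `mask_iou(m["segmentation"], gt_mask)` with `m["predicted_iou"]`.
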
### COCO RLE format

```python
import numpy as np
from pycocotools import mask as mask_utils

mask = masks[0]["segmentation"]  # e.g. an HxW boolean mask from the generator

# Encode mask to RLE (pycocotools expects a Fortran-ordered uint8 array)
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str for JSON

# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
```

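`pycocotools` produces *compressed* RLE strings. The idea underneath is simpler: flatten the mask in column-major (Fortran) order and record alternating run lengths, always starting with a run of zeros (which may have length 0). A pure numpy sketch of COCO's uncompressed `{"size", "counts"}` form, useful for understanding the format or for tools that accept it; the function names are illustrative.

```python
import numpy as np


def mask_to_uncompressed_rle(mask):
    """Encode an HxW binary mask as COCO-style uncompressed RLE (column-major runs)."""
    arr = np.asarray(mask, dtype=bool)
    flat = arr.flatten(order="F")
    counts = []
    current = False  # COCO RLE always starts by counting zeros
    run = 0
    for value in flat:
        if value == current:
            run += 1
        else:
            counts.append(run)
            current = value
            run = 1
    counts.append(run)
    return {"size": list(arr.shape), "counts": counts}


def uncompressed_rle_to_mask(rle):
    """Decode the uncompressed RLE produced above back into a boolean mask."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=bool)
    pos, value = 0, False
    for count in rle["counts"]:
        flat[pos:pos + count] = value
        pos += count
        value = not value
    return flat.reshape((h, w), order="F")


mask = np.array([[0, 1], [1, 1]], dtype=bool)
rle = mask_to_uncompressed_rle(mask)
print(rle)  # {'size': [2, 2], 'counts': [1, 3]}
```

`mask_utils.frPyObjects(rle, h, w)` converts this uncompressed form into pycocotools' compressed representation.
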
## Performance optimization

### GPU memory

```python
import torch
from segment_anything import sam_model_registry

# Use smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Process images in batches
# Clear CUDA cache between large batches
torch.cuda.empty_cache()
```

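### Speed optimization

```python
import torch

# Mixed precision: run the expensive image encoder under autocast. Calling
# sam.half() directly can cause float16/float32 dtype mismatches, because
# SamPredictor builds float32 prompt tensors
with torch.autocast(device_type="cuda", dtype=torch.float16):
    predictor.set_image(image)

# Reduce points for automatic generation
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # Default is 32; 16 means 4x fewer point prompts
)

# Use ONNX for deployment
# Export with --return-single-mask for faster inference
```
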
## Common issues

| Issue | Solution |
|-------|----------|
| Out of memory | Use the ViT-B model; lower `points_per_batch` for automatic generation; downscale very large images (masks are returned at full resolution) |
| Slow inference | Use ViT-B; reduce `points_per_side` |
| Poor mask quality | Try different prompts; combine a box with points |
| Edge artifacts | Filter by `stability_score` |
| Small objects missed | Increase `points_per_side` or enable `crop_n_layers` |

## References

- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions

## Resources

- **GitHub**: https://github.com/facebookresearch/segment-anything
- **Paper**: https://arxiv.org/abs/2304.02643
- **Demo**: https://segment-anything.com
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge