# Segment Anything Advanced Usage Guide

## SAM 2 (Video Segmentation)

### Overview

SAM 2 extends SAM to video segmentation with a streaming memory architecture:

```bash
pip install git+https://github.com/facebookresearch/segment-anything-2.git
```

### Video segmentation

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # Initialize with video. The returned inference state holds the per-video
    # memory bank and must be passed to every subsequent predictor call.
    state = predictor.init_state(video_path="video.mp4")

    # Add prompt on first frame
    predictor.add_new_points(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        points=[[100, 200]],
        labels=[1],
    )

    # Propagate through video
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        # mask_logits[i] belongs to obj_ids[i]; threshold the logits at 0 so
        # that masks contains the segmentation for all tracked objects
        masks = (mask_logits > 0.0).cpu().numpy()
        process_frame(frame_idx, masks)  # process_frame: your own callback
```

### SAM 2 vs SAM comparison

| Feature | SAM | SAM 2 |
|---------|-----|-------|
| Input | Images only | Images + Videos |
| Architecture | ViT + Decoder | Hiera + Memory |
| Memory | Per-image | Streaming memory bank |
| Tracking | No | Yes, across frames |
| Models | ViT-B/L/H | Hiera-T/S/B+/L |

## Grounded SAM (Text-Prompted Segmentation)

### Setup

```bash
pip install groundingdino-py
pip install git+https://github.com/facebookresearch/segment-anything.git
```

### Text-to-mask pipeline

```python
import cv2
import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.util.inference import load_model, predict
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
from torchvision.ops import box_convert

# Load Grounding DINO (config file first, then checkpoint)
grounding_model = load_model("GroundingDINO_SwinT_OGC.py", "groundingdino_swint_ogc.pth")

# Load SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

# Same preprocessing that groundingdino.util.inference.load_image applies
dino_transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def text_to_mask(image, text_prompt, box_threshold=0.3, text_threshold=0.25):
    """Generate masks from text description.

    Args:
        image: RGB image as an HxWx3 uint8 numpy array.
        text_prompt: Grounding DINO caption, e.g. "person . dog . car".
        box_threshold: Minimum box confidence passed to Grounding DINO.
        text_threshold: Minimum phrase confidence passed to Grounding DINO.

    Returns:
        (masks, boxes, phrases): one boolean HxW mask per detection, the
        boxes exactly as returned by Grounding DINO (normalized cxcywh),
        and the matched phrase for each box.
    """
    # Grounding DINO expects a normalized tensor, not the raw numpy image
    image_tensor, _ = dino_transform(Image.fromarray(image), None)

    # Get bounding boxes from text
    boxes, logits, phrases = predict(
        model=grounding_model,
        image=image_tensor,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )

    # Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2),
    # which is the box format SAM expects
    h, w = image.shape[:2]
    scale = torch.tensor([w, h, w, h], dtype=boxes.dtype)
    boxes_xyxy = box_convert(boxes * scale, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    # Generate masks with SAM
    predictor.set_image(image)
    masks = []
    for box_pixels in boxes_xyxy:
        mask, score, _ = predictor.predict(
            box=box_pixels,
            multimask_output=False,
        )
        masks.append(mask[0])

    return masks, boxes, phrases


# Usage
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks, boxes, phrases = text_to_mask(image, "person . dog . car")
```

## Batched Processing

### Efficient multi-image processing

```python
import cv2
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry


class BatchedSAM:
    """Run one SAM predictor over a list of (image, prompt) pairs."""

    def __init__(self, checkpoint, model_type="vit_h", device="cuda"):
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.sam.to(device)
        self.predictor = SamPredictor(self.sam)
        self.device = device

    def process_batch(self, images, prompts):
        """Process multiple images with corresponding prompts.

        Args:
            images: Iterable of RGB HxWx3 uint8 arrays.
            prompts: Iterable of dicts, each holding either "point" and
                "label" arrays or a "box" array.

        Returns:
            One dict per image with "masks", "scores" and "best_mask".

        Raises:
            ValueError: If a prompt has neither a "point" nor a "box" key.
        """
        results = []

        for image, prompt in zip(images, prompts):
            self.predictor.set_image(image)

            if "point" in prompt:
                masks, scores, _ = self.predictor.predict(
                    point_coords=prompt["point"],
                    point_labels=prompt["label"],
                    multimask_output=True,
                )
            elif "box" in prompt:
                masks, scores, _ = self.predictor.predict(
                    box=prompt["box"],
                    multimask_output=False,
                )
            else:
                # Without this, the previous iteration's masks would be reused
                raise ValueError("prompt must contain a 'point' or 'box' key")

            results.append({
                "masks": masks,
                "scores": scores,
                "best_mask": masks[np.argmax(scores)],
            })

        return results


# Usage
batch_sam = BatchedSAM("sam_vit_h_4b8939.pth")
# cv2 loads BGR; SamPredictor expects RGB by default
images = [
    cv2.cvtColor(cv2.imread(f"image_{i}.jpg"), cv2.COLOR_BGR2RGB)
    for i in range(10)
]
prompts = [{"point": np.array([[100, 100]]), "label": np.array([1])} for _ in range(10)]
results = batch_sam.process_batch(images, prompts)
```

### Parallel automatic mask generation

```python
import queue
from concurrent.futures import ThreadPoolExecutor

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry


def generate_masks_parallel(images, num_workers=4):
    """Generate masks for multiple images in parallel.

    Args:
        images: Sequence of RGB HxWx3 uint8 arrays.
        num_workers: Number of threads and model instances to create.

    Returns:
        A list with the generator output for each image, in input order.
    """
    # Note: Each worker needs its own model instance
    def worker_init():
        sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
        return SamAutomaticMaskGenerator(sam)

    # A generator keeps per-image state, so it must never be used by two
    # threads at once. Checking generators out of a queue guarantees exclusive
    # use; indexing by `idx % num_workers` does not.
    generator_pool = queue.Queue()
    for _ in range(num_workers):
        generator_pool.put(worker_init())

    def process_image(image):
        generator = generator_pool.get()
        try:
            return generator.generate(image)
        finally:
            generator_pool.put(generator)

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(process_image, images))

    return results
```

## Custom Integration

### FastAPI service

```python
import io
import json

import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from pydantic import BaseModel, ValidationError
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

app = FastAPI()

# Load model once
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)


class PointPrompt(BaseModel):
    x: int
    y: int
    label: int = 1


def decode_image(contents):
    """Decode uploaded bytes into an RGB array; raise HTTP 400 if invalid."""
    if not contents:
        raise HTTPException(status_code=400, detail="Empty upload")
    nparr = np.frombuffer(contents, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(status_code=400, detail="Could not decode image")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


@app.post("/segment/point")
async def segment_with_point(
    file: UploadFile = File(...),
    # A JSON body cannot be combined with a multipart file upload, so the
    # points arrive as a JSON string in a form field, e.g.
    # points='[{"x": 100, "y": 200, "label": 1}]'
    points: str = Form(...),
):
    # Parse prompts
    try:
        prompts = [PointPrompt(**p) for p in json.loads(points)]
    except (ValueError, TypeError, ValidationError) as err:
        raise HTTPException(status_code=422, detail=f"Invalid points: {err}") from err
    if not prompts:
        raise HTTPException(status_code=422, detail="At least one point is required")

    # Read image
    image = decode_image(await file.read())

    # Set image. There is no `await` between set_image and predict, so
    # concurrent requests cannot interleave on the shared predictor.
    predictor.set_image(image)

    # Prepare prompts
    point_coords = np.array([[p.x, p.y] for p in prompts])
    point_labels = np.array([p.label for p in prompts])

    # Generate masks
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )

    best_idx = np.argmax(scores)

    return {
        "mask": masks[best_idx].tolist(),
        "score": float(scores[best_idx]),
        "all_scores": scores.tolist(),
    }


@app.post("/segment/auto")
async def segment_automatic(file: UploadFile = File(...)):
    image = decode_image(await file.read())

    masks = mask_generator.generate(image)

    return {
        "num_masks": len(masks),
        "masks": [
            {
                "bbox": m["bbox"],
                "area": m["area"],
                "predicted_iou": m["predicted_iou"],
                "stability_score": m["stability_score"],
            }
            for m in masks
        ],
    }
```

### Gradio
interface

```python
import gradio as gr
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)


def segment_image(image, evt: gr.SelectData):
    """Segment object at clicked point.

    Args:
        image: The input image as a numpy array supplied by Gradio.
        evt: Selection event; evt.index holds the clicked (x, y) pixel.

    Returns:
        A copy of the image with the best mask blended 50% with red.
    """
    predictor.set_image(image)

    point = np.array([[evt.index[0], evt.index[1]]])
    label = np.array([1])

    masks, scores, _ = predictor.predict(
        point_coords=point,
        point_labels=label,
        multimask_output=True,
    )

    best_mask = masks[np.argmax(scores)]

    # Overlay mask on image (blend in float, then cast back explicitly)
    overlay = image.copy()
    blended = overlay[best_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
    overlay[best_mask] = blended.astype(overlay.dtype)

    return overlay


with gr.Blocks() as demo:
    gr.Markdown("# SAM Interactive Segmentation")
    gr.Markdown("Click on an object to segment it")

    with gr.Row():
        input_image = gr.Image(label="Input Image", interactive=True)
        output_image = gr.Image(label="Segmented Image")

    input_image.select(segment_image, inputs=[input_image], outputs=[output_image])

demo.launch()
```

## Fine-Tuning SAM

### LoRA fine-tuning (experimental)

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import SamModel

# Load model
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Configure LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["qkv"],  # Attention layers
    lora_dropout=0.1,
    bias="none",
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.train()

# Only the LoRA adapter weights require gradients; skip the frozen base weights
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

# Training loop (simplified).
# `dataloader` and `compute_loss` are placeholders you must provide.
for batch in dataloader:
    outputs = model(
        pixel_values=batch["pixel_values"],
        input_points=batch["input_points"],
        input_labels=batch["input_labels"],
    )

    # Custom loss (e.g., IoU loss with ground truth)
    loss = compute_loss(outputs.pred_masks, batch["gt_masks"])

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

### MedSAM (Medical imaging)

```python
# MedSAM is a fine-tuned SAM for medical images
# https://github.com/bowang-lab/MedSAM

import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

# Load MedSAM checkpoint
medsam = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth")
medsam.to("cuda")

predictor = SamPredictor(medsam)

# Process medical image
# Convert grayscale to RGB if needed
medical_image = cv2.imread("ct_scan.png", cv2.IMREAD_GRAYSCALE)
if medical_image is None:
    raise FileNotFoundError("ct_scan.png could not be read")
rgb_image = np.stack([medical_image] * 3, axis=-1)

predictor.set_image(rgb_image)

# Example box in pixel coordinates; replace with your region of interest
x1, y1, x2, y2 = 100, 100, 300, 300

# Segment with box prompt (common for medical imaging)
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),
    multimask_output=False,
)
```

## Advanced Mask Processing

### Mask refinement

```python
import cv2
import numpy as np
from scipy import ndimage


def refine_mask(mask, kernel_size=5, iterations=2):
    """Refine mask with morphological operations.

    Args:
        mask: 2-D binary mask (bool or 0/1).
        kernel_size: Diameter of the elliptical structuring element.
        iterations: Number of times each morphological pass is applied.

    Returns:
        Boolean mask after a closing followed by an opening.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    # Close small holes
    closed = cv2.morphologyEx(
        mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iterations
    )

    # Remove small noise
    opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=iterations)

    return opened.astype(bool)


def fill_holes(mask):
    """Fill holes in mask using scipy's binary_fill_holes."""
    filled = ndimage.binary_fill_holes(mask)
    return filled


def remove_small_regions(mask, min_area=100):
    """Remove small disconnected regions.

    Args:
        mask: Binary mask array.
        min_area: Regions whose pixel sum is below this value are dropped.

    Returns:
        Array with the same shape and dtype as mask, keeping only the
        connected regions that reach min_area.
    """
    labeled, num_features = ndimage.label(mask)
    sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))

    # Keep only regions larger than min_area (labels are 1-based)
    keep_labels = np.flatnonzero(np.asarray(sizes) >= min_area) + 1
    mask_clean = np.zeros_like(mask)
    mask_clean[np.isin(labeled, keep_labels)] = True

    return mask_clean
```

### Mask to polygon conversion

```python
import cv2
import numpy as np


def mask_to_polygons(mask, epsilon_factor=0.01):
    """Convert binary mask to polygon coordinates.

    Args:
        mask: 2-D binary mask.
        epsilon_factor: Simplification tolerance as a fraction of each
            contour's perimeter.

    Returns:
        List of polygons; each polygon is a list of [x, y] points with at
        least three vertices.
    """
    contours, _ = cv2.findContours(
        mask.astype(np.uint8),
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE,
    )

    polygons = []
    for contour in contours:
        epsilon = epsilon_factor * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        # reshape keeps a (num_points, 2) layout even for a one-point
        # contour, where squeeze() would collapse it to a flat [x, y]
        polygon = approx.reshape(-1, 2).tolist()
        if len(polygon) >= 3:  # Valid polygon
            polygons.append(polygon)

    return polygons


def polygons_to_mask(polygons, height, width):
    """Convert polygons back to binary mask.

    Args:
        polygons: List of polygons, each a list of [x, y] points.
        height: Output mask height in pixels.
        width: Output mask width in pixels.

    Returns:
        Boolean (height, width) mask with every polygon filled.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    for polygon in polygons:
        pts = np.array(polygon, dtype=np.int32)
        cv2.fillPoly(mask, [pts], 1)

    return mask.astype(bool)
```

### Multi-scale segmentation

```python
import cv2
import numpy as np


def multiscale_segment(image, predictor, point, scales=(0.5, 1.0, 2.0)):
    """Generate masks at multiple scales and combine.

    Args:
        image: RGB HxWx3 image array.
        predictor: A SamPredictor; its current image is overwritten.
        point: (x, y) prompt in original-image pixel coordinates.
        scales: Resize factors to evaluate.

    Returns:
        Boolean HxW mask of pixels selected at a strict majority of scales.

    Raises:
        ValueError: If scales is empty.
    """
    if len(scales) == 0:
        raise ValueError("scales must contain at least one value")

    h, w = image.shape[:2]
    point = np.asarray(point)
    masks_all = []

    for scale in scales:
        # Resize image
        new_h, new_w = int(h * scale), int(w * scale)
        scaled_image = cv2.resize(image, (new_w, new_h))
        scaled_point = (point * scale).astype(int)

        # Segment
        predictor.set_image(scaled_image)
        masks, scores, _ = predictor.predict(
            point_coords=scaled_point.reshape(1, 2),
            point_labels=np.array([1]),
            multimask_output=True,
        )

        # Resize mask back; nearest-neighbour keeps the mask strictly binary
        best_mask = masks[np.argmax(scores)]
        original_mask = cv2.resize(
            best_mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST
        ).astype(bool)
        masks_all.append(original_mask)

    # Combine masks (majority voting)
    combined = np.stack(masks_all, axis=0)
    final_mask = np.sum(combined, axis=0) >= len(scales) // 2 + 1

    return final_mask
```

## Performance Optimization

### TensorRT acceleration

```python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit


def export_to_tensorrt(onnx_path, engine_path, fp16=True):
    """Convert ONNX model to TensorRT engine.

    Requires TensorRT 8.4 or newer.

    Args:
        onnx_path: Path of the ONNX model to read.
        engine_path: Path the serialized engine is written to.
        fp16: Enable FP16 kernels when True.

    Returns:
        The deserialized engine, or None if parsing or building failed.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    config = builder.create_builder_config()
    # Replaces the removed `config.max_workspace_size` attribute
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Replaces the removed `builder.build_engine`; returns None on failure
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print("TensorRT engine build failed")
        return None

    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    # Deserialize so callers still receive an engine object
    runtime = trt.Runtime(logger)
    return runtime.deserialize_cuda_engine(serialized_engine)
```

###
Memory-efficient inference

```python
import torch
from segment_anything import SamPredictor, sam_model_registry


class MemoryEfficientSAM:
    """Context manager that keeps SAM on the GPU only inside a `with` block."""

    def __init__(self, checkpoint, model_type="vit_b"):
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.sam.eval()
        self.predictor = None

    def __enter__(self):
        self.sam.to("cuda")
        self.predictor = SamPredictor(self.sam)
        return self

    def __exit__(self, *args):
        # The predictor caches the image embedding on the GPU. Drop it first,
        # otherwise empty_cache() cannot release that memory.
        if self.predictor is not None:
            self.predictor.reset_image()
            self.predictor = None
        self.sam.to("cpu")
        torch.cuda.empty_cache()

    def segment(self, image, points, labels):
        """Segment image from point prompts.

        Args:
            image: RGB HxWx3 uint8 array.
            points: Nx2 array of (x, y) prompt coordinates.
            labels: Length-N array of point labels (1 foreground, 0 background).

        Returns:
            (masks, scores) from SamPredictor.predict with multimask output.

        Raises:
            RuntimeError: If called outside the `with` block.
        """
        if self.predictor is None:
            raise RuntimeError("MemoryEfficientSAM must be used inside a 'with' block")
        self.predictor.set_image(image)
        masks, scores, _ = self.predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )
        return masks, scores


# Usage with context manager (auto-cleanup)
with MemoryEfficientSAM("sam_vit_b_01ec64.pth") as sam:
    masks, scores = sam.segment(image, points, labels)
# CUDA memory freed automatically
```

## Dataset Generation

### Create segmentation dataset

```python
import json
from pathlib import Path

import cv2
import numpy as np


def mask_to_rle(mask):
    """Encode a binary mask as uncompressed COCO-style RLE.

    Runs are counted in column-major order and always start with a run of
    zeros (which has length 0 when the first pixel is set).

    Returns:
        Dict with "size" ([height, width]) and "counts" (list of ints).
    """
    mask = np.asarray(mask, dtype=bool)
    flat = mask.flatten(order="F")
    # Indices where the value changes mark the start of a new run
    changes = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate(([0], changes, [flat.size]))
    counts = np.diff(boundaries).tolist()
    if flat.size and flat[0]:
        counts = [0] + counts
    return {"size": list(mask.shape), "counts": counts}


def generate_dataset(images_dir, output_dir, mask_generator):
    """Generate segmentation dataset from images.

    Args:
        images_dir: Directory containing the *.jpg input images.
        output_dir: Directory for annotations.json; created if missing.
        mask_generator: Object whose generate(image) returns SAM mask dicts
            with a binary "segmentation" array.

    Returns:
        The list of annotation dicts that was written to annotations.json.
    """
    output_dir = Path(output_dir)  # accept both str and Path
    output_dir.mkdir(parents=True, exist_ok=True)

    annotations = []

    # Sorted so the annotation order is reproducible across runs
    for img_path in sorted(Path(images_dir).glob("*.jpg")):
        image = cv2.imread(str(img_path))
        if image is None:
            # Unreadable or corrupt file: skip it instead of crashing
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Generate masks
        masks = mask_generator.generate(image)

        # Filter high-quality masks
        good_masks = [m for m in masks if m["predicted_iou"] > 0.9]

        # Save annotations
        for i, mask_data in enumerate(good_masks):
            annotation = {
                "image_id": img_path.stem,
                "mask_id": i,
                "bbox": mask_data["bbox"],
                "area": mask_data["area"],
                "segmentation": mask_to_rle(mask_data["segmentation"]),
                "predicted_iou": mask_data["predicted_iou"],
                "stability_score": mask_data["stability_score"],
            }
            annotations.append(annotation)

    # Save dataset
    with open(output_dir / "annotations.json", "w") as f:
        json.dump(annotations, f)

    return annotations
```