DSPy Modules

Complete guide to DSPy's built-in modules for language model programming.

Module Basics

DSPy modules are composable building blocks inspired by PyTorch's NN modules:

  • Have learnable parameters (prompts, few-shot examples)
  • Can be composed using Python control flow
  • Generalized to handle any signature
  • Optimizable with DSPy optimizers

Base Module Pattern

import dspy

class CustomModule(dspy.Module):
    def __init__(self):
        super().__init__()
        # Initialize sub-modules
        self.predictor = dspy.Predict("input -> output")

    def forward(self, input):
        # Module logic
        result = self.predictor(input=input)
        return result

Core Modules

dspy.Predict

Basic prediction module - Makes LM calls without reasoning steps.

# Inline signature
qa = dspy.Predict("question -> answer")
result = qa(question="What is 2+2?")

# Class signature
class QA(dspy.Signature):
    """Answer questions concisely."""
    question = dspy.InputField()
    answer = dspy.OutputField(desc="short, factual answer")

qa = dspy.Predict(QA)
result = qa(question="What is the capital of France?")
print(result.answer)  # "Paris"

When to use:

  • Simple, direct predictions
  • No reasoning steps needed
  • Fast responses required

dspy.ChainOfThought

Step-by-step reasoning - Generates rationale before answer.

Parameters:

  • signature: Task signature
  • rationale_field: Custom reasoning field (optional)
  • rationale_field_type: Type for rationale (default: str)
# Basic usage
cot = dspy.ChainOfThought("question -> answer")
result = cot(question="If I have 5 apples and give away 2, how many remain?")
print(result.rationale)  # "Let's think step by step..."
print(result.answer)     # "3"

# Custom rationale field
cot = dspy.ChainOfThought(
    signature="problem -> solution",
    rationale_field=dspy.OutputField(
        prefix="Reasoning: Let's break this down step by step to"
    )
)

When to use:

  • Complex reasoning tasks
  • Math word problems
  • Logical deduction
  • Quality > speed

Performance:

  • ~2x slower than Predict
  • Significantly better accuracy on reasoning tasks

dspy.ProgramOfThought

Code-based reasoning - Generates and executes Python code.

pot = dspy.ProgramOfThought("question -> answer")

result = pot(question="What is 15% of 240?")
# Internally generates: answer = 240 * 0.15
# Executes code and returns result
print(result.answer)  # 36.0

result = pot(question="If a train travels 60 mph for 2.5 hours, how far does it go?")
# Generates: distance = 60 * 2.5
print(result.answer)  # 150.0

When to use:

  • Arithmetic calculations
  • Symbolic math
  • Data transformations
  • Deterministic computations

Benefits:

  • More reliable than text-based math
  • Handles complex calculations
  • Transparent (shows generated code)

dspy.ReAct

Reasoning + Acting - Agent that uses tools iteratively.

from dspy.predict import ReAct

# Define tools
def search_wikipedia(query: str) -> str:
    """Search Wikipedia for information."""
    # Your search implementation goes here
    return "..."  # placeholder result

def calculate(expression: str) -> float:
    """Evaluate a mathematical expression."""
    # Note: eval() is unsafe on untrusted input; prefer a restricted parser
    return eval(expression)

# Create ReAct agent
class ResearchQA(dspy.Signature):
    """Answer questions using available tools."""
    question = dspy.InputField()
    answer = dspy.OutputField()

react = ReAct(ResearchQA, tools=[search_wikipedia, calculate])

# Agent decides which tools to use
result = react(question="How old was Einstein when he published special relativity?")
# Internally:
# 1. Thinks: "Need birth year and publication year"
# 2. Acts: search_wikipedia("Albert Einstein")
# 3. Acts: search_wikipedia("Special relativity 1905")
# 4. Acts: calculate("1905 - 1879")
# 5. Returns: "26 years old"

When to use:

  • Multi-step research tasks
  • Tool-using agents
  • Complex information retrieval
  • Tasks requiring multiple API calls

Best practices:

  • Keep tool descriptions clear and specific
  • Limit to 5-7 tools (too many = confusion)
  • Provide tool usage examples in docstrings
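To illustrate these practices, a well-documented tool might look like the following (convert_currency is a hypothetical example, not a DSPy API):

```python
def convert_currency(amount: float, rate: float) -> float:
    """Convert a money amount using an exchange rate.

    Use this when a question involves amounts in two currencies.
    Example: convert_currency(amount=100.0, rate=0.92) returns 92.0.
    """
    return amount * rate
```

The docstring states when to use the tool and shows a concrete call; this text is what the agent reads when deciding which action to take.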

dspy.MultiChainComparison

Generate multiple outputs and compare - Self-consistency pattern.

# Generate candidate completions, then compare them
cot = dspy.ChainOfThought("question -> answer", n=5)
completions = cot(question="What is the capital of France?").completions

mcc = dspy.MultiChainComparison("question -> answer", M=5)
result = mcc(completions, question="What is the capital of France?")
# Compares the 5 candidates and synthesizes the most consistent answer
print(result.answer)  # "Paris"

Parameters:

  • M: Number of candidate completions to compare (default: 3)
  • temperature: Sampling temperature for diversity

When to use:

  • High-stakes decisions
  • Ambiguous questions
  • When single answer may be unreliable

Tradeoff:

  • Roughly M times the LM calls (higher cost and latency)
  • Higher accuracy on ambiguous tasks

dspy.majority

Majority voting over multiple predictions.

# Sample multiple completions from one predictor
predictor = dspy.Predict("question -> answer", n=5)
prediction = predictor(question="What is 2+2?")

# Take the majority vote over the answer field
answer = dspy.majority(prediction).answer
print(answer)  # "4"

When to use:

  • Combining multiple model outputs
  • Reducing variance in predictions
  • Ensemble approaches

Advanced Modules

dspy.TypedPredictor

Structured output with Pydantic models. (Recent DSPy releases fold typed outputs directly into dspy.Predict; TypedPredictor is the earlier, equivalent API.)

from pydantic import BaseModel, Field

class PersonInfo(BaseModel):
    name: str = Field(description="Full name")
    age: int = Field(description="Age in years")
    occupation: str = Field(description="Current job")

class ExtractPerson(dspy.Signature):
    """Extract person information from text."""
    text = dspy.InputField()
    person: PersonInfo = dspy.OutputField()

extractor = dspy.TypedPredictor(ExtractPerson)
result = extractor(text="John Doe is a 35-year-old software engineer.")

print(result.person.name)       # "John Doe"
print(result.person.age)        # 35
print(result.person.occupation) # "software engineer"

Benefits:

  • Type safety
  • Automatic validation
  • JSON schema generation
  • IDE autocomplete

dspy.Retry

Automatic retry with validation.

from dspy.primitives import Retry

def validate_number(example, pred, trace=None):
    """Validate output is a number."""
    try:
        float(pred.answer)
        return True
    except ValueError:
        return False

# Retry up to 3 times if validation fails
qa = Retry(
    dspy.ChainOfThought("question -> answer"),
    validate=validate_number,
    max_retries=3
)

result = qa(question="What is 15% of 80?")
# If first attempt returns non-numeric, retries automatically

dspy.Assert

Assertion-driven optimization.

import dspy
from dspy.primitives.assertions import assert_transform_module, backtrack_handler

class ValidatedQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.qa = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        answer = self.qa(question=question).answer

        # Assert the answer parses as a number; on failure, DSPy backtracks
        try:
            float(answer)
            is_numeric = True
        except (TypeError, ValueError):
            is_numeric = False

        dspy.Assert(is_numeric, "Answer must be a number")

        return dspy.Prediction(answer=answer)

# Enable assertion-driven backtracking
validated_qa = assert_transform_module(ValidatedQA(), backtrack_handler)

Benefits:

  • Catches errors during optimization
  • Guides LM toward valid outputs
  • Better than post-hoc filtering

Module Composition

Sequential Pipeline

class Pipeline(dspy.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = dspy.Predict("input -> intermediate")
        self.stage2 = dspy.ChainOfThought("intermediate -> output")

    def forward(self, input):
        intermediate = self.stage1(input=input).intermediate
        output = self.stage2(intermediate=intermediate).output
        return dspy.Prediction(output=output)

Conditional Logic

class ConditionalModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.router = dspy.Predict("question -> category: str")
        self.simple_qa = dspy.Predict("question -> answer")
        self.complex_qa = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        category = self.router(question=question).category

        if category == "simple":
            return self.simple_qa(question=question)
        else:
            return self.complex_qa(question=question)

Parallel Execution

class ParallelModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.approach1 = dspy.ChainOfThought("question -> answer")
        self.approach2 = dspy.ProgramOfThought("question -> answer")

    def forward(self, question):
        # Run both approaches
        answer1 = self.approach1(question=question).answer
        answer2 = self.approach2(question=question).answer

        # Compare or combine results
        if answer1 == answer2:
            return dspy.Prediction(answer=answer1, confidence="high")
        else:
            return dspy.Prediction(answer=answer1, confidence="low")

Batch Processing

All modules support batch processing for efficiency:

cot = dspy.ChainOfThought("question -> answer")

questions = [
    "What is 2+2?",
    "What is 3+3?",
    "What is 4+4?"
]

# Process all at once (inputs wrapped as dspy.Example objects)
examples = [dspy.Example(question=q).with_inputs("question") for q in questions]
results = cot.batch(examples)

for result in results:
    print(result.answer)

Saving and Loading

# Save module
qa = dspy.ChainOfThought("question -> answer")
qa.save("models/qa_v1.json")

# Load module
loaded_qa = dspy.ChainOfThought("question -> answer")
loaded_qa.load("models/qa_v1.json")

What gets saved:

  • Few-shot examples
  • Prompt instructions
  • Module configuration

What doesn't get saved:

  • Model weights (DSPy doesn't fine-tune by default)
  • LM provider configuration

Module Selection Guide

Task                    Module                 Reason
----                    ------                 ------
Simple classification   Predict                Fast, direct
Math word problems      ProgramOfThought       Reliable calculations
Logical reasoning       ChainOfThought         Better with steps
Multi-step research     ReAct                  Tool usage
High-stakes decisions   MultiChainComparison   Self-consistency
Structured extraction   TypedPredictor         Type safety
Ambiguous questions     MultiChainComparison   Multiple perspectives

Performance Tips

  1. Start with Predict, add reasoning only if needed
  2. Use batch processing for multiple inputs
  3. Cache predictions for repeated queries
  4. Profile token usage with track_usage=True
  5. Optimize after prototyping with teleprompters
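Tip 3 (caching) can be sketched with plain functools memoization around any callable; the `answer` stub below stands in for a DSPy predictor call and is not a DSPy API:

```python
import functools

calls = {"n": 0}

def answer(question: str) -> str:
    """Stand-in for qa(question=...); counts real LM invocations."""
    calls["n"] += 1
    return f"answer to {question!r}"

# Cache repeated queries so identical questions cost a single call
cached_answer = functools.lru_cache(maxsize=256)(answer)

cached_answer("What is 2+2?")
cached_answer("What is 2+2?")  # served from the cache, no second call
print(calls["n"])  # 1
```

The same wrapper works around any deterministic module call whose inputs are hashable; for sampled (high-temperature) outputs, caching changes behavior, so apply it deliberately.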

Common Patterns

Pattern: Retrieval + Generation

class RAG(dspy.Module):
    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)

Pattern: Verification Loop

class VerifiedQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.answer = dspy.ChainOfThought("question -> answer")
        self.verify = dspy.Predict("question, answer -> is_correct: bool")

    def forward(self, question, max_attempts=3):
        for _ in range(max_attempts):
            answer = self.answer(question=question).answer
            is_correct = self.verify(question=question, answer=answer).is_correct

            if is_correct:
                return dspy.Prediction(answer=answer)

        return dspy.Prediction(answer="Unable to verify answer")

Pattern: Multi-Turn Dialog

class DialogAgent(dspy.Module):
    def __init__(self):
        super().__init__()
        self.respond = dspy.Predict("history, user_message -> assistant_message")
        self.history = []

    def forward(self, user_message):
        history_str = "\n".join(self.history)
        response = self.respond(history=history_str, user_message=user_message)

        self.history.append(f"User: {user_message}")
        self.history.append(f"Assistant: {response.assistant_message}")

        return response