# DSPy Modules

Complete guide to DSPy's built-in modules for language model programming.

## Module Basics

DSPy modules are composable building blocks inspired by PyTorch's `nn.Module`. They:

- Have learnable parameters (prompts, few-shot examples)
- Can be composed using Python control flow
- Generalize to any signature
- Can be tuned with DSPy optimizers

### Base Module Pattern

```python
import dspy

class CustomModule(dspy.Module):
    def __init__(self):
        super().__init__()
        # Initialize sub-modules
        self.predictor = dspy.Predict("input -> output")

    def forward(self, input):
        # Module logic
        result = self.predictor(input=input)
        return result
```

## Core Modules

### dspy.Predict

**Basic prediction module** - Makes LM calls without reasoning steps.

```python
# Inline signature
qa = dspy.Predict("question -> answer")
result = qa(question="What is 2+2?")

# Class signature
class QA(dspy.Signature):
    """Answer questions concisely."""
    question = dspy.InputField()
    answer = dspy.OutputField(desc="short, factual answer")

qa = dspy.Predict(QA)
result = qa(question="What is the capital of France?")
print(result.answer)  # "Paris"
```

**When to use:**
- Simple, direct predictions
- No reasoning steps needed
- Fast responses required

### dspy.ChainOfThought

**Step-by-step reasoning** - Generates a rationale before the answer.

**Parameters:**
- `signature`: Task signature
- `rationale_field`: Custom reasoning field (optional)
- `rationale_field_type`: Type for rationale (default: `str`)

```python
# Basic usage
cot = dspy.ChainOfThought("question -> answer")
result = cot(question="If I have 5 apples and give away 2, how many remain?")
# The reasoning output is exposed as `reasoning` in releases that accept
# `rationale_field`; older releases named it `rationale`
print(result.reasoning)  # "Let's think step by step..."
print(result.answer)  # "3"

# Custom rationale field
cot = dspy.ChainOfThought(
    signature="problem -> solution",
    rationale_field=dspy.OutputField(
        prefix="Reasoning: Let's break this down step by step to"
    )
)
```

**When to use:**
- Complex reasoning tasks
- Math word problems
- Logical deduction
- Quality > speed

**Performance:**
- ~2x slower than Predict (the reasoning text is generated before the answer)
- Significantly better accuracy on reasoning tasks

### dspy.ProgramOfThought

**Code-based reasoning** - Generates and executes Python code.

```python
pot = dspy.ProgramOfThought("question -> answer")

result = pot(question="What is 15% of 240?")
# Internally generates: answer = 240 * 0.15
# Executes code and returns result
print(result.answer)  # 36.0

result = pot(question="If a train travels 60 mph for 2.5 hours, how far does it go?")
# Generates: distance = 60 * 2.5
print(result.answer)  # 150.0
```

**When to use:**
- Arithmetic calculations
- Symbolic math
- Data transformations
- Deterministic computations

**Benefits:**
- More reliable than text-based math
- Handles complex calculations
- Transparent (shows generated code)

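The generate-then-execute idea can be illustrated without an LM. The sketch below is not DSPy's implementation: the helper name `run_generated_code` and the convention that the generated snippet assigns its result to a named variable are assumptions made for illustration. It executes a code string in a namespace with no builtins and reads the result back.

```python
def run_generated_code(code: str, result_name: str = "answer"):
    """Execute a generated snippet and return the variable it assigns.

    Hypothetical helper for illustration only. Removing builtins is NOT a real
    sandbox, so never run untrusted code this way.
    """
    namespace = {"__builtins__": {}}  # no builtins visible to the snippet
    exec(code, namespace)
    if result_name not in namespace:
        raise KeyError(f"generated code did not assign {result_name!r}")
    return namespace[result_name]


# The two snippets mirror the code the examples above say is generated
percent = run_generated_code("answer = 240 * 0.15")
distance = run_generated_code("distance = 60 * 2.5", result_name="distance")
print(percent, distance)
```

Because the arithmetic is done by the interpreter rather than by the model, the result is exact up to floating-point precision, which is the reliability benefit listed above.
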
### dspy.ReAct

**Reasoning + Acting** - Agent that uses tools iteratively.

```python
import dspy
from dspy.predict import ReAct

# Define tools
def search_wikipedia(query: str) -> str:
    """Search Wikipedia for information."""
    # Your search implementation goes here
    raise NotImplementedError("plug in your search backend")

def calculate(expression: str) -> float:
    """Evaluate a mathematical expression."""
    # eval() is unsafe on untrusted input; restrict or replace it in production
    return eval(expression)

# Create ReAct agent
class ResearchQA(dspy.Signature):
    """Answer questions using available tools."""
    question = dspy.InputField()
    answer = dspy.OutputField()

react = ReAct(ResearchQA, tools=[search_wikipedia, calculate])

# Agent decides which tools to use
result = react(question="How old was Einstein when he published special relativity?")
# Internally:
# 1. Thinks: "Need birth year and publication year"
# 2. Acts: search_wikipedia("Albert Einstein")
# 3. Acts: search_wikipedia("Special relativity 1905")
# 4. Acts: calculate("1905 - 1879")
# 5. Returns: "26 years old"
```

**When to use:**
- Multi-step research tasks
- Tool-using agents
- Complex information retrieval
- Tasks requiring multiple API calls

**Best practices:**
- Keep tool descriptions clear and specific
- Limit to 5-7 tools (too many = confusion)
- Provide tool usage examples in docstrings

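The `calculate` tool above calls `eval`, which will run arbitrary Python if the model passes it something other than arithmetic. One possible replacement, sketched here under the assumption that the tool only needs `+`, `-`, `*`, `/` and unary signs, walks the parsed expression with the standard `ast` module. The name `safe_calculate` is hypothetical.

```python
import ast
import operator

# Arithmetic operators the calculator tool is allowed to apply
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def safe_calculate(expression: str) -> float:
    """Evaluate a plain arithmetic expression without calling eval()."""

    def _eval(node):
        # Numeric literals only (bool is excluded even though it subclasses int)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            return _BINARY_OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        # Names, calls, attribute access, etc. are rejected
        raise ValueError(f"unsupported expression: {expression!r}")

    tree = ast.parse(expression, mode="eval")
    return float(_eval(tree.body))


print(safe_calculate("1905 - 1879"))
```

Anything that is not a numeric literal or one of the listed operators raises `ValueError`, so a call such as `__import__('os')` never executes.
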
### dspy.MultiChainComparison

**Compare multiple reasoning chains** - Takes M candidate completions and synthesizes a final answer from them.

```python
question = "What is the capital of France?"

# Generate the candidate reasoning chains first
cot = dspy.ChainOfThought("question -> answer")
candidates = [cot(question=question) for _ in range(3)]

# Compare the candidates and produce one final answer
mcc = dspy.MultiChainComparison("question -> answer", M=3)
result = mcc(candidates, question=question)
print(result.answer)  # "Paris"
print([c.answer for c in candidates])  # All 3 candidate answers
```

**Parameters:**
- `M`: Number of candidate completions to compare (default: 3); must match the number of completions passed in
- `temperature`: Sampling temperature for the comparison call

**When to use:**
- High-stakes decisions
- Ambiguous questions
- When single answer may be unreliable

**Tradeoff:**
- Roughly M+1 LM calls (M candidates plus one comparison call)
- Higher accuracy on ambiguous tasks

### dspy.majority

**Majority voting over multiple predictions.**

```python
import dspy

# Generate multiple completions in a single call
predictor = dspy.Predict("question -> answer", n=5)
prediction = predictor(question="What is 2+2?")

# Take majority vote over the completions (returns a Prediction)
result = dspy.majority(prediction)
print(result.answer)  # "4"
```

**When to use:**
- Combining multiple model outputs
- Reducing variance in predictions
- Ensemble approaches

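The voting step itself is simple enough to sketch in plain Python. The function below is an illustration rather than DSPy's implementation; the name `majority_vote` and the strip-and-lowercase normalization are assumptions.

```python
from collections import Counter


def majority_vote(answers, normalize=lambda s: s.strip().lower()):
    """Return the most common answer after normalization.

    Ties go to the answer seen first, because Counter.most_common orders
    equal counts by first occurrence.
    """
    if not answers:
        raise ValueError("need at least one answer")
    counts = Counter(normalize(a) for a in answers)
    winner, _ = counts.most_common(1)[0]
    return winner


votes = ["4", " 4", "Four", "4", "5"]
print(majority_vote(votes))
```

Normalizing before counting matters in practice: without it, `"4"` and `" 4"` would be counted as different answers and could split the vote.
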
## Advanced Modules

### dspy.TypedPredictor

**Structured output with Pydantic models.**

```python
from pydantic import BaseModel, Field

class PersonInfo(BaseModel):
    name: str = Field(description="Full name")
    age: int = Field(description="Age in years")
    occupation: str = Field(description="Current job")

class ExtractPerson(dspy.Signature):
    """Extract person information from text."""
    text = dspy.InputField()
    person: PersonInfo = dspy.OutputField()

extractor = dspy.TypedPredictor(ExtractPerson)
result = extractor(text="John Doe is a 35-year-old software engineer.")

print(result.person.name)  # "John Doe"
print(result.person.age)  # 35
print(result.person.occupation)  # "software engineer"
```

**Benefits:**
- Type safety
- Automatic validation
- JSON schema generation
- IDE autocomplete

### dspy.Retry

**Automatic retry with validation.**

```python
from dspy.primitives import Retry

def validate_number(example, pred, trace=None):
    """Validate output is a number."""
    try:
        float(pred.answer)
        return True
    except ValueError:
        return False

# Retry up to 3 times if validation fails
qa = Retry(
    dspy.ChainOfThought("question -> answer"),
    validate=validate_number,
    max_retries=3
)

result = qa(question="What is 15% of 80?")
# If first attempt returns non-numeric, retries automatically
```

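The retry-with-validation idea does not depend on any particular wrapper. A minimal plain-Python sketch follows; `call_with_retry`, `is_numeric`, and `fake_predict` are hypothetical names, and the predictor is a stand-in that returns strings rather than a real module.

```python
def call_with_retry(predict, validate, max_retries=3, **inputs):
    """Call predict(**inputs) until validate(result) passes.

    Makes one initial attempt plus up to max_retries retries and raises
    ValueError if no attempt validates.
    """
    last = None
    for _ in range(max_retries + 1):
        last = predict(**inputs)
        if validate(last):
            return last
    raise ValueError(f"no valid result after {max_retries} retries; last={last!r}")


def is_numeric(text):
    """Return True if text can be parsed as a float."""
    try:
        float(text)
        return True
    except (TypeError, ValueError):
        return False


# Stand-in predictor: the first reply is invalid, the second is a number
_replies = iter(["twelve", "12"])

def fake_predict(question):
    return next(_replies)


answer = call_with_retry(fake_predict, is_numeric, question="What is 15% of 80?")
print(answer)
```

Raising after the last attempt, instead of returning the invalid result, keeps bad outputs from flowing silently into later pipeline stages.
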
### dspy.Assert

**Assertion-driven optimization.**

```python
import dspy
from dspy.primitives.assertions import assert_transform_module, backtrack_handler

def is_number(value):
    """Return True if value can be parsed as a float."""
    try:
        float(value)
        return True
    except (TypeError, ValueError):
        return False

class ValidatedQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.qa = dspy.ChainOfThought("question -> answer: float")

    def forward(self, question):
        answer = self.qa(question=question).answer

        # Assert answer is numeric; a failed assertion triggers backtracking
        dspy.Assert(is_number(answer), "Answer must be a number")

        return dspy.Prediction(answer=answer)

# Wrap the module so failed assertions are handled by backtracking
validated_qa = assert_transform_module(ValidatedQA(), backtrack_handler)
```

**Benefits:**
- Catches errors during optimization
- Guides LM toward valid outputs
- Better than post-hoc filtering

## Module Composition

### Sequential Pipeline

```python
class Pipeline(dspy.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = dspy.Predict("input -> intermediate")
        self.stage2 = dspy.ChainOfThought("intermediate -> output")

    def forward(self, input):
        intermediate = self.stage1(input=input).intermediate
        output = self.stage2(intermediate=intermediate).output
        return dspy.Prediction(output=output)
```

### Conditional Logic

```python
class ConditionalModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.router = dspy.Predict("question -> category: str")
        self.simple_qa = dspy.Predict("question -> answer")
        self.complex_qa = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        category = self.router(question=question).category

        if category == "simple":
            return self.simple_qa(question=question)
        else:
            return self.complex_qa(question=question)
```

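Routing on an exact string match is brittle, since a model may return `"Simple."` or `" simple"` instead of `"simple"`. One way to harden the comparison is to normalize the router output first. The helper below is a sketch; the name `normalize_category` and the choice to fall back to the more thorough `"complex"` path are assumptions.

```python
def normalize_category(raw, allowed=("simple", "complex"), default="complex"):
    """Map free-form router output onto a known category label."""
    # Drop surrounding whitespace, then quotes and trailing periods, then case
    cleaned = str(raw).strip().strip("\"'.").lower()
    return cleaned if cleaned in allowed else default


print(normalize_category(" Simple.\n"))
print(normalize_category("I think it's hard"))
```

Inside `forward`, the branch would then test `normalize_category(category) == "simple"`, so unexpected output takes the fallback path instead of silently misrouting.
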
### Parallel Execution

```python
class ParallelModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.approach1 = dspy.ChainOfThought("question -> answer")
        self.approach2 = dspy.ProgramOfThought("question -> answer")

    def forward(self, question):
        # Run both approaches on the same input (the calls run one after the other)
        answer1 = self.approach1(question=question).answer
        answer2 = self.approach2(question=question).answer

        # Compare or combine results
        if answer1 == answer2:
            return dspy.Prediction(answer=answer1, confidence="high")
        else:
            return dspy.Prediction(answer=answer1, confidence="low")
```

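The `forward` method above calls the two approaches one after the other. If the calls are independent and I/O-bound, they can be issued concurrently with the standard library. This is a sketch with stand-in functions; `run_concurrently` is a hypothetical helper, and whether your modules are safe to call from multiple threads is an assumption to verify for your DSPy version.

```python
from concurrent.futures import ThreadPoolExecutor


def run_concurrently(question, *approaches):
    """Call each approach with the same question on its own thread.

    Results come back in the same order as the approaches were passed.
    """
    with ThreadPoolExecutor(max_workers=len(approaches)) as pool:
        futures = [pool.submit(fn, question) for fn in approaches]
        return [f.result() for f in futures]


# Stand-ins for two sub-modules
def approach1(question):
    return "36"

def approach2(question):
    return "36"


answer1, answer2 = run_concurrently("What is 15% of 240?", approach1, approach2)
confidence = "high" if answer1 == answer2 else "low"
print(answer1, confidence)
```

Collecting results from the futures in submission order keeps the comparison logic identical to the sequential version.
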
## Batch Processing

Modules expose a `batch` method that processes a list of examples in one call:

```python
cot = dspy.ChainOfThought("question -> answer")

questions = [
    "What is 2+2?",
    "What is 3+3?",
    "What is 4+4?"
]

# Wrap each input as an Example and mark its input fields
examples = [dspy.Example(question=q).with_inputs("question") for q in questions]

# Process all at once
results = cot.batch(examples)

for result in results:
    print(result.answer)
```

## Saving and Loading

```python
# Save module
qa = dspy.ChainOfThought("question -> answer")
qa.save("models/qa_v1.json")

# Load module
loaded_qa = dspy.ChainOfThought("question -> answer")
loaded_qa.load("models/qa_v1.json")
```

**What gets saved:**
- Few-shot examples
- Prompt instructions
- Module configuration

**What doesn't get saved:**
- Model weights (DSPy doesn't fine-tune by default)
- LM provider configuration

## Module Selection Guide

| Task | Module | Reason |
|------|--------|--------|
| Simple classification | Predict | Fast, direct |
| Math word problems | ProgramOfThought | Reliable calculations |
| Logical reasoning | ChainOfThought | Better with steps |
| Multi-step research | ReAct | Tool usage |
| High-stakes decisions | MultiChainComparison | Self-consistency |
| Structured extraction | TypedPredictor | Type safety |
| Ambiguous questions | MultiChainComparison | Multiple perspectives |

## Performance Tips

1. **Start with Predict**, add reasoning only if needed
2. **Use batch processing** for multiple inputs
3. **Cache predictions** for repeated queries
4. **Profile token usage** with `track_usage=True`
5. **Optimize after prototyping** with DSPy optimizers (formerly called teleprompters)

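Tip 3 can be sketched at the application level with `functools.lru_cache`. The predictor here is a stand-in (`expensive_predict` is hypothetical), and the sketch assumes identical questions should always return the same answer. LM clients may also cache requests themselves, in which case this layer only saves the surrounding module overhead.

```python
from functools import lru_cache

call_count = 0


def expensive_predict(question: str) -> str:
    """Stand-in for an LM-backed module call (hypothetical)."""
    global call_count
    call_count += 1
    return f"answer to: {question}"


@lru_cache(maxsize=256)
def cached_predict(question: str) -> str:
    # Arguments must be hashable; identical questions hit the cache
    return expensive_predict(question)


cached_predict("What is 2+2?")
cached_predict("What is 2+2?")  # served from the cache
cached_predict("What is 3+3?")
print(call_count)
```

`cached_predict.cache_info()` reports hits and misses, which gives a quick check on whether the cache is actually earning its keep.
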
## Common Patterns

### Pattern: Retrieval + Generation

```python
class RAG(dspy.Module):
    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)
```

### Pattern: Verification Loop

```python
class VerifiedQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.answer = dspy.ChainOfThought("question -> answer")
        self.verify = dspy.Predict("question, answer -> is_correct: bool")

    def forward(self, question, max_attempts=3):
        for _ in range(max_attempts):
            answer = self.answer(question=question).answer
            is_correct = self.verify(question=question, answer=answer).is_correct

            if is_correct:
                return dspy.Prediction(answer=answer)

        return dspy.Prediction(answer="Unable to verify answer")
```

### Pattern: Multi-Turn Dialog

```python
class DialogAgent(dspy.Module):
    def __init__(self):
        super().__init__()
        self.respond = dspy.Predict("history, user_message -> assistant_message")
        self.history = []

    def forward(self, user_message):
        history_str = "\n".join(self.history)
        response = self.respond(history=history_str, user_message=user_message)

        self.history.append(f"User: {user_message}")
        self.history.append(f"Assistant: {response.assistant_message}")

        return response
```

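In the dialog pattern, `self.history` grows with every turn, so the prompt gets longer each time. A bounded buffer is one way to cap it. The class below is a sketch; `BoundedHistory` and its `max_turns` parameter are hypothetical names, and dropping the oldest turns (rather than summarizing them) is a design assumption.

```python
from collections import deque


class BoundedHistory:
    """Keep only the most recent dialog turns."""

    def __init__(self, max_turns=3):
        # Each turn stores two lines: the user message and the reply
        self._lines = deque(maxlen=2 * max_turns)

    def add_turn(self, user_message, assistant_message):
        self._lines.append(f"User: {user_message}")
        self._lines.append(f"Assistant: {assistant_message}")

    def render(self):
        return "\n".join(self._lines)


history = BoundedHistory(max_turns=2)
for i in range(1, 4):
    history.add_turn(f"question {i}", f"reply {i}")
print(history.render())
```

Because lines are always appended in user/assistant pairs and the buffer holds an even number of lines, eviction never leaves a reply without its question.
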