Pydantic 2.x Discriminated Unions Research
Overview
Discriminated unions in Pydantic 2.x allow efficient validation of union types by routing to the correct model based on a discriminator field, rather than trying each variant sequentially.
Key Components
1. Discriminator(callable)
Takes a function that receives input data and returns a tag string:
from typing import Annotated, Union, Any
from pydantic import Discriminator, Tag

def get_type_tag(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('type', 'unknown')
    return getattr(v, 'type', 'unknown')

MyUnion = Annotated[
    Union[
        Annotated[TypeA, Tag('a')],
        Annotated[TypeB, Tag('b')],
    ],
    Discriminator(get_type_tag)
]
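A minimal runnable sketch of the snippet above. TypeA and TypeB are not defined in these notes, so the two models here are hypothetical stand-ins, and TypeAdapter is used only to exercise the union directly.

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


# Hypothetical stand-ins for TypeA / TypeB (fields are assumptions).
class TypeA(BaseModel):
    type: Literal['a']
    name: str


class TypeB(BaseModel):
    type: Literal['b']
    count: int


def get_type_tag(v: Any) -> str:
    # Raw input arrives as a dict; an already-built model arrives as an instance.
    if isinstance(v, dict):
        return v.get('type', 'unknown')
    return getattr(v, 'type', 'unknown')


MyUnion = Annotated[
    Union[
        Annotated[TypeA, Tag('a')],
        Annotated[TypeB, Tag('b')],
    ],
    Discriminator(get_type_tag),
]

adapter = TypeAdapter(MyUnion)
first = adapter.validate_python({'type': 'a', 'name': 'x'})
second = adapter.validate_python({'type': 'b', 'count': 3})
print(type(first).__name__, type(second).__name__)
```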
2. Tag('name')
Labels each union member. The discriminator function returns this string to route validation:
Annotated[MyModel, Tag('my_tag')]
3. Callable Discriminator Requirements
The callable must handle both:
- dict input (during validation of raw data)
- Model instance (during serialization, or when an already-built instance is validated)
def discriminator(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('discriminator_field', '')
    return getattr(v, 'discriminator_field', '')
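To illustrate the two input shapes, the sketch below pushes a raw dict and then an already-built instance through the same union and records what the callable received. The Circle and Square models and the kind field are hypothetical, chosen only for this example.

```python
from typing import Annotated, Any, List, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


# Hypothetical models for illustration.
class Circle(BaseModel):
    kind: str = 'circle'
    radius: float


class Square(BaseModel):
    kind: str = 'square'
    side: float


seen: List[str] = []  # type names of the values handed to the callable


def discriminator(v: Any) -> str:
    seen.append(type(v).__name__)
    if isinstance(v, dict):
        return v.get('kind', '')
    return getattr(v, 'kind', '')


Shape = Annotated[
    Union[
        Annotated[Circle, Tag('circle')],
        Annotated[Square, Tag('square')],
    ],
    Discriminator(discriminator),
]

adapter = TypeAdapter(Shape)
from_dict = adapter.validate_python({'kind': 'circle', 'radius': 1.0})
from_instance = adapter.validate_python(Square(side=2.0))
print(seen)
```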
Patterns
Pattern 1: Simple Literal Discriminator
When the discriminator field has a fixed set of values, use Literal types:
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Discriminator

class Dog(BaseModel):
    pet_type: Literal['dog']
    bark_volume: int

class Cat(BaseModel):
    pet_type: Literal['cat']
    meow_pitch: float

Pet = Annotated[
    Union[Dog, Cat],
    Discriminator('pet_type')  # A field name works when each member declares it as a Literal
]
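A short usage sketch for the Pet union, with the definitions repeated so the block runs on its own. It checks that each dict is routed by pet_type and that a value outside the declared Literals is rejected.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, TypeAdapter, ValidationError


class Dog(BaseModel):
    pet_type: Literal['dog']
    bark_volume: int


class Cat(BaseModel):
    pet_type: Literal['cat']
    meow_pitch: float


Pet = Annotated[Union[Dog, Cat], Discriminator('pet_type')]

adapter = TypeAdapter(Pet)
dog = adapter.validate_python({'pet_type': 'dog', 'bark_volume': 7})
cat = adapter.validate_python({'pet_type': 'cat', 'meow_pitch': 440.0})

# A tag that no member declares fails validation instead of guessing a model.
try:
    adapter.validate_python({'pet_type': 'fish'})
    rejected = False
except ValidationError:
    rejected = True
```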
Pattern 2: Callable Discriminator for Complex Routing
When routing logic is more complex than exact string matching:
def pet_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        pet_type = v.get('pet_type', '')
    else:
        pet_type = getattr(v, 'pet_type', '')
    # Complex routing logic
    if pet_type.startswith('dog'):
        return 'canine'
    elif pet_type.startswith('cat'):
        return 'feline'
    return 'unknown'

Pet = Annotated[
    Union[
        Annotated[CanineModel, Tag('canine')],
        Annotated[FelineModel, Tag('feline')],
        Annotated[UnknownPetModel, Tag('unknown')],
    ],
    Discriminator(pet_discriminator)
]
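A runnable sketch of the prefix-based routing. CanineModel, FelineModel and UnknownPetModel are not defined in these notes, so the field layouts below are assumptions made only to let the example execute.

```python
from typing import Annotated, Any, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


# Hypothetical models; the notes do not specify their fields.
class CanineModel(BaseModel):
    pet_type: str
    bark_volume: int = 0


class FelineModel(BaseModel):
    pet_type: str
    meow_pitch: float = 0.0


class UnknownPetModel(BaseModel):
    pet_type: str = ''


def pet_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        pet_type = v.get('pet_type', '')
    else:
        pet_type = getattr(v, 'pet_type', '')
    # Prefix matching rather than exact matching
    if pet_type.startswith('dog'):
        return 'canine'
    elif pet_type.startswith('cat'):
        return 'feline'
    return 'unknown'


Pet = Annotated[
    Union[
        Annotated[CanineModel, Tag('canine')],
        Annotated[FelineModel, Tag('feline')],
        Annotated[UnknownPetModel, Tag('unknown')],
    ],
    Discriminator(pet_discriminator),
]

adapter = TypeAdapter(Pet)
# Map each raw pet_type to the name of the model it was routed to.
routed = {
    raw: type(adapter.validate_python({'pet_type': raw})).__name__
    for raw in ('dog_large', 'cat_small', 'parrot')
}
print(routed)
```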
Pattern 3: Literal Types + Validators for Hybrid Matching
Combine Literal for simple cases and validators for complex patterns:
from typing import List, Literal
from pydantic import BaseModel, field_validator

class SimpleList(BaseModel):
    collection_type: Literal["list"]
    elements: List[Item]

class NestedList(BaseModel):
    collection_type: str  # Accepts any string; the validator below narrows it

    @field_validator('collection_type')
    @classmethod
    def must_contain_colon(cls, v: str) -> str:
        if ':' not in v:
            raise ValueError(f'Must contain ":", got "{v}"')
        return v

    elements: List[NestedItem]
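The notes do not show how the two models are combined. One possible sketch is a plain Union, where the Literal rejects nested strings and the validator rejects simple ones, so only one member can match. Item and NestedItem are hypothetical placeholders here.

```python
from typing import List, Literal, Union

from pydantic import BaseModel, TypeAdapter, field_validator


# Hypothetical element models; not defined in the notes.
class Item(BaseModel):
    name: str


class NestedItem(BaseModel):
    items: List[Item]


class SimpleList(BaseModel):
    collection_type: Literal["list"]
    elements: List[Item]


class NestedList(BaseModel):
    collection_type: str

    @field_validator('collection_type')
    @classmethod
    def must_contain_colon(cls, v: str) -> str:
        if ':' not in v:
            raise ValueError(f'Must contain ":", got "{v}"')
        return v

    elements: List[NestedItem]


# Assumption: a plain (non-discriminated) union is enough for this hybrid case.
adapter = TypeAdapter(Union[SimpleList, NestedList])

simple = adapter.validate_python(
    {'collection_type': 'list', 'elements': [{'name': 'x'}]}
)
nested = adapter.validate_python(
    {'collection_type': 'list:paired', 'elements': [{'items': [{'name': 'a'}]}]}
)
print(type(simple).__name__, type(nested).__name__)
```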
Pattern 4: Fallback with Left-to-Right Union Mode
Handle unknown types with a fallback model:
# Inner discriminated union for known types
KnownTypes = Annotated[
    Union[
        Annotated[TypeA, Tag('a')],
        Annotated[TypeB, Tag('b')],
    ],
    Discriminator(known_type_discriminator)
]

# Outer union with left-to-right fallback
WithFallback = Annotated[
    Union[KnownTypes, GenericFallback],
    Field(union_mode="left_to_right")
]
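A runnable sketch of the fallback pattern. TypeA, TypeB, GenericFallback and known_type_discriminator are not defined in these notes, so the versions below are assumptions, and the Envelope wrapper model is a hypothetical holder used to apply the Field metadata.

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag


# Hypothetical models and discriminator for illustration.
class TypeA(BaseModel):
    type: Literal['a']
    value: int


class TypeB(BaseModel):
    type: Literal['b']
    label: str


class GenericFallback(BaseModel):
    model_config = ConfigDict(extra='allow')  # keep unknown keys
    type: str


def known_type_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('type', '')
    return getattr(v, 'type', '')


KnownTypes = Annotated[
    Union[
        Annotated[TypeA, Tag('a')],
        Annotated[TypeB, Tag('b')],
    ],
    Discriminator(known_type_discriminator),
]

WithFallback = Annotated[
    Union[KnownTypes, GenericFallback],
    Field(union_mode='left_to_right'),
]


class Envelope(BaseModel):
    item: WithFallback


known = Envelope.model_validate({'item': {'type': 'a', 'value': 1}}).item
unknown = Envelope.model_validate({'item': {'type': 'zzz', 'payload': 5}}).item
# A known tag with a bad payload fails the inner union and also falls through.
malformed = Envelope.model_validate({'item': {'type': 'a'}}).item
```

One design consequence worth keeping in mind: with this layout, malformed data for a known type is accepted by the fallback rather than reported as an error, which may or may not be the intended behavior.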
Pattern 5: Nested Discriminated Unions
For recursive structures, use forward references and model_rebuild():
class Container(BaseModel):
    items: List["ItemUnion"]

class ItemA(BaseModel):
    type: Literal["a"]
    value: str

class ItemB(BaseModel):
    type: Literal["b"]
    nested: "Container"  # Forward reference

ItemUnion = Annotated[
    Union[
        Annotated[ItemA, Tag('a')],
        Annotated[ItemB, Tag('b')],
    ],
    Discriminator('type')
]

# Rebuild after all definitions
Container.model_rebuild()
ItemB.model_rebuild()
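A self-contained sketch of the recursive structure, validating a payload two levels deep. The Tag wrappers are left out in this version because a field-name discriminator reads its tags from the Literal values; the sample data is made up for the example.

```python
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Discriminator


class Container(BaseModel):
    items: List["ItemUnion"]  # resolved later by model_rebuild()


class ItemA(BaseModel):
    type: Literal["a"]
    value: str


class ItemB(BaseModel):
    type: Literal["b"]
    nested: "Container"


ItemUnion = Annotated[Union[ItemA, ItemB], Discriminator('type')]

# Resolve the forward references now that every name exists.
Container.model_rebuild()
ItemB.model_rebuild()

tree = Container.model_validate(
    {
        'items': [
            {'type': 'a', 'value': 'leaf'},
            {'type': 'b', 'nested': {'items': [{'type': 'a', 'value': 'inner'}]}},
        ]
    }
)
inner = tree.items[1].nested.items[0]
print(type(tree.items[1]).__name__, inner.value)
```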
Performance
Discriminated unions are faster than regular unions because:
- Pydantic extracts the discriminator value first
- Routes directly to the matching model
- No need to try each variant sequentially
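With a field-name discriminator, the tag lookup and routing happen inside pydantic-core, which is written in Rust. A callable discriminator is an ordinary Python function that pydantic-core calls for each value it validates, so it should stay cheap.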
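Timing depends on the machine, so the sketch below shows a different observable effect of direct routing: error reporting. A plain union reports failures from every member it tried, while the discriminated union reports only the member the tag selected. The Dog and Cat models follow Pattern 1; the error_count helper is hypothetical.

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, TypeAdapter, ValidationError


class Dog(BaseModel):
    pet_type: Literal['dog']
    bark_volume: int


class Cat(BaseModel):
    pet_type: Literal['cat']
    meow_pitch: float


plain = TypeAdapter(Union[Dog, Cat])
tagged = TypeAdapter(Annotated[Union[Dog, Cat], Discriminator('pet_type')])

bad_dog = {'pet_type': 'dog'}  # bark_volume is missing


def error_count(adapter: TypeAdapter, data: Any) -> int:
    """Return how many individual errors validation reports for data."""
    try:
        adapter.validate_python(data)
    except ValidationError as exc:
        return len(exc.errors())
    return 0


plain_errors = error_count(plain, bad_dog)
tagged_errors = error_count(tagged, bad_dog)
print(plain_errors, tagged_errors)
```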
Common Pitfalls
1. Forgetting to handle both dict and model instance
# Wrong - only handles dict
def bad_discriminator(v: dict) -> str:
    return v['type']

# Correct - handles both
def good_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('type', '')
    return getattr(v, 'type', '')
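The failure mode can be reproduced without building a union at all. In this sketch the two functions are called directly on a model instance; TypeA is a hypothetical one-field model.

```python
from typing import Any

from pydantic import BaseModel


# Hypothetical model for illustration.
class TypeA(BaseModel):
    type: str = 'a'


def bad_discriminator(v: dict) -> str:
    return v['type']


def good_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('type', '')
    return getattr(v, 'type', '')


instance = TypeA()

# Model instances do not support item access, so the dict-only version fails.
try:
    bad_discriminator(instance)
    bad_failed = False
except TypeError:
    bad_failed = True

good_tag = good_discriminator(instance)
print(bad_failed, good_tag)
```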
2. Missing model_rebuild() for forward references
class A(BaseModel):
    items: List["B"]

class B(BaseModel):
    parent: Optional["A"]

# Must call after all definitions
A.model_rebuild()
B.model_rebuild()
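A runnable version of the mutual reference, with imports added. Giving parent a default of None is an assumption made here so that the sample payload can omit it.

```python
from typing import List, Optional

from pydantic import BaseModel


class A(BaseModel):
    items: List["B"]  # B does not exist yet at this point


class B(BaseModel):
    parent: Optional["A"] = None  # default is an assumption for this sketch


# Resolve both forward references once A and B are defined.
A.model_rebuild()
B.model_rebuild()

a = A.model_validate({'items': [{'parent': {'items': []}}, {}]})
print(type(a.items[0]).__name__, a.items[0].parent, a.items[1].parent)
```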
3. Tag mismatch
The string returned by discriminator must exactly match a Tag():
# Discriminator returns 'type_a'
# But Tag is 'a' - won't match!
Annotated[TypeA, Tag('a')] # Wrong
Annotated[TypeA, Tag('type_a')] # Correct
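A small sketch of what the mismatch looks like at validation time. The same callable is paired with two unions, one whose tags do not match its return values and one whose tags do; TypeA, TypeB and get_tag are hypothetical.

```python
from typing import Annotated, Any, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, ValidationError


# Hypothetical models for illustration.
class TypeA(BaseModel):
    type: str


class TypeB(BaseModel):
    type: str


def get_tag(v: Any) -> str:
    # Returns the raw field value, e.g. 'type_a'
    if isinstance(v, dict):
        return v.get('type', '')
    return getattr(v, 'type', '')


Mismatched = Annotated[
    Union[Annotated[TypeA, Tag('a')], Annotated[TypeB, Tag('b')]],
    Discriminator(get_tag),
]
Matched = Annotated[
    Union[Annotated[TypeA, Tag('type_a')], Annotated[TypeB, Tag('type_b')]],
    Discriminator(get_tag),
]

data = {'type': 'type_a'}

# 'type_a' is not one of the declared tags ('a', 'b'), so validation fails.
try:
    TypeAdapter(Mismatched).validate_python(data)
    mismatch_rejected = False
except ValidationError as exc:
    mismatch_rejected = True
    print(exc.errors()[0]['type'])

result = TypeAdapter(Matched).validate_python(data)
```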
Use Case: Collection Type Discrimination
For Galaxy collection types where collection_type can be:
- Simple: "list", "paired", "record"
- Nested: "list:paired", "sample_sheet:record"
def collection_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        ct = v.get('collection_type', '')
    else:
        ct = getattr(v, 'collection_type', '')
    # Simple types - exact match
    # Note: 'paired_or_unpaired' and 'sample_sheet' are returned here but have
    # no matching Tag() in CollectionUnion below; each needs its own member.
    if ct in ('list', 'paired', 'record', 'paired_or_unpaired', 'sample_sheet'):
        return ct
    # Nested types - route by outer structure
    if ':' in ct:
        first_segment = ct.split(':')[0]
        if first_segment in ('list', 'sample_sheet'):
            return 'nested_list'
        else:
            return 'nested_record'
    return 'list'  # fallback

CollectionUnion = Annotated[
    Union[
        Annotated[ListRuntime, Tag('list')],
        Annotated[PairedRuntime, Tag('paired')],
        Annotated[RecordRuntime, Tag('record')],
        Annotated[NestedListRuntime, Tag('nested_list')],
        Annotated[NestedRecordRuntime, Tag('nested_record')],
    ],
    Discriminator(collection_discriminator)
]
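The routing function can be checked on its own, without the runtime models (which are not defined in these notes). The sketch below runs it over a few sample collection types and then compares the tags it can return for simple types against the tags declared in CollectionUnion. The sample strings are illustrative, not an exhaustive list of Galaxy collection types.

```python
from types import SimpleNamespace
from typing import Any


def collection_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        ct = v.get('collection_type', '')
    else:
        ct = getattr(v, 'collection_type', '')
    # Simple types - exact match
    if ct in ('list', 'paired', 'record', 'paired_or_unpaired', 'sample_sheet'):
        return ct
    # Nested types - route by outer structure
    if ':' in ct:
        first_segment = ct.split(':')[0]
        if first_segment in ('list', 'sample_sheet'):
            return 'nested_list'
        else:
            return 'nested_record'
    return 'list'  # fallback


samples = ['list', 'paired', 'list:paired', 'sample_sheet:record',
           'paired:list', 'unexpected']
routes = {ct: collection_discriminator({'collection_type': ct}) for ct in samples}

# Attribute access path, with a plain object standing in for a model instance.
from_object = collection_discriminator(SimpleNamespace(collection_type='record'))

# Tags declared via Tag() in CollectionUnion.
declared_tags = {'list', 'paired', 'record', 'nested_list', 'nested_record'}
simple_types = ('list', 'paired', 'record', 'paired_or_unpaired', 'sample_sheet')
returned = {collection_discriminator({'collection_type': ct}) for ct in simple_types}
untagged = returned - declared_tags  # tags with no union member yet
print(routes, sorted(untagged))
```

Any tag left in untagged would hit the tag-mismatch pitfall described earlier, so the union needs a member for each of them before those collection types can validate.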