Tutorial 4: Configs and Dynamic Patterns
The previous tutorials covered models where dimensions flow directly through constructor parameters. In practice, many models store hyperparameters in config dataclasses and construct modules dynamically. This tutorial shows how to type those patterns.
Config classes with type parameters
When a @dataclass holds dimension hyperparameters consumed by multiple
modules, make it generic so dimensions propagate through constructors:
@dataclass
class GPTConfig[VocabSize, BlockSize, NEmbedding, NHead, NLayer]:
    block_size: Dim[BlockSize]
    vocab_size: Dim[VocabSize]
    n_layer: Dim[NLayer]
    n_head: Dim[NHead]
    n_embd: Dim[NEmbedding]
    dropout: float = 0.0
    bias: bool = True
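A call site might then look like this (a sketch; the concrete values are illustrative, and we assume the checker infers the type arguments from the literal arguments):
# Sketch: type arguments inferred from the literal values
config = GPTConfig(
    block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768
)
# inferred: GPTConfig[50304, 1024, 768, 12, 12]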
Modules extract only the type parameters they need, using Any for the rest:
class MLP[NEmbedding](nn.Module):
    def __init__(self, config: GPTConfig[Any, Any, NEmbedding, Any, Any]):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward[B, T](
        self, x: Tensor[B, T, NEmbedding]
    ) -> Tensor[B, T, NEmbedding]:
        h = F.gelu(self.c_fc(x))
        assert_type(h, Tensor[B, T, 4 * NEmbedding])
        return self.c_proj(h)
Without config parameterization, each module would need to independently accept and thread every dimension through its constructor — error-prone and verbose.
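For comparison, a sketch of the unparameterized alternative (the signatures are hypothetical):
# Without the generic config, every constructor re-declares each
# dimension it touches, and every caller must pass them all again:
class MLP[NEmbedding](nn.Module):
    def __init__(self, n_embd: Dim[NEmbedding], dropout: float, bias: bool): ...

class Block[NEmbedding, NHead](nn.Module):
    def __init__(self, n_embd: Dim[NEmbedding], n_head: Dim[NHead],
                 dropout: float, bias: bool): ...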
Final class attributes for constants
When a model has fixed hyperparameters, Final class attributes let
Literal arithmetic work at the type level:
class DCGAN:
    nc: Final = 3
    nz: Final = 100
    ngf: Final = 64
    ndf: Final = 64
The type checker resolves DCGAN.ngf * 8 to Literal[512], so you can
write:
self.project = nn.ConvTranspose2d(DCGAN.nz, DCGAN.ngf * 8, 4, 1, 0)
# Type checker infers: ConvTranspose2d[100, 512, ...]
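The same arithmetic chains through later layers (a sketch; the layer names and convolution arguments are illustrative):
# Each constant expression resolves at the type level
self.up1 = nn.ConvTranspose2d(DCGAN.ngf * 8, DCGAN.ngf * 4, 4, 2, 1)
# inferred: ConvTranspose2d[512, 256, ...]
self.up2 = nn.ConvTranspose2d(DCGAN.ngf * 4, DCGAN.ngf * 2, 4, 2, 1)
# inferred: ConvTranspose2d[256, 128, ...]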
Dynamic construction patterns
Several common patterns break shape tracking. Here's a quick reference and the fix for each:
nn.Sequential(*list_var)
When modules are constructed in a loop and passed as *list_var, the
Sequential loses all type information:
# Bad: Sequential erases module types
layers = [nn.Linear(dim, dim) for _ in range(n)]
self.net = nn.Sequential(*layers)  # calls to self.net return a bare Tensor
Fix: extract shape-changing modules as individual attributes and chain them
in forward. Shape-preserving modules (activations, norms, dropout) can
remain grouped:
# Good: individual attributes preserve types
self.fc1 = nn.Linear(dim, hidden)
self.fc2 = nn.Linear(hidden, dim)

def forward[B](self, x: Tensor[B, D]) -> Tensor[B, D]:
    h = F.relu(self.fc1(x))
    return self.fc2(h)
Note: nn.Sequential(M1(), M2(), M3()) with direct arguments is
tracked — only the *list_var form loses types.
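So a fixed-depth stack can stay grouped (a sketch, assuming the library's stubs provide typed overloads for the direct-argument form):
# Tracked: each module's type flows through the direct-argument overload
self.net = nn.Sequential(
    nn.Linear(dim, hidden),
    nn.ReLU(),
    nn.Linear(hidden, dim),
)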
Factory functions returning nn.Sequential
Returning nn.Sequential from a function erases all type parameters at the
function boundary:
# Bad: factory function — types erased
def _make_block(in_c, out_c) -> nn.Sequential:
    return nn.Sequential(nn.Conv2d(in_c, 128, ...), nn.Conv2d(128, out_c, ...))
Fix: use a class with a typed forward method:
# Good: class preserves shape contract
class Block[InC, OutC](nn.Module):
    def __init__(self, in_c: Dim[InC], out_c: Dim[OutC]) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, 128, ...), nn.Conv2d(128, out_c, ...)
        )

    def forward[B, H, W](
        self, x: Tensor[B, InC, H, W]
    ) -> Tensor[B, OutC, H, W]:
        return self.net(x)
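At the call site, the contract now survives the boundary that the factory version lost (a sketch; the channel counts are illustrative):
block = Block(in_c=64, out_c=256)  # inferred: Block[64, 256]
y = block(x)  # x: Tensor[B, 64, H, W] -> y: Tensor[B, 256, H, W]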
Dimensions from list[int]
list[int] element access returns int, losing the concrete value:
hidden_units = [512, 256, 128]
last_dim = hidden_units[-1] # type is int, not Dim[128]
Fix: add an explicit Dim field to the config:
@dataclass
class Config[K, MlpOut]:
    num_output_features: Dim[K]
    mlp_output_dim: Dim[MlpOut]  # explicit — was hidden_units[-1]
    mlp_hidden_units: list[int] = field(default_factory=lambda: [512, 256])
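Construction then carries the value at the type level (a sketch; the numbers are illustrative):
# Sketch: the last hidden size now travels as a Dim, not a bare int
cfg = Config(
    num_output_features=10,
    mlp_output_dim=128,
    mlp_hidden_units=[512, 256, 128],
)  # inferred: Config[10, 128]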
Typed interfaces
When none of the above restructurings can recover shape tracking, the last
resort is a typed interface: the module's forward signature provides
the shape contract, and type: ignore narrows the internal result:
class DynamicMLP[InDim, OutDim](nn.Module):
    def __init__(self, in_dim: Dim[InDim], out_dim: Dim[OutDim],
                 hidden: list[int]) -> None:
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden:
            layers.append(nn.Linear(prev, h))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward[B](self, x: Tensor[B, InDim]) -> Tensor[B, OutDim]:
        h = x
        for layer in self.layers:
            h = layer(h)
        result: Tensor[B, OutDim] = h  # type: ignore[bad-assignment]
        return result
The caller sees a clean Tensor[B, InDim] -> Tensor[B, OutDim] contract.
The type: ignore is localized to the module's internals.
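From the outside, the module then checks like any fully tracked one (a sketch; the sizes are illustrative):
# Sketch: callers get full shape checking despite the dynamic internals
mlp = DynamicMLP(in_dim=784, out_dim=10, hidden=[512, 256])
logits = mlp(batch)  # batch: Tensor[B, 784] -> logits: Tensor[B, 10]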
Use typed interfaces only after exhausting all restructuring options — they're the fallback, not the first move.
Key concepts
- Parameterized configs propagate dimensions across modules without threading every Dim through every constructor.
- Final class attributes let the checker resolve constants at the type level.
- Dynamic construction (Sequential(*list), factory functions, list[int] access) breaks tracking — restructure or add explicit Dim fields.
- Typed interfaces provide a clean shape contract when internals are dynamic, but should be the last resort.
Sequential(*list), factory functions,list[int]access) breaks tracking — restructure or add explicitDimfields. - Typed interfaces provide a clean shape contract when internals are dynamic, but should be the last resort.