Tutorial 4: Configs and Dynamic Patterns

The previous tutorials covered models where dimensions flow directly through constructor parameters. In practice, many models store hyperparameters in config dataclasses and construct modules dynamically. This tutorial shows how to type those patterns.

Config classes with type parameters

When a @dataclass holds dimension hyperparameters consumed by multiple modules, make it generic so dimensions propagate through constructors:

@dataclass
class GPTConfig[VocabSize, BlockSize, NEmbedding, NHead, NLayer]:
    block_size: Dim[BlockSize]
    vocab_size: Dim[VocabSize]
    n_layer: Dim[NLayer]
    n_head: Dim[NHead]
    n_embd: Dim[NEmbedding]
    dropout: float = 0.0
    bias: bool = True

Modules extract only the type parameters they need, using Any for the rest:

class MLP[NEmbedding](nn.Module):
    def __init__(self, config: GPTConfig[Any, Any, NEmbedding, Any, Any]):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward[B, T](
        self, x: Tensor[B, T, NEmbedding]
    ) -> Tensor[B, T, NEmbedding]:
        h = F.gelu(self.c_fc(x))
        assert_type(h, Tensor[B, T, 4 * NEmbedding])
        return self.c_proj(h)

Without config parameterization, each module would need to independently accept and thread every dimension through its constructor — error-prone and verbose.
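The Dim and Tensor machinery above belongs to the shape checker, but the config-threading pattern itself works the same at runtime. Here is a minimal torch-free sketch with plain ints standing in for Dim fields; GPTConfig and MLP below are toy stand-ins defined only for this example:

```python
from dataclasses import dataclass

# Toy stand-in for the typed config: plain ints instead of Dim[...] fields.
@dataclass
class GPTConfig:
    block_size: int
    vocab_size: int
    n_layer: int
    n_head: int
    n_embd: int
    dropout: float = 0.0
    bias: bool = True

# Each module reads only the fields it needs from the shared config.
class MLP:
    def __init__(self, config: GPTConfig) -> None:
        self.in_features = config.n_embd
        self.hidden_features = 4 * config.n_embd

config = GPTConfig(block_size=256, vocab_size=50304,
                   n_layer=6, n_head=6, n_embd=384)
mlp = MLP(config)
assert mlp.hidden_features == 1536  # 4 * n_embd
```

One config object is constructed once and passed to every module; no module's constructor needs to list every dimension individually.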

Final class attributes for constants

When a model has fixed hyperparameters, Final class attributes let Literal arithmetic work at the type level:

class DCGAN:
    nc: Final = 3
    nz: Final = 100
    ngf: Final = 64
    ndf: Final = 64

The type checker resolves DCGAN.ngf * 8 to Literal[512], so you can write:

self.project = nn.ConvTranspose2d(DCGAN.nz, DCGAN.ngf * 8, 4, 1, 0)
# Type checker infers: ConvTranspose2d[100, 512, ...]
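Final itself is standard typing; the Literal arithmetic is the shape checker's extension. At runtime the same constants are ordinary ints, as this small torch-free sketch shows:

```python
from typing import Final

class DCGAN:
    nc: Final = 3    # image channels
    nz: Final = 100  # latent vector size
    ngf: Final = 64  # generator feature maps
    ndf: Final = 64  # discriminator feature maps

# The checker resolves DCGAN.ngf * 8 to Literal[512];
# at runtime it is the same plain int value.
assert DCGAN.ngf * 8 == 512
```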

Dynamic construction patterns

Several common patterns break shape tracking. Here's a quick reference and the fix for each:

nn.Sequential(*list_var)

When modules are constructed in a loop and passed as *list_var, the Sequential loses all type information:

# Bad: Sequential erases module types
layers = [nn.Linear(dim, dim) for _ in range(n)]
self.net = nn.Sequential(*layers) # returns bare Tensor

Fix: extract shape-changing modules as individual attributes and chain them in forward. Shape-preserving modules (activations, norms, dropout) can remain grouped:

# Good: individual attributes preserve types
self.fc1 = nn.Linear(dim, hidden)
self.fc2 = nn.Linear(hidden, dim)

def forward[B](self, x: Tensor[B, D]) -> Tensor[B, D]:
    h = F.relu(self.fc1(x))
    return self.fc2(h)

Note: nn.Sequential(M1(), M2(), M3()) with direct arguments is tracked — only the *list_var form loses types.
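The restructuring changes only what the type checker can see, not runtime behavior. A torch-free sketch, using a hypothetical stand-in Linear that checks its input width at call time:

```python
class Linear:
    """Stand-in for nn.Linear that only tracks dimensions."""
    def __init__(self, in_dim: int, out_dim: int) -> None:
        self.in_dim, self.out_dim = in_dim, out_dim

    def __call__(self, x: list[float]) -> list[float]:
        assert len(x) == self.in_dim, "dimension mismatch"
        return [0.0] * self.out_dim  # dummy output of the right width

class Net:
    def __init__(self, dim: int, hidden: int) -> None:
        # Individual attributes instead of Sequential(*layers)
        self.fc1 = Linear(dim, hidden)
        self.fc2 = Linear(hidden, dim)

    def forward(self, x: list[float]) -> list[float]:
        return self.fc2(self.fc1(x))

net = Net(dim=8, hidden=32)
assert len(net.forward([0.0] * 8)) == 8
```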

Factory functions returning nn.Sequential

Returning nn.Sequential from a function erases all type parameters at the function boundary:

# Bad: factory function — types erased
def _make_block(in_c, out_c) -> nn.Sequential:
    return nn.Sequential(nn.Conv2d(in_c, 128, ...), nn.Conv2d(128, out_c, ...))

Fix: use a class with a typed forward method:

# Good: class preserves shape contract
class Block[InC, OutC](nn.Module):
    def __init__(self, in_c: Dim[InC], out_c: Dim[OutC]) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, 128, ...), nn.Conv2d(128, out_c, ...)
        )

    def forward[B, H, W](
        self, x: Tensor[B, InC, H, W]
    ) -> Tensor[B, OutC, H, W]:
        return self.net(x)

Dimensions from list[int]

list[int] element access returns int, losing the concrete value:

hidden_units = [512, 256, 128]
last_dim = hidden_units[-1] # type is int, not Dim[128]

Fix: add an explicit Dim field to the config:

@dataclass
class Config[K, MlpOut]:
    num_output_features: Dim[K]
    mlp_output_dim: Dim[MlpOut]  # explicit — was hidden_units[-1]
    mlp_hidden_units: list[int] = field(default_factory=lambda: [512, 256])
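The explicit field duplicates information that is already in the list; its job is to keep the last hidden width visible to the type checker. A runtime sketch with plain ints for Dim, plus a __post_init__ consistency guard added for illustration:

```python
from dataclasses import dataclass, field

@dataclass
class Config:
    num_output_features: int
    mlp_output_dim: int  # explicit -- duplicates mlp_hidden_units[-1]
    mlp_hidden_units: list[int] = field(default_factory=lambda: [512, 256])

    def __post_init__(self) -> None:
        # Runtime guard: the explicit dim must agree with the list.
        assert self.mlp_output_dim == self.mlp_hidden_units[-1]

cfg = Config(num_output_features=10, mlp_output_dim=256)
assert cfg.mlp_output_dim == 256
```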

Typed interfaces

When none of the above restructurings can recover shape tracking, the last resort is a typed interface: the module's forward signature provides the shape contract, and type: ignore narrows the internal result:

class DynamicMLP[InDim, OutDim](nn.Module):
    def __init__(self, in_dim: Dim[InDim], out_dim: Dim[OutDim],
                 hidden: list[int]) -> None:
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden:
            layers.append(nn.Linear(prev, h))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward[B](self, x: Tensor[B, InDim]) -> Tensor[B, OutDim]:
        h = x
        for layer in self.layers:
            h = layer(h)
        result: Tensor[B, OutDim] = h  # type: ignore[bad-assignment]
        return result

The caller sees a clean Tensor[B, InDim] -> Tensor[B, OutDim] contract. The type: ignore is localized to the module's internals.
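A torch-free runtime analogue of the same interface, using a toy stand-in Linear defined for this sketch: the loop builds layers dynamically, and the final assertion plays the role the type: ignore-narrowed annotation plays statically.

```python
class Linear:
    """Stand-in for nn.Linear that only tracks dimensions."""
    def __init__(self, in_dim: int, out_dim: int) -> None:
        self.in_dim, self.out_dim = in_dim, out_dim

    def __call__(self, x: list[float]) -> list[float]:
        assert len(x) == self.in_dim, "dimension mismatch"
        return [0.0] * self.out_dim

class DynamicMLP:
    def __init__(self, in_dim: int, out_dim: int, hidden: list[int]) -> None:
        self.out_dim = out_dim
        self.layers = []
        prev = in_dim
        for h in hidden:
            self.layers.append(Linear(prev, h))
            prev = h
        self.layers.append(Linear(prev, out_dim))

    def forward(self, x: list[float]) -> list[float]:
        for layer in self.layers:
            x = layer(x)
        # Runtime counterpart of the statically narrowed result
        assert len(x) == self.out_dim
        return x

mlp = DynamicMLP(in_dim=16, out_dim=4, hidden=[512, 256, 128])
assert len(mlp.forward([0.0] * 16)) == 4
```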

Use typed interfaces only after exhausting all restructuring options — they're the fallback, not the first move.

Key concepts

  • Parameterized configs propagate dimensions across modules without threading every Dim through every constructor.
  • Final class attributes let the checker resolve constants at the type level.
  • Dynamic construction (Sequential(*list), factory functions, list[int] access) breaks tracking — restructure or add explicit Dim fields.
  • Typed interfaces provide a clean shape contract when internals are dynamic, but should be the last resort.