Tutorial 2: Loops and Stacking

In Tutorial 1, every layer was called directly. But Transformer-style architectures stack identical layers in a ModuleList and iterate over them. This tutorial shows how to type those patterns.

Shape-preserving loops

When every layer in a loop has the same type signature — input shape equals output shape — the type checker can verify that the loop preserves the shape invariant.

Here's a Transformer encoder that stacks n_layers identical EncoderLayer modules:

class Encoder[NHead, DK, DInner](nn.Module):
    def __init__(
        self,
        n_head: Dim[NHead],
        d_k: Dim[DK],
        d_inner: Dim[DInner],
        n_layers: int = 6,
    ) -> None:
        super().__init__()
        self.layer_stack = nn.ModuleList(
            [EncoderLayer(n_head, d_k, d_inner) for _ in range(n_layers)]
        )

    def forward[B, T](
        self, src_seq: Tensor[B, T, NHead * DK]
    ) -> Tensor[B, T, NHead * DK]:
        enc_output = src_seq
        for layer in self.layer_stack:
            enc_output, _attn = layer(enc_output)
            assert_type(enc_output, Tensor[B, T, NHead * DK])
        return enc_output

Notice that n_layers is int, not Dim — it's an iteration count that doesn't flow into any tensor dimension. Only values that determine tensor shapes need to be Dim.

Each EncoderLayer takes Tensor[B, T, NHead * DK] and returns the same shape, so the loop preserves the invariant and the type checker is satisfied.
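
Since the loop body unpacks two values, each EncoderLayer evidently returns its output alongside attention weights. Here is a minimal sketch of the signature implied by that call site (the attention-weights shape is an assumption on our part, not taken from the original):

class EncoderLayer[NHead, DK, DInner](nn.Module):
    def forward[B, T](
        self, enc_input: Tensor[B, T, NHead * DK]
    ) -> tuple[Tensor[B, T, NHead * DK], Tensor[B, NHead, T, T]]:
        # Self-attention plus position-wise feed-forward; the output
        # shape matches the input shape, which is what lets the loop
        # in Encoder.forward preserve its invariant.
        ...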

Derived dimensions

The model dimension here is NHead * DK — a derived expression, not an independent type parameter. This is important: only independent degrees of freedom get type params. If you wrote class Encoder[NHead, DK, D] with an independent D, you'd lose the constraint that D == NHead * DK.
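
To see the difference, compare a hypothetical variant with an independent D (a sketch for contrast; EncoderBad is not real code from this tutorial):

# Hypothetical: an independent D loses the NHead * DK relationship.
class EncoderBad[NHead, DK, D](nn.Module):
    def __init__(
        self, n_head: Dim[NHead], d_k: Dim[DK], d_model: Dim[D]
    ) -> None:
        super().__init__()
        # Nothing here ties D to NHead * DK, so a caller can pass a
        # d_model that disagrees with n_head * d_k, and the checker
        # cannot relate D to the per-head dimensions downstream.
        ...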

Separating the first iteration

Sometimes the first iteration of a loop changes the shape, while subsequent iterations preserve it. The type checker sees the union of both shapes and widens to a less precise type.

The fix is to separate the first iteration:

# Problem: x widens to Tensor[B, F, D] | Tensor[B, K, D]
x = input_embs
for layer in self.layers:
    x = layer(x)
out: Tensor[B, K, D] = x  # type: ignore[bad-assignment]

# Solution: no union, no type: ignore
x = self.layers[0](input_embs)  # [B, F, D] -> [B, K, D]
assert_type(x, Tensor[B, K, D])
for i in range(1, len(self.layers)):
    x = self.layers[i](x)  # [B, K, D] -> [B, K, D]

The first call changes the shape; subsequent calls preserve it. By separating them, you avoid the union widening entirely.

Shape-preserving activations

Many architectures accept an activation function as a parameter (ReLU, GELU, SiLU, etc.). Since each activation's forward signature is Tensor[*S] -> Tensor[*S], you can express this with a type alias:

ShapePreservingActivation = (
    type[nn.ReLU] | type[nn.GELU] | type[nn.SiLU] | type[nn.Tanh]
)

class ResBlock[C](nn.Module):
    def __init__(self, c: Dim[C], act_fn: ShapePreservingActivation) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c, c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c),
            act_fn(),
        )

The nn.Sequential chains the modules, and the type checker verifies that the overall shape is preserved.
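
A forward method isn't shown above. As a sketch, a conventional residual add (our assumption, not part of the original block) would also typecheck, because kernel_size=3 with padding=1 preserves the spatial dimensions:

class ResBlock[C](nn.Module):
    ...  # __init__ as above

    def forward[B, H, W](self, x: Tensor[B, C, H, W]) -> Tensor[B, C, H, W]:
        # self.net preserves [B, C, H, W], so the residual add typechecks.
        return x + self.net(x)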

Multi-head attention

Multi-head attention involves reshaping tensors from [B, T, D] to [B, NHead, T, D // NHead] and back. The dimension arithmetic is expressed directly in annotations:

class CausalSelfAttention[NEmbedding, NHead](nn.Module):
    def __init__(
        self,
        n_embd: Dim[NEmbedding],
        n_head: Dim[NHead],
    ) -> None:
        super().__init__()
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward[B, T](
        self, x: Tensor[B, T, NEmbedding]
    ) -> Tensor[B, T, NEmbedding]:
        qkv = self.c_attn(x)
        assert_type(qkv, Tensor[B, T, 3 * NEmbedding])
        q, k, v = qkv.split(self.n_embd, dim=2)
        assert_type(q, Tensor[B, T, NEmbedding])

        # Reshape for multi-head: [B, T, D] -> [B, NHead, T, D // NHead]
        head_dim = self.n_embd // self.n_head
        q = q.view(q.size(0), q.size(1), self.n_head, head_dim)
        q = q.transpose(1, 2)
        assert_type(q, Tensor[B, NHead, T, NEmbedding // NHead])
        ...

The type checker tracks the reshape and transpose, verifying that D // NHead is consistent throughout the attention computation.
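
The elided remainder would follow the same pattern. As an illustration only (not part of the original snippet), assuming k and v have been reshaped and transposed exactly like q, the score computation and the merge back might typecheck like this:

# Illustration: continuation under the assumption that k and v were
# reshaped like q, so all three are Tensor[B, NHead, T, NEmbedding // NHead].
attn = q @ k.transpose(-2, -1)  # scaling and causal mask omitted
assert_type(attn, Tensor[B, NHead, T, T])
out = attn.softmax(dim=-1) @ v
assert_type(out, Tensor[B, NHead, T, NEmbedding // NHead])

# Merge heads back: [B, NHead, T, D // NHead] -> [B, T, D]
out = out.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.n_embd)
assert_type(out, Tensor[B, T, NEmbedding])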

Key concepts

  • int for iteration counts, Dim for tensor dimensions. n_layers doesn't flow into tensor shapes, so it stays int.
  • Derived dimensions express relationships: NHead * DK, not independent D.
  • Separate the first iteration when it changes shape to avoid union widening.
  • Arithmetic in annotations (3 * NEmbedding, NEmbedding // NHead) is tracked and simplified automatically.

Next steps

In Tutorial 3, you'll see how to handle encoder-decoder architectures with skip connections, where shapes change as you go deeper and must be restored on the way back up.