Tutorial 3: Complex Architectures

Tutorial 2 covered shape-preserving loops. This tutorial tackles architectures where shapes change systematically — encoder-decoder networks with skip connections, and recursive chains where dimensions grow or shrink exponentially.

Encoder-decoder with skip connections

Encoder-decoder architectures (UNet, Demucs, Super SloMo) encode the input to a bottleneck and then decode back, with skip connections between corresponding encoder and decoder layers.

The shape pattern

Each encode step transforms (B, C, H, W) to (B, 2C, H', W') — doubling channels and shrinking spatial dimensions. Decoding reverses this, using the skip connection to restore the original shape. The key insight is that each encode-recurse-decode cycle preserves the input shape:

encode: (B, C, H, W) → (B, 2C, H', W')
recurse: preserves (B, 2C, H', W')
decode + skip: (B, 2C, H', W') + (B, C, H, W) → (B, C, H, W)
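The cycle above can be checked with plain integer arithmetic. The sketch below is a hypothetical helper, not part of the tutorial's code: it tracks shapes as tuples, assumes the spatial formula H' = (H - 2) // 2 + 1 (the one that appears in the _encode signature later in this tutorial), and assumes decode returns the shape of its skip input.

```python
Shape = tuple[int, int, int, int]  # (B, C, H, W)


def encode_shape(shape: Shape) -> Shape:
    """Double the channels and roughly halve each spatial dimension."""
    b, c, h, w = shape
    return (b, 2 * c, (h - 2) // 2 + 1, (w - 2) // 2 + 1)


def decode_shape(skip: Shape, deep: Shape) -> Shape:
    """Combine the deep tensor with the skip tensor; the result has the skip's shape."""
    # The deep input must be exactly what encoding the skip would produce.
    assert deep == encode_shape(skip), "deep shape does not match encode(skip)"
    return skip


def recurse_shape(shape: Shape, depth: int) -> Shape:
    """Mirror the encode-recurse-decode cycle on shapes only."""
    if depth == 0:
        return shape
    encoded = encode_shape(shape)
    middle = recurse_shape(encoded, depth - 1)
    return decode_shape(shape, middle)


start = (1, 64, 128, 128)
print(encode_shape(start))            # one encode step
print(recurse_shape(start, depth=3))  # full cycle returns the input shape
```

Because every level hands back the shape it received, the recursion can be given a single signature that does not mention the depth in its return type.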

Typing the recursion

This gives a recursive signature where recurse takes and returns the same shape:

class UNet[NChannels, NClasses](nn.Module):
    def _encode[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: int
    ) -> Tensor[B, 2 * C, (H - 2) // 2 + 1, (W - 2) // 2 + 1]:
        idx = len(self.downs) - depth
        down: Down[C, 2 * C] = self.downs[idx]
        return down(x)

    def _decode[B, C, H, W](
        self,
        skip: Tensor[B, C, H, W],
        deep: Tensor[B, 2 * C, (H - 2) // 2 + 1, (W - 2) // 2 + 1],
        depth: int,
    ) -> Tensor[B, C, H, W]:
        idx = len(self.ups) - depth
        up: Up[2 * C, C] = self.ups[idx]
        return up(deep, skip)

    def recurse[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> Tensor[B, C, H, W]:
        if depth == 0:
            return x
        skip = x
        encoded = self._encode(x, depth)
        middle = self.recurse(encoded, depth - 1)
        decoded = self._decode(skip, middle, depth)
        return decoded
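To see how the depth argument selects a list element, the following sketch tabulates idx = len(self.downs) - depth for each recursion level. It is illustrative only: stage_schedule is a hypothetical helper, and it assumes the recursion starts with depth equal to the number of stages and that stage i takes base_channels * 2**i input channels.

```python
def stage_schedule(num_stages: int, base_channels: int) -> list[tuple[int, int, int, int]]:
    """Return (depth, idx, in_channels, out_channels) for each recursion level."""
    schedule = []
    # depth counts down from num_stages to 1, so idx counts up from 0.
    for depth in range(num_stages, 0, -1):
        idx = num_stages - depth
        in_channels = base_channels * 2 ** idx
        schedule.append((depth, idx, in_channels, 2 * in_channels))
    return schedule


for row in stage_schedule(num_stages=3, base_channels=64):
    print(row)
```

Under these assumptions the outermost call (largest depth) uses element 0 and the innermost call uses the last element, so the same idx pairs each Down[C, 2 * C] with its matching Up[2 * C, C].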

Narrowing annotations for heterogeneous lists

Python has no way to express "element i of this list has type Stage[C * 2**i]". The workaround:

  1. Declare the list with Any: list[Down[Any, Any]]
  2. Narrow at the access site:
down: Down[C, 2 * C] = self.downs[idx]

The Any erases element-level type info, and the annotation re-introduces it for each use.
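The same two steps can be tried with ordinary generics, independent of any tensor library. In this minimal sketch Stage is a hypothetical stand-in for a typed layer, and its type parameter T plays the role of the shape parameters. The annotations matter only to a static checker; at runtime they have no effect.

```python
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class Stage(Generic[T]):
    """Hypothetical stand-in for a typed layer."""

    def __init__(self, payload: T) -> None:
        self.payload = payload

    def run(self) -> T:
        return self.payload


# Step 1: the list is declared with Any, so element-level types are erased.
stages: list[Stage[Any]] = [Stage(64), Stage("skip")]

# Step 2: each access site states the type it expects. A static checker
# accepts this because Stage[Any] is compatible with any Stage[...].
first: Stage[int] = stages[0]
second: Stage[str] = stages[1]

print(first.run() + 1)       # checker treats run() as returning int
print(second.run().upper())  # checker treats run() as returning str
```

Note that the checker does not verify the narrowing: annotating stages[1] as Stage[int] would also be accepted, so the annotation is a promise made by the author rather than a proof.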

Algebraic gaps

Some algebraic equivalences can't be proven automatically. For example, ((H - 2) // 2 + 1) * 2 equals H whenever H is even, but the checker does not simplify it back to H. When you hit this, use type: ignore with a comment explaining the gap:

return up(deep, skip)  # type: ignore[bad-argument-type]  # ((H-2)//2+1)*2 = H

Keep these to an absolute minimum and document each one.
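Before suppressing an error, it is worth confirming numerically that the documented equivalence actually holds for the sizes you use. The sketch below (encode_then_double is a hypothetical name) evaluates ((H - 2) // 2 + 1) * 2 with plain integers: the result equals H for even H and H - 1 for odd H, so the comment is only sound when the spatial size is even.

```python
def encode_then_double(h: int) -> int:
    """Apply the encoder's spatial formula to h, then double the result."""
    return ((h - 2) // 2 + 1) * 2


# Even sizes round-trip exactly.
even_ok = all(encode_then_double(h) == h for h in range(2, 1026, 2))

# Odd sizes lose one element to the floor division.
odd_results = {h: encode_then_double(h) for h in (3, 5, 7)}

print(even_ok)
print(odd_results)
```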

Recursive chains with exponential shapes

When each stage doubles or halves a dimension, the result after I stages involves 2**I. This appears in DCGAN (generator and discriminator), ResNet, and DenseNet.

The @overload pattern

Use @overload to separate the base case from the recursive case:

class Generator(nn.Module):
    def _apply_stage[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: int
    ) -> Tensor[B, C // 2, (H - 1) * 2 + 2, (W - 1) * 2 + 2]:
        idx = len(self.up_stages) - depth
        stage: GenUpStage[C] = self.up_stages[idx]
        return stage(x)

    @overload
    def _chain[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[1]
    ) -> Tensor[B, C // 2, H * 2, W * 2]: ...

    @overload
    def _chain[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> Tensor[B, C // 2 ** I, H * 2 ** I, W * 2 ** I]: ...

    def _chain[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> (Tensor[B, C // 2, H * 2, W * 2]
          | Tensor[B, C // 2 ** I, H * 2 ** I, W * 2 ** I]):
        y = self._apply_stage(x, depth)
        if depth == 1:
            return y
        return self._chain(y, depth - 1)

The base-case overload (depth: Dim[1]) handles the single-stage case where the formula simplifies concretely. The recursive overload uses 2**I to express the exponential relationship.
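The closed form in the recursive overload can be checked against stage-by-stage application using integers only. This is a sketch under stated assumptions: the helper names are hypothetical, the batch dimension is omitted, and each stage is assumed to follow the per-stage formula from _apply_stage, where (H - 1) * 2 + 2 is simply 2 * H.

```python
def apply_stage_shape(c: int, h: int, w: int) -> tuple[int, int, int]:
    """One upsampling stage: halve channels, double height and width."""
    return (c // 2, (h - 1) * 2 + 2, (w - 1) * 2 + 2)


def chain_shape(c: int, h: int, w: int, depth: int) -> tuple[int, int, int]:
    """Apply `depth` stages recursively, mirroring the structure of _chain."""
    c, h, w = apply_stage_shape(c, h, w)
    if depth == 1:
        return (c, h, w)
    return chain_shape(c, h, w, depth - 1)


def closed_form(c: int, h: int, w: int, depth: int) -> tuple[int, int, int]:
    """The shape promised by the recursive overload."""
    return (c // 2 ** depth, h * 2 ** depth, w * 2 ** depth)


# The recursive result and the closed form agree for each depth tried.
for depth in (1, 2, 3):
    assert chain_shape(512, 4, 4, depth) == closed_form(512, 4, 4, depth)

print(chain_shape(512, 4, 4, 3))
```

With depth 1 the closed form reduces to (C // 2, H * 2, W * 2), which is exactly the return type of the base-case overload.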

The two-method pattern

This _apply_stage + _chain pattern separates concerns:

  • _apply_stage: applies a single stage from the ModuleList, using a narrowing annotation to type the list element.
  • _chain: recursively applies _apply_stage with overloaded return types.

The caller invokes _chain with a concrete depth:

def forward[B](self, input: Tensor[B, 100, 1, 1]) -> Tensor[B, 3, 64, 64]:
    h0 = F.relu(self.project_bn(self.project(input)))
    assert_type(h0, Tensor[B, 512, 4, 4])
    h1 = self._chain(h0, 3)  # 512->64, 4->32
    assert_type(h1, Tensor[B, 64, 32, 32])
    return torch.tanh(self.output(h1))

Key concepts

  • Recursive shape preservation: encode-recurse-decode cycles preserve the input shape, enabling a clean recursive signature.
  • Narrowing annotations re-introduce type information lost by heterogeneous ModuleLists.
  • @overload separates base and recursive cases for exponential shape chains.
  • type: ignore is a last resort for algebraic gaps the checker can't prove. Always document the specific equivalence.

Next steps

In Tutorial 4, you'll see how to handle config classes with type parameters, dynamic construction patterns, and other advanced techniques.