# Tutorial 3: Complex Architectures
Tutorial 2 covered shape-preserving loops. This tutorial tackles architectures where shapes change systematically — encoder-decoder networks with skip connections, and recursive chains where dimensions grow or shrink exponentially.
## Encoder-decoder with skip connections
Encoder-decoder architectures (UNet, Demucs, Super SloMo) encode the input to a bottleneck and then decode back, with skip connections between corresponding encoder and decoder layers.
### The shape pattern
Each encode step transforms `(B, C, H, W)` to `(B, 2C, H', W')`, doubling channels and shrinking spatial dimensions. Decoding reverses this, using the skip connection to restore the original shape. The key insight is that each encode-recurse-decode cycle preserves the input shape:

```
encode:        (B, C, H, W) → (B, 2C, H', W')
recurse:       preserves (B, 2C, H', W')
decode + skip: (B, 2C, H', W') + (B, C, H, W) → (B, C, H, W)
```
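This arithmetic can be checked with plain Python. The sketch below assumes kernel-size-2, stride-2 downsampling and its transposed counterpart, which is what the `(H - 2) // 2 + 1` and `(H - 1) * 2 + 2` formulas appearing later in this tutorial correspond to; the helper names are illustrative, not part of any model:

```python
def encode_hw(h: int, w: int) -> tuple[int, int]:
    # Conv with kernel_size=2, stride=2: out = (in - kernel) // stride + 1
    return (h - 2) // 2 + 1, (w - 2) // 2 + 1

def decode_hw(h: int, w: int) -> tuple[int, int]:
    # Transposed conv with kernel_size=2, stride=2: out = (in - 1) * stride + kernel
    return (h - 1) * 2 + 2, (w - 1) * 2 + 2

# One encode/decode cycle restores even spatial dimensions:
eh, ew = encode_hw(64, 64)
assert (eh, ew) == (32, 32)
assert decode_hw(eh, ew) == (64, 64)
```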
### Typing the recursion

This gives a recursive signature where `recurse` takes and returns the same shape:
```python
class UNet[NChannels, NClasses](nn.Module):
    def _encode[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: int
    ) -> Tensor[B, 2 * C, (H - 2) // 2 + 1, (W - 2) // 2 + 1]:
        idx = len(self.downs) - depth
        down: Down[C, 2 * C] = self.downs[idx]
        return down(x)

    def _decode[B, C, H, W](
        self,
        skip: Tensor[B, C, H, W],
        deep: Tensor[B, 2 * C, (H - 2) // 2 + 1, (W - 2) // 2 + 1],
        depth: int,
    ) -> Tensor[B, C, H, W]:
        idx = len(self.ups) - depth
        up: Up[2 * C, C] = self.ups[idx]
        return up(deep, skip)

    def recurse[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> Tensor[B, C, H, W]:
        if depth == 0:
            return x
        skip = x
        encoded = self._encode(x, depth)
        middle = self.recurse(encoded, depth - 1)
        decoded = self._decode(skip, middle, depth)
        return decoded
```
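To see the preservation argument concretely, here is a minimal sketch that runs the same recursion on plain shape tuples instead of tensors. The helper names are hypothetical and mirror `_encode`, `_decode`, and `recurse` above:

```python
def encode_shape(s: tuple[int, ...]) -> tuple[int, ...]:
    # Mirrors _encode: double channels, shrink spatial dims.
    b, c, h, w = s
    return b, 2 * c, (h - 2) // 2 + 1, (w - 2) // 2 + 1

def decode_shape(skip: tuple[int, ...], deep: tuple[int, ...]) -> tuple[int, ...]:
    # Mirrors _decode: the deep features carry the encoded shape of skip,
    # and the skip connection restores skip's own shape.
    assert deep == encode_shape(skip)
    return skip

def recurse_shape(s: tuple[int, ...], depth: int) -> tuple[int, ...]:
    if depth == 0:
        return s
    middle = recurse_shape(encode_shape(s), depth - 1)
    return decode_shape(s, middle)

# Three encode-recurse-decode cycles hand back the input shape:
assert recurse_shape((1, 64, 256, 256), depth=3) == (1, 64, 256, 256)
```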
### Narrowing annotations for heterogeneous lists
Python has no way to express "element `i` of this list has type `Stage[C * 2**i]`". The workaround:

- Declare the list with `Any`: `list[Down[Any, Any]]`
- Narrow at the access site: `down: Down[C, 2 * C] = self.downs[idx]`

The `Any` erases element-level type info, and the annotation re-introduces it for each use.
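At runtime the pattern looks like this. This is a minimal sketch with an untyped stand-in `Down` class; the real narrowing annotation `Down[C, 2 * C]` is meaningful only to the shape checker, so it is approximated here by a runtime assertion:

```python
from typing import Any

class Down:
    """Stand-in for a Down[C_in, C_out] stage (hypothetical)."""
    def __init__(self, c_in: int, c_out: int) -> None:
        self.c_in, self.c_out = c_in, c_out

# The list type erases element-level channel info...
downs: list[Any] = [Down(64, 128), Down(128, 256), Down(256, 512)]

# ...which is re-introduced at each access site. Each stage doubles its
# input channels, matching the Down[C, 2 * C] annotation:
down: Down = downs[1]
assert (down.c_in, down.c_out) == (128, 256)
assert down.c_out == 2 * down.c_in
```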
### Algebraic gaps

Some algebraic equivalences can't be automatically proven. For example, `((H - 2) // 2 + 1) * 2` does not simplify back to `H` (the identity holds only when `H` is even). When you hit this, use `type: ignore` with a comment explaining the gap:

```python
return up(deep, skip)  # type: ignore[bad-argument-type]  # ((H-2)//2+1)*2 = H
```

Keep these to an absolute minimum and document each one.
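The gap is real, not just a checker limitation: the round trip restores `H` only when `H` is even, so the `type: ignore` is encoding an evenness invariant the types cannot see. A quick check:

```python
def round_trip(h: int) -> int:
    # (h - 2) // 2 + 1 is the encoded size; doubling it is the decoded size.
    return ((h - 2) // 2 + 1) * 2

for h in range(2, 12):
    if h % 2 == 0:
        assert round_trip(h) == h       # even sizes survive the round trip
    else:
        assert round_trip(h) == h - 1   # odd sizes lose a row
```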
## Recursive chains with exponential shapes
When each stage doubles or halves a dimension, the result after `I` stages involves `2**I`. This pattern appears in DCGAN (both generator and discriminator), ResNet, and DenseNet.
### The `@overload` pattern

Use `@overload` to separate the base case from the recursive case:
```python
class Generator(nn.Module):
    def _apply_stage[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: int
    ) -> Tensor[B, C // 2, (H - 1) * 2 + 2, (W - 1) * 2 + 2]:
        idx = len(self.up_stages) - depth
        stage: GenUpStage[C] = self.up_stages[idx]
        return stage(x)

    @overload
    def _chain[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[1]
    ) -> Tensor[B, C // 2, H * 2, W * 2]: ...

    @overload
    def _chain[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> Tensor[B, C // 2 ** I, H * 2 ** I, W * 2 ** I]: ...

    def _chain[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> (
        Tensor[B, C // 2, H * 2, W * 2]
        | Tensor[B, C // 2 ** I, H * 2 ** I, W * 2 ** I]
    ):
        y = self._apply_stage(x, depth)
        if depth == 1:
            return y
        return self._chain(y, depth - 1)
```
The base-case overload (`depth: Dim[1]`) handles the single-stage case, where the formula simplifies concretely. The recursive overload uses `2**I` to express the exponential relationship.
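The closed form in the recursive overload can be verified by iterating the per-stage transform in plain Python (the helper name is illustrative):

```python
def chain_shape(c: int, h: int, w: int, depth: int) -> tuple[int, int, int]:
    # Apply _apply_stage's shape transform `depth` times.
    for _ in range(depth):
        c, h, w = c // 2, (h - 1) * 2 + 2, (w - 1) * 2 + 2
    return c, h, w

# Iterating matches the closed form (C // 2**I, H * 2**I, W * 2**I),
# because (h - 1) * 2 + 2 simplifies to 2 * h:
for i in range(1, 5):
    assert chain_shape(512, 4, 4, i) == (512 // 2**i, 4 * 2**i, 4 * 2**i)
```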
### The two-method pattern

The `_apply_stage` + `_chain` pattern separates concerns:

- `_apply_stage`: applies a single stage from the `ModuleList`, using a narrowing annotation to type the list element.
- `_chain`: recursively applies `_apply_stage` with overloaded return types.

The caller invokes `_chain` with a concrete depth:
```python
def forward[B](self, input: Tensor[B, 100, 1, 1]) -> Tensor[B, 3, 64, 64]:
    h0 = F.relu(self.project_bn(self.project(input)))
    assert_type(h0, Tensor[B, 512, 4, 4])
    h1 = self._chain(h0, 3)  # 512 -> 64, 4 -> 32
    assert_type(h1, Tensor[B, 64, 32, 32])
    return torch.tanh(self.output(h1))
```
## Key concepts
- Recursive shape preservation: encode-recurse-decode cycles preserve the input shape, enabling a clean recursive signature.
- Narrowing annotations re-introduce type information lost by heterogeneous `ModuleList`s.
- `@overload` separates base and recursive cases for exponential shape chains.
- `type: ignore` is a last resort for algebraic gaps the checker can't prove. Always document the specific equivalence.
## Next steps
In Tutorial 4, you'll see how to handle config classes with type parameters, dynamic construction patterns, and other advanced techniques.