Tutorial 1: Your First Port

In this tutorial, you'll add tensor shape annotations to a simple multi-layer perceptron (MLP) model. By the end, you'll understand Dim, Tensor[...], class-level type parameters, and method-level type parameters.

The model

Here's a simple actor network from a reinforcement learning setup — three Linear layers in sequence:

class BaselineActor(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward(self, state):
        h1 = F.relu(self.fc1(state))
        h2 = F.relu(self.fc2(h1))
        return torch.tanh(self.out(h2))

Without shape annotations, every intermediate value is just Tensor. You can't tell from reading the code what shape h1 has, or whether the layer dimensions are consistent.

Step 1: Identify the dimensions

The constructor takes two parameters that determine tensor dimensions:

  • state_size — the input dimension (flows into nn.Linear)
  • action_size — the output dimension (flows into nn.Linear)

Both flow to sub-module constructors, so both must be Dim, not int. (Dim[X] is a type that bridges a runtime integer value to a type-level symbol X — see Getting Started for details.)

There is also one fixed constant: 400 (the hidden dimension), used twice. It's a literal value, not a parameter, so it doesn't need a type parameter.

Step 2: Type the constructor

Make the dimension parameters into Dim[...] and add class-level type parameters:

class BaselineActor[S, A](nn.Module):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

Now when someone writes BaselineActor(24, 4), the type checker binds S = 24 and A = 4, inferring the type BaselineActor[24, 4]. The sub-modules are automatically typed: self.fc1 is Linear[24, 400], self.out is Linear[400, 4].

Step 3: Type the forward

The forward method has one dynamic dimension — batch size — that varies across calls. Make it a method-level type parameter:

def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    h2 = F.relu(self.fc2(h1))
    return torch.tanh(self.out(h2))

S and A are class-level params (fixed at construction). B is a method-level param (bound per call).

Step 4: Verify inferred shapes

Add assert_type after each intermediate to verify what pyrefly infers:

def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    assert_type(h1, Tensor[B, 400])
    h2 = F.relu(self.fc2(h1))
    assert_type(h2, Tensor[B, 400])
    act = torch.tanh(self.out(h2))
    assert_type(act, Tensor[B, A])
    return act

Run pyrefly check. If any assert_type fails, the shape you expected doesn't match what pyrefly inferred — investigate the mismatch.

Once all shapes check out, remove the assert_type calls. Each one corresponds to an inlay type hint that your IDE shows permanently. Pyrefly catches shape errors through your function signatures and return types regardless — you don't need assert_type in the final code.

Step 5: Add a smoke test

Smoke tests exercise the model at concrete dimensions. They verify that the shape annotations are consistent end-to-end:

def test_baseline_actor():
    actor = BaselineActor(24, 4)
    state = torch.randn(8, 24)
    # pyrefly infers: Tensor[8, 24]
    act = actor(state)
    # pyrefly infers: Tensor[8, 4]

Use concrete dimensions in tests (Tensor[8, 24], not generic Tensor[B, S]) so the type checker verifies the full shape calculation.

The complete port

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_shapes import Dim


class BaselineActor[S, A](nn.Module):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
        h1 = F.relu(self.fc1(state))
        # pyrefly infers: Tensor[B, 400]
        h2 = F.relu(self.fc2(h1))
        # pyrefly infers: Tensor[B, 400]
        act = torch.tanh(self.out(h2))
        # pyrefly infers: Tensor[B, A]
        return act
Note

This example uses from __future__ import annotations with new-style generics — the simplest setup. Pyrefly shows inferred shapes as inlay type hints in your IDE. If you want assert_type for runtime regression guards, see Getting Started for the torch_shapes.TypeVar import style, which requires old-style generics.

Key concepts

  • Dim[X] bridges runtime integer values to type-level symbols. Constructor parameters that determine tensor dimensions should be Dim[X], not int.
  • Class type parameters (class Foo[S, A]) represent dimensions fixed at construction time.
  • Method type parameters (def forward[B]) represent dimensions that vary per call (batch size, sequence length).
  • Inlay type hints show inferred shapes in your IDE. Use reveal_type during development to inspect shapes in checker output.

Next steps

This model had a simple linear pipeline — each layer feeds into the next with known shapes. In Tutorial 2, you'll see what happens when layers are stacked in loops, as in Transformer architectures.