Tutorial 1: Your First Port
In this tutorial, you'll add tensor shape annotations to a simple multi-layer perceptron (MLP) model.
By the end, you'll understand Dim, Tensor[...], class-level type
parameters, and method-level type parameters.
The model
Here's a simple actor network from a reinforcement learning setup — three Linear layers in sequence:
```python
class BaselineActor(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward(self, state):
        h1 = F.relu(self.fc1(state))
        h2 = F.relu(self.fc2(h1))
        return torch.tanh(self.out(h2))
```
Without shape annotations, every intermediate value is just Tensor. You
can't tell from reading the code what shape h1 has, or whether the layer
dimensions are consistent.
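For example, nothing in this untyped version stops a caller from passing a mis-sized input. The mistake below (a hypothetical misuse) only surfaces as a runtime matrix-multiplication error inside fc1, never at type-check time:

```python
actor = BaselineActor(24, 4)  # expects 24 input features
state = torch.randn(8, 30)    # oops: 30 features, not 24
act = actor(state)            # checker sees plain Tensor; this fails only at runtime
```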
Step 1: Identify the dimensions
The constructor takes two parameters that determine tensor dimensions:
- state_size — the input dimension (flows into nn.Linear)
- action_size — the output dimension (flows into nn.Linear)
Both flow to sub-module constructors, so both must be Dim, not int.
(Dim[X] is a type that bridges a runtime integer value to a type-level
symbol X — see Getting Started for details.)
There's also a fixed constant: 400, the hidden dimension. It's a literal value, not a parameter, so it doesn't need a type param.
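To make the bridging concrete, here's a minimal sketch assuming the torch_shapes semantics described above (make_layer is a hypothetical helper, not part of any library, and it relies on the stubs making nn.Linear generic as this tutorial describes):

```python
def make_layer[In](in_features: Dim[In]) -> nn.Linear[In, 400]:
    # The literal 400 needs no type param: the checker reads it directly.
    return nn.Linear(in_features, 400)

layer = make_layer(24)  # the checker binds In = 24, so layer: Linear[24, 400]
```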
Step 2: Type the constructor
Make the dimension parameters into Dim[...] and add class-level type
parameters:
```python
class BaselineActor[S, A](nn.Module):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)
```
Now when someone writes BaselineActor(24, 4), the type checker binds
S = 24 and A = 4, inferring the type BaselineActor[24, 4]. The
sub-modules are automatically typed: self.fc1 is Linear[24, 400],
self.out is Linear[400, 4].
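You can confirm the bindings with reveal_type; the comments below are a sketch of the output pyrefly would report for the types stated above:

```python
actor = BaselineActor(24, 4)
reveal_type(actor)      # BaselineActor[24, 4]
reveal_type(actor.fc1)  # Linear[24, 400]
reveal_type(actor.out)  # Linear[400, 4]
```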
Step 3: Type the forward
The forward method has one dynamic dimension — batch size — that varies across calls. Make it a method-level type parameter:
```python
def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    h2 = F.relu(self.fc2(h1))
    return torch.tanh(self.out(h2))
```
S and A are class-level params (fixed at construction). B is a
method-level param (bound per call).
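Because B is bound per call, one actor instance accepts any batch size, and each call gets its own binding (a short sketch using the class as defined above):

```python
actor = BaselineActor(24, 4)
act8 = actor(torch.randn(8, 24))    # B = 8  -> act8: Tensor[8, 4]
act32 = actor(torch.randn(32, 24))  # B = 32 -> act32: Tensor[32, 4]
```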
Step 4: Verify inferred shapes
Add assert_type after each intermediate to verify what pyrefly infers:
```python
def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    assert_type(h1, Tensor[B, 400])
    h2 = F.relu(self.fc2(h1))
    assert_type(h2, Tensor[B, 400])
    act = torch.tanh(self.out(h2))
    assert_type(act, Tensor[B, A])
    return act
```
Run pyrefly check. If any assert_type fails, the shape you expected
doesn't match what pyrefly inferred — investigate the mismatch.
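For instance, a deliberately wrong assertion (a hypothetical mistake, 300 instead of the actual hidden width) would be flagged:

```python
h1 = F.relu(self.fc1(state))
assert_type(h1, Tensor[B, 300])  # error: pyrefly infers Tensor[B, 400]
```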
Once all shapes check out, remove the assert_type calls. Each one duplicates an inlay type hint that your IDE shows permanently, and pyrefly catches shape errors through your function signatures and return types regardless — you don't need assert_type in the final code.
Step 5: Add a smoke test
Smoke tests exercise the model at concrete dimensions. They verify that the shape annotations are consistent end-to-end:
```python
def test_baseline_actor():
    actor = BaselineActor(24, 4)
    state = torch.randn(8, 24)
    # pyrefly infers: Tensor[8, 24]
    act = actor(state)
    # pyrefly infers: Tensor[8, 4]
```
Use concrete dimensions in tests (Tensor[8, 24], not generic Tensor[B, S])
so the type checker verifies the full shape calculation.
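Concrete dimensions also turn the test into a negative check: an inconsistent shape becomes a type error instead of a runtime failure. A hypothetical broken test, for contrast:

```python
def test_wrong_state_size():
    actor = BaselineActor(24, 4)
    state = torch.randn(8, 25)  # Tensor[8, 25]
    act = actor(state)          # pyrefly error: forward expects Tensor[B, 24]
```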
The complete port
With tensor shapes:

```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torch_shapes import Dim


class BaselineActor[S, A](nn.Module):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
        h1 = F.relu(self.fc1(state))
        # pyrefly infers: Tensor[B, 400]
        h2 = F.relu(self.fc2(h1))
        # pyrefly infers: Tensor[B, 400]
        act = torch.tanh(self.out(h2))
        # pyrefly infers: Tensor[B, A]
        return act
```

Without tensor shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaselineActor(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward(self, state):
        h1 = F.relu(self.fc1(state))
        # what shape is h1? Tensor — that's all you know
        h2 = F.relu(self.fc2(h1))
        return torch.tanh(self.out(h2))
```
This example uses from __future__ import annotations with new-style
generics — the simplest setup. Pyrefly shows inferred shapes as inlay type
hints in your IDE. If you want assert_type for runtime regression guards,
see Getting Started
for the torch_shapes.TypeVar import style, which requires old-style
generics.
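For reference, a rough sketch of what the old-style version of this class would look like, assuming torch_shapes exports a TypeVar along the lines Getting Started describes (the exact API lives there, not here):

```python
from typing import Generic

import torch.nn as nn
from torch_shapes import Dim, TypeVar

S = TypeVar("S")
A = TypeVar("A")

class BaselineActor(nn.Module, Generic[S, A]):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        ...
```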
Key concepts
- Dim[X] bridges runtime integer values to type-level symbols. Constructor parameters that determine tensor dimensions should be Dim[X], not int.
- Class type parameters (class Foo[S, A]) represent dimensions fixed at construction time.
- Method type parameters (def forward[B]) represent dimensions that vary per call (batch size, sequence length).
- Inlay type hints show inferred shapes in your IDE. Use reveal_type during development to inspect shapes in checker output.
Next steps
This model had a simple linear pipeline — each layer feeds into the next with known shapes. In Tutorial 2, you'll see what happens when layers are stacked in loops, as in Transformer architectures.