Getting Started with Pyrefly Tensor Shapes
This page walks you through configuring Pyrefly for tensor shape checking and getting your first shape-annotated code running.
Configuration
Tensor shape checking requires two settings in your pyrefly.toml (or under
[tool.pyrefly] in pyproject.toml):
tensor-shapes = true
search_path = [
    "path/to/fixtures",
]
tensor-shapes = true enables shape inference for Tensor, including
subscript syntax (Tensor[B, C, H, W]), algebraic dimension arithmetic, and
shape-aware dispatch for operations like conv2d, view, and cat.
search_path points to a directory of fixture stubs — .pyi files that
provide shape-generic type signatures for PyTorch modules and functions. The
real torch library's type stubs don't carry shape information, so the fixtures
replace them with shape-aware versions (e.g., nn.Conv2d.__init__ that captures
kernel size, stride, and padding as type-level values, and a forward that
computes the output spatial dimensions).
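The output-size arithmetic that a shape-aware Conv2d forward has to express is the standard rule documented for torch.nn.Conv2d. Below is a runtime illustration of that rule on plain integers. The function name conv2d_out_size is hypothetical. We assume, without having checked the fixture source, that the fixtures encode the same rule at the type level.

```python
def conv2d_out_size(
    size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Output height or width of nn.Conv2d along one spatial axis.

    Same formula as the PyTorch Conv2d docs:
    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# A 3x3 kernel with padding=1 and stride=1 preserves the spatial size.
print(conv2d_out_size(32, kernel=3, padding=1))  # 32
# A 7x7, stride-2, padding-3 stem halves 224 to 112.
print(conv2d_out_size(224, kernel=7, stride=2, padding=3))  # 112
```

A shape-aware stub performs this computation with type-level values such as H and W instead of concrete integers.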
The fixtures live in Pyrefly's source tree at
test/tensor_shapes/fixtures/.
To use them in your project, copy the fixtures/ directory into your project
and set search_path to point to it. The path is relative to the location of
your pyrefly.toml.
The fixtures also provide the torch_shapes package, which exports Dim — the
bridge between runtime integer values and type-level symbols.
Imports and runtime considerations
Python evaluates type annotations at runtime by default. This is a problem
for tensor shape annotations because Python's built-in typing.TypeVar
doesn't support arithmetic — expressions like D // NHead in an annotation
will raise TypeError when the annotation is evaluated. There are two ways
to avoid this:
Option 1: from __future__ import annotations (recommended)
Adding this import at the top of the file defers evaluation of all annotations, so shape arithmetic never executes at runtime:
from __future__ import annotations
import torch
import torch.nn as nn
from torch import Tensor
from torch_shapes import Dim
This works with both old-style and new-style generics (PEP 695
class Foo[T] syntax).
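Here is a standard-library check of both halves of the argument. It uses list as a stand-in for Tensor so that it runs without torch, and the function name split_heads is hypothetical. Evaluating D // NHead on a typing.TypeVar raises TypeError. With the future import in place, a signature containing the same expression is stored as a string and never executed.

```python
from __future__ import annotations

from typing import TypeVar

D = TypeVar("D")
NHead = TypeVar("NHead")

# Evaluating the arithmetic directly fails: typing.TypeVar does not define //.
try:
    D // NHead
except TypeError as exc:
    print("TypeError:", exc)


# With the future import, this annotation is kept as a string and never
# evaluated, so defining and calling the function is safe at runtime.
def split_heads(x: list[D // NHead]) -> list[D // NHead]:
    return x


print(split_heads.__annotations__["x"])  # list[D // NHead]
```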
assert_type during development: While porting a model, you can add assert_type
calls and run pyrefly check to confirm each intermediate shape. Once you're
done, remove them. Each assert_type duplicates information that an IDE inlay
type hint already shows permanently, and Pyrefly still catches shape errors
through your function signatures and return types without them.
Note that the second argument to assert_type is an ordinary expression, not an
annotation, so from __future__ import annotations does not defer it. Any shape
arithmetic it contains (such as D // NHead) is evaluated when the call runs, so
the file will crash if you execute it with assert_type calls still present.
This is fine; just remove them when the port is complete.
You can also move the Tensor and Dim imports under an if TYPE_CHECKING: guard
if you prefer to keep shape-only imports invisible at runtime:
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
if TYPE_CHECKING:
    from torch import Tensor
    from torch_shapes import Dim
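This standard-library sketch shows why the guard is safe even when torch_shapes is not installed. typing.TYPE_CHECKING is False at runtime, so the guarded import never executes, and with deferred annotations the name Dim is never looked up. The function channels is hypothetical, and no claim is made about what Pyrefly infers for a plain int argument.

```python
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
    # Only type checkers such as Pyrefly follow this import. TYPE_CHECKING is
    # False at runtime, so this line never executes.
    from torch_shapes import Dim

N = TypeVar("N")


def channels(n: Dim[N]) -> Dim[N]:
    # "Dim[N]" stays a string at runtime, so the undefined name Dim is harmless.
    return n


print(channels(64))
# False unless some other module has already imported torch_shapes.
print("torch_shapes" in sys.modules)
```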
Option 2: torch_shapes.TypeVar (runtime-compatible)
If you need annotations to evaluate at runtime (e.g., for runtime shape
validation or keeping assert_type in production code), import
torch_shapes directly. The package patches
torch.Tensor, nn.Conv2d, and other torch classes to accept subscript
syntax at runtime without crashing. It also provides a TypeVar that
supports arithmetic (N + 1 returns self instead of raising TypeError).
Use old-style generics with torch_shapes.TypeVar:
from typing import assert_type
import torch
import torch.nn as nn
from torch import Tensor
from torch_shapes import Dim, TypeVar
N = TypeVar("N")
M = TypeVar("M")
class Linear(nn.Module):
    def __init__(self, n: Dim[N], m: Dim[M]):
        ...
PEP 695 new-style generics (class Foo[T]) automatically use
typing.TypeVar internally, which doesn't support arithmetic — so this
option requires old-style generics.
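To make the two runtime behaviours described above concrete, here is a self-contained stand-in. ArithDim and ShapedTensor are invented names, not torch_shapes APIs, and we have not checked how torch_shapes actually implements either behaviour. The first is a type-variable-like object whose arithmetic returns itself. The second is a class whose subscript accepts arbitrary shape expressions.

```python
from typing import Any


class ArithDim:
    """Hypothetical stand-in for a type variable that tolerates arithmetic.

    Every arithmetic operator returns the object itself, so expressions such as
    N + 1 or D // NHead evaluate at runtime instead of raising TypeError.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name

    def _absorb(self, other: Any) -> "ArithDim":
        # Ignore the other operand and hand back the same symbol.
        return self

    __add__ = __radd__ = __sub__ = __rsub__ = _absorb
    __mul__ = __rmul__ = __floordiv__ = __rfloordiv__ = _absorb


class ShapedTensor:
    """Hypothetical class whose subscript accepts any shape expression."""

    def __class_getitem__(cls, shape: Any) -> type:
        # This sketch discards the shape; it only shows that subscripting works.
        return cls


N = ArithDim("N")
D = ArithDim("D")
NHead = ArithDim("NHead")

print(N + 1 is N)  # True
print(D // NHead is D)  # True
print(ShapedTensor[N, N + 1, D // NHead] is ShapedTensor)  # True
```

Because the runtime object simply absorbs the arithmetic, it carries no shape information itself. In this design the type checker, not the runtime value, does the dimension arithmetic.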
Hello world
Here's a minimal example to verify everything works. This example uses
Option 1 (from __future__ import annotations) since it's the simplest
setup. We skip assert_type — instead, run pyrefly check and use your
IDE's inlay type hints to verify shapes.
Create a file hello_shapes.py (its PEP 695 class TwoLayerNet[...] syntax
requires Python 3.12 or later):
from __future__ import annotations
import torch
import torch.nn as nn
from torch import Tensor
from torch_shapes import Dim
class TwoLayerNet[InDim, HidDim, OutDim](nn.Module):
    def __init__(
        self,
        in_dim: Dim[InDim],
        hid_dim: Dim[HidDim],
        out_dim: Dim[OutDim],
    ):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward[B](self, x: Tensor[B, InDim]) -> Tensor[B, OutDim]:
        h = self.fc1(x)  # pyrefly infers: Tensor[B, HidDim]
        return self.fc2(torch.relu(h))
Run pyrefly check hello_shapes.py. You should see no errors — pyrefly
infers the shapes through the nn.Linear calls.
If you're using an IDE with Pyrefly's language server, you'll see inlay
type hints showing the inferred shape of h as Tensor[B, HidDim]
without needing any assert_type calls.
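As a runtime analogue of what pyrefly tracks symbolically, here is the same two-layer network evaluated on concrete integer shapes. The helper names are hypothetical. The shape rule for nn.Linear, (*, in_features) -> (*, out_features), is the one in the PyTorch docs. Where this sketch raises ValueError on a mismatched dimension, pyrefly would instead report an error at check time.

```python
def linear_shape(
    shape: tuple[int, ...], in_features: int, out_features: int
) -> tuple[int, ...]:
    """Shape rule of nn.Linear: (*, in_features) -> (*, out_features)."""
    *batch, last = shape
    if last != in_features:
        raise ValueError(f"expected last dimension {in_features}, got {last}")
    return (*batch, out_features)


def two_layer_net_shape(
    shape: tuple[int, ...], in_dim: int, hid_dim: int, out_dim: int
) -> tuple[int, ...]:
    h = linear_shape(shape, in_dim, hid_dim)  # fc1: [B, InDim] -> [B, HidDim]
    # torch.relu is elementwise, so h keeps its shape.
    return linear_shape(h, hid_dim, out_dim)  # fc2: [B, HidDim] -> [B, OutDim]


print(two_layer_net_shape((8, 4), in_dim=4, hid_dim=16, out_dim=2))  # (8, 2)
```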
Inlay hints in action
Here's what inlay hints look like on a real model (NanoGPT).
[Screenshot: the MLP module, with shapes flowing through its linear layers and activations]
[Screenshot: the forward method signature, showing how x.size() unpacks into typed dimensions]
[Screenshot: the attention module, where the view/transpose reshapes for multi-head attention are fully tracked]
[Screenshot: the full attention body, including both the flash and manual paths]
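The multi-head reshapes in the attention module reduce to two small rules: view must preserve the element count, and transpose swaps two dimensions. The sketch below applies those rules to integer shapes, loosely modelled on NanoGPT's attention code. All function names are hypothetical. It is a runtime illustration of the bookkeeping that the inlay hints display, not how pyrefly implements it.

```python
from math import prod


def view_shape(shape: tuple[int, ...], *new: int) -> tuple[int, ...]:
    """Shape rule of Tensor.view (without -1): the element count must match."""
    if prod(shape) != prod(new):
        raise ValueError(f"cannot view {shape} as {new}")
    return tuple(new)


def transpose_shape(shape: tuple[int, ...], d0: int, d1: int) -> tuple[int, ...]:
    """Shape rule of Tensor.transpose: swap two dimensions."""
    dims = list(shape)
    dims[d0], dims[d1] = dims[d1], dims[d0]
    return tuple(dims)


def split_heads_shape(shape: tuple[int, int, int], n_head: int) -> tuple[int, ...]:
    B, T, C = shape  # like `B, T, C = x.size()`
    hs = C // n_head  # per-head size
    # x.view(B, T, n_head, hs).transpose(1, 2) -> (B, n_head, T, hs)
    return transpose_shape(view_shape(shape, B, T, n_head, hs), 1, 2)


def merge_heads_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    B, nh, T, hs = shape
    # y.transpose(1, 2).contiguous().view(B, T, C): in real PyTorch the
    # .contiguous() is needed before .view, but it does not change the shape.
    return view_shape(transpose_shape(shape, 1, 2), B, T, nh * hs)


q = split_heads_shape((2, 16, 64), n_head=4)
print(q)  # (2, 4, 16, 16)
print(merge_heads_shape(q))  # (2, 16, 64)
```

If C is not divisible by n_head, the element counts no longer match and view_shape raises.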
