
Getting Started with Pyrefly Tensor Shapes

This page walks you through configuring Pyrefly for tensor shape checking and getting your first shape-annotated code running.

Configuration

Tensor shape checking requires two settings in your pyrefly.toml (or under [tool.pyrefly] in pyproject.toml):

tensor-shapes = true

search_path = [
    "path/to/fixtures",
]

tensor-shapes = true enables shape inference for Tensor, including subscript syntax (Tensor[B, C, H, W]), algebraic dimension arithmetic, and shape-aware dispatch for operations like conv2d, view, and cat.

search_path points to a directory of fixture stub files (.pyi) that provide shape-generic type signatures for PyTorch modules and functions. The real torch library's type stubs don't carry shape information, so the fixtures replace them with shape-aware versions (e.g., an nn.Conv2d.__init__ that captures kernel size, stride, and padding as type-level values, and a forward that computes the output spatial dimensions).

The fixtures live in Pyrefly's source tree at test/tensor_shapes/fixtures/. To use them in your project, copy the fixtures/ directory into your project and set search_path to point to it. The path is relative to the location of your pyrefly.toml.

The fixtures also provide the torch_shapes package, which exports Dim — the bridge between runtime integer values and type-level symbols.
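To make "computes the output spatial dimensions" concrete, the sketch below evaluates the standard Conv2d output-size formula from the PyTorch documentation on ordinary integers. It is plain runtime Python, not the fixture's stub code, and the helper name conv2d_out_size is made up for illustration. A shape-aware forward would presumably express the same arithmetic over type-level symbols.

```python
def conv2d_out_size(
    in_size: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> int:
    """Output length along one spatial axis, per the nn.Conv2d documentation.

    Hypothetical helper for illustration only; it is not part of the fixtures.
    """
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# A 3x3 kernel with padding 1 and stride 1 preserves the spatial size.
same_size = conv2d_out_size(32, kernel_size=3, stride=1, padding=1)

# A 7x7 kernel with stride 2 and padding 3 halves a 224-pixel axis.
halved = conv2d_out_size(224, kernel_size=7, stride=2, padding=3)
```

The same function applies independently to height and width, which is why the stubs need kernel size, stride, and padding available as type-level values.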

Imports and runtime considerations

Python evaluates type annotations at runtime by default. This is a problem for tensor shape annotations because Python's built-in typing.TypeVar doesn't support arithmetic — expressions like D // NHead in an annotation will raise TypeError when the annotation is evaluated. There are two ways to avoid this:

Option 1: from __future__ import annotations

Adding this import at the top of the file defers evaluation of all annotations, so shape arithmetic never executes at runtime:

from __future__ import annotations

import torch
import torch.nn as nn
from torch import Tensor
from torch_shapes import Dim

This works with both old-style and new-style generics (PEP 695 class Foo[T] syntax).
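The difference between eager and deferred evaluation can be checked with the standard library alone. The sketch below is illustrative: the names D, NHead, and head_dim are placeholders, and the deferred module is compiled from a string only so that the example fits in a single file.

```python
import typing

D = typing.TypeVar("D")
NHead = typing.TypeVar("NHead")


def typevar_arithmetic_fails() -> bool:
    """Return True if evaluating `D // NHead` raises TypeError."""
    try:
        D // NHead  # typing.TypeVar defines no __floordiv__
    except TypeError:
        return True
    return False


# Source of a tiny module that opts in to deferred annotations.
DEFERRED_SOURCE = '''
from __future__ import annotations
import typing

D = typing.TypeVar("D")
NHead = typing.TypeVar("NHead")

def head_dim(x: D) -> D // NHead:
    return x
'''


def deferred_annotations() -> dict:
    """Execute the module above and return head_dim's raw annotations."""
    namespace: dict = {}
    exec(DEFERRED_SOURCE, namespace)
    return namespace["head_dim"].__annotations__
```

With the future import in effect, the annotations are stored as strings, so the expression D // NHead is never evaluated when the function is defined.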

assert_type during development: You can use assert_type while porting to verify shapes via pyrefly check. Once you're done, remove the assert_type calls — each one corresponds to an IDE inlay type hint that shows the same information permanently. Pyrefly catches shape errors through your function signatures and return types regardless.

Note that assert_type forces evaluation of its type argument, so the file will crash if you try to run it with assert_type calls still present. This is fine — just remove them when the port is complete.

You can also guard Tensor and Dim under TYPE_CHECKING if you prefer to keep shape imports invisible at runtime:

from __future__ import annotations
from typing import TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    from torch import Tensor
    from torch_shapes import Dim

Option 2: torch_shapes.TypeVar (runtime-compatible)

If you need annotations to evaluate at runtime (e.g., for runtime shape validation or keeping assert_type in production code), import torch_shapes directly. The package patches torch.Tensor, nn.Conv2d, and other torch classes to accept subscript syntax at runtime without crashing. It also provides a TypeVar that supports arithmetic (N + 1 returns self instead of raising TypeError).

Use old-style generics with torch_shapes.TypeVar:

from typing import assert_type

import torch
import torch.nn as nn
from torch import Tensor
from torch_shapes import Dim, TypeVar

N = TypeVar("N")
M = TypeVar("M")

class Linear(nn.Module):
    def __init__(self, n: Dim[N], m: Dim[M]):
        ...

PEP 695 new-style generics (class Foo[T]) automatically use typing.TypeVar internally, which doesn't support arithmetic — so this option requires old-style generics.
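The idea of arithmetic that "returns self" can be sketched without torch_shapes. The class below is a hypothetical stand-in written for illustration; it is not the package's TypeVar, and the name ArithmeticSymbol is an assumption. It only shows why such an object lets shape expressions evaluate at runtime without raising.

```python
class ArithmeticSymbol:
    """Hypothetical stand-in: a symbol whose arithmetic is a runtime no-op."""

    def __init__(self, name: str) -> None:
        self.name = name

    def _return_self(self, other: object) -> "ArithmeticSymbol":
        # The type checker does the real arithmetic; at runtime we only
        # need the expression to evaluate without raising.
        return self

    __add__ = __radd__ = _return_self
    __sub__ = __rsub__ = _return_self
    __mul__ = __rmul__ = _return_self
    __floordiv__ = __rfloordiv__ = _return_self

    def __repr__(self) -> str:
        return f"~{self.name}"


N = ArithmeticSymbol("N")
NHead = ArithmeticSymbol("NHead")


# The return annotation contains arithmetic, yet evaluating it is harmless.
def shift(values: list, n: N) -> N + 1:
    return values
```

Because every operator hands back the left-hand symbol, expressions such as N + 1 or N // NHead carry no numeric meaning at runtime; they exist only so that the annotation can be evaluated.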

Hello world

Here's a minimal example to verify everything works. This example uses Option 1 (from __future__ import annotations) since it's the simplest setup. We skip assert_type — instead, run pyrefly check and use your IDE's inlay type hints to verify shapes.

Create a file hello_shapes.py:

from __future__ import annotations

import torch
import torch.nn as nn
from torch import Tensor
from torch_shapes import Dim


class TwoLayerNet[InDim, HidDim, OutDim](nn.Module):
    def __init__(
        self,
        in_dim: Dim[InDim],
        hid_dim: Dim[HidDim],
        out_dim: Dim[OutDim],
    ):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward[B](self, x: Tensor[B, InDim]) -> Tensor[B, OutDim]:
        h = self.fc1(x)  # pyrefly infers: Tensor[B, HidDim]
        return self.fc2(torch.relu(h))

Run pyrefly check hello_shapes.py. You should see no errors: Pyrefly infers the shapes through the nn.Linear calls.

If you're using an IDE with Pyrefly's language server, you'll see inlay type hints showing the inferred shape of h as Tensor[B, HidDim] without needing any assert_type calls.
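As a runtime analogue of the shape flow in TwoLayerNet, the sketch below propagates concrete integer shapes through two linear layers. It does not use torch or Pyrefly, and the helper names are hypothetical; it only mirrors the rule that a linear layer replaces the last dimension and that relu leaves the shape unchanged.

```python
def linear_out_shape(
    shape: tuple[int, ...], in_features: int, out_features: int
) -> tuple[int, ...]:
    """Shape produced by a linear layer: the last dimension is replaced."""
    if shape[-1] != in_features:
        raise ValueError(
            f"expected last dimension {in_features}, got {shape[-1]}"
        )
    return (*shape[:-1], out_features)


def two_layer_net_shapes(
    batch: int, in_dim: int, hid_dim: int, out_dim: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the shapes of the hidden activation and the output."""
    x = (batch, in_dim)
    h = linear_out_shape(x, in_dim, hid_dim)      # analogue of Tensor[B, HidDim]
    out = linear_out_shape(h, hid_dim, out_dim)   # relu keeps h's shape
    return h, out


hidden_shape, output_shape = two_layer_net_shapes(8, 784, 128, 10)
```

A static checker works with symbols rather than integers, so a mismatch that raises ValueError here would instead be reported before the code runs.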

Inlay hints in action

Here's what inlay hints look like on a real model (NanoGPT). The MLP module shows shapes flowing through linear layers and activations:

[Figure: NanoGPT MLP module with inlay type hints showing Tensor[B, T, 4 * NEmbedding] after the expansion layer]

The forward method signature shows how x.size() unpacks into typed dimensions:

[Figure: NanoGPT forward signature with x.size() unpacking to b: Dim[B], t: Dim[T], c: Dim[NEmbedding]]

And the attention module, where view/transpose reshapes for multi-head attention are fully tracked:

[Figure: NanoGPT attention QKV computation with inlay type hints showing Tensor[B, NHead, T, NEmbedding // NHead]]

The full attention body, including both flash and manual paths:

[Figure: NanoGPT attention body with flash and manual attention paths fully shape-tracked]
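The view/transpose step in the attention module comes down to simple shape arithmetic, which can be written out on concrete integers. The sketch below is plain Python with a hypothetical helper name; it assumes the usual multi-head layout in which the embedding axis is split into NHead chunks and axes 1 and 2 are then swapped.

```python
def split_heads_shape(
    shape: tuple[int, int, int], n_head: int
) -> tuple[int, int, int, int]:
    """Shape after view(B, T, NHead, C // NHead) followed by transpose(1, 2)."""
    b, t, c = shape
    if c % n_head != 0:
        raise ValueError(f"embedding size {c} is not divisible by {n_head} heads")
    viewed = (b, t, n_head, c // n_head)  # view: split the last axis
    # transpose(1, 2): swap the T and NHead axes
    return (viewed[0], viewed[2], viewed[1], viewed[3])


# B=2, T=16, NEmbedding=768 with 12 heads
attention_shape = split_heads_shape((2, 16, 768), n_head=12)
```

The result has the layout B, NHead, T, NEmbedding // NHead, matching the type shown in the QKV inlay hint.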