Tensor Shapes
This feature is experimental. The API and behavior may change in future releases without notice.
Pyrefly can track tensor shapes through your PyTorch models, giving you end-to-end static type checking of shape transformations.
Why tensor shapes?
When you write PyTorch code, the hardest thing to keep track of is tensor shapes. Every operator transforms shapes in non-trivial ways, and shape mistakes don't always crash — they can silently produce wrong results. With Pyrefly's tensor shape support, you get automatic inlay type hints showing tensor shapes as you type, so you can see exactly what shape every intermediate tensor has without running your code or adding print statements.
Here's the same NanoGPT forward method, without and with tensor shape tracking:
[Figure: NanoGPT forward method, panels "Without tensor shapes" and "With tensor shapes"]
With tensor shapes enabled, Pyrefly infers and displays the shape of every
intermediate tensor — Tensor[B, T, NEmbedding] for embeddings,
Tensor[T] for position indices — without any manual annotations on local
variables.
How it works
Pyrefly's tensor shape support is built on two extensions that work together:
- Symbolic integer arithmetic in the core type system, which lets you write Tensor[B, C, H, W] and have arithmetic like D // NHead work at the type level.
- Shape transform specifications for PyTorch operators: a library of shape rules that tells Pyrefly how each operator transforms shapes.
With these two extensions, you can use Pyrefly to typecheck real-world PyTorch models, where the tensor shapes of all local variables are inferred from just a few annotations at class and function boundaries.
Symbolic integer arithmetic
You can write tensor types with integer dimensions — Tensor[3, 4] is a
2D tensor with shape (3, 4). This works for modules too: nn.Linear[3, 4]
takes Tensor[..., 3] as input and returns Tensor[..., 4].
Dim[X] bridges runtime integer values to the type level. When
x: Tensor[3, 4], then x.shape has type tuple[Dim[3], Dim[4]] — you
can extract dimensions from tensors and use them to construct new ones.
Arithmetic works too: if a: Dim[3] and b: Dim[4], then
a * b: Dim[12].
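The toy model below mimics this behavior in plain Python so the folding rule is easy to see: concrete sizes are computed, and anything involving a symbolic name stays an expression (such as D // NHead). `Dim` here is a hypothetical stand-in written for illustration, not Pyrefly's `Dim`, and Pyrefly's own simplification rules may differ.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Dim:
    """Toy stand-in for a type-level dimension: a concrete int or a symbolic expression."""

    expr: int | str

    def _text(self) -> str:
        # Parenthesize compound expressions so operator precedence stays explicit.
        text = str(self.expr)
        return f"({text})" if " " in text else text

    def _apply(self, other: Dim | int, op: str) -> Dim:
        rhs = other if isinstance(other, Dim) else Dim(other)
        if isinstance(self.expr, int) and isinstance(rhs.expr, int):
            # Both sides concrete: fold to one integer, e.g. Dim(3) * Dim(4) -> Dim(12).
            if op == "+":
                return Dim(self.expr + rhs.expr)
            if op == "*":
                return Dim(self.expr * rhs.expr)
            return Dim(self.expr // rhs.expr)
        # At least one side is symbolic: keep the expression, e.g. "D // NHead".
        return Dim(f"{self._text()} {op} {rhs._text()}")

    def __add__(self, other: Dim | int) -> Dim:
        return self._apply(other, "+")

    def __mul__(self, other: Dim | int) -> Dim:
        return self._apply(other, "*")

    def __floordiv__(self, other: Dim | int) -> Dim:
        return self._apply(other, "//")


print(Dim(3) * Dim(4))             # Dim(expr=12)
print(Dim("D") // Dim("NHead"))    # Dim(expr='D // NHead')
print((Dim("A") + Dim("B")) // 2)  # Dim(expr='(A + B) // 2')
```

Keeping unresolved sizes as expressions, rather than rejecting them, is what lets a rule like D // NHead be written once and reused for any concrete model size.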
Generic type parameters let you write shape-polymorphic modules:
```python
class Linear[N, M]:
    def __init__(self, n: Dim[N], m: Dim[M]):
        ...

    def forward[*Xs](self, inp: Tensor[*Xs, N]) -> Tensor[*Xs, M]:
        ...


linear: Linear[3, 4] = Linear(3, 4)  # N = 3, M = 4
inp: Tensor[2, 5, 3] = ...
x: Tensor[2, 5, 4] = linear(inp)  # *Xs = (2, 5)
```
Dim encodes symbolic shapes at the type level throughout PyTorch, covering modules as well as tensors.
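For intuition, the sketch below performs the same bookkeeping at runtime for the call linear(inp): it binds type variables from concrete shapes, then substitutes them into the return type. `bind` and `substitute` are hypothetical helpers written for illustration; Pyrefly does this on types during checking, and its actual algorithm is not shown here.

```python
from __future__ import annotations


def bind(pattern: tuple, shape: tuple[int, ...], env: dict) -> None:
    """Bind the names in `pattern` to sizes in `shape`, extending `env` in place.

    A name starting with "*" is variadic and captures a (possibly empty) run of
    dimensions. Integers must match exactly, and a name that is already bound
    must agree with its earlier value.
    """
    stars = [i for i, p in enumerate(pattern) if isinstance(p, str) and p.startswith("*")]
    if len(stars) > 1:
        raise ValueError("at most one variadic name per pattern")
    if stars:
        i = stars[0]
        head, tail = pattern[:i], pattern[i + 1:]
        if len(shape) < len(head) + len(tail):
            raise TypeError(f"shape {shape} is too short for {pattern}")
        split = len(shape) - len(tail)
        captured = tuple(shape[len(head):split])
        if env.setdefault(pattern[i], captured) != captured:
            raise TypeError(f"{pattern[i]} was {env[pattern[i]]}, got {captured}")
        pairs = list(zip(head, shape[:len(head)])) + list(zip(tail, shape[split:]))
    else:
        if len(shape) != len(pattern):
            raise TypeError(f"rank mismatch: {shape} vs {pattern}")
        pairs = list(zip(pattern, shape))
    for p, size in pairs:
        if isinstance(p, int):
            if p != size:
                raise TypeError(f"expected {p}, got {size}")
        elif env.setdefault(p, size) != size:
            raise TypeError(f"{p} was {env[p]}, got {size}")


def substitute(pattern: tuple, env: dict) -> tuple[int, ...]:
    """Build a concrete shape from `pattern` using the bindings in `env`."""
    out: list[int] = []
    for p in pattern:
        if isinstance(p, int):
            out.append(p)
        elif p.startswith("*"):
            out.extend(env[p])
        else:
            out.append(env[p])
    return tuple(out)


# Linear(3, 4) fixes N = 3 and M = 4; calling forward then binds *Xs from the input.
env = {"N": 3, "M": 4}
bind(("*Xs", "N"), (2, 5, 3), env)
print(substitute(("*Xs", "M"), env))  # (2, 5, 4)
```

Because N is already bound to 3, an input whose last dimension is anything else fails in `bind`, which is the runtime counterpart of a type error on linear(inp).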
Arithmetic on type variables lets you write custom shape transforms:
```python
def custom_rand_tensor[A, B](a: Dim[A], b: Dim[B]) -> Tensor[(A + B) // 2]:
    return torch.randn((a + b) // 2)


x: Tensor[3] = custom_rand_tensor(2, 4)
```
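Worked through for the call above, assuming the type-level // follows Python's floor-division semantics: the literal arguments bind A and B, and the return type reduces to a single concrete size.

$$
A = 2,\quad B = 4 \quad\Longrightarrow\quad \left\lfloor \frac{A + B}{2} \right\rfloor = \left\lfloor \frac{6}{2} \right\rfloor = 3
$$

This gives Tensor[3], matching the annotation on x. Because the division floors, custom_rand_tensor(2, 5) would also give Tensor[3], since ⌊7/2⌋ = 3, which is the same length torch.randn((a + b) // 2) produces at runtime.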
While these examples use type annotations for exposition, the types of local
variables like x and linear are inferred automatically — Pyrefly shows
them as inlay type hints in your editor.
Shape transform specifications
Some PyTorch operators have simple shape signatures that can be expressed as
standard type stubs. For example, torch.mm:
```python
def mm[M, K, N](x: Tensor[M, K], y: Tensor[K, N]) -> Tensor[M, N]:
    ...
```
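Read as a shape rule, the stub above says: bind M and K from the first argument, require the second argument to start with the same K, and return (M, N). A plain-Python mirror makes the role of the shared K explicit; `mm_shape` is a hypothetical helper, not part of Pyrefly or PyTorch.

```python
from __future__ import annotations


def mm_shape(x: tuple[int, ...], y: tuple[int, ...]) -> tuple[int, int]:
    """Runtime mirror of mm[M, K, N](x: Tensor[M, K], y: Tensor[K, N]) -> Tensor[M, N]."""
    if len(x) != 2 or len(y) != 2:
        raise TypeError(f"mm expects two 2D shapes, got {x} and {y}")
    m, k_x = x
    k_y, n = y
    # K appears in both parameter types, so both occurrences must be the same size.
    if k_x != k_y:
        raise TypeError(f"K mismatch: {k_x} vs {k_y}")
    return (m, n)


print(mm_shape((2, 3), (3, 5)))  # (2, 5)
```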
For operators with more complex shape logic (like reshape, cat, or
F.interpolate), Pyrefly uses a small DSL to specify shape transforms:
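```python
# Internal library definition (not user-facing code)
def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor:
    # Multiplies each dimension by its repeat count. zip() stops at the shorter
    # input, so this simplified rule is exact only when len(sizes) == len(self.shape).
    return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)])
```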
This means you can extend Pyrefly's shape coverage for new PyTorch operators without touching Pyrefly's internals — see the contributing guide for details.
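For checking the arithmetic by hand, here is a plain-Python sketch of two of the rules mentioned above, repeat and cat, following PyTorch's documented shape semantics. `repeat_shape` and `cat_shape` are hypothetical helpers, not Pyrefly's DSL; unlike the simplified repeat_ir, `repeat_shape` also handles `sizes` longer than the tensor's rank by treating the tensor as having extra leading dimensions of size 1.

```python
from __future__ import annotations


def repeat_shape(shape: tuple[int, ...], sizes: tuple[int, ...]) -> tuple[int, ...]:
    """Output shape of Tensor.repeat(*sizes) for an input of the given shape."""
    if len(sizes) < len(shape):
        raise TypeError("repeat needs at least as many sizes as the tensor has dims")
    # Extra sizes act on new leading dimensions of size 1, as in PyTorch.
    padded = (1,) * (len(sizes) - len(shape)) + tuple(shape)
    return tuple(d * r for d, r in zip(padded, sizes))


def cat_shape(shapes: list[tuple[int, ...]], dim: int = 0) -> tuple[int, ...]:
    """Output shape of torch.cat: sizes add along `dim`; other dims must match.

    Ignores PyTorch's legacy allowance for empty 1-D tensors of shape (0,).
    """
    if not shapes or len(shapes[0]) == 0:
        raise TypeError("cat needs at least one tensor with one or more dims")
    rank = len(shapes[0])
    if not -rank <= dim < rank:
        raise TypeError(f"dim {dim} is out of range for rank {rank}")
    dim %= rank  # normalize negative dims such as -1
    for s in shapes:
        mismatch = len(s) != rank or any(
            a != b for i, (a, b) in enumerate(zip(s, shapes[0])) if i != dim
        )
        if mismatch:
            raise TypeError(f"shape {s} does not match {shapes[0]} outside dim {dim}")
    out = list(shapes[0])
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


print(repeat_shape((3,), (4, 2, 1)))        # (4, 2, 3)
print(cat_shape([(2, 3), (2, 5)], dim=-1))  # (2, 8)
```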
Related work
Tensor shape checking is not a new idea. Several projects have explored this space, each with different trade-offs. Here's how Pyrefly's approach compares.
Pyre and Pyright (static type checking)
An early attempt at tensor shapes in Pyre, the precursor of Pyrefly, is
described in this PyCon 2023 talk.
That system used Literal to wrap concrete sizes and explicit type constructors
like IntDiv for arithmetic — avoiding the need to support integers as type
arguments or arithmetic on type variables, but at the cost of verbose syntax
(e.g., Tensor[float, M, Literal[2], IntDiv[M, 2]]).
More recently, this video details two approaches to extending Pyright with tensor shapes; the author concluded that neither was the right approach. Both suffered from verbose syntax and a large set of type-level operators.
Pyrefly's system is leaner by design:
- No constraint solver — Pyrefly prioritizes "more shape inference with fewer annotations" rather than "more red squiggles with more annotations."
- The complexity of specifying shape transforms for PyTorch operators is isolated in a DSL for library maintainers, so the type-system features exposed to users stay simple.
Jaxtyping (runtime type checking)
With jaxtyping, users can
express tensor shapes in type annotations that other libraries like typeguard
and beartype can check at runtime. The syntax is designed to be universal for
array-like containers (supporting JAX, NumPy, and PyTorch), but is somewhat
verbose — for example, Shaped[Tensor, "M 2 M//2"] instead of
Tensor[M, 2, M // 2].
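To make the runtime side concrete, the sketch below checks a shape against a dimension string like "M 2 M//2": a bare name is bound the first time it appears, and later occurrences and expressions are evaluated against that binding. It is a toy, not jaxtyping's implementation or API; `check_shape` is a hypothetical helper that resolves names strictly left to right, and real jaxtyping supports more syntax (for example `...` and `*name` for variadic dimensions).

```python
from __future__ import annotations

import ast
import operator

# Operators allowed inside a symbolic dimension such as "M//2".
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def _eval(node: ast.AST, env: dict[str, int]) -> int:
    """Evaluate a tiny integer expression over already-bound dimension names."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.Name):
        return env[node.id]  # KeyError if the name has not been bound yet
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left, env), _eval(node.right, env))
    raise ValueError(f"unsupported dimension syntax: {ast.dump(node)}")


def check_shape(shape: tuple[int, ...], spec: str, env: dict[str, int] | None = None) -> dict[str, int]:
    """Check `shape` against a space-separated spec and return the name bindings."""
    env = {} if env is None else env
    tokens = spec.split()
    if len(tokens) != len(shape):
        raise TypeError(f"rank mismatch: {shape} vs {spec!r}")
    for token, size in zip(tokens, shape):
        node = ast.parse(token, mode="eval").body
        if isinstance(node, ast.Name) and node.id not in env:
            env[node.id] = size  # first occurrence binds the name
        elif _eval(node, env) != size:
            raise TypeError(f"{token!r} expected {_eval(node, env)}, got {size}")
    return env


print(check_shape((6, 2, 3), "M 2 M//2"))  # {'M': 6}
```

Passing `env` explicitly is how this toy shares bindings between separate checks; each call otherwise starts from an empty set of bindings.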
Pyrefly supports jaxtyping annotations as an alternative front-end to our native syntax: they are translated internally into generics and displayed back in jaxtyping syntax.
However, jaxtyping has a significant limitation: there is no way to share symbolic dimensions across variables and functions in a class (see this issue). This restricts jaxtyping to individual functions that operate on tensors, rather than a hierarchy of modules connected end-to-end. In practice, our fully typechecked implementations of real-world models (e.g., NanoGPT) cannot be faithfully ported to jaxtyping syntax alone.
Get involved
Tensor shapes are under active development. We welcome contributions — especially new fixture stubs and DSL specifications for PyTorch operators. See the contributing guide to get started.
Join the conversation: