API Reference
This page documents the types, functions, and type-level constructs that make up pyrefly's tensor shape type system.
Dim[X]
Dim[X] is a type constructor that bridges runtime integer values to
type-level symbols. It is defined in the torch_shapes package.
Basics
Dim[X] denotes the type of an integer value whose type-level identity is
X. For example:
- Dim[5] is the type of the literal 5
- Dim[N] (where N is a type variable) is the type of an integer whose value is bound to N at the type level
Dim is a subtype of int, so Dim values can be used anywhere int is
expected. However, the reverse is not true — passing a plain int where
Dim[X] is expected loses tracking.
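A minimal sketch of the subtyping direction, using the Dim constructor from torch_shapes described above (the function names here are illustrative, not part of the API):

```python
from torch_shapes import Dim

def pad_to[N](length: Dim[N]) -> None: ...   # expects a tracked dimension

def example[D](dim: Dim[D], plain: int) -> None:
    n: int = dim    # OK: Dim[D] is a subtype of int
    pad_to(dim)     # OK: the type-level identity D is preserved
    pad_to(plain)   # tracking is lost: plain int has no type-level identity
```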
Arithmetic
Arithmetic on Dim values produces Dim results with the corresponding
type-level expression:
| Expression | Type |
|---|---|
| a + b where a: Dim[A], b: Dim[B] | Dim[A + B] |
| a - b | Dim[A - B] |
| a * b | Dim[A * B] |
| a // b | Dim[A // B] |
| a ** b | Dim[A ** B] |
Caution: int * Dim produces Unknown because the int side has no
type-level identity. Use Dim * Dim or literal * Dim instead.
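A sketch of both the tracked and untracked cases (names are illustrative):

```python
def head_dims[D, NH](dim: Dim[D], n_head: Dim[NH]) -> None:
    head = dim // n_head   # Dim[D // NH]
    double = 2 * dim       # Dim[2 * D]: literal * Dim keeps tracking
    k = int(n_head)        # downcast to plain int
    lost = k * dim         # Unknown: int * Dim has no type-level identity
```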
Dim[X] | None
For optional dimensions — parameters that may or may not be present — use
Dim[X] | None. In the forward method, narrow with
if param is not None: to recover Dim[X] inside the branch:
```python
class Attention[D, RK](nn.Module):
    def __init__(self, dim: Dim[D], rank_k: Dim[RK] | None = None):
        self.rank_k = rank_k
        ...

    def forward[B, T](self, x: Tensor[B, T, D]) -> Tensor[B, T, D]:
        if self.rank_k is not None:
            # rank_k is Dim[RK] here
            ...
```
Usage patterns
| Pattern | Purpose |
|---|---|
| def __init__(self, dim: Dim[D]) | Accept a dimension as a constructor parameter |
| class Model[D](nn.Module) | Make a dimension a class-level type parameter |
| def forward[B](self, x: Tensor[B, D]) | Bind a per-call dimension |
| self.head_dim = dim // n_head | Compute a derived dimension (Dim[D // NHead]) |
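The four patterns combine as in this sketch of a small module (it assumes nn.Linear accepts Dim arguments and is shape-tracked in this system):

```python
class MLP[In, Out](nn.Module):
    def __init__(self, in_dim: Dim[In], out_dim: Dim[Out]):
        super().__init__()
        self.hidden = out_dim * 2   # derived dimension: Dim[Out * 2]
        self.fc1 = nn.Linear(in_dim, self.hidden)
        self.fc2 = nn.Linear(self.hidden, out_dim)

    def forward[B](self, x: Tensor[B, In]) -> Tensor[B, Out]:
        return self.fc2(self.fc1(x))
```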
Tensor[D1, D2, ...]
Tensor with type arguments represents a tensor with a known shape. The
type arguments are the dimensions, in order.
Forms
| Form | Meaning |
|---|---|
| Tensor[3, 4] | Concrete 2D tensor with shape (3, 4) |
| Tensor[B, C, H, W] | Generic 4D tensor with symbolic dimensions |
| Tensor[B, 3 * C, H // 2] | Dimensions can contain arithmetic expressions |
| Tensor[*Bs, D] | Variadic: any number of leading batch dimensions |
| Tensor (bare) | Shape unknown — tracking gap |
Variadic dimensions with *Bs
Use TypeVarTuple (or PEP 646 *Bs syntax) for dimensions that should be
propagated without being enumerated:
```python
def forward[*Bs](self, x: Tensor[*Bs, InDim]) -> Tensor[*Bs, OutDim]:
    ...
```
This accepts any number of leading dimensions (batch, sequence, etc.) and preserves them in the output.
Don't hide known class dims inside variadic params. If the module has a
class-level Dim D, use Tensor[*Bs, D] not Tensor[*S].
.shape and .size()
When x: Tensor[B, C, H, W]:
- x.shape has type tuple[Dim[B], Dim[C], Dim[H], Dim[W]]
- x.size(0) has type Dim[B]
- x.size() has type tuple[Dim[B], Dim[C], Dim[H], Dim[W]]
This means you can extract dimensions from tensors and use them to construct new tensors with matching shapes.
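For example, a dimension read off one tensor can size a new one (a sketch, assuming torch.zeros is shape-tracked in this system):

```python
def make_mask[B, T, D](x: Tensor[B, T, D]) -> Tensor[B, T]:
    b, t, _ = x.shape          # b: Dim[B], t: Dim[T]
    return torch.zeros(b, t)   # inferred as Tensor[B, T]
```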
assert_type
assert_type(expr, Type) is checked by the type checker: it verifies that
expr has exactly the stated type. If the types don't match, the checker
reports an error.
```python
h = self.fc1(x)
assert_type(h, Tensor[B, 512])  # checked by pyrefly
```
Use assert_type during development to verify inferred shapes as you port
a model. Once the port is complete, remove the assert_type calls — each
one corresponds to an inlay type hint that your IDE shows permanently.
Pyrefly catches shape errors through function signatures and return types
regardless.
assert_type forces evaluation of its type argument at runtime, so a file
with assert_type calls will crash if executed. This is fine during
development (you run pyrefly check, not the file itself) — just remove
them when the port is done.
When to use
- During porting, after key shape-changing operations (reshapes, convolutions, matmuls)
- As regression guards for complex shape computations
In practice, pyrefly shows inferred shapes as inlay type hints in your
editor, so you can verify shapes visually. Use assert_type at key
checkpoints where you want a permanent regression guard.
reveal_type
reveal_type(expr) prints the inferred type of expr during type checking.
Use it to understand what pyrefly infers before writing assert_type:
```python
h = self.fc1(x)
reveal_type(h)  # Revealed type: Tensor[B, 512]
```
Replace reveal_type with assert_type once you know the expected type.
Type-level arithmetic
Annotations can contain arithmetic on type parameters and literals:
| Operation | Example |
|---|---|
| Addition | Tensor[B, C1 + C2, H, W] — concatenation |
| Subtraction | Tensor[B, T, D - 1] |
| Multiplication | Tensor[B, NHead * DK] — multi-head reshape |
| Floor division | Tensor[B, NHead, T, D // NHead] |
| Exponentiation | Tensor[B, C * 2 ** I, H // 2 ** I] |
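For instance, addition in a return type tracks channel concatenation (a sketch, assuming torch.cat is shape-tracked in this system):

```python
def fuse[B, C1, C2, H, W](
    a: Tensor[B, C1, H, W], b: Tensor[B, C2, H, W]
) -> Tensor[B, C1 + C2, H, W]:
    return torch.cat([a, b], dim=1)  # channel dims add: C1 + C2
```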
Simplification rules
The type checker automatically simplifies expressions:
- 2 * C // 2 → C
- (H - 1) * 2 + 2 → H * 2
- (a * b) // b → a (sound for all positive integers)
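These identities hold at the value level for all positive integers, which a quick check in plain Python (independent of the type checker) confirms:

```python
# Value-level check of the simplification identities.
for a in range(1, 6):
    for b in range(1, 6):
        assert (a * b) // b == a       # (a * b) // b -> a
for c in range(1, 10):
    assert 2 * c // 2 == c             # 2 * C // 2 -> C
for h in range(1, 10):
    assert (h - 1) * 2 + 2 == h * 2    # (H - 1) * 2 + 2 -> H * 2
print("all identities hold")
```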
Known limitations
N * (X // N) does not simplify to X — floor division loses the
remainder, so the equivalence only holds when X is divisible by N.
The checker can't assume this. Common instances:
- Multi-head reassembly: NHead * (D // NHead) — use type: ignore
- BiLSTM output: 2 * (D // 2) — use type: ignore
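The remainder loss is visible at the value level whenever X is not divisible by N (plain Python, independent of the type checker):

```python
# Floor division discards the remainder, so N * (X // N) recovers X
# only when N divides X exactly.
x, n = 10, 3
assert n * (x // n) == 9    # not 10: the remainder 1 is gone
x, n = 12, 3
assert n * (x // n) == 12   # exact when N divides X
print("remainder demo ok")
```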
Annotation hierarchy
When annotating local variables, choose from most to least desirable:
1. assert_type — verifies the checker's inference. Proves the system works, not just that you annotated correctly.
2. Annotation fallback — x: Tensor[B, C, H, W] = untracked_op(...). The checker can't infer the shape, but the annotation is compatible. Document WHY.
3. type: ignore — the checker produces a WRONG type (algebraic gap). Last resort. Always include a comment explaining the specific gap.
4. Bare Tensor — shape genuinely unknowable (data-dependent token counts, conditional accumulation). Document the specific reason.
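In a ported model the four tiers look roughly like this (a sketch: the op and attribute names are illustrative, not part of any real model):

```python
h = self.fc1(x)
assert_type(h, Tensor[B, 512])          # 1. verify the checker's inference

# 2. untracked op; shape known from its documentation
feats: Tensor[B, C, H, W] = third_party_op(x)

# 3. algebraic gap: NHead * (D // NHead) does not simplify to D
out = heads.reshape(b, t, self.dim)  # type: ignore

# 4. data-dependent shape: token count varies per input
tokens: Tensor = tokenizer(text)
```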
Jaxtyping compatibility
Pyrefly supports jaxtyping annotations as an alternative front-end:
| Pyrefly native | Jaxtyping equivalent |
|---|---|
| Tensor[M, 2, M // 2] | Shaped[Tensor, "M 2 M//2"] |
| Tensor[B, C, H, W] | Shaped[Tensor, "B C H W"] |
Jaxtyping annotations are translated internally to generics and display back in jaxtyping syntax. Note that jaxtyping cannot share symbolic dimensions across class boundaries — see the overview for details.
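The two spellings of one signature side by side (Shaped is jaxtyping's real annotation class; the native form follows this page's conventions):

```python
from jaxtyping import Shaped

def norm_jax(x: Shaped[Tensor, "B C H W"]) -> Shaped[Tensor, "B C H W"]: ...

def norm_native[B, C, H, W](x: Tensor[B, C, H, W]) -> Tensor[B, C, H, W]: ...
```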