
API Reference

This page documents the types, functions, and type-level constructs that make up pyrefly's tensor shape type system.

Dim[X]

Dim[X] is a type constructor that bridges runtime integer values to type-level symbols. It is defined in the torch_shapes package.

Basics

Dim[X] denotes the type of an integer value whose type-level identity is X. For example:

  • Dim[5] is the type of the literal 5
  • Dim[N] (where N is a type variable) is the type of an integer whose value is bound to N at the type level

Dim is a subtype of int, so Dim values can be used anywhere int is expected. The reverse is not true: passing a plain int where Dim[X] is expected loses tracking, because a plain int carries no type-level identity.

Arithmetic

Arithmetic on Dim values produces Dim results with the corresponding type-level expression. For a: Dim[A] and b: Dim[B]:

Expression | Type
a + b | Dim[A + B]
a - b | Dim[A - B]
a * b | Dim[A * B]
a // b | Dim[A // B]
a ** b | Dim[A ** B]

Caution: int * Dim produces Unknown because the int side has no type-level identity. Use Dim * Dim or literal * Dim instead.
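The arithmetic rules can be mimicked at runtime to build intuition. The sketch below is a hypothetical toy model, not pyrefly's implementation and not part of torch_shapes: SymDim is an int subclass (mirroring "Dim is a subtype of int") that carries a string standing in for the type-level identity X. Combining two SymDim values combines their expressions, while mixing in a plain int yields Unknown, as in the caution above.

```python
from __future__ import annotations

import operator
from typing import Callable


class SymDim(int):
    """Hypothetical toy model of Dim[X]: an int that also carries a symbolic
    expression standing in for its type-level identity X."""

    expr: str

    def __new__(cls, value: int, expr: str) -> SymDim:
        obj = super().__new__(cls, value)
        obj.expr = expr
        return obj

    def __repr__(self) -> str:
        return f"Dim[{self.expr}]"

    def _apply(self, other: int, sym: str, fn: Callable[[int, int], int]) -> SymDim:
        # Both operands tracked: combine expressions, e.g. Dim[A] + Dim[B] -> Dim[A + B].
        # A plain int has no type-level identity, so the expression becomes Unknown.
        if isinstance(other, SymDim):
            expr = f"({self.expr} {sym} {other.expr})"
        else:
            expr = "Unknown"
        return SymDim(fn(int(self), int(other)), expr)

    def __add__(self, other: int) -> SymDim:
        return self._apply(other, "+", operator.add)

    def __sub__(self, other: int) -> SymDim:
        return self._apply(other, "-", operator.sub)

    def __mul__(self, other: int) -> SymDim:
        return self._apply(other, "*", operator.mul)

    def __floordiv__(self, other: int) -> SymDim:
        return self._apply(other, "//", operator.floordiv)

    def __pow__(self, other: int, mod: None = None) -> SymDim:
        return self._apply(other, "**", operator.pow)

    def __rmul__(self, other: int) -> SymDim:
        # Reached for `plain_int * dim`: the int side has no identity.
        return SymDim(int(other) * int(self), "Unknown")


A, B = SymDim(3, "A"), SymDim(4, "B")
repr(A + B)               # 'Dim[(A + B)]', value 7
repr(SymDim(2, "2") * A)  # 'Dim[(2 * A)]': a tracked literal keeps its identity
repr(2 * A)               # 'Dim[Unknown]': the plain int loses tracking
```

Modeling Dim as an int subclass keeps every SymDim usable wherever an int is expected, which is the same one-way compatibility the type system provides.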

Dim[X] | None

For optional dimensions — parameters that may or may not be present — use Dim[X] | None. In the forward method, narrow with if param is not None: to recover Dim[X] inside the branch:

class Attention[D, RK](nn.Module):
    def __init__(self, dim: Dim[D], rank_k: Dim[RK] | None = None):
        super().__init__()
        self.rank_k = rank_k
        ...

    def forward[B, T](self, x: Tensor[B, T, D]) -> Tensor[B, T, D]:
        if self.rank_k is not None:
            # self.rank_k is narrowed to Dim[RK] here
            ...

Usage patterns

Pattern | Purpose
def __init__(self, dim: Dim[D]) | Accept a dimension as a constructor parameter
class Model[D](nn.Module) | Make a dimension a class-level type parameter
def forward[B](self, x: Tensor[B, D]) | Bind a per-call dimension
self.head_dim = dim // n_head | Compute a derived dimension (Dim[D // NHead], given n_head: Dim[NHead])

Tensor[D1, D2, ...]

Tensor with type arguments represents a tensor with a known shape. The type arguments are the dimensions, in order.

Forms

Form | Meaning
Tensor[3, 4] | Concrete 2D tensor with shape (3, 4)
Tensor[B, C, H, W] | Generic 4D tensor with symbolic dimensions
Tensor[B, 3 * C, H // 2] | Dimensions can contain arithmetic expressions
Tensor[*Bs, D] | Variadic: any number of leading batch dimensions
Tensor (bare) | Shape unknown (a tracking gap)

Variadic dimensions with *Bs

Use a TypeVarTuple (PEP 646), declared here with the PEP 695 type-parameter syntax [*Bs], for dimensions that should be propagated without being enumerated:

def forward[*Bs](self, x: Tensor[*Bs, InDim]) -> Tensor[*Bs, OutDim]:
    ...

This accepts any number of leading dimensions (batch, sequence, etc.) and preserves them in the output.

Don't hide known class dims inside variadic params. If the module has a class-level Dim D, use Tensor[*Bs, D], not Tensor[*S]; with *S alone, the last dimension is absorbed into the variadic and is no longer tied to D.
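At runtime, the shape rule that Tensor[*Bs, InDim] -> Tensor[*Bs, OutDim] encodes is simply "check the last dimension, keep everything before it". A minimal sketch on plain shape tuples, so torch is not required; the helper name linear_output_shape is hypothetical:

```python
def linear_output_shape(shape: tuple[int, ...], in_dim: int, out_dim: int) -> tuple[int, ...]:
    """Shape rule of a Tensor[*Bs, InDim] -> Tensor[*Bs, OutDim] layer.

    The leading dimensions (*Bs) pass through untouched, however many there are;
    only the last dimension is checked against in_dim and replaced by out_dim.
    """
    if not shape or shape[-1] != in_dim:
        raise ValueError(f"expected last dimension {in_dim}, got shape {shape}")
    *batch, _ = shape
    return (*batch, out_dim)


linear_output_shape((512,), 512, 64)         # (64,): no leading dims at all
linear_output_shape((8, 128, 512), 512, 64)  # (8, 128, 64): batch and sequence kept
```

The check on shape[-1] is only expressible because the last dimension is named separately from the variadic part, which is the reason for writing Tensor[*Bs, D] rather than Tensor[*S].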

.shape and .size()

When x: Tensor[B, C, H, W]:

  • x.shape has type tuple[Dim[B], Dim[C], Dim[H], Dim[W]]
  • x.size(0) has type Dim[B]
  • x.size() has type tuple[Dim[B], Dim[C], Dim[H], Dim[W]]

This means you can extract dimensions from tensors and use them to construct new tensors with matching shapes.

assert_type

assert_type(expr, Type) is checked by the type checker: it verifies that expr has exactly the stated type. If the types don't match, the checker reports an error.

h = self.fc1(x)
assert_type(h, Tensor[B, 512]) # checked by pyrefly

Use assert_type during development to verify inferred shapes as you port a model. Once the port is complete, remove the assert_type calls: the shape each one checks is already displayed as an inlay type hint in your IDE, and pyrefly catches shape errors through function signatures and return types regardless.

assert_type forces evaluation of its type argument at runtime, so a file with assert_type calls will crash if executed. This is fine during development (you run pyrefly check, not the file itself) — just remove them when the port is done.

When to use

  • During porting, after key shape-changing operations (reshapes, convolutions, matmuls)
  • As regression guards for complex shape computations

In practice, pyrefly shows inferred shapes as inlay type hints in your editor, so you can verify shapes visually. Use assert_type at key checkpoints where you want a permanent regression guard.

reveal_type

reveal_type(expr) prints the inferred type of expr during type checking. Use it to understand what pyrefly infers before writing assert_type:

h = self.fc1(x)
reveal_type(h) # Revealed type: Tensor[B, 512]

Replace reveal_type with assert_type once you know the expected type.

Type-level arithmetic

Annotations can contain arithmetic on type parameters and literals:

Operation | Example
Addition | Tensor[B, C1 + C2, H, W] (concatenation)
Subtraction | Tensor[B, T, D - 1]
Multiplication | Tensor[B, NHead * DK] (multi-head reshape)
Floor division | Tensor[B, NHead, T, D // NHead]
Exponentiation | Tensor[B, C * 2 ** I, H // 2 ** I]
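Substituting concrete numbers into these annotations is a quick way to sanity-check them. The sketch below evaluates three of the table's shapes with plain integers; the function names are hypothetical, and Python's operator precedence (** binds tighter than * and //) matches the annotations as written.

```python
def concat_shape(b: int, c1: int, c2: int, h: int, w: int) -> tuple[int, int, int, int]:
    # Tensor[B, C1 + C2, H, W]: concatenating along channels adds channel counts.
    return (b, c1 + c2, h, w)


def split_heads_shape(b: int, t: int, d: int, n_head: int) -> tuple[int, int, int, int]:
    # Tensor[B, NHead, T, D // NHead]: each head gets D // NHead features.
    return (b, n_head, t, d // n_head)


def pyramid_shape(b: int, c: int, h: int, i: int) -> tuple[int, int, int]:
    # Tensor[B, C * 2 ** I, H // 2 ** I]: level I doubles the channels and
    # halves the spatial size I times.
    return (b, c * 2 ** i, h // 2 ** i)


pyramid_shape(4, 16, 64, 0)  # (4, 16, 64): level 0 is the input itself
pyramid_shape(4, 16, 64, 3)  # (4, 128, 8)
```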

Simplification rules

The type checker automatically simplifies expressions:

  • 2 * C // 2 → C
  • (H - 1) * 2 + 2 → H * 2
  • (a * b) // b → a (sound for all positive integers)
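Because these rewrites are ordinary integer identities, they can be spot-checked by brute force. The check below is only illustrative and is not how pyrefly's simplifier works; testing a finite range of positive integers supports, but does not prove, the "for all" claim.

```python
SAMPLE = range(1, 201)

# 2 * C // 2 -> C  (Python parses this as (2 * C) // 2)
rule_double_halve = all(2 * c // 2 == c for c in SAMPLE)

# (H - 1) * 2 + 2 -> H * 2
rule_expand = all((h - 1) * 2 + 2 == h * 2 for h in SAMPLE)

# (a * b) // b -> a: a * b is an exact multiple of b, so no remainder is lost
rule_cancel = all((a * b) // b == a for a in SAMPLE for b in SAMPLE)
```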

Known limitations

N * (X // N) does not simplify to X — floor division loses the remainder, so the equivalence only holds when X is divisible by N. The checker can't assume this. Common instances:

  • Multi-head reassembly: NHead * (D // NHead) — use type: ignore
  • BiLSTM output: 2 * (D // 2) — use type: ignore
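A short experiment shows why the checker refuses to simplify N * (X // N): the identity holds exactly when N divides X. The helper name round_trips is hypothetical.

```python
def round_trips(x: int, n: int) -> bool:
    """True when n * (x // n) reconstructs x."""
    return n * (x // n) == x


round_trips(512, 8)  # True: 8 heads of 64 features reassemble to 512
round_trips(10, 3)   # False: 3 * (10 // 3) == 9, one feature is dropped
round_trips(7, 2)    # False: BiLSTM-style 2 * (7 // 2) == 6

# For positive integers the identity holds exactly when n divides x.
assert all(round_trips(x, n) == (x % n == 0) for x in range(1, 100) for n in range(1, 20))
```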

Annotation hierarchy

When annotating local variables, choose from most to least desirable:

  1. assert_type — verifies the checker's inference. Proves the system works, not just that you annotated correctly.
  2. Annotation fallback: x: Tensor[B, C, H, W] = untracked_op(...). The checker can't infer the shape, but the annotation is compatible. Document WHY.
  3. type: ignore — the checker produces a WRONG type (algebraic gap). Last resort. Always include a comment explaining the specific gap.
  4. Bare Tensor — shape genuinely unknowable (data-dependent token counts, conditional accumulation). Document the specific reason.

Jaxtyping compatibility

Pyrefly supports jaxtyping annotations as an alternative front-end:

Pyrefly native | Jaxtyping equivalent
Tensor[M, 2, M // 2] | Shaped[Tensor, "M 2 M//2"]
Tensor[B, C, H, W] | Shaped[Tensor, "B C H W"]

Jaxtyping annotations are translated internally to generics and display back in jaxtyping syntax. Note that jaxtyping cannot share symbolic dimensions across class boundaries — see the overview for details.
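The correspondence in the table is mechanical: each whitespace-separated token in a jaxtyping shape string is one dimension. The toy function below is hypothetical and is not pyrefly's translator; it rewrites such a string into the native spelling and handles only names, integers, and the arithmetic operators used on this page. Jaxtyping features such as *batch or ... are not handled.

```python
import re

# Longest operators first so "//" and "**" are not split into single characters.
_OPERATOR = re.compile(r"\s*(//|\*\*|[+\-*])\s*")


def jaxtyping_to_native(dims: str) -> str:
    """Rewrite a jaxtyping dimension string such as "M 2 M//2" as "Tensor[M, 2, M // 2]"."""
    tokens = dims.split()  # one token per dimension
    spaced = [_OPERATOR.sub(r" \1 ", token) for token in tokens]
    return f"Tensor[{', '.join(spaced)}]"


jaxtyping_to_native("M 2 M//2")  # 'Tensor[M, 2, M // 2]'
jaxtyping_to_native("B C H W")   # 'Tensor[B, C, H, W]'
```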