Agent Skill for Tensor Shape Porting

Pyrefly includes a skill for Claude Code that can automatically port PyTorch models to use tensor shape annotations. The skill runs a structured, multi-step workflow — auditing ops, inventorying parameters, probing shapes with reveal_type, restructuring for tracking, and verifying the result — that produces a fully annotated port file.

Invoking the skill

In Claude Code, from a project with Pyrefly's tensor shape fixtures, ask:

Port this model to use tensor shape types: path/to/model.py

Or use the slash command directly:

/port-model path/to/model.py

The skill will:

  1. Audit ops — check every nn.Module and torch/F. function against fixture stubs and the DSL registry, adding missing stubs before proceeding.
  2. Inventory the original — list every class, method, and function, noting which constructor parameters are Dim vs int.
  3. Port each module in dependency order — for each module, inventory parameters, type the constructor, probe the forward with reveal_type, restructure bare results for tracking, write the forward with assert_type, and fill a post-module checklist.
  4. Verify — run verify_port.sh, audit bare assert_type calls, compare against the style guide, and produce a completion report.

The iterate-and-reflect loop

A single pass usually gets 80-90% of the way there. The remaining gaps — bridge dims, optional dims, parameterized configs — benefit from a second pass where the AI reflects on what it missed.

The workflow:

  1. First pass. Invoke the skill. Review the completion report, paying attention to the bare assert_type fraction and any type: ignore counts.

  2. Reflect. Ask the AI to re-read the skill file and compare what it did against what the skill prescribes. A prompt like:

    Re-read the port-model skill and reflect on the port you just produced.
    What did you miss? What could be improved?

    This typically surfaces missed restructuring opportunities — cases where the AI used a bare assert_type or annotation fallback without attempting all the restructuring steps in Step 4.

  3. Second pass. Ask it to fix the issues it identified. This usually converges: the bare fraction drops and shaped coverage improves.

Most models converge in 2-3 iterations. Complex models with many dynamic patterns (e.g., nn.Sequential(*list), conditional branching, late-initialized buffers) may need one more round.

Reading the output

The skill's verify_port.sh step produces a report with these key metrics:

Metric           | What it means
-----------------|--------------------------------------------------------------
ig (ignore)      | type: ignore comments. Lower is better. Each should have a category (algebraic gap, conditional equality, stub gap).
bs (bare sig)    | Bare Tensor in function signatures. Should be 0 for well-typed modules.
bv (bare var)    | Bare Tensor in local variable annotations. Lower is better.
sh (shaped)      | assert_type calls with full shapes (e.g., Tensor[B, D]). Higher is better.
ba (bare assert) | assert_type(x, Tensor) — tracking gaps. Lower is better. Each should have a root-cause comment.
sm (smoke)       | Smoke test functions. At least 1-2 per model.

A good port has high sh, low ba, and every bare assert documented with a root cause tracing back to a restructuring receipt.
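The bare assert_type fraction used during review is just ba over the total number of assert_type calls. A minimal sketch (the metric names follow the table above; the numbers are illustrative):

```python
def bare_fraction(sh: int, ba: int) -> float:
    """Fraction of assert_type calls that are bare (ba) rather than fully shaped (sh)."""
    total = sh + ba
    return ba / total if total else 0.0

# Illustrative numbers: 42 shaped asserts, 3 bare ones.
frac = bare_fraction(sh=42, ba=3)
print(f"bare assert fraction: {frac:.1%}")  # → bare assert fraction: 6.7%
```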

When to intervene

The AI workflow handles most patterns automatically, but some decisions require domain knowledge:

  • Bridge dims. When an untracked section (e.g., a feature extractor built with nn.Sequential(*list)) connects to a tracked downstream module, you may need to decide which dimension to promote to a class type parameter.
  • Missing stubs. If the model uses a library without fixture stubs, you'll need to add minimal stubs for the specific ops used.
  • Architectural choices. When a model has multiple valid ways to restructure for tracking (e.g., separating the first iteration of a loop vs. using a typed interface), domain knowledge about the model's invariants can guide the better choice.
  • Config parameterization. Deciding which config fields should become Dim type parameters (and which modules extract which parameters) is often a design decision about the model's API.
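For the bridge-dim case, one hypothetical shape of the fix is sketched below. All names are illustrative stand-ins, not the skill's actual output: the feature dimension is promoted to a class type parameter, and `typing.cast` marks the single documented point where the untracked backbone's result is bridged into tracked territory.

```python
from typing import Generic, TypeVar, cast

F = TypeVar("F")  # the dimension promoted to a class type parameter (the bridge dim)

class Tensor(Generic[F]):
    """Minimal 1-D shaped tensor stand-in (hypothetical)."""

class UntrackedBackbone:
    """Stand-in for an nn.Sequential(*list) section whose output shape isn't tracked."""
    def __call__(self, x: object) -> object:
        return Tensor()

class Model(Generic[F]):
    def __init__(self, feat_dim: int) -> None:
        self.backbone = UntrackedBackbone()
        self.feat_dim = feat_dim  # runtime value behind the F type parameter

    def forward(self, x: object) -> "Tensor[F]":
        raw = self.backbone(x)
        # One documented bridge point: promote the untracked result to Tensor[F]
        # so everything downstream stays tracked.
        return cast(Tensor[F], raw)
```

Choosing which dimension becomes F is exactly the domain-knowledge decision described above; the sketch only shows where that decision lands in the code.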