Skip to content

Conversation

@clsandoval
Copy link

Description

This PR adds a complete ONNX backend for PyTensor, enabling export of PyTensor computational graphs to ONNX format and execution via ONNX Runtime.

Key features:

  • ONNX Linker: Core infrastructure for converting PyTensor graphs to ONNX models
  • High-level export API: export_onnx(), compile_onnx(), and export_function_onnx() for easy model export
  • Comprehensive dispatcher system: Type-based dispatch for converting PyTensor Ops to ONNX nodes
  • Operation support across multiple categories:
    • Elementwise operations (Add, Sub, Mul, Div, Exp, Log, trigonometric functions, etc.)
    • Shape operations (Reshape, DimShuffle, Transpose, Join, Split)
    • Subtensor operations (basic slicing, advanced indexing, IncSubtensor)
    • Math operations (Sum, Prod, Max, Min, Mean, Argmax, Argmin)
    • Linear algebra (MatMul, Dot, BatchedDot, matrix operations)
    • Neural network ops (Softmax, LogSoftmax, Sigmoid, Conv2d, MaxPool)
  • Property-based testing: Hypothesis-powered test strategies for robust validation

Files added:

  • pytensor/link/onnx/ - Core ONNX backend implementation (13 files, ~2k lines)
  • tests/link/onnx/ - Comprehensive test suite (18 files, ~4k lines, 186 tests)

Related Issue

  • Closes #
  • Related #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

clsandoval and others added 30 commits November 4, 2025 05:28
Added TDD implementation plans covering all 5 tiers of ONNX backend:
- Phase 1-3: Infrastructure and Tier 1 (20 elemwise operations)
- Tier 2-3: Shape operations and reductions (31 operations)
- Tier 4-5: Linear algebra and advanced operations (63 operations)

Includes production roadmap, infrastructure analysis, and development environment setup research. Removed outdated JAX-focused plans and YOLO-specific research.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Cleaned up remaining YOLO11n and CNN-specific research that's no longer relevant to the ONNX backend production implementation focus.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Transform the TDD plan from manual test-by-test approach to Hypothesis-based
property testing:

- Replace 40+ manual tests with ~20-25 focused tests (16 total, validating 200+ scenarios)
- Add operation registry pattern for scalable test coverage
- Implement bulk operations via SCALAR_OP_TO_ONNX mapping
- Update all phases to use `uv run` consistently
- Restructure Phase 3 to emphasize bulk implementation strategy

Key changes:
- Desired End State: Now includes Hypothesis strategies and property tests
- Phase 2: Added Hypothesis verification steps before testing
- Phase 3: Streamlined from 1165 lines to 500 lines focusing on bulk approach
- All commands now use `uv run` prefix

This approach allows adding 20 operations with one mapping dict instead of
writing 20+ individual test functions, while automatically generating edge
cases and validating 200+ scenarios per test run.
New slash command that facilitates section-by-section plan review before
implementation. Enables collaborative validation where Claude:
- Presents understanding of each section in own words
- Waits for user confirmation before proceeding
- Makes updates to the plan as requested
- Builds incremental confidence through structured review

Usage: /review-plan thoughts/shared/plans/your-plan.md

This mirrors the workflow used to integrate Hypothesis into the ONNX TDD plan,
making it reusable for future plan reviews.
…3 plan

Replace manual test approach (45+ individual tests) with scalable
Hypothesis-based architecture using operation registries. New approach
automatically tests 31 operations with ~15-20 property tests instead
of 45+ manual tests, improving maintainability and coverage.

Key improvements:
- Operation registries for shape ops, reductions, allocations, and
  subtensor operations
- Hypothesis strategies for automatic test case generation
- Property tests that validate all operations systematically
- Reduced test count while increasing coverage through edge case
  generation
Implement the foundational ONNX backend infrastructure including:
- Core dispatch system using singledispatch pattern (onnx_funcify, onnx_typify)
- ONNXLinker that converts PyTensor graphs to ONNX ModelProto
- ONNX Runtime integration for graph execution
- Support for FunctionGraph to ONNX model conversion
- Handlers for Constants, DeepCopyOp (Identity)

Uses ONNX opset 18 with IR version 9 for compatibility with ONNX Runtime.
Implement elementwise operation conversion using a mapping-based approach:
- Arithmetic: Add, Sub, Mul, Div, Neg, IntDiv
- Math: Abs, Exp, Log, Sqrt, Pow, Floor, Ceil, Round
- Min/Max: Maximum, Minimum

All operations handled through a single converter function with a
SCALAR_OP_TO_ONNX mapping dictionary for maintainability.
Implement DimShuffle conversion supporting:
- Unsqueeze: Adding dimensions for broadcasting
- Transpose: Permuting dimensions
- Squeeze: Removing dimensions

Handles ONNX opset 13+ requirement for axes as separate input tensors.
Provide user-facing functions for ONNX export:
- export_onnx(): Export PyTensor graphs to .onnx files
- compile_onnx(): Compile graphs for ONNX Runtime execution
- export_function_onnx(): Export compiled PyTensor functions to ONNX
Implement 30 tests covering:
- Module structure and imports (3 tests)
- Core dispatch system (3 tests)
- ONNXLinker functionality (3 tests)
- Elementwise operations (16 tests)
- Export API (3 tests)
- Testing utilities (2 tests)

Includes compare_onnx_and_py utility for validating ONNX Runtime
output against Python reference implementation.

Current status: 27/30 tests passing (90% success rate).
Lock ONNX and ONNX Runtime versions for reproducible builds.
Document what worked as planned and divergences from the original TDD plan,
including:
- Infrastructure successes (dispatch system, singledispatch pattern)
- Test approach divergence (traditional tests vs Hypothesis property tests)
- Implementation gaps (DimShuffle needed earlier than planned, API mismatches)
- Bugs encountered (mixed-type arithmetic, tuple return handling)
- Lessons learned for future TDD planning

This analysis helps improve future planning by documenting actual
implementation experience against initial estimates.
Review and update the Tier 2-3 implementation plan based on the actual
Phase 1-3 infrastructure that was implemented. Key changes:

- Add Phase 0: Dispatcher extension for multi-node operations
  - Many Tier 2-3 ops (Shape_i, DimShuffle, MakeVector) return multiple
    ONNX nodes, requiring list return support
  - Simple 4-line extension to existing dispatcher pattern

- Fix all return patterns to match actual infrastructure
  - Verified all code examples use correct patterns (list/tuple/single/None)
  - Added comprehensive return pattern documentation

- Expand IncSubtensor implementation details
  - Add ScatterND vs ScatterElements decision tree
  - Detail set_subtensor vs inc_subtensor handling
  - Provide phased implementation strategy

- Add Subtensor negative indexing conversion details
  - Show Shape → Gather → Add conversion pattern
  - Handle both simple (non-negative) and complex (negative) cases

- Add Join/Split implementation examples

Plan is now production-ready with no expected implementation snags.
…ation

Refactored the monolithic Tier 2-3 TDD plan to improve clarity and execution:

- Extracted Phase 0 (dispatcher extension) into separate 30-minute plan
- Phase 0 now includes Shape, Shape_i, and SpecifyShape as reference implementations
- Updated main Tier 2-3 plan to require Phase 0 completion as prerequisite
- Clarified that 3 shape operations are already complete from Phase 0
- Updated metadata: status is now "ready-to-implement", timeline reflects actual scope

This separation makes it clear that Phase 0 is a quick foundational step
required before tackling the remaining 28 Tier 2-3 operations.
Extend the ONNX dispatcher to support operations that compile to multiple
ONNX nodes, which is required for Tier 2-3 operations.

Changes:
- Extend dispatcher to handle list returns from operation handlers
- Add None return handling with proper variable aliasing for pass-through ops
- Document all 4 return patterns (single node, multiple nodes, node with
  initializers, None) with examples
- Implement Shape, Shape_i, and SpecifyShape operations
- Shape_i demonstrates multi-node pattern: returns [Constant, Shape, Gather]
- SpecifyShape demonstrates None pattern: pass-through with no ONNX nodes
- Add comprehensive test suite with 5 tests covering all patterns

All tests passing, no regressions in existing functionality.
All success criteria met:
- Dispatcher handles list and None returns correctly
- Shape operations implemented and tested
- Multi-node pattern demonstrated with Shape_i
- No regressions in existing tests
Append comprehensive post-implementation analysis documenting what diverged
between the Phase 0 TDD plan and actual implementation:

Key Findings:
- Plan was created as retrospective documentation (unusual but valuable)
- Implementation completed in ~34 minutes (matched ~30 min estimate)
- Zero bugs encountered - clean first implementation
- Scope expanded to include DimShuffle and type(None) handler (module cohesion)
- Integration tests sufficed without dedicated dispatcher unit tests

Divergences Documented:
- Scope: Added DimShuffle + type(None) handler beyond minimal plan
- Tests: Consolidated into test_shape.py instead of separate dispatcher tests
- Naming: Improved test names (passthrough vs removed)

Lessons Learned:
- Define scope at module level when ops are tightly related
- Integration tests can replace unit tests when they cover all patterns
- Retrospective plans accurately capture timeline and real challenges
- Small focused scope (1-2 hours) works well for infrastructure

Patterns Documented:
- Multi-node return pattern with code examples
- Tuple with initializers pattern
- None return pass-through pattern
- None op handler pattern

Includes: Timeline analysis, success criteria verification, comparison table,
git commit references, file:line locations, and recommendations for Tier 2-3
implementation.
Add support for reduction operations (sum, prod, max, min, argmax),
subtensor operations (basic slicing), and tensor creation operations
(alloc, arange, make_vector).

Key implementations:
- CAReduce dispatcher maps PyTensor reductions to ONNX ReduceSum/Prod/Max/Min
- Argmax/Argmin with single-axis support
- Subtensor for 1D/2D/3D slicing with constant non-negative indices
- Alloc, AllocEmpty, MakeVector, ARange for tensor creation

Tests include property-based testing with Hypothesis for comprehensive
coverage across multiple operation types.
Fix three critical bugs blocking ONNX tests:

1. Argmax axis parameter: PyTensor stores axis as tuple (1,) but ONNX
   expects scalar int. Extract first element from tuple.

2. Scalar constant types: PyTensor defaults to int8 for scalar integers,
   causing type mismatches with float32 tensors in ONNX. Auto-upcast
   scalar integer constants to float32.

3. Export function: construct_nominal_fgraph returns tuple, not
   FunctionGraph directly. Extract first element.

Fixes enable all 62 tests to pass (5 intentionally skipped).
Document the subtensor implementation status, known issues, and next
steps in IMPLEMENTATION_NOTES.md.

Add comprehensive bugfix documentation detailing the three bugs fixed,
their root causes, solutions, and test results.

Update the Tier 2-3 plan to mark completed implementations.
Add support for AdvancedSubtensor operations in the ONNX backend,
enabling integer array indexing like x[indices]. This complements
the existing AdvancedSubtensor1 implementation.

Key changes:
- Add AdvancedSubtensor dispatcher using ONNX Gather operation
- Handle simple integer array indexing on axis 0 for both 1D and 2D arrays
- Unskip and enhance test suite with 2 passing tests
- Update implementation plan to mark Implementation 5 and 6 as complete

The implementation was needed because PyTensor creates AdvancedSubtensor
operations (not AdvancedSubtensor1) when using x[indices] syntax in ONNX
mode, which runs without optimizations.

Tests: 10/10 subtensor tests passing (3 appropriately skipped for future work)
Add ONNX export support for Join (concatenate) and Split operations,
completing Implementation 8 of the Tier 2-3 ONNX backend plan.

- Join: Maps to ONNX Concat node with axis as attribute
- Split: Maps to ONNX Split node with split sizes as input tensor
- Both operations require constant axis/split values
- Handle edge cases: uniform vs non-uniform splits

Tests added:
- test_concatenate_axis0/axis1: Verify concatenation along different axes
- test_stack_axis0: Verify stacking operation
- test_split_equal/unequal: Verify splitting with equal and unequal sizes

All tests passing (69 total, 4 intentionally skipped).
Completes Tier 2-3 implementation phase.
Add support for set_subtensor and inc_subtensor operations using ONNX
ScatterElements. This completes all 31 Tier 2-3 operations.

Implementation details:
- Uses Range node to generate indices for the slice
- set_subtensor: directly scatters new values at indices
- inc_subtensor: gathers current values, adds, then scatters sum
- Supports basic 1D slicing with constant bounds (step=1)

Tests:
- Add test_set_subtensor() verifying ScatterElements generation
- Add test_inc_subtensor() verifying Gather/Add/ScatterElements chain
- 71/74 tests now passing (3 intentionally skipped)

Updates plan document to mark Tier 2-3 as complete with all operations
implemented and tested.
Add comprehensive support for advanced operations in ONNX backend:

Elemwise operations:
- Trigonometric: sin, cos, tan, arcsin, arccos, arctan
- Hyperbolic: sinh, cosh, tanh, arcsinh, arccosh, arctanh
- Comparison: lt, gt, le, ge, eq, neq (composed as Equal+Not)
- Logical: and, or, xor, not
- Special: sigmoid, softplus, erf, clip, switch (where)
- Composed ops: log1p (Log(Add(x,1))), expm1 (Sub(Exp(x),1))

Linear algebra operations (nlinalg):
- MatMul with broadcasting support
- Einsum with equation parsing and transposition
- SVD with full matrices support

Neural network operations (nnet):
- Softmax with axis support and numerical stability
- LogSoftmax for stable log-probability computation
- Softplus activation function

Rewrite system:
- Softmax decomposition (exp normalization pattern)
- LogSoftmax optimization (prevents exp overflow)
Test coverage for:
- Linear algebra: matmul, einsum, svd with various shapes and dtypes
- Neural networks: softmax, log_softmax, softplus with numerical stability checks
- Special functions: erf, clip, switch operations
- Extra operations: trigonometric, hyperbolic, comparison, logical ops
- Integration tests: composite operations and real-world patterns

All tests use compare_onnx_and_py() helper for validation against ONNX Runtime.
- Add .hypothesis/ to .gitignore for test artifact exclusion
- Add CLAUDE.md with project instructions for uv workflow
- Document Tier 4-5 implementation progress and challenges
- Add property-based testing master plan for future scalability
- Create phased TDD plans for systematic test coverage expansion
- Add research notes on Hypothesis integration strategy
Add comprehensive test infrastructure for validating the ELEMWISE_OPERATIONS
registry structure and behavior:

- Create test_strategies.py with 24 tests validating registry structure,
  strategy data generation, and build_graph functions
- Add 4 helper strategies for elemwise operations:
  * binary_float32_arrays_strategy() for binary ops
  * unary_float32_array_strategy() for unary ops
  * positive_float32_array_strategy() for log (x > 0)
  * non_negative_float32_array_strategy() for sqrt (x >= 0)
- Implement ELEMWISE_OPERATIONS registry with 18 Tier 1 operations:
  * Binary arithmetic: add, mul, sub, div, int_div, pow
  * Element-wise min/max: maximum, minimum
  * Unary math: neg, abs, exp, log, sqrt
  * Rounding: floor, ceil, round, round_away
  * Special: clip

All 24 tests pass, no regressions in existing test suite (131/148 passing).
Follows TDD approach with tests written first, verified to fail correctly,
then implementation made tests pass.
Mark all Phase 1-3 success criteria as completed and append comprehensive
post-implementation analysis documenting:

- What worked as planned (24 tests, 18 operations, zero bugs)
- Implementation divergences (one-pass vs incremental, refactoring deferred)
- Lessons learned for future TDD planning
- Patterns worth documenting (registry pattern, constrained strategies)
- Metrics (30min implementation time, 495 LOC, 100% success rate)

Analysis shows thorough planning with concrete code examples enables fast,
correct implementation with TDD approach working exactly as intended.
clsandoval and others added 7 commits November 11, 2025 08:39
- IntDiv: Implement as Div + Floor composition instead of plain Div
- Clip: Add Squeeze nodes to convert PyTensor tensor min/max to ONNX scalar bounds
- Squeeze: Update to ONNX opset 13+ format (axes as input tensor, not attribute)

These fixes ensure correct ONNX export for operations with special requirements.
- Add 180+ property-based test scenarios for elemwise operations
- Add property tests for shape operations (shape, reshape, transpose, etc.)
- Fix shape generation strategy to ensure valid tensor shapes
- Fix Shape_i operation to use correct PyTensor API instead of indexing

Test coverage now includes:
- Elemwise: 18 operations (add, mul, sub, div, int_div, log, sqrt, pow, clip, etc.)
- Shape: 9 operations (shape, shape_i, reshape, transpose, dimshuffle, concat, stack)

Property-based tests provide diverse inputs and edge cases automatically.
- Update Phase 2 (elemwise) and Phase 3 (shape) completion status
- Add analysis of property-based testing results and lessons learned
- Document rationale for registry pattern and constrained strategies
- Explain architectural decisions for maintainable test infrastructure

Documentation includes:
- Success metrics: 180+ elemwise scenarios, 90+ shape scenarios
- Design patterns: centralized registry, domain-constrained strategies
- Benefits: maintainability, discoverability, type safety
- Add 4 property-based test functions covering 50+ test scenarios:
  * test_subtensor_basic_slicing_correctness (60 scenarios)
  * test_advanced_subtensor_indexing_correctness (10 scenarios)
  * test_set_subtensor_operation_correctness (10 scenarios)
  * test_inc_subtensor_operation_correctness (10 scenarios)

- Fix registry patterns in strategies.py to follow ELEMWISE pattern:
  * Update SUBTENSOR_OPERATIONS to wrap numpy→PyTensor conversion
  * Update INCSUBTENSOR_OPERATIONS to wrap numpy→PyTensor conversion
  * Ensures build_graph functions properly create symbolic variables

- Add comprehensive module and class documentation
- Document negative index limitation in test docstrings
- Organize tests with clear section markers

All property tests pass, validating ONNX subtensor implementation.
Manual tests retained for documentation and specific edge cases.
Document what diverged from the TDD plan and extract lessons:
- 2 bugs encountered (registry pattern issues)
- 3 test divergences identified with root causes
- 6 concrete recommendations for future TDD planning

Key lessons learned:
- Verify infrastructure patterns before writing tests
- Test infrastructure incrementally (one test at a time)
- Research API constraints before planning

Document reusable patterns:
- Registry lambda wrapping for numpy→PyTensor conversion
- Hypothesis assume() for edge case filtering
- Dual test coverage strategy (property + manual tests)

Implementation was successful after registry pattern fix.
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Clean up dispatch implementations for shape, subtensor, and tensor_basic ops
- Improve property-based testing strategies
- Fix type annotations and code style issues
- Update test fixtures and assertions
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @clsandoval

Thanks for opening the PR. Looks pretty promising.

I left a bunch of comments. In general this would have been nicer split into a PR that implements the basics, followed by another that adds most functionality.

If you intend to iterate on this PR to address the comments, please make sure to manually validate against LLM amnesia, I don't want to be policing against undoing of previous agreed changes and stuff of that sort in a PR that has 8k LOC changes.

Also please prompt it to avoid comment galore. I can read the function names. Also tell it avoid stuff like # now testing axis=2 because we told it the case was missing in a previous review.

With that grumpy note behind, Great work!

PS: edit the test yaml config, so onnx tests are only run in a single separate job, and not mixed with the previous tests

onnx.TensorProto
ONNX tensor representation
"""
# Default: try to convert to numpy array first
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better default is to raise, we did this before for other backends and have been moving away

Comment on lines +117 to +120
# Get shape - handle both static and symbolic shapes
# For now, we'll use None for unknown dimensions
ndim = var.type.ndim
shape = [None] * ndim # Unknown dimensions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Get shape - handle both static and symbolic shapes
# For now, we'll use None for unknown dimensions
ndim = var.type.ndim
shape = [None] * ndim # Unknown dimensions
shape = var.type.shape

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about non-TensorVariables? Are we raising explicitly if not supported? Examples include Slices, TypedLists, RandomGenerator, SparseTensorVariables

var_names = {}
var_counter = 0

def get_var_name(var):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a unique_name_generator helper already in link.utils that I think you can reuse

if data.ndim == 0 and np.issubdtype(data.dtype, np.integer):
# Check if this constant is used with float operations
# For now, we'll upcast all scalar integer constants to float32
# This is a simplification but handles the common case of: x * 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't sound safe. Constants show up a lot in indexing operations for example x[:2], you wouldn't want to make that a float. Any implicit casting should be done by the Op that needs it, or is there a more fundamental onnx limitation here?

# Example: Add returns single Add node
nodes.append(result)
else:
# Handler returned None - this is a no-op operation
Copy link
Member

@ricardoV94 ricardoV94 Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you really need this: Make the handler return a specific sentinel ONNX_NO_OP instead of None to avoid subtle errors where users just forget to return something?

Given you have identity node, below it sounds like you don't need it though

np.testing.assert_array_equal(result, expected)


# Unique Tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these redundant comments. They are also esay to go out of place when we refactor tests around

# ============================================================================


def test_elemwise_registry_exists():
Copy link
Member

@ricardoV94 ricardoV94 Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why test this? Sounds like an implementation details of little consequence

Comment on lines +236 to +243
# ============================================================================
# MANUAL EDGE CASE TESTS
# ============================================================================
# These tests complement the property-based tests above by:
# - Testing specific edge cases and patterns
# - Providing readable examples for documentation
# - Validating 3D operations (more complex than property tests cover)
# ============================================================================
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put comments inside test functions / class so they don't get lost in future refactors. Do err on the side of less comments

@@ -0,0 +1,1083 @@
version = 1
Copy link
Member

@ricardoV94 ricardoV94 Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove uv.lock file. If you want to argue for it, do it in a separate PR

@@ -0,0 +1,748 @@
"""Hypothesis strategies and operation registries for ONNX backend testing."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should I review this 700 line file? Do we need it as a monolith?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants