-
Notifications
You must be signed in to change notification settings - Fork 149
Onnx backend #1777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Onnx backend #1777
Conversation
Added TDD implementation plans covering all 5 tiers of ONNX backend: - Phase 1-3: Infrastructure and Tier 1 (20 elemwise operations) - Tier 2-3: Shape operations and reductions (31 operations) - Tier 4-5: Linear algebra and advanced operations (63 operations) Includes production roadmap, infrastructure analysis, and development environment setup research. Removed outdated JAX-focused plans and YOLO-specific research. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
Cleaned up remaining YOLO11n and CNN-specific research that's no longer relevant to the ONNX backend production implementation focus. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
Transform the TDD plan from manual test-by-test approach to Hypothesis-based property testing: - Replace 40+ manual tests with ~20-25 focused tests (16 total, validating 200+ scenarios) - Add operation registry pattern for scalable test coverage - Implement bulk operations via SCALAR_OP_TO_ONNX mapping - Update all phases to use `uv run` consistently - Restructure Phase 3 to emphasize bulk implementation strategy Key changes: - Desired End State: Now includes Hypothesis strategies and property tests - Phase 2: Added Hypothesis verification steps before testing - Phase 3: Streamlined from 1165 lines to 500 lines focusing on bulk approach - All commands now use `uv run` prefix This approach allows adding 20 operations with one mapping dict instead of writing 20+ individual test functions, while automatically generating edge cases and validating 200+ scenarios per test run.
New slash command that facilitates section-by-section plan review before implementation. Enables collaborative validation where Claude: - Presents understanding of each section in own words - Waits for user confirmation before proceeding - Makes updates to the plan as requested - Builds incremental confidence through structured review Usage: /review-plan thoughts/shared/plans/your-plan.md This mirrors the workflow used to integrate Hypothesis into the ONNX TDD plan, making it reusable for future plan reviews.
…3 plan Replace manual test approach (45+ individual tests) with scalable Hypothesis-based architecture using operation registries. New approach automatically tests 31 operations with ~15-20 property tests instead of 45+ manual tests, improving maintainability and coverage. Key improvements: - Operation registries for shape ops, reductions, allocations, and subtensor operations - Hypothesis strategies for automatic test case generation - Property tests that validate all operations systematically - Reduced test count while increasing coverage through edge case generation
Implement the foundational ONNX backend infrastructure including: - Core dispatch system using singledispatch pattern (onnx_funcify, onnx_typify) - ONNXLinker that converts PyTensor graphs to ONNX ModelProto - ONNX Runtime integration for graph execution - Support for FunctionGraph to ONNX model conversion - Handlers for Constants, DeepCopyOp (Identity) Uses ONNX opset 18 with IR version 9 for compatibility with ONNX Runtime.
Implement elementwise operation conversion using a mapping-based approach: - Arithmetic: Add, Sub, Mul, Div, Neg, IntDiv - Math: Abs, Exp, Log, Sqrt, Pow, Floor, Ceil, Round - Min/Max: Maximum, Minimum All operations handled through a single converter function with a SCALAR_OP_TO_ONNX mapping dictionary for maintainability.
Implement DimShuffle conversion supporting: - Unsqueeze: Adding dimensions for broadcasting - Transpose: Permuting dimensions - Squeeze: Removing dimensions Handles ONNX opset 13+ requirement for axes as separate input tensors.
Provide user-facing functions for ONNX export: - export_onnx(): Export PyTensor graphs to .onnx files - compile_onnx(): Compile graphs for ONNX Runtime execution - export_function_onnx(): Export compiled PyTensor functions to ONNX
Implement 30 tests covering: - Module structure and imports (3 tests) - Core dispatch system (3 tests) - ONNXLinker functionality (3 tests) - Elementwise operations (16 tests) - Export API (3 tests) - Testing utilities (2 tests) Includes compare_onnx_and_py utility for validating ONNX Runtime output against Python reference implementation. Current status: 27/30 tests passing (90% success rate).
Lock ONNX and ONNX Runtime versions for reproducible builds.
Document what worked as planned and divergences from the original TDD plan, including: - Infrastructure successes (dispatch system, singledispatch pattern) - Test approach divergence (traditional tests vs Hypothesis property tests) - Implementation gaps (DimShuffle needed earlier than planned, API mismatches) - Bugs encountered (mixed-type arithmetic, tuple return handling) - Lessons learned for future TDD planning This analysis helps improve future planning by documenting actual implementation experience against initial estimates.
Review and update the Tier 2-3 implementation plan based on the actual
Phase 1-3 infrastructure that was implemented. Key changes:
- Add Phase 0: Dispatcher extension for multi-node operations
- Many Tier 2-3 ops (Shape_i, DimShuffle, MakeVector) return multiple
ONNX nodes, requiring list return support
- Simple 4-line extension to existing dispatcher pattern
- Fix all return patterns to match actual infrastructure
- Verified all code examples use correct patterns (list/tuple/single/None)
- Added comprehensive return pattern documentation
- Expand IncSubtensor implementation details
- Add ScatterND vs ScatterElements decision tree
- Detail set_subtensor vs inc_subtensor handling
- Provide phased implementation strategy
- Add Subtensor negative indexing conversion details
- Show Shape → Gather → Add conversion pattern
- Handle both simple (non-negative) and complex (negative) cases
- Add Join/Split implementation examples
Plan is now production-ready with no expected implementation snags.
…ation Refactored the monolithic Tier 2-3 TDD plan to improve clarity and execution: - Extracted Phase 0 (dispatcher extension) into separate 30-minute plan - Phase 0 now includes Shape, Shape_i, and SpecifyShape as reference implementations - Updated main Tier 2-3 plan to require Phase 0 completion as prerequisite - Clarified that 3 shape operations are already complete from Phase 0 - Updated metadata: status is now "ready-to-implement", timeline reflects actual scope This separation makes it clear that Phase 0 is a quick foundational step required before tackling the remaining 28 Tier 2-3 operations.
Extend the ONNX dispatcher to support operations that compile to multiple ONNX nodes, which is required for Tier 2-3 operations. Changes: - Extend dispatcher to handle list returns from operation handlers - Add None return handling with proper variable aliasing for pass-through ops - Document all 4 return patterns (single node, multiple nodes, node with initializers, None) with examples - Implement Shape, Shape_i, and SpecifyShape operations - Shape_i demonstrates multi-node pattern: returns [Constant, Shape, Gather] - SpecifyShape demonstrates None pattern: pass-through with no ONNX nodes - Add comprehensive test suite with 5 tests covering all patterns All tests passing, no regressions in existing functionality.
All success criteria met: - Dispatcher handles list and None returns correctly - Shape operations implemented and tested - Multi-node pattern demonstrated with Shape_i - No regressions in existing tests
Append comprehensive post-implementation analysis documenting what diverged between the Phase 0 TDD plan and actual implementation: Key Findings: - Plan was created as retrospective documentation (unusual but valuable) - Implementation completed in ~34 minutes (matched ~30 min estimate) - Zero bugs encountered - clean first implementation - Scope expanded to include DimShuffle and type(None) handler (module cohesion) - Integration tests sufficed without dedicated dispatcher unit tests Divergences Documented: - Scope: Added DimShuffle + type(None) handler beyond minimal plan - Tests: Consolidated into test_shape.py instead of separate dispatcher tests - Naming: Improved test names (passthrough vs removed) Lessons Learned: - Define scope at module level when ops are tightly related - Integration tests can replace unit tests when they cover all patterns - Retrospective plans accurately capture timeline and real challenges - Small focused scope (1-2 hours) works well for infrastructure Patterns Documented: - Multi-node return pattern with code examples - Tuple with initializers pattern - None return pass-through pattern - None op handler pattern Includes: Timeline analysis, success criteria verification, comparison table, git commit references, file:line locations, and recommendations for Tier 2-3 implementation.
Add support for reduction operations (sum, prod, max, min, argmax), subtensor operations (basic slicing), and tensor creation operations (alloc, arange, make_vector). Key implementations: - CAReduce dispatcher maps PyTensor reductions to ONNX ReduceSum/Prod/Max/Min - Argmax/Argmin with single-axis support - Subtensor for 1D/2D/3D slicing with constant non-negative indices - Alloc, AllocEmpty, MakeVector, ARange for tensor creation Tests include property-based testing with Hypothesis for comprehensive coverage across multiple operation types.
Fix three critical bugs blocking ONNX tests: 1. Argmax axis parameter: PyTensor stores axis as tuple (1,) but ONNX expects scalar int. Extract first element from tuple. 2. Scalar constant types: PyTensor defaults to int8 for scalar integers, causing type mismatches with float32 tensors in ONNX. Auto-upcast scalar integer constants to float32. 3. Export function: construct_nominal_fgraph returns tuple, not FunctionGraph directly. Extract first element. Fixes enable all 62 tests to pass (5 intentionally skipped).
Document the subtensor implementation status, known issues, and next steps in IMPLEMENTATION_NOTES.md. Add comprehensive bugfix documentation detailing the three bugs fixed, their root causes, solutions, and test results. Update the Tier 2-3 plan to mark completed implementations.
Add support for AdvancedSubtensor operations in the ONNX backend, enabling integer array indexing like x[indices]. This complements the existing AdvancedSubtensor1 implementation. Key changes: - Add AdvancedSubtensor dispatcher using ONNX Gather operation - Handle simple integer array indexing on axis 0 for both 1D and 2D arrays - Unskip and enhance test suite with 2 passing tests - Update implementation plan to mark Implementation 5 and 6 as complete The implementation was needed because PyTensor creates AdvancedSubtensor operations (not AdvancedSubtensor1) when using x[indices] syntax in ONNX mode, which runs without optimizations. Tests: 10/10 subtensor tests passing (3 appropriately skipped for future work)
Add ONNX export support for Join (concatenate) and Split operations, completing Implementation 8 of the Tier 2-3 ONNX backend plan. - Join: Maps to ONNX Concat node with axis as attribute - Split: Maps to ONNX Split node with split sizes as input tensor - Both operations require constant axis/split values - Handle edge cases: uniform vs non-uniform splits Tests added: - test_concatenate_axis0/axis1: Verify concatenation along different axes - test_stack_axis0: Verify stacking operation - test_split_equal/unequal: Verify splitting with equal and unequal sizes All tests passing (69 total, 4 intentionally skipped). Completes Tier 2-3 implementation phase.
Add support for set_subtensor and inc_subtensor operations using ONNX ScatterElements. This completes all 31 Tier 2-3 operations. Implementation details: - Uses Range node to generate indices for the slice - set_subtensor: directly scatters new values at indices - inc_subtensor: gathers current values, adds, then scatters sum - Supports basic 1D slicing with constant bounds (step=1) Tests: - Add test_set_subtensor() verifying ScatterElements generation - Add test_inc_subtensor() verifying Gather/Add/ScatterElements chain - 71/74 tests now passing (3 intentionally skipped) Updates plan document to mark Tier 2-3 as complete with all operations implemented and tested.
Add comprehensive support for advanced operations in ONNX backend: Elemwise operations: - Trigonometric: sin, cos, tan, arcsin, arccos, arctan - Hyperbolic: sinh, cosh, tanh, arcsinh, arccosh, arctanh - Comparison: lt, gt, le, ge, eq, neq (composed as Equal+Not) - Logical: and, or, xor, not - Special: sigmoid, softplus, erf, clip, switch (where) - Composed ops: log1p (Log(Add(x,1))), expm1 (Sub(Exp(x),1)) Linear algebra operations (nlinalg): - MatMul with broadcasting support - Einsum with equation parsing and transposition - SVD with full matrices support Neural network operations (nnet): - Softmax with axis support and numerical stability - LogSoftmax for stable log-probability computation - Softplus activation function Rewrite system: - Softmax decomposition (exp normalization pattern) - LogSoftmax optimization (prevents exp overflow)
Test coverage for: - Linear algebra: matmul, einsum, svd with various shapes and dtypes - Neural networks: softmax, log_softmax, softplus with numerical stability checks - Special functions: erf, clip, switch operations - Extra operations: trigonometric, hyperbolic, comparison, logical ops - Integration tests: composite operations and real-world patterns All tests use compare_onnx_and_py() helper for validation against ONNX Runtime.
- Add .hypothesis/ to .gitignore for test artifact exclusion - Add CLAUDE.md with project instructions for uv workflow - Document Tier 4-5 implementation progress and challenges - Add property-based testing master plan for future scalability - Create phased TDD plans for systematic test coverage expansion - Add research notes on Hypothesis integration strategy
Add comprehensive test infrastructure for validating the ELEMWISE_OPERATIONS registry structure and behavior: - Create test_strategies.py with 24 tests validating registry structure, strategy data generation, and build_graph functions - Add 4 helper strategies for elemwise operations: * binary_float32_arrays_strategy() for binary ops * unary_float32_array_strategy() for unary ops * positive_float32_array_strategy() for log (x > 0) * non_negative_float32_array_strategy() for sqrt (x >= 0) - Implement ELEMWISE_OPERATIONS registry with 18 Tier 1 operations: * Binary arithmetic: add, mul, sub, div, int_div, pow * Element-wise min/max: maximum, minimum * Unary math: neg, abs, exp, log, sqrt * Rounding: floor, ceil, round, round_away * Special: clip All 24 tests pass, no regressions in existing test suite (131/148 passing). Follows TDD approach with tests written first, verified to fail correctly, then implementation made tests pass.
Mark all Phase 1-3 success criteria as completed and append comprehensive post-implementation analysis documenting: - What worked as planned (24 tests, 18 operations, zero bugs) - Implementation divergences (one-pass vs incremental, refactoring deferred) - Lessons learned for future TDD planning - Patterns worth documenting (registry pattern, constrained strategies) - Metrics (30min implementation time, 495 LOC, 100% success rate) Analysis shows thorough planning with concrete code examples enables fast, correct implementation with TDD approach working exactly as intended.
- IntDiv: Implement as Div + Floor composition instead of plain Div - Clip: Add Squeeze nodes to convert PyTensor tensor min/max to ONNX scalar bounds - Squeeze: Update to ONNX opset 13+ format (axes as input tensor, not attribute) These fixes ensure correct ONNX export for operations with special requirements.
- Add 180+ property-based test scenarios for elemwise operations - Add property tests for shape operations (shape, reshape, transpose, etc.) - Fix shape generation strategy to ensure valid tensor shapes - Fix Shape_i operation to use correct PyTensor API instead of indexing Test coverage now includes: - Elemwise: 18 operations (add, mul, sub, div, int_div, log, sqrt, pow, clip, etc.) - Shape: 9 operations (shape, shape_i, reshape, transpose, dimshuffle, concat, stack) Property-based tests provide diverse inputs and edge cases automatically.
- Update Phase 2 (elemwise) and Phase 3 (shape) completion status - Add analysis of property-based testing results and lessons learned - Document rationale for registry pattern and constrained strategies - Explain architectural decisions for maintainable test infrastructure Documentation includes: - Success metrics: 180+ elemwise scenarios, 90+ shape scenarios - Design patterns: centralized registry, domain-constrained strategies - Benefits: maintainability, discoverability, type safety
- Add 4 property-based test functions covering 50+ test scenarios: * test_subtensor_basic_slicing_correctness (60 scenarios) * test_advanced_subtensor_indexing_correctness (10 scenarios) * test_set_subtensor_operation_correctness (10 scenarios) * test_inc_subtensor_operation_correctness (10 scenarios) - Fix registry patterns in strategies.py to follow ELEMWISE pattern: * Update SUBTENSOR_OPERATIONS to wrap numpy→PyTensor conversion * Update INCSUBTENSOR_OPERATIONS to wrap numpy→PyTensor conversion * Ensures build_graph functions properly create symbolic variables - Add comprehensive module and class documentation - Document negative index limitation in test docstrings - Organize tests with clear section markers All property tests pass, validating ONNX subtensor implementation. Manual tests retained for documentation and specific edge cases.
Document what diverged from the TDD plan and extract lessons: - 2 bugs encountered (registry pattern issues) - 3 test divergences identified with root causes - 6 concrete recommendations for future TDD planning Key lessons learned: - Verify infrastructure patterns before writing tests - Test infrastructure incrementally (one test at a time) - Research API constraints before planning Document reusable patterns: - Registry lambda wrapping for numpy→PyTensor conversion - Hypothesis assume() for edge case filtering - Dual test coverage strategy (property + manual tests) Implementation was successful after registry pattern fix.
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Clean up dispatch implementations for shape, subtensor, and tensor_basic ops - Improve property-based testing strategies - Fix type annotations and code style issues - Update test fixtures and assertions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @clsandoval
Thanks for opening the PR. Looks pretty promising.
I left a bunch of comments. In general this would have been nicer split into a PR that implements the basics, followed by another that adds most functionality.
If you intend to iterate on this PR to address the comments, please make sure to manually validate against LLM amnesia, I don't want to be policing against undoing of previous agreed changes and stuff of that sort in a PR that has 8k LOC changes.
Also please prompt it to avoid comment galore. I can read the function names. Also tell it avoid stuff like # now testing axis=2 because we told it the case was missing in a previous review.
With that grumpy note behind, Great work!
PS: edit the test yaml config, so onnx tests are only run in a single separate job, and not mixed with the previous tests
| onnx.TensorProto | ||
| ONNX tensor representation | ||
| """ | ||
| # Default: try to convert to numpy array first |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A better default is to raise, we did this before for other backends and have been moving away
| # Get shape - handle both static and symbolic shapes | ||
| # For now, we'll use None for unknown dimensions | ||
| ndim = var.type.ndim | ||
| shape = [None] * ndim # Unknown dimensions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # Get shape - handle both static and symbolic shapes | |
| # For now, we'll use None for unknown dimensions | |
| ndim = var.type.ndim | |
| shape = [None] * ndim # Unknown dimensions | |
| shape = var.type.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about non-TensorVariables? Are we raising explicitly if not supported? Examples include Slices, TypedLists, RandomGenerator, SparseTensorVariables
| var_names = {} | ||
| var_counter = 0 | ||
|
|
||
| def get_var_name(var): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's a unique_name_generator helper already in link.utils that I think you can reuse
| if data.ndim == 0 and np.issubdtype(data.dtype, np.integer): | ||
| # Check if this constant is used with float operations | ||
| # For now, we'll upcast all scalar integer constants to float32 | ||
| # This is a simplification but handles the common case of: x * 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't sound safe. Constants show up a lot in indexing operations for example x[:2], you wouldn't want to make that a float. Any implicit casting should be done by the Op that needs it, or is there a more fundamental onnx limitation here?
| # Example: Add returns single Add node | ||
| nodes.append(result) | ||
| else: | ||
| # Handler returned None - this is a no-op operation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you really need this: Make the handler return a specific sentinel ONNX_NO_OP instead of None to avoid subtle errors where users just forget to return something?
Given you have identity node, below it sounds like you don't need it though
| np.testing.assert_array_equal(result, expected) | ||
|
|
||
|
|
||
| # Unique Tests |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove these redundant comments. They are also esay to go out of place when we refactor tests around
| # ============================================================================ | ||
|
|
||
|
|
||
| def test_elemwise_registry_exists(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why test this? Sounds like an implementation details of little consequence
| # ============================================================================ | ||
| # MANUAL EDGE CASE TESTS | ||
| # ============================================================================ | ||
| # These tests complement the property-based tests above by: | ||
| # - Testing specific edge cases and patterns | ||
| # - Providing readable examples for documentation | ||
| # - Validating 3D operations (more complex than property tests cover) | ||
| # ============================================================================ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put comments inside test functions / class so they don't get lost in future refactors. Do err on the side of less comments
| @@ -0,0 +1,1083 @@ | |||
| version = 1 | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove uv.lock file. If you want to argue for it, do it in a separate PR
| @@ -0,0 +1,748 @@ | |||
| """Hypothesis strategies and operation registries for ONNX backend testing.""" | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How should I review this 700 line file? Do we need it as a monolith?
Description
This PR adds a complete ONNX backend for PyTensor, enabling export of PyTensor computational graphs to ONNX format and execution via ONNX Runtime.
Key features:
export_onnx(),compile_onnx(), andexport_function_onnx()for easy model exportFiles added:
pytensor/link/onnx/- Core ONNX backend implementation (13 files, ~2k lines)tests/link/onnx/- Comprehensive test suite (18 files, ~4k lines, 186 tests)Related Issue
Checklist
Type of change