37 commits
b556aec - seed with plans and research from demo (clsandoval, Nov 4, 2025)
f173030 - Add comprehensive ONNX backend TDD plans and research (clsandoval, Nov 4, 2025)
321157e - Remove outdated YOLO-specific research files (clsandoval, Nov 4, 2025)
0e58ed5 - Integrate Hypothesis property-based testing into ONNX backend TDD plan (clsandoval, Nov 5, 2025)
31fb2c5 - Add /review-plan command for interactive plan confidence-building (clsandoval, Nov 5, 2025)
0039a41 - Integrate Hypothesis property-based testing into ONNX backend Tier 2-… (clsandoval, Nov 5, 2025)
5999d62 - Add ONNX backend infrastructure and core dispatch system (clsandoval, Nov 5, 2025)
5044404 - Add ONNX support for 20 Tier 1 elementwise operations (clsandoval, Nov 5, 2025)
ec61d79 - Add ONNX support for shape operations (DimShuffle) (clsandoval, Nov 5, 2025)
2908352 - Add high-level ONNX export API (clsandoval, Nov 5, 2025)
cf2d445 - Add comprehensive test suite for ONNX backend (clsandoval, Nov 5, 2025)
55ac06c - Add uv.lock with ONNX dependencies (clsandoval, Nov 5, 2025)
9e47c4c - Add post-implementation analysis for ONNX backend Phase 1-3 (clsandoval, Nov 5, 2025)
414b0cd - Update Tier 2-3 ONNX plan for Phase 1-3 infrastructure compatibility (clsandoval, Nov 5, 2025)
8e827e9 - Split ONNX Tier 2-3 plan into Phase 0 prerequisite and main implement… (clsandoval, Nov 5, 2025)
2cfcaa4 - Implement ONNX dispatcher extension for multi-node operations (Phase 0) (clsandoval, Nov 5, 2025)
8a49018 - Mark Phase 0 dispatcher extension as complete (clsandoval, Nov 5, 2025)
787f0b0 - Add post-implementation analysis to Phase 0 dispatcher extension plan (clsandoval, Nov 5, 2025)
0667634 - Implement ONNX dispatchers for Tier 2-3 operations (clsandoval, Nov 5, 2025)
c6aeb27 - Fix ONNX backend type handling and API issues (clsandoval, Nov 5, 2025)
4d505e8 - Add implementation notes and bugfix documentation (clsandoval, Nov 5, 2025)
1f24bf3 - Implement AdvancedSubtensor ONNX dispatcher for integer array indexing (clsandoval, Nov 5, 2025)
a987659 - Implement Join and Split ONNX dispatchers (clsandoval, Nov 5, 2025)
0b11ba7 - Implement IncSubtensor for ONNX backend (clsandoval, Nov 7, 2025)
bba554f - Expand ONNX dispatch support for Tier 4-5 operations (clsandoval, Nov 8, 2025)
ac33055 - Add comprehensive test suite for ONNX operations (clsandoval, Nov 8, 2025)
10b546f - Update project configuration and documentation (clsandoval, Nov 8, 2025)
be45132 - phased property based testing plans (clsandoval, Nov 9, 2025)
490862b - Implement ELEMWISE_OPERATIONS registry and test infrastructure (clsandoval, Nov 10, 2025)
d0fb0d0 - Update Phase 1 TDD plan with completion status and analysis (clsandoval, Nov 10, 2025)
0392c88 - Fix IntDiv, Clip, and Squeeze ONNX operation implementations (clsandoval, Nov 11, 2025)
8a6912b - Add comprehensive property-based tests for ONNX operations (clsandoval, Nov 11, 2025)
f6d7cb8 - Document TDD completion status and registry design rationale (clsandoval, Nov 11, 2025)
db6fe34 - Add property-based tests for subtensor operations (clsandoval, Nov 11, 2025)
f2735d6 - Add post-implementation analysis to Phase 4 subtensor plan (clsandoval, Nov 11, 2025)
34b0239 - Remove .claude directory, thoughts, and markdown files (clsandoval, Dec 7, 2025)
c877068 - Refactor ONNX backend dispatch and improve test coverage (clsandoval, Dec 7, 2025)
1 change: 1 addition & 0 deletions .gitignore
@@ -55,3 +55,4 @@ pytensor-venv/
testing-report.html
coverage.xml
.coverage.*
.hypothesis/
23 changes: 23 additions & 0 deletions pytensor/link/onnx/__init__.py
@@ -0,0 +1,23 @@
"""ONNX backend for PyTensor.

This module provides functionality to export PyTensor graphs to ONNX format
and execute them using ONNX Runtime.
"""

from pytensor.link.onnx.dispatch import onnx_funcify, onnx_typify
from pytensor.link.onnx.export import compile_onnx, export_function_onnx, export_onnx
from pytensor.link.onnx.linker import ONNXLinker


# ONNX opset version used by default
ONNX_OPSET_VERSION = 18

__all__ = [
"ONNX_OPSET_VERSION",
"ONNXLinker",
"compile_onnx",
"export_function_onnx",
"export_onnx",
"onnx_funcify",
"onnx_typify",
]
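
For orientation, a minimal usage sketch (not part of the diff). The exact signature of export_onnx is not shown here, so the call pattern below is an assumption for illustration only:

import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.onnx import export_onnx

# Build a tiny PyTensor graph: y = x + x
x = pt.vector("x", dtype="float32")
fgraph = FunctionGraph([x], [x + x])

# Assumed call pattern: convert the graph to an onnx.ModelProto
model = export_onnx(fgraph)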
15 changes: 15 additions & 0 deletions pytensor/link/onnx/dispatch/__init__.py
@@ -0,0 +1,15 @@
"""ONNX dispatch system for converting PyTensor operations to ONNX."""

# isort: off
from pytensor.link.onnx.dispatch.basic import onnx_funcify, onnx_typify

# Load dispatch specializations
import pytensor.link.onnx.dispatch.elemwise
import pytensor.link.onnx.dispatch.shape
import pytensor.link.onnx.dispatch.math
import pytensor.link.onnx.dispatch.tensor_basic
import pytensor.link.onnx.dispatch.subtensor
import pytensor.link.onnx.dispatch.nlinalg
import pytensor.link.onnx.dispatch.nnet

# isort: on
324 changes: 324 additions & 0 deletions pytensor/link/onnx/dispatch/basic.py
@@ -0,0 +1,324 @@
"""Core ONNX dispatch functions for converting PyTensor graphs to ONNX."""

from functools import singledispatch

import numpy as np
import onnx
from onnx import helper, numpy_helper

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph


# Mapping from PyTensor dtypes to ONNX TensorProto dtypes
PYTENSOR_DTYPE_TO_ONNX = {
"float32": onnx.TensorProto.FLOAT,
"float64": onnx.TensorProto.DOUBLE,
"int32": onnx.TensorProto.INT32,
"int64": onnx.TensorProto.INT64,
"uint8": onnx.TensorProto.UINT8,
"int8": onnx.TensorProto.INT8,
"uint16": onnx.TensorProto.UINT16,
"int16": onnx.TensorProto.INT16,
"bool": onnx.TensorProto.BOOL,
}


@singledispatch
def onnx_typify(data, dtype=None, name=None, **kwargs):
"""Convert Python/NumPy data to ONNX TensorProto.

Parameters
----------
data : array-like
Data to convert
dtype : str, optional
Data type
name : str, optional
Name for the tensor

Returns
-------
onnx.TensorProto
ONNX tensor representation
"""
# Default: try to convert to numpy array first
    if not isinstance(data, np.ndarray):
        data = np.array(data, dtype=dtype)
    return numpy_helper.from_array(data, name=name)

Review comment (Member): A better default is to raise; we did this before for other backends and have been moving away from it.


@onnx_typify.register(np.ndarray)
def onnx_typify_ndarray(data, dtype=None, name=None, **kwargs):
"""Convert NumPy array to ONNX TensorProto."""
if dtype is not None:
data = data.astype(dtype)
return numpy_helper.from_array(data, name=name)
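
A quick round-trip illustration (not part of the diff) of the typify path, using only public onnx helpers:

import numpy as np
from onnx import numpy_helper
from pytensor.link.onnx import onnx_typify

arr = np.arange(6, dtype="float32").reshape(2, 3)
tp = onnx_typify(arr, name="example")  # dispatches to onnx_typify_ndarray
assert tp.name == "example"
np.testing.assert_array_equal(numpy_helper.to_array(tp), arr)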


@singledispatch
def onnx_funcify(op, node=None, **kwargs):
"""Convert a PyTensor Op to an ONNX node.

This is the core dispatch function that converts PyTensor operations
to their ONNX equivalents.

Parameters
----------
op : Op or FunctionGraph
The operation or graph to convert
node : Apply, optional
The Apply node containing this operation
**kwargs : dict
Additional arguments passed through the conversion

Returns
-------
onnx.NodeProto or onnx.ModelProto
ONNX representation of the operation

Raises
------
NotImplementedError
If no ONNX conversion is available for this operation
"""
op_type = type(op).__name__
raise NotImplementedError(
f"No ONNX conversion available for: {op_type}. "
f"The operation {op} is not yet supported in the ONNX backend."
)
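
For context (not part of the diff): handlers extend this dispatcher via the standard functools.singledispatch registration. A hypothetical handler for a placeholder SomeOp, using the single-node pattern and the get_var_name callback documented below, might look like:

from onnx import helper

@onnx_funcify.register(SomeOp)  # SomeOp is a placeholder, not a real PyTensor Op
def onnx_funcify_SomeOp(op, node, get_var_name, **kwargs):
    # Map PyTensor variables to ONNX value names, then emit a single ONNX node
    return helper.make_node(
        "Neg",
        inputs=[get_var_name(inp) for inp in node.inputs],
        outputs=[get_var_name(out) for out in node.outputs],
    )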


def make_value_info(var, name):
"""Create ONNX ValueInfoProto from PyTensor Variable.

Parameters
----------
var : Variable
PyTensor variable
name : str
Name for the ONNX value

Returns
-------
onnx.ValueInfoProto
ONNX value info with shape and dtype
"""
# Get dtype
dtype_str = var.type.dtype
if dtype_str not in PYTENSOR_DTYPE_TO_ONNX:
raise ValueError(
f"Unsupported dtype: {dtype_str}. "
f"Supported dtypes: {list(PYTENSOR_DTYPE_TO_ONNX.keys())}"
)
onnx_dtype = PYTENSOR_DTYPE_TO_ONNX[dtype_str]

# Get shape - handle both static and symbolic shapes
# For now, we'll use None for unknown dimensions
ndim = var.type.ndim
shape = [None] * ndim # Unknown dimensions
Review comment (Member) on lines +117 to +120, suggested change: replace the four lines above with

    shape = var.type.shape

Review comment (Member): What about non-TensorVariables? Are we raising explicitly if not supported? Examples include Slices, TypedLists, RandomGenerator, SparseTensorVariables.


# Create tensor type
return helper.make_tensor_value_info(name, onnx_dtype, shape)
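
An illustration (not part of the diff) of what make_value_info produces for a rank-2 float32 variable, using onnx.helper directly:

from onnx import TensorProto, helper

vi = helper.make_tensor_value_info("x", TensorProto.FLOAT, [None, None])
# Both dims are left unknown; the element type is TensorProto.FLOAT
assert vi.type.tensor_type.elem_type == TensorProto.FLOAT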


@onnx_funcify.register(FunctionGraph)
def onnx_funcify_FunctionGraph(
fgraph,
opset_version=18,
**kwargs,
):
"""Convert a PyTensor FunctionGraph to an ONNX ModelProto.

This function:
1. Does topological sort of nodes
2. Converts each node to ONNX via onnx_funcify
3. Collects constants as initializers
4. Creates ONNX ModelProto with inputs, outputs, and nodes

Operation Handler Return Patterns
----------------------------------
Handlers registered via @onnx_funcify.register can return:

1. **Single node** (most common):
return helper.make_node('Add', inputs=[...], outputs=[...])

2. **Multiple nodes** (operations requiring intermediate steps):
return [
helper.make_node('Shape', ...),
helper.make_node('Gather', ...),
helper.make_node('Slice', ...),
]

3. **Node with initializers** (operations with constant data):
return (
helper.make_node('Transpose', ...),
[axes_initializer], # List of TensorProto initializers
)

4. **None** (no-op, pass-through):
return None

Notes:
- List items can be None (will be filtered out)
- Tuple pattern is (node, [initializers]), not (node, initializer)
- Cannot mix patterns: either list OR tuple, not both
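
    Example (hypothetical; SomeOp is a placeholder illustrating pattern 2,
    computing the rank of a tensor as Size of Shape):

        @onnx_funcify.register(SomeOp)
        def onnx_funcify_SomeOp(op, node, get_var_name, **kwargs):
            x = get_var_name(node.inputs[0])
            out = get_var_name(node.outputs[0])
            return [
                helper.make_node("Shape", inputs=[x], outputs=[f"{out}_shape"]),
                helper.make_node("Size", inputs=[f"{out}_shape"], outputs=[out]),
            ]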

Parameters
----------
fgraph : FunctionGraph
The function graph to convert
opset_version : int
ONNX opset version to use

Returns
-------
onnx.ModelProto
Complete ONNX model
"""
# Track variable names to ensure uniqueness
var_names = {}
var_counter = 0

Review comment (Member, on get_var_name): There's a unique_name_generator helper already in link.utils that I think you can reuse.

    def get_var_name(var):
        """Get or create unique name for a variable."""
nonlocal var_counter
if var not in var_names:
if hasattr(var, "name") and var.name:
base_name = var.name
else:
base_name = "var"
# Ensure uniqueness
name = f"{base_name}_{var_counter}"
var_counter += 1
var_names[var] = name
return var_names[var]

# Collect all nodes in topological order
nodes = []
initializers = []

# Process constants first
for var in fgraph.variables:
if isinstance(var, Constant):
name = get_var_name(var)
# Convert constant to ONNX initializer
# Special handling: if constant is a scalar int type and is used in operations
# with float tensors, upcast to float32 to avoid type mismatches
data = var.data
if data.ndim == 0 and np.issubdtype(data.dtype, np.integer):
# Check if this constant is used with float operations
# For now, we'll upcast all scalar integer constants to float32
# This is a simplification but handles the common case of: x * 2
                # where x is float and 2 is an int scalar
                data = data.astype("float32")

Review comment (Member): This doesn't sound safe. Constants show up a lot in indexing operations, for example x[:2]; you wouldn't want to make that a float. Any implicit casting should be done by the Op that needs it, or is there a more fundamental ONNX limitation here?

tensor_proto = onnx_typify(data, name=name)
initializers.append(tensor_proto)

# Process each node in topological order
for node in fgraph.toposort():
# Convert node via dispatch
result = onnx_funcify(
node.op,
node=node,
var_names=var_names,
get_var_name=get_var_name,
**kwargs,
)

# Handle multiple return patterns from operation handlers
if result is not None:
if isinstance(result, list):
# Multiple nodes - add all to graph
# Used for operations that compile to multiple ONNX ops
# Example: Shape_i returns [Constant, Shape, Gather]
nodes.extend(item for item in result if item is not None)
elif isinstance(result, tuple):
# Returned (node, additional_initializers)
# Used for operations with constant initializers
# Example: DimShuffle returns (Transpose, [axes_tensor])
onnx_node, node_initializers = result
if onnx_node is not None:
nodes.append(onnx_node)
if node_initializers:
initializers.extend(node_initializers)
else:
# Returned single node (most common case)
# Example: Add returns single Add node
nodes.append(result)
else:
# Handler returned None - this is a no-op operation
Review comment (ricardoV94, Dec 8, 2025): If you really need this, make the handler return a specific sentinel ONNX_NO_OP instead of None, to avoid subtle errors where users just forget to return something. Given you have an Identity node below, it sounds like you don't need this case though.

# Map output variables to input variables (pass-through)
# This is used for operations like SpecifyShape that don't
# change the data, only provide shape hints for optimization
if len(node.outputs) == 1 and len(node.inputs) > 0:
# For single-output ops, alias output to first input
output_var = node.outputs[0]
input_var = node.inputs[0]
# Map the output to use the same name as the input
if output_var not in var_names:
var_names[output_var] = get_var_name(input_var)

# Create input ValueInfos
inputs = []
for inp in fgraph.inputs:
if not isinstance(inp, Constant):
name = get_var_name(inp)
value_info = make_value_info(inp, name)
inputs.append(value_info)

# Create output ValueInfos
outputs = []
for out in fgraph.outputs:
name = get_var_name(out)
value_info = make_value_info(out, name)
outputs.append(value_info)

# Create the graph
graph_def = helper.make_graph(
nodes=nodes,
name="pytensor_graph",
inputs=inputs,
outputs=outputs,
initializer=initializers,
)

# Create the model with IR version 9 for compatibility with ONNX Runtime
model_def = helper.make_model(
graph_def,
opset_imports=[helper.make_opsetid("", opset_version)],
producer_name="PyTensor",
ir_version=9, # Use IR version 9 for ONNX Runtime compatibility
)

# Check the model
onnx.checker.check_model(model_def)

return model_def
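
End-to-end illustration (not part of the diff): the resulting ModelProto can be executed with ONNX Runtime, assuming onnxruntime is installed; onnx_funcify and the elemwise handlers are loaded by importing pytensor.link.onnx:

import numpy as np
import onnxruntime as ort
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.onnx import onnx_funcify

x = pt.vector("x", dtype="float32")
fgraph = FunctionGraph([x], [x + x])
model = onnx_funcify(fgraph)  # dispatches to onnx_funcify_FunctionGraph

sess = ort.InferenceSession(model.SerializeToString())
input_name = sess.get_inputs()[0].name
(result,) = sess.run(None, {input_name: np.arange(3, dtype="float32")})
# result == array([0., 2., 4.], dtype=float32)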


Review comment (Member): Constants aren't nodes, so you shouldn't need to register them for funcify.

@onnx_funcify.register(Constant)
def onnx_funcify_Constant(op, **kwargs):
"""Constants are handled as initializers, not nodes."""
# Constants don't produce nodes - they're added as initializers
# in the FunctionGraph converter
return None


@onnx_funcify.register(DeepCopyOp)
def onnx_funcify_DeepCopyOp(op, node, get_var_name, **kwargs):
"""Convert DeepCopyOp to ONNX Identity node.

DeepCopyOp is equivalent to Identity in ONNX.
"""
input_names = [get_var_name(inp) for inp in node.inputs]
output_names = [get_var_name(out) for out in node.outputs]

return helper.make_node(
"Identity",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have Identity in ONNX why can't you use those for the Ops like SpecifyShape that you are handling with special casing in the FunctionGraph loop?

inputs=input_names,
outputs=output_names,
name=f"Identity_{output_names[0]}",
)