-
Notifications
You must be signed in to change notification settings - Fork 149
Onnx backend #1777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Onnx backend #1777
Changes from all commits
b556aec
f173030
321157e
0e58ed5
31fb2c5
0039a41
5999d62
5044404
ec61d79
2908352
cf2d445
55ac06c
9e47c4c
414b0cd
8e827e9
2cfcaa4
8a49018
787f0b0
0667634
c6aeb27
4d505e8
1f24bf3
a987659
0b11ba7
bba554f
ac33055
10b546f
be45132
490862b
d0fb0d0
0392c88
8a6912b
f6d7cb8
db6fe34
f2735d6
34b0239
c877068
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,3 +55,4 @@ pytensor-venv/ | |
| testing-report.html | ||
| coverage.xml | ||
| .coverage.* | ||
| .hypothesis/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| """ONNX backend for PyTensor. | ||
|
|
||
| This module provides functionality to export PyTensor graphs to ONNX format | ||
| and execute them using ONNX Runtime. | ||
| """ | ||
|
|
||
| from pytensor.link.onnx.dispatch import onnx_funcify, onnx_typify | ||
| from pytensor.link.onnx.export import compile_onnx, export_function_onnx, export_onnx | ||
| from pytensor.link.onnx.linker import ONNXLinker | ||
|
|
||
|
|
||
| # ONNX opset version used by default | ||
| ONNX_OPSET_VERSION = 18 | ||
|
|
||
| __all__ = [ | ||
| "ONNX_OPSET_VERSION", | ||
| "ONNXLinker", | ||
| "compile_onnx", | ||
| "export_function_onnx", | ||
| "export_onnx", | ||
| "onnx_funcify", | ||
| "onnx_typify", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| """ONNX dispatch system for converting PyTensor operations to ONNX.""" | ||
|
|
||
| # isort: off | ||
| from pytensor.link.onnx.dispatch.basic import onnx_funcify, onnx_typify | ||
|
|
||
| # Load dispatch specializations | ||
| import pytensor.link.onnx.dispatch.elemwise | ||
| import pytensor.link.onnx.dispatch.shape | ||
| import pytensor.link.onnx.dispatch.math | ||
| import pytensor.link.onnx.dispatch.tensor_basic | ||
| import pytensor.link.onnx.dispatch.subtensor | ||
| import pytensor.link.onnx.dispatch.nlinalg | ||
| import pytensor.link.onnx.dispatch.nnet | ||
|
|
||
| # isort: on |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,324 @@ | ||||||||||||
| """Core ONNX dispatch functions for converting PyTensor graphs to ONNX.""" | ||||||||||||
|
|
||||||||||||
| from functools import singledispatch | ||||||||||||
|
|
||||||||||||
| import numpy as np | ||||||||||||
| import onnx | ||||||||||||
| from onnx import helper, numpy_helper | ||||||||||||
|
|
||||||||||||
| from pytensor.compile.ops import DeepCopyOp | ||||||||||||
| from pytensor.graph import Constant | ||||||||||||
| from pytensor.graph.fg import FunctionGraph | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # Mapping from PyTensor dtypes to ONNX TensorProto dtypes | ||||||||||||
| PYTENSOR_DTYPE_TO_ONNX = { | ||||||||||||
| "float32": onnx.TensorProto.FLOAT, | ||||||||||||
| "float64": onnx.TensorProto.DOUBLE, | ||||||||||||
| "int32": onnx.TensorProto.INT32, | ||||||||||||
| "int64": onnx.TensorProto.INT64, | ||||||||||||
| "uint8": onnx.TensorProto.UINT8, | ||||||||||||
| "int8": onnx.TensorProto.INT8, | ||||||||||||
| "uint16": onnx.TensorProto.UINT16, | ||||||||||||
| "int16": onnx.TensorProto.INT16, | ||||||||||||
| "bool": onnx.TensorProto.BOOL, | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
@singledispatch
def onnx_typify(data, dtype=None, name=None, **kwargs):
    """Convert Python/NumPy data to an ONNX ``TensorProto``.

    Parameters
    ----------
    data : array-like
        Data to convert.
    dtype : str, optional
        Target data type for the conversion.
    name : str, optional
        Name assigned to the resulting tensor.

    Returns
    -------
    onnx.TensorProto
        ONNX tensor representation of ``data``.
    """
    # Fallback path: coerce anything array-like through NumPy first.
    arr = data if isinstance(data, np.ndarray) else np.array(data, dtype=dtype)
    return numpy_helper.from_array(arr, name=name)


@onnx_typify.register(np.ndarray)
def onnx_typify_ndarray(data, dtype=None, name=None, **kwargs):
    """Convert a NumPy array to an ONNX ``TensorProto``."""
    # Only cast when a target dtype was explicitly requested.
    arr = data.astype(dtype) if dtype is not None else data
    return numpy_helper.from_array(arr, name=name)
|
|
||||||||||||
|
|
||||||||||||
@singledispatch
def onnx_funcify(op, node=None, **kwargs):
    """Convert a PyTensor Op to an ONNX node.

    This is the core dispatch function that converts PyTensor operations
    to their ONNX equivalents.

    Parameters
    ----------
    op : Op or FunctionGraph
        The operation or graph to convert.
    node : Apply, optional
        The Apply node containing this operation.
    **kwargs : dict
        Additional arguments passed through the conversion.

    Returns
    -------
    onnx.NodeProto or onnx.ModelProto
        ONNX representation of the operation.

    Raises
    ------
    NotImplementedError
        If no ONNX conversion is available for this operation.
    """
    # Raising (rather than silently falling back) surfaces unsupported
    # operations immediately at export time.
    raise NotImplementedError(
        f"No ONNX conversion available for: {type(op).__name__}. "
        f"The operation {op} is not yet supported in the ONNX backend."
    )
|
|
||||||||||||
|
|
||||||||||||
def make_value_info(var, name):
    """Create an ONNX ``ValueInfoProto`` from a PyTensor Variable.

    Parameters
    ----------
    var : Variable
        PyTensor variable. Must be tensor-like, i.e. its type must carry
        ``dtype`` and ``ndim`` (e.g. ``TensorVariable``).
    name : str
        Name for the ONNX value.

    Returns
    -------
    onnx.ValueInfoProto
        ONNX value info with dtype and (symbolic) shape.

    Raises
    ------
    TypeError
        If ``var`` is not a tensor-like variable (Slice, TypedList,
        RandomGenerator and sparse variables are not supported).
    ValueError
        If the variable's dtype has no ONNX equivalent.
    """
    var_type = var.type
    # Raise explicitly for non-tensor variables instead of failing with an
    # opaque AttributeError further down.
    if not hasattr(var_type, "dtype") or not hasattr(var_type, "ndim"):
        raise TypeError(
            f"Cannot create ONNX value info for non-tensor variable {var} "
            f"of type {var_type}. Only tensor variables are supported."
        )

    dtype_str = var_type.dtype
    if dtype_str not in PYTENSOR_DTYPE_TO_ONNX:
        raise ValueError(
            f"Unsupported dtype: {dtype_str}. "
            f"Supported dtypes: {list(PYTENSOR_DTYPE_TO_ONNX.keys())}"
        )
    onnx_dtype = PYTENSOR_DTYPE_TO_ONNX[dtype_str]

    # Shapes may be symbolic; export every dimension as unknown (None) and
    # let ONNX shape inference fill in what it can.
    shape = [None] * var_type.ndim

    return helper.make_tensor_value_info(name, onnx_dtype, shape)
|
|
||||||||||||
|
|
||||||||||||
@onnx_funcify.register(FunctionGraph)
def onnx_funcify_FunctionGraph(
    fgraph,
    opset_version=18,
    **kwargs,
):
    """Convert a PyTensor FunctionGraph to an ONNX ModelProto.

    This function:
    1. Collects constants as graph initializers
    2. Converts each Apply node (in topological order) via onnx_funcify
    3. Creates an ONNX ModelProto with inputs, outputs, and nodes

    Operation Handler Return Patterns
    ----------------------------------
    Handlers registered via @onnx_funcify.register can return:

    1. **Single node** (most common):
       return helper.make_node('Add', inputs=[...], outputs=[...])

    2. **Multiple nodes** (operations requiring intermediate steps):
       return [
           helper.make_node('Shape', ...),
           helper.make_node('Gather', ...),
           helper.make_node('Slice', ...),
       ]

    3. **Node with initializers** (operations with constant data):
       return (
           helper.make_node('Transpose', ...),
           [axes_initializer],  # List of TensorProto initializers
       )

    4. **None** (no-op, pass-through): the node's single output is aliased
       to its first input.

    Notes:
    - List items can be None (they are filtered out)
    - The tuple pattern is (node, [initializers]), not (node, initializer)
    - Patterns cannot be mixed: either a list OR a tuple, not both

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph to convert.
    opset_version : int
        ONNX opset version to use.

    Returns
    -------
    onnx.ModelProto
        Complete, checker-validated ONNX model.
    """
    # Track variable names to ensure uniqueness.
    var_names = {}
    var_counter = 0

    def get_var_name(var):
        """Get or create a unique name for a variable."""
        nonlocal var_counter
        if var not in var_names:
            base_name = var.name if getattr(var, "name", None) else "var"
            # Suffix with a running counter to guarantee uniqueness even
            # when two variables share the same user-assigned name.
            var_names[var] = f"{base_name}_{var_counter}"
            var_counter += 1
        return var_names[var]

    nodes = []
    initializers = []

    # Constants become graph initializers rather than nodes.
    for var in fgraph.variables:
        if isinstance(var, Constant):
            name = get_var_name(var)
            # Export the constant with its original dtype. Do NOT upcast
            # integer scalars to float here: such constants frequently feed
            # indexing/shape operations (e.g. x[:2]) where a float index
            # would be invalid. Any implicit casting must be inserted by
            # the Op handler that actually needs it.
            tensor_proto = onnx_typify(var.data, name=name)
            initializers.append(tensor_proto)

    # Convert each Apply node in topological order.
    for node in fgraph.toposort():
        result = onnx_funcify(
            node.op,
            node=node,
            var_names=var_names,
            get_var_name=get_var_name,
            **kwargs,
        )

        # Handle the return patterns documented above.
        if result is not None:
            if isinstance(result, list):
                # Pattern 2: several ONNX nodes; drop any Nones.
                # Example: Shape_i returns [Constant, Shape, Gather].
                nodes.extend(item for item in result if item is not None)
            elif isinstance(result, tuple):
                # Pattern 3: (node, [initializers]).
                # Example: DimShuffle returns (Transpose, [axes_tensor]).
                onnx_node, node_initializers = result
                if onnx_node is not None:
                    nodes.append(onnx_node)
                if node_initializers:
                    initializers.extend(node_initializers)
            else:
                # Pattern 1: a single ONNX node (e.g. Add).
                nodes.append(result)
        else:
            # Pattern 4: no-op (e.g. SpecifyShape, which only provides
            # shape hints). Alias the output to the first input so
            # downstream nodes reference the same ONNX value.
            if len(node.outputs) == 1 and len(node.inputs) > 0:
                output_var = node.outputs[0]
                input_var = node.inputs[0]
                if output_var not in var_names:
                    var_names[output_var] = get_var_name(input_var)

    # Graph inputs: every non-constant fgraph input.
    inputs = [
        make_value_info(inp, get_var_name(inp))
        for inp in fgraph.inputs
        if not isinstance(inp, Constant)
    ]

    # Graph outputs.
    outputs = [make_value_info(out, get_var_name(out)) for out in fgraph.outputs]

    graph_def = helper.make_graph(
        nodes=nodes,
        name="pytensor_graph",
        inputs=inputs,
        outputs=outputs,
        initializer=initializers,
    )

    # IR version 9 keeps the model loadable by current ONNX Runtime builds.
    model_def = helper.make_model(
        graph_def,
        opset_imports=[helper.make_opsetid("", opset_version)],
        producer_name="PyTensor",
        ir_version=9,
    )

    # Validate the model before handing it to the caller.
    onnx.checker.check_model(model_def)

    return model_def
|
|
||||||||||||
|
|
||||||||||||
@onnx_funcify.register(Constant)
def onnx_funcify_Constant(op, **kwargs):
    """Constants are handled as initializers, not nodes."""
    # The FunctionGraph converter emits constants as graph initializers,
    # so there is no ONNX node to build here.
    return None
|
|
||||||||||||
|
|
||||||||||||
@onnx_funcify.register(DeepCopyOp)
def onnx_funcify_DeepCopyOp(op, node, get_var_name, **kwargs):
    """Convert DeepCopyOp to an ONNX Identity node.

    A deep copy has no observable effect inside an ONNX graph, so the op
    maps directly onto ONNX's ``Identity``.
    """
    in_names = [get_var_name(v) for v in node.inputs]
    out_names = [get_var_name(v) for v in node.outputs]

    return helper.make_node(
        "Identity",
        inputs=in_names,
        outputs=out_names,
        name=f"Identity_{out_names[0]}",
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A better default is to raise; we did this before for other backends and have been moving away from silent fallbacks.