diff --git a/.gitignore b/.gitignore index ebe8e61bd0..58d2cb6cbc 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ pytensor-venv/ testing-report.html coverage.xml .coverage.* +.hypothesis/ \ No newline at end of file diff --git a/pytensor/link/onnx/__init__.py b/pytensor/link/onnx/__init__.py new file mode 100644 index 0000000000..fdb33a4ab4 --- /dev/null +++ b/pytensor/link/onnx/__init__.py @@ -0,0 +1,23 @@ +"""ONNX backend for PyTensor. + +This module provides functionality to export PyTensor graphs to ONNX format +and execute them using ONNX Runtime. +""" + +from pytensor.link.onnx.dispatch import onnx_funcify, onnx_typify +from pytensor.link.onnx.export import compile_onnx, export_function_onnx, export_onnx +from pytensor.link.onnx.linker import ONNXLinker + + +# ONNX opset version used by default +ONNX_OPSET_VERSION = 18 + +__all__ = [ + "ONNX_OPSET_VERSION", + "ONNXLinker", + "compile_onnx", + "export_function_onnx", + "export_onnx", + "onnx_funcify", + "onnx_typify", +] diff --git a/pytensor/link/onnx/dispatch/__init__.py b/pytensor/link/onnx/dispatch/__init__.py new file mode 100644 index 0000000000..e422c67a00 --- /dev/null +++ b/pytensor/link/onnx/dispatch/__init__.py @@ -0,0 +1,15 @@ +"""ONNX dispatch system for converting PyTensor operations to ONNX.""" + +# isort: off +from pytensor.link.onnx.dispatch.basic import onnx_funcify, onnx_typify + +# Load dispatch specializations +import pytensor.link.onnx.dispatch.elemwise +import pytensor.link.onnx.dispatch.shape +import pytensor.link.onnx.dispatch.math +import pytensor.link.onnx.dispatch.tensor_basic +import pytensor.link.onnx.dispatch.subtensor +import pytensor.link.onnx.dispatch.nlinalg +import pytensor.link.onnx.dispatch.nnet + +# isort: on diff --git a/pytensor/link/onnx/dispatch/basic.py b/pytensor/link/onnx/dispatch/basic.py new file mode 100644 index 0000000000..ac92756710 --- /dev/null +++ b/pytensor/link/onnx/dispatch/basic.py @@ -0,0 +1,324 @@ +"""Core ONNX dispatch functions for converting PyTensor graphs to ONNX.""" + +from functools import singledispatch + +import numpy as np +import onnx +from onnx import helper, numpy_helper + +from pytensor.compile.ops import DeepCopyOp +from pytensor.graph import Constant +from pytensor.graph.fg import FunctionGraph + + +# Mapping from PyTensor dtypes to ONNX TensorProto dtypes +PYTENSOR_DTYPE_TO_ONNX = { + "float32": onnx.TensorProto.FLOAT, + "float64": onnx.TensorProto.DOUBLE, + "int32": onnx.TensorProto.INT32, + "int64": onnx.TensorProto.INT64, + "uint8": onnx.TensorProto.UINT8, + "int8": onnx.TensorProto.INT8, + "uint16": onnx.TensorProto.UINT16, + "int16": onnx.TensorProto.INT16, + "bool": onnx.TensorProto.BOOL, +} + + +@singledispatch +def onnx_typify(data, dtype=None, name=None, **kwargs): + """Convert Python/NumPy data to ONNX TensorProto. + + Parameters + ---------- + data : array-like + Data to convert + dtype : str, optional + Data type + name : str, optional + Name for the tensor + + Returns + ------- + onnx.TensorProto + ONNX tensor representation + """ + # Default: try to convert to numpy array first + if not isinstance(data, np.ndarray): + data = np.array(data, dtype=dtype) + return numpy_helper.from_array(data, name=name) + + +@onnx_typify.register(np.ndarray) +def onnx_typify_ndarray(data, dtype=None, name=None, **kwargs): + """Convert NumPy array to ONNX TensorProto.""" + if dtype is not None: + data = data.astype(dtype) + return numpy_helper.from_array(data, name=name) + + +@singledispatch +def onnx_funcify(op, node=None, **kwargs): + """Convert a PyTensor Op to an ONNX node. + + This is the core dispatch function that converts PyTensor operations + to their ONNX equivalents. + + Parameters + ---------- + op : Op or FunctionGraph + The operation or graph to convert + node : Apply, optional + The Apply node containing this operation + **kwargs : dict + Additional arguments passed through the conversion + + Returns + ------- + onnx.NodeProto or onnx.ModelProto + ONNX representation of the operation + + Raises + ------ + NotImplementedError + If no ONNX conversion is available for this operation + """ + op_type = type(op).__name__ + raise NotImplementedError( + f"No ONNX conversion available for: {op_type}. " + f"The operation {op} is not yet supported in the ONNX backend." + ) + + +def make_value_info(var, name): + """Create ONNX ValueInfoProto from PyTensor Variable. + + Parameters + ---------- + var : Variable + PyTensor variable + name : str + Name for the ONNX value + + Returns + ------- + onnx.ValueInfoProto + ONNX value info with shape and dtype + """ + # Get dtype + dtype_str = var.type.dtype + if dtype_str not in PYTENSOR_DTYPE_TO_ONNX: + raise ValueError( + f"Unsupported dtype: {dtype_str}. " + f"Supported dtypes: {list(PYTENSOR_DTYPE_TO_ONNX.keys())}" + ) + onnx_dtype = PYTENSOR_DTYPE_TO_ONNX[dtype_str] + + # Get shape - handle both static and symbolic shapes + # For now, we'll use None for unknown dimensions + ndim = var.type.ndim + shape = [None] * ndim # Unknown dimensions + + # Create tensor type + return helper.make_tensor_value_info(name, onnx_dtype, shape) + + +@onnx_funcify.register(FunctionGraph) +def onnx_funcify_FunctionGraph( + fgraph, + opset_version=18, + **kwargs, +): + """Convert a PyTensor FunctionGraph to an ONNX ModelProto. + + This function: + 1. Does topological sort of nodes + 2. Converts each node to ONNX via onnx_funcify + 3. Collects constants as initializers + 4. Creates ONNX ModelProto with inputs, outputs, and nodes + + Operation Handler Return Patterns + ---------------------------------- + Handlers registered via @onnx_funcify.register can return: + + 1. **Single node** (most common): + return helper.make_node('Add', inputs=[...], outputs=[...]) + + 2. **Multiple nodes** (operations requiring intermediate steps): + return [ + helper.make_node('Shape', ...), + helper.make_node('Gather', ...), + helper.make_node('Slice', ...), + ] + + 3. **Node with initializers** (operations with constant data): + return ( + helper.make_node('Transpose', ...), + [axes_initializer], # List of TensorProto initializers + ) + + 4. **None** (no-op, pass-through): + return None + + Notes: + - List items can be None (will be filtered out) + - Tuple pattern is (node, [initializers]), not (node, initializer) + - Cannot mix patterns: either list OR tuple, not both + + Parameters + ---------- + fgraph : FunctionGraph + The function graph to convert + opset_version : int + ONNX opset version to use + + Returns + ------- + onnx.ModelProto + Complete ONNX model + """ + # Track variable names to ensure uniqueness + var_names = {} + var_counter = 0 + + def get_var_name(var): + """Get or create unique name for a variable.""" + nonlocal var_counter + if var not in var_names: + if hasattr(var, "name") and var.name: + base_name = var.name + else: + base_name = "var" + # Ensure uniqueness + name = f"{base_name}_{var_counter}" + var_counter += 1 + var_names[var] = name + return var_names[var] + + # Collect all nodes in topological order + nodes = [] + initializers = [] + + # Process constants first + for var in fgraph.variables: + if isinstance(var, Constant): + name = get_var_name(var) + # Convert constant to ONNX initializer + # Special handling: if constant is a scalar int type and is used in operations + # with float tensors, upcast to float32 to avoid type mismatches + data = var.data + if data.ndim == 0 and np.issubdtype(data.dtype, np.integer): + # Check if this constant is used with float operations + # For now, we'll upcast all scalar integer constants to float32 + # This is a simplification but handles the common case of: x * 2 + # where x is float and 2 is an int scalar + data = data.astype("float32") + + tensor_proto = onnx_typify(data, name=name) + initializers.append(tensor_proto) + + # Process each node in topological order + for node in fgraph.toposort(): + # Convert node via dispatch + result = onnx_funcify( + node.op, + node=node, + var_names=var_names, + get_var_name=get_var_name, + **kwargs, + ) + + # Handle multiple return patterns from operation handlers + if result is not None: + if isinstance(result, list): + # Multiple nodes - add all to graph + # Used for operations that compile to multiple ONNX ops + # Example: Shape_i returns [Constant, Shape, Gather] + nodes.extend(item for item in result if item is not None) + elif isinstance(result, tuple): + # Returned (node, additional_initializers) + # Used for operations with constant initializers + # Example: DimShuffle returns (Transpose, [axes_tensor]) + onnx_node, node_initializers = result + if onnx_node is not None: + nodes.append(onnx_node) + if node_initializers: + initializers.extend(node_initializers) + else: + # Returned single node (most common case) + # Example: Add returns single Add node + nodes.append(result) + else: + # Handler returned None - this is a no-op operation + # Map output variables to input variables (pass-through) + # This is used for operations like SpecifyShape that don't + # change the data, only provide shape hints for optimization + if len(node.outputs) == 1 and len(node.inputs) > 0: + # For single-output ops, alias output to first input + output_var = node.outputs[0] + input_var = node.inputs[0] + # Map the output to use the same name as the input + if output_var not in var_names: + var_names[output_var] = get_var_name(input_var) + + # Create input ValueInfos + inputs = [] + for inp in fgraph.inputs: + if not isinstance(inp, Constant): + name = get_var_name(inp) + value_info = make_value_info(inp, name) + inputs.append(value_info) + + # Create output ValueInfos + outputs = [] + for out in fgraph.outputs: + name = get_var_name(out) + value_info = make_value_info(out, name) + outputs.append(value_info) + + # Create the graph + graph_def = helper.make_graph( + nodes=nodes, + name="pytensor_graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + + # Create the model with IR version 9 for compatibility with ONNX Runtime + model_def = helper.make_model( + graph_def, + opset_imports=[helper.make_opsetid("", opset_version)], + producer_name="PyTensor", + ir_version=9, # Use IR version 9 for ONNX Runtime compatibility + ) + + # Check the model + onnx.checker.check_model(model_def) + + return model_def + + +@onnx_funcify.register(Constant) +def onnx_funcify_Constant(op, **kwargs): + """Constants are handled as initializers, not nodes.""" + # Constants don't produce nodes - they're added as initializers + # in the FunctionGraph converter + return None + + +@onnx_funcify.register(DeepCopyOp) +def onnx_funcify_DeepCopyOp(op, node, get_var_name, **kwargs): + """Convert DeepCopyOp to ONNX Identity node. + + DeepCopyOp is equivalent to Identity in ONNX. + """ + input_names = [get_var_name(inp) for inp in node.inputs] + output_names = [get_var_name(out) for out in node.outputs] + + return helper.make_node( + "Identity", + inputs=input_names, + outputs=output_names, + name=f"Identity_{output_names[0]}", + ) diff --git a/pytensor/link/onnx/dispatch/elemwise.py b/pytensor/link/onnx/dispatch/elemwise.py new file mode 100644 index 0000000000..bf4a280d5b --- /dev/null +++ b/pytensor/link/onnx/dispatch/elemwise.py @@ -0,0 +1,267 @@ +"""ONNX conversion for elementwise operations.""" + +from onnx import helper + +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.scalar import basic as scalar +from pytensor.scalar import math as scalar_math +from pytensor.tensor.elemwise import Elemwise + + +# ⭐ THE MAGIC MAPPING - Tier 1 + Tier 4-5 operations +SCALAR_OP_TO_ONNX = { + # Arithmetic (Tier 1) + scalar.Add: "Add", + scalar.Mul: "Mul", + scalar.Sub: "Sub", + scalar.TrueDiv: "Div", + scalar.Neg: "Neg", + # Note: IntDiv handled specially in onnx_funcify_Elemwise as Div + Floor + # Math (Tier 1) + scalar.Abs: "Abs", + scalar.Exp: "Exp", + scalar.Log: "Log", + scalar.Sqrt: "Sqrt", + scalar.Pow: "Pow", + scalar.Floor: "Floor", + scalar.Ceil: "Ceil", + scalar.RoundHalfToEven: "Round", + scalar.RoundHalfAwayFromZero: "Round", + # Min/Max (Tier 1) + scalar.Maximum: "Max", + scalar.Minimum: "Min", + # Trigonometric (Tier 5) + scalar.Sin: "Sin", + scalar.Cos: "Cos", + scalar.Tan: "Tan", + scalar.ArcSin: "Asin", + scalar.ArcCos: "Acos", + scalar.ArcTan: "Atan", + # Hyperbolic (Tier 5) + scalar.Sinh: "Sinh", + scalar.Cosh: "Cosh", + scalar.Tanh: "Tanh", + scalar.ArcSinh: "Asinh", + scalar.ArcCosh: "Acosh", + scalar.ArcTanh: "Atanh", + # Comparison (Tier 5) + scalar.LT: "Less", + scalar.GT: "Greater", + scalar.LE: "LessOrEqual", + scalar.GE: "GreaterOrEqual", + scalar.EQ: "Equal", + # Note: NEQ is handled specially in onnx_funcify_Elemwise as Equal + Not + # Logical (Tier 5) + scalar.AND: "And", + scalar.OR: "Or", + scalar.XOR: "Xor", + scalar.Invert: "Not", + # Special (Tier 5) + scalar_math.Sigmoid: "Sigmoid", + scalar_math.Softplus: "Softplus", + scalar_math.Erf: "Erf", + # Note: Clip handled specially in onnx_funcify_Elemwise (requires scalar min/max) + # Conditional + scalar.Switch: "Where", +} + + +@onnx_funcify.register(Elemwise) +def onnx_funcify_Elemwise(op, node, get_var_name, **kwargs): + """Convert Elemwise op to ONNX node. + + This ONE function handles ALL operations, including composed ones! + + Parameters + ---------- + op : Elemwise + The elementwise operation + node : Apply + The Apply node + get_var_name : callable + Function to get variable names + **kwargs : dict + Additional keyword arguments + + Returns + ------- + onnx.NodeProto or list[onnx.NodeProto] + ONNX node(s) for the operation + """ + scalar_op_type = type(op.scalar_op) + + # Special handling for operations that need to be composed + # Clip(x, min, max) - ONNX requires scalar min/max, but PyTensor may provide tensors + if scalar_op_type == scalar.Clip: + input_names = [get_var_name(inp) for inp in node.inputs] + output_name = get_var_name(node.outputs[0]) + + # Input 0 is the array to clip, inputs 1 and 2 are min/max + # ONNX Clip expects scalars for min/max, but PyTensor may have added dimensions + # We need to squeeze them if they're not scalars + x_name = input_names[0] + min_name = input_names[1] + max_name = input_names[2] + + # Create Squeeze nodes for min and max to ensure they're scalars + # ONNX Squeeze with empty axes removes all dimensions of size 1 + min_scalar_name = f"{output_name}_min_scalar" + min_squeeze = helper.make_node( + "Squeeze", + inputs=[min_name], + outputs=[min_scalar_name], + name=f"Squeeze_{min_scalar_name}", + ) + + max_scalar_name = f"{output_name}_max_scalar" + max_squeeze = helper.make_node( + "Squeeze", + inputs=[max_name], + outputs=[max_scalar_name], + name=f"Squeeze_{max_scalar_name}", + ) + + # Clip with scalar min/max + clip_node = helper.make_node( + "Clip", + inputs=[x_name, min_scalar_name, max_scalar_name], + outputs=[output_name], + name=f"Clip_{output_name}", + ) + + return [min_squeeze, max_squeeze, clip_node] + + # IntDiv(x, y) = Floor(Div(x, y)) + if scalar_op_type == scalar.IntDiv: + input_names = [get_var_name(inp) for inp in node.inputs] + output_name = get_var_name(node.outputs[0]) + + # Div(x, y) + div_name = f"{output_name}_div" + div_node = helper.make_node( + "Div", + inputs=input_names, + outputs=[div_name], + name=f"Div_{div_name}", + ) + + # Floor(Div(x, y)) + floor_node = helper.make_node( + "Floor", + inputs=[div_name], + outputs=[output_name], + name=f"Floor_{output_name}", + ) + + return [div_node, floor_node] + + # NEQ(x, y) = Not(Equal(x, y)) + if scalar_op_type == scalar.NEQ: + input_names = [get_var_name(inp) for inp in node.inputs] + output_name = get_var_name(node.outputs[0]) + + # Equal(x, y) + equal_name = f"{output_name}_equal" + equal_node = helper.make_node( + "Equal", + inputs=input_names, + outputs=[equal_name], + name=f"Equal_{equal_name}", + ) + + # Not(Equal(x, y)) + not_node = helper.make_node( + "Not", + inputs=[equal_name], + outputs=[output_name], + name=f"Not_{output_name}", + ) + + return [equal_node, not_node] + + # Log1p(x) = Log(Add(x, 1)) + if scalar_op_type == scalar.Log1p: + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + # Create constant 1 + one_name = f"{output_name}_one" + one_node = helper.make_node( + "Constant", + inputs=[], + outputs=[one_name], + value=helper.make_tensor("value", helper.TensorProto.FLOAT, [], [1.0]), + ) + + # Add(x, 1) + add_name = f"{output_name}_add" + add_node = helper.make_node( + "Add", + inputs=[input_name, one_name], + outputs=[add_name], + name=f"Add_{add_name}", + ) + + # Log(Add(x, 1)) + log_node = helper.make_node( + "Log", + inputs=[add_name], + outputs=[output_name], + name=f"Log_{output_name}", + ) + + return [one_node, add_node, log_node] + + # Expm1(x) = Sub(Exp(x), 1) + if scalar_op_type == scalar.Expm1: + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + # Exp(x) + exp_name = f"{output_name}_exp" + exp_node = helper.make_node( + "Exp", + inputs=[input_name], + outputs=[exp_name], + name=f"Exp_{exp_name}", + ) + + # Create constant 1 + one_name = f"{output_name}_one" + one_node = helper.make_node( + "Constant", + inputs=[], + outputs=[one_name], + value=helper.make_tensor("value", helper.TensorProto.FLOAT, [], [1.0]), + ) + + # Sub(Exp(x), 1) + sub_node = helper.make_node( + "Sub", + inputs=[exp_name, one_name], + outputs=[output_name], + name=f"Sub_{output_name}", + ) + + return [exp_node, one_node, sub_node] + + # Standard operations + if scalar_op_type not in SCALAR_OP_TO_ONNX: + raise NotImplementedError( + f"Elemwise scalar op not supported for ONNX export: {scalar_op_type.__name__}. " + f"Supported operations: {list(SCALAR_OP_TO_ONNX.keys())}" + ) + + onnx_op_type = SCALAR_OP_TO_ONNX[scalar_op_type] + + # Get input and output names + input_names = [get_var_name(inp) for inp in node.inputs] + output_names = [get_var_name(out) for out in node.outputs] + + # Create ONNX node + return helper.make_node( + onnx_op_type, + inputs=input_names, + outputs=output_names, + name=f"{onnx_op_type}_{output_names[0]}", + ) diff --git a/pytensor/link/onnx/dispatch/math.py b/pytensor/link/onnx/dispatch/math.py new file mode 100644 index 0000000000..6e1e431342 --- /dev/null +++ b/pytensor/link/onnx/dispatch/math.py @@ -0,0 +1,141 @@ +"""ONNX conversion for math operations (reductions).""" + +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.scalar.basic import AND, OR, Add, Maximum, Minimum, Mul +from pytensor.tensor.math import Argmax, CAReduce + + +try: + from onnx import helper +except ImportError as e: + raise ImportError("ONNX package required for export") from e + + +# Mapping from PyTensor scalar ops to ONNX reduction ops +REDUCE_OP_MAP = { + Add: "ReduceSum", + Mul: "ReduceProd", + Maximum: "ReduceMax", + Minimum: "ReduceMin", + AND: "ReduceMin", # For boolean AND + OR: "ReduceMax", # For boolean OR +} + + +@onnx_funcify.register(CAReduce) +def onnx_funcify_CAReduce(op, node, get_var_name, **kwargs): + """Convert CAReduce op to ONNX reduction node. + + CAReduce performs reductions (sum, prod, max, min) along specified axes. + + For ONNX opset 18+, axes must be provided as an input tensor, + not as an attribute. + """ + scalar_op_type = type(op.scalar_op) + + if scalar_op_type not in REDUCE_OP_MAP: + raise NotImplementedError( + f"CAReduce with scalar op {scalar_op_type.__name__} not supported for ONNX export" + ) + + onnx_op_type = REDUCE_OP_MAP[scalar_op_type] + + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + # Get axis parameter + axes = op.axis + nodes = [] + + if axes is not None: + # Convert to list if needed + if isinstance(axes, (tuple, list)): + axes_list = list(axes) + else: + axes_list = [axes] + + # For opset 18+, axes must be an input tensor + axes_name = f"{output_name}_axes" + axes_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[axes_name], + name=f"Constant_{axes_name}", + value=helper.make_tensor( + name=f"{axes_name}_value", + data_type=helper.TensorProto.INT64, + dims=[len(axes_list)], + vals=axes_list, + ), + ) + nodes.append(axes_constant) + + onnx_node = helper.make_node( + onnx_op_type, + inputs=[input_name, axes_name], + outputs=[output_name], + name=f"{onnx_op_type}_{output_name}", + keepdims=0, # PyTensor default is to not keep dims + ) + else: + # Reduce over all axes - don't provide axes input + onnx_node = helper.make_node( + onnx_op_type, + inputs=[input_name], + outputs=[output_name], + name=f"{onnx_op_type}_{output_name}", + keepdims=0, + ) + + nodes.append(onnx_node) + return nodes if len(nodes) > 1 else onnx_node + + +@onnx_funcify.register(Argmax) +def onnx_funcify_Argmax(op, node, get_var_name, **kwargs): + """Convert Argmax op to ONNX ArgMax node.""" + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + axis = op.axis + if axis is None: + # Argmax over all axes - need to flatten first + flatten_name = f"{output_name}_flat" + flatten_node = helper.make_node( + "Flatten", + inputs=[input_name], + outputs=[flatten_name], + name=f"Flatten_{flatten_name}", + axis=0, + ) + + argmax_node = helper.make_node( + "ArgMax", + inputs=[flatten_name], + outputs=[output_name], + name=f"ArgMax_{output_name}", + axis=0, + keepdims=0, + ) + + return [flatten_node, argmax_node] + else: + # Argmax over specific axis + # PyTensor stores axis as a tuple, ONNX ArgMax expects a single int + if isinstance(axis, (tuple, list)): + if len(axis) != 1: + raise NotImplementedError( + f"ONNX ArgMax only supports single axis, got {axis}" + ) + axis = axis[0] + + onnx_node = helper.make_node( + "ArgMax", + inputs=[input_name], + outputs=[output_name], + name=f"ArgMax_{output_name}", + axis=int(axis), + keepdims=0, + ) + + return onnx_node diff --git a/pytensor/link/onnx/dispatch/nlinalg.py b/pytensor/link/onnx/dispatch/nlinalg.py new file mode 100644 index 0000000000..01ca6a977d --- /dev/null +++ b/pytensor/link/onnx/dispatch/nlinalg.py @@ -0,0 +1,98 @@ +"""ONNX conversion for linear algebra operations.""" + +from onnx import helper + +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.tensor.blas import BatchedDot, Gemm +from pytensor.tensor.math import Dot + + +@onnx_funcify.register(Dot) +def onnx_funcify_Dot(op, node, get_var_name, **kwargs): + """Convert Dot op to ONNX MatMul node. + + Dot performs matrix multiplication. ONNX MatMul handles: + - Matrix @ Matrix + - Vector @ Matrix (with implicit unsqueeze) + - Batched operations + """ + input_a = get_var_name(node.inputs[0]) + input_b = get_var_name(node.inputs[1]) + output_name = get_var_name(node.outputs[0]) + + # ONNX MatMul handles most cases directly + matmul_node = helper.make_node( + "MatMul", + inputs=[input_a, input_b], + outputs=[output_name], + name=f"MatMul_{output_name}", + ) + + return matmul_node + + +@onnx_funcify.register(Gemm) +def onnx_funcify_Gemm(op, node, get_var_name, **kwargs): + """Convert Gemm op to ONNX Gemm node. + + PyTensor Gemm: gemm(C, alpha, A, B, beta) = beta*C + alpha*dot(A, B) + ONNX Gemm: Y = alpha * A' * B' + beta * C + + Where inputs are: [C, alpha, A, B, beta] + Remap to ONNX: [A, B, C] with alpha and beta as attributes + """ + from pytensor.graph import Constant + + # PyTensor inputs: [C, alpha, A, B, beta] + input_c = get_var_name(node.inputs[0]) + alpha_var = node.inputs[1] + input_a = get_var_name(node.inputs[2]) + input_b = get_var_name(node.inputs[3]) + beta_var = node.inputs[4] + output_name = get_var_name(node.outputs[0]) + + # Extract alpha and beta values (should be constants) + if isinstance(alpha_var, Constant): + alpha = float(alpha_var.data) + else: + alpha = 1.0 + + if isinstance(beta_var, Constant): + beta = float(beta_var.data) + else: + beta = 1.0 + + # ONNX Gemm: Y = alpha * A @ B + beta * C + gemm_node = helper.make_node( + "Gemm", + inputs=[input_a, input_b, input_c], + outputs=[output_name], + name=f"Gemm_{output_name}", + alpha=alpha, + beta=beta, + transA=0, + transB=0, + ) + + return gemm_node + + +@onnx_funcify.register(BatchedDot) +def onnx_funcify_BatchedDot(op, node, get_var_name, **kwargs): + """Convert BatchedDot to ONNX MatMul. + + BatchedDot performs batched matrix multiplication. + ONNX MatMul handles batching natively. + """ + input_a = get_var_name(node.inputs[0]) + input_b = get_var_name(node.inputs[1]) + output_name = get_var_name(node.outputs[0]) + + matmul_node = helper.make_node( + "MatMul", + inputs=[input_a, input_b], + outputs=[output_name], + name=f"MatMul_{output_name}", + ) + + return matmul_node diff --git a/pytensor/link/onnx/dispatch/nnet.py b/pytensor/link/onnx/dispatch/nnet.py new file mode 100644 index 0000000000..53c18389b3 --- /dev/null +++ b/pytensor/link/onnx/dispatch/nnet.py @@ -0,0 +1,184 @@ +"""ONNX conversion for neural network operations.""" + +from onnx import helper + +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.tensor.special import LogSoftmax, Softmax + + +@onnx_funcify.register(Softmax) +def onnx_funcify_Softmax(op, node, get_var_name, **kwargs): + """Convert Softmax op to ONNX Softmax node. + + PyTensor Softmax: Softmax(x, axis=axis) + ONNX Softmax: Softmax operator with axis attribute + + Special case: When axis=None, PyTensor applies softmax to the entire + flattened array. ONNX doesn't support this directly, so we need to: + 1. Flatten the input + 2. Apply softmax with axis=-1 + 3. Reshape back to original shape + + Parameters + ---------- + op : Softmax + The Softmax operation + node : Apply + The Apply node + get_var_name : callable + Function to get variable names + **kwargs : dict + Additional keyword arguments + + Returns + ------- + onnx.NodeProto or list[onnx.NodeProto] + ONNX node(s) for the operation + """ + input_x = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + if op.axis is None: + # axis=None means apply to flattened array + # Need to: Flatten -> Softmax(axis=-1) -> Reshape + + # Get input shape for reshaping back + shape_name = f"{output_name}_orig_shape" + flatten_name = f"{output_name}_flat" + softmax_name = f"{output_name}_softmax" + + # Get original shape + shape_node = helper.make_node( + "Shape", + inputs=[input_x], + outputs=[shape_name], + name=f"Shape_{output_name}", + ) + + # Flatten to 1D + flatten_node = helper.make_node( + "Flatten", + inputs=[input_x], + outputs=[flatten_name], + name=f"Flatten_{output_name}", + axis=0, # Flatten to 1D + ) + + # Apply softmax to flattened array (axis=-1) + softmax_node = helper.make_node( + "Softmax", + inputs=[flatten_name], + outputs=[softmax_name], + name=f"Softmax_{output_name}", + axis=-1, + ) + + # Reshape back to original shape + reshape_node = helper.make_node( + "Reshape", + inputs=[softmax_name, shape_name], + outputs=[output_name], + name=f"Reshape_{output_name}", + ) + + return [shape_node, flatten_node, softmax_node, reshape_node] + else: + # Normal case: axis is specified + softmax_node = helper.make_node( + "Softmax", + inputs=[input_x], + outputs=[output_name], + name=f"Softmax_{output_name}", + axis=op.axis, + ) + + return softmax_node + + +@onnx_funcify.register(LogSoftmax) +def onnx_funcify_LogSoftmax(op, node, get_var_name, **kwargs): + """Convert LogSoftmax op to ONNX LogSoftmax node. + + PyTensor LogSoftmax: LogSoftmax(x, axis=axis) + ONNX LogSoftmax: LogSoftmax operator with axis attribute + + Special case: When axis=None, PyTensor applies logsoftmax to the entire + flattened array. ONNX doesn't support this directly, so we need to: + 1. Flatten the input + 2. Apply logsoftmax with axis=-1 + 3. Reshape back to original shape + + Parameters + ---------- + op : LogSoftmax + The LogSoftmax operation + node : Apply + The Apply node + get_var_name : callable + Function to get variable names + **kwargs : dict + Additional keyword arguments + + Returns + ------- + onnx.NodeProto or list[onnx.NodeProto] + ONNX node(s) for the operation + """ + input_x = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + if op.axis is None: + # axis=None means apply to flattened array + # Need to: Flatten -> LogSoftmax(axis=-1) -> Reshape + + # Get input shape for reshaping back + shape_name = f"{output_name}_orig_shape" + flatten_name = f"{output_name}_flat" + logsoftmax_name = f"{output_name}_logsoftmax" + + # Get original shape + shape_node = helper.make_node( + "Shape", + inputs=[input_x], + outputs=[shape_name], + name=f"Shape_{output_name}", + ) + + # Flatten to 1D + flatten_node = helper.make_node( + "Flatten", + inputs=[input_x], + outputs=[flatten_name], + name=f"Flatten_{output_name}", + axis=0, # Flatten to 1D + ) + + # Apply logsoftmax to flattened array (axis=-1) + logsoftmax_node = helper.make_node( + "LogSoftmax", + inputs=[flatten_name], + outputs=[logsoftmax_name], + name=f"LogSoftmax_{output_name}", + axis=-1, + ) + + # Reshape back to original shape + reshape_node = helper.make_node( + "Reshape", + inputs=[logsoftmax_name, shape_name], + outputs=[output_name], + name=f"Reshape_{output_name}", + ) + + return [shape_node, flatten_node, logsoftmax_node, reshape_node] + else: + # Normal case: axis is specified + logsoftmax_node = helper.make_node( + "LogSoftmax", + inputs=[input_x], + outputs=[output_name], + name=f"LogSoftmax_{output_name}", + axis=op.axis, + ) + + return logsoftmax_node diff --git a/pytensor/link/onnx/dispatch/shape.py b/pytensor/link/onnx/dispatch/shape.py new file mode 100644 index 0000000000..6efb2e532a --- /dev/null +++ b/pytensor/link/onnx/dispatch/shape.py @@ -0,0 +1,382 @@ +"""ONNX conversion for shape operations.""" + +import numpy as np +from onnx import helper, numpy_helper + +from pytensor.graph.basic import Constant +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.tensor.basic import Join, Split, get_scalar_constant_value +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape + + +@onnx_funcify.register(type(None)) +def onnx_funcify_None(op, **kwargs): + """Handle None ops (used in some graph optimizations).""" + return None + + +@onnx_funcify.register(Shape) +def onnx_funcify_Shape(op, node, get_var_name, **kwargs): + """Convert Shape op to ONNX Shape node. + + Returns tensor containing shape of input. + """ + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + onnx_node = helper.make_node( + "Shape", + inputs=[input_name], + outputs=[output_name], + name=f"Shape_{output_name}", + ) + + return onnx_node + + +@onnx_funcify.register(Shape_i) +def onnx_funcify_Shape_i(op, node, get_var_name, **kwargs): + """Convert Shape_i op to ONNX Shape + Gather nodes. + + Shape_i extracts a specific dimension from a tensor's shape. + This requires multiple ONNX nodes: + 1. Constant - create index constant + 2. Shape - get full shape tensor + 3. Gather - extract the specific dimension + + This operation demonstrates the multi-node return pattern. + + Example: + x = pt.matrix('x') + dim0 = x.shape[0] # Shape_i with i=0 + + ONNX graph: + Constant(value=0) → idx + Shape(x) → shape_tensor + Gather(shape_tensor, idx, axis=0) → dim0 + """ + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + # Get dimension index from op + axis_idx = op.i + + # Create intermediate names + shape_name = f"{output_name}_shape" + idx_name = f"{output_name}_idx" + + # Node 1: Create constant for index + idx_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[idx_name], + name=f"Constant_{idx_name}", + value=helper.make_tensor( + name=f"{idx_name}_value", + data_type=helper.TensorProto.INT64, + dims=[], + vals=[axis_idx], + ), + ) + + # Node 2: Get full shape + shape_node = helper.make_node( + "Shape", + inputs=[input_name], + outputs=[shape_name], + name=f"Shape_{shape_name}", + ) + + # Node 3: Gather specific dimension + gather_node = helper.make_node( + "Gather", + inputs=[shape_name, idx_name], + outputs=[output_name], + name=f"Gather_{output_name}", + axis=0, # Gather from dimension 0 of shape tensor + ) + + # Return list of nodes - this is the key pattern! + return [idx_constant, shape_node, gather_node] + + +@onnx_funcify.register(SpecifyShape) +def onnx_funcify_SpecifyShape(op, node, get_var_name, **kwargs): + """SpecifyShape is just a hint - pass through input. + + SpecifyShape doesn't change the tensor data, it just provides + shape information for optimization. In ONNX export, we can + safely ignore it and just pass the input through. + """ + # Return None - no ONNX node needed + # The input will be directly connected to uses of the output + return None + + +# Import DimShuffle after TensorVariable to avoid circular imports +try: + from pytensor.tensor.elemwise import DimShuffle + + @onnx_funcify.register(DimShuffle) + def onnx_funcify_DimShuffle(op, node, get_var_name, **kwargs): + """Convert DimShuffle to ONNX operations. + + DimShuffle handles: + - Adding dimensions (broadcasting): ('x',) -> Unsqueeze + - Removing dimensions: drop -> Squeeze + - Permuting dimensions: (1, 0) -> Transpose + + For now, we focus on the most common case: adding dimensions for broadcasting. + """ + input_names = [get_var_name(inp) for inp in node.inputs] + output_names = [get_var_name(out) for out in node.outputs] + + new_order = op.new_order + input_ndim = op.input_ndim + + # Case 1: Adding dimensions (broadcasting a scalar or expanding dims) + # Example: new_order = ('x',) means add a dimension at the start + # Example: new_order = ('x', 0) means add dimension at start, keep original dim + if "x" in new_order: + # Find positions where 'x' appears - these are the axes to unsqueeze + axes = [i for i, dim in enumerate(new_order) if dim == "x"] + + # In ONNX opset 13+, Unsqueeze requires axes as a separate input (not attribute) + # Create a constant tensor for axes + axes_tensor_name = f"{output_names[0]}_axes" + axes_tensor = numpy_helper.from_array( + np.array(axes, dtype=np.int64), name=axes_tensor_name + ) + + # Create the Unsqueeze node + node = helper.make_node( + "Unsqueeze", + inputs=[input_names[0], axes_tensor_name], + outputs=output_names, + name=f"Unsqueeze_{output_names[0]}", + ) + + # Return (node, [initializers]) + return (node, [axes_tensor]) + + # Case 2: Transpose (permuting dimensions) + # new_order is a permutation of input dimensions + elif len(new_order) == input_ndim and all( + isinstance(d, int) for d in new_order + ): + return helper.make_node( + "Transpose", + inputs=input_names, + outputs=output_names, + name=f"Transpose_{output_names[0]}", + perm=list(new_order), + ) + + # Case 3: Squeeze (removing dimensions) + # This happens when new_order has fewer elements than input_ndim + # and doesn't contain 'x' + elif len(new_order) < input_ndim: + # Find which dimensions to remove + # The dimensions to squeeze are those not in new_order + axes_to_keep = set(new_order) + axes_to_squeeze = [i for i in range(input_ndim) if i not in axes_to_keep] + + # In ONNX opset 13+, Squeeze requires axes as a separate input (not attribute) + # Create a constant tensor for axes + axes_tensor_name = f"{output_names[0]}_axes" + axes_tensor = numpy_helper.from_array( + np.array(axes_to_squeeze, dtype=np.int64), name=axes_tensor_name + ) + + # Create the Squeeze node + node = helper.make_node( + "Squeeze", + inputs=[input_names[0], axes_tensor_name], + outputs=output_names, + name=f"Squeeze_{output_names[0]}", + ) + + # Return (node, [initializers]) + return (node, [axes_tensor]) + + else: + raise NotImplementedError( + f"DimShuffle with new_order={new_order} and input_ndim={input_ndim} " + f"is not yet supported in ONNX backend." + ) + + +except ImportError: + # DimShuffle not available + pass + + +@onnx_funcify.register(Reshape) +def onnx_funcify_Reshape(op, node, get_var_name, **kwargs): + """Convert Reshape op to ONNX Reshape node. + + Reshape changes tensor dimensions without changing data. + ONNX Reshape takes two inputs: + 1. data - the tensor to reshape + 2. shape - target shape (as 1D int64 tensor) + + The shape can be constant or computed dynamically. + """ + data_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + # The second input is the target shape + # It may be a constant or computed from other tensors + shape_input = node.inputs[1] + + if isinstance(shape_input, Constant): + # Shape is constant - create ONNX Constant node + shape_data = np.array(shape_input.data, dtype=np.int64) + shape_name = f"{output_name}_shape" + + shape_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[shape_name], + name=f"Constant_{shape_name}", + value=helper.make_tensor( + name=f"{shape_name}_value", + data_type=helper.TensorProto.INT64, + dims=[len(shape_data)], + vals=shape_data.tolist(), + ), + ) + + reshape_node = helper.make_node( + "Reshape", + inputs=[data_name, shape_name], + outputs=[output_name], + name=f"Reshape_{output_name}", + ) + + return [shape_constant, reshape_node] + else: + # Shape is computed - use its name directly + shape_name = get_var_name(shape_input) + + reshape_node = helper.make_node( + "Reshape", + inputs=[data_name, shape_name], + outputs=[output_name], + name=f"Reshape_{output_name}", + ) + + return reshape_node + + +@onnx_funcify.register(Join) +def onnx_funcify_Join(op, node, get_var_name, **kwargs): + """Convert Join op to ONNX Concat node. + + Join concatenates tensors along a specified axis. + The first input (node.inputs[0]) is the axis (as a scalar tensor). + The remaining inputs (node.inputs[1:]) are the tensors to concatenate. + + ONNX Concat requires the axis as an attribute (not input), so we need + to extract the constant axis value. + """ + axis_input = node.inputs[0] + tensor_inputs = node.inputs[1:] + + # Extract axis value - it must be constant + try: + axis = get_scalar_constant_value(axis_input) + axis = int(axis) + except NotScalarConstantError: + raise NotImplementedError( + "Join with non-constant axis is not supported for ONNX export. " + "The axis must be a constant integer value." + ) + + # Get tensor input names + input_names = [get_var_name(inp) for inp in tensor_inputs] + output_name = get_var_name(node.outputs[0]) + + # Create ONNX Concat node + concat_node = helper.make_node( + "Concat", + inputs=input_names, + outputs=[output_name], + name=f"Concat_{output_name}", + axis=axis, + ) + + return concat_node + + +@onnx_funcify.register(Split) +def onnx_funcify_Split(op, node, get_var_name, **kwargs): + """Convert Split op to ONNX Split node. + + Split partitions a tensor along a specified axis. + PyTensor Split takes: (tensor, axis, splits_size) as inputs + where splits_size defines the size of each output chunk. + + ONNX Split takes the tensor as input and axis/split as attributes. + """ + # Get input tensor + input_tensor = node.inputs[0] + axis_input = node.inputs[1] + splits_input = node.inputs[2] + + input_name = get_var_name(input_tensor) + output_names = [get_var_name(out) for out in node.outputs] + + # Extract axis - must be constant + try: + axis = get_scalar_constant_value(axis_input) + axis = int(axis) + except NotScalarConstantError: + raise NotImplementedError( + "Split with non-constant axis is not supported for ONNX export." + ) + + # Extract splits - must be constant + # splits_input is typically a 1D array of split sizes + # In ONNX opset 13+, split is provided as a second input tensor (not attribute) + if isinstance(splits_input, Constant): + splits_data = splits_input.data + if np.isscalar(splits_data): + # If it's a scalar, it means uniform split + # Number of splits = number of outputs + splits = np.array([int(splits_data)] * len(node.outputs), dtype=np.int64) + else: + # It's an array of split sizes + splits = np.array([int(s) for s in splits_data], dtype=np.int64) + else: + raise NotImplementedError( + "Split with non-constant split sizes is not supported for ONNX export. " + "The split sizes must be constant values." + ) + + # Create constant node for split sizes (required in opset 13+) + split_name = f"{output_names[0]}_split" + split_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[split_name], + name=f"Constant_{split_name}", + value=helper.make_tensor( + name=f"{split_name}_value", + data_type=helper.TensorProto.INT64, + dims=[len(splits)], + vals=splits.tolist(), + ), + ) + + # Create ONNX Split node with split as an input + split_node = helper.make_node( + "Split", + inputs=[input_name, split_name], + outputs=output_names, + name=f"Split_{output_names[0]}", + axis=axis, + ) + + return [split_constant, split_node] diff --git a/pytensor/link/onnx/dispatch/subtensor.py b/pytensor/link/onnx/dispatch/subtensor.py new file mode 100644 index 0000000000..e0c5da4d76 --- /dev/null +++ b/pytensor/link/onnx/dispatch/subtensor.py @@ -0,0 +1,436 @@ +"""ONNX conversion for subtensor (slicing) operations.""" + +import sys + +import numpy as np +from onnx import helper, numpy_helper + +from pytensor.graph.basic import Constant +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.tensor.subtensor import ( + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, +) + + +@onnx_funcify.register(Subtensor) +def onnx_funcify_Subtensor(op, node, get_var_name, **kwargs): + """Convert Subtensor (slicing) to ONNX Slice node. + + Subtensor performs array slicing like x[start:stop:step]. + + ONNX Slice (opset 11+) takes inputs: + - data: the tensor to slice + - starts: starting indices for each axis (1D tensor) + - ends: ending indices for each axis (1D tensor) + - axes: which axes to slice (optional, 1D tensor) + - steps: step size for each axis (optional, 1D tensor) + + Key challenges: + 1. PyTensor idx_list contains Type objects (placeholders) and slice objects + 2. Actual slice bounds are in node.inputs[1:] as Constants or Variables + 3. Scalar indices reduce dimensionality (not supported by Slice alone) + 4. Negative indices must be converted using Shape operations + + For now, we focus on basic slicing with constant bounds. + """ + from pytensor.tensor.subtensor import indices_from_subtensor + + input_name = get_var_name(node.inputs[0]) + output_name = get_var_name(node.outputs[0]) + + # Reconstruct the actual slice objects from op.idx_list and node.inputs + # This gives us slice objects with actual Constant values + actual_indices = indices_from_subtensor(node.inputs[1:], op.idx_list) + + # For now, we only handle pure slice objects (not scalar indices) + # Scalar indices would reduce dimensionality and require Gather + Squeeze + if not all(isinstance(idx, slice) for idx in actual_indices): + raise NotImplementedError( + f"Subtensor with scalar indices not yet supported. " + f"Got indices: {actual_indices}. " + f"Only slice objects (e.g., x[1:3]) are supported." + ) + + # Extract starts, ends, steps, axes from slice objects + starts = [] + ends = [] + steps = [] + axes = [] + + has_negative_indices = False + has_non_constant_bounds = False + + for axis, idx in enumerate(actual_indices): + if isinstance(idx, slice): + # Get start, stop, step from the slice + # These might be None, int, or Constant Variables + start = idx.start + stop = idx.stop + step = idx.step + + # Convert None to appropriate defaults + if start is None: + start_val = 0 + elif isinstance(start, Constant): + start_val = int(start.data) + elif isinstance(start, int): + start_val = start + else: + # Dynamic/non-constant start - not yet supported + has_non_constant_bounds = True + start_val = 0 # placeholder + + if stop is None: + stop_val = sys.maxsize + elif isinstance(stop, Constant): + stop_val = int(stop.data) + elif isinstance(stop, int): + stop_val = stop + else: + # Dynamic/non-constant stop + has_non_constant_bounds = True + stop_val = sys.maxsize # placeholder + + if step is None: + step_val = 1 + elif isinstance(step, Constant): + step_val = int(step.data) + elif isinstance(step, int): + step_val = step + else: + # Dynamic/non-constant step + has_non_constant_bounds = True + step_val = 1 # placeholder + + # Check for negative indices + if start_val < 0 or stop_val < 0: + has_negative_indices = True + + starts.append(start_val) + ends.append(stop_val) + steps.append(step_val) + axes.append(axis) + + # Check for unsupported cases + if has_non_constant_bounds: + raise NotImplementedError( + "Subtensor with dynamic (non-constant) slice bounds not yet supported. " + "All start, stop, step values must be constants at export time." + ) + + # If no slicing needed (all slices are [:]), pass through + if not starts: + return None + + if has_negative_indices: + raise NotImplementedError( + f"Subtensor with negative indices not yet implemented. " + f"Please use non-negative indices for now. " + f"Got starts={starts}, ends={ends}" + ) + + # Simple case: all indices are non-negative constants + # Create constant tensors for starts, ends, axes, steps + starts_name = f"{output_name}_starts" + ends_name = f"{output_name}_ends" + axes_name = f"{output_name}_axes" + steps_name = f"{output_name}_steps" + + # Create constants as initializers + starts_tensor = numpy_helper.from_array( + np.array(starts, dtype=np.int64), name=starts_name + ) + ends_tensor = numpy_helper.from_array( + np.array(ends, dtype=np.int64), name=ends_name + ) + axes_tensor = numpy_helper.from_array( + np.array(axes, dtype=np.int64), name=axes_name + ) + steps_tensor = numpy_helper.from_array( + np.array(steps, dtype=np.int64), name=steps_name + ) + + # Create Slice node with input tensors + slice_node = helper.make_node( + "Slice", + inputs=[input_name, starts_name, ends_name, axes_name, steps_name], + outputs=[output_name], + name=f"Slice_{output_name}", + ) + + # Return (node, initializers) + return (slice_node, [starts_tensor, ends_tensor, axes_tensor, steps_tensor]) + + +@onnx_funcify.register(AdvancedSubtensor1) +def onnx_funcify_AdvancedSubtensor1(op, node, get_var_name, **kwargs): + """Convert AdvancedSubtensor1 to ONNX Gather node. + + AdvancedSubtensor1 performs integer array indexing like x[[0, 2, 5]]. + This maps directly to ONNX Gather operation. + + Example: + x = pt.vector('x') + indices = pt.vector('indices', dtype='int64') + y = x[indices] # AdvancedSubtensor1 + + ONNX: Gather(x, indices, axis=0) + """ + data_name = get_var_name(node.inputs[0]) + indices_name = get_var_name(node.inputs[1]) + output_name = get_var_name(node.outputs[0]) + + gather_node = helper.make_node( + "Gather", + inputs=[data_name, indices_name], + outputs=[output_name], + name=f"Gather_{output_name}", + axis=0, # AdvancedSubtensor1 operates on axis 0 + ) + + return gather_node + + +@onnx_funcify.register(AdvancedSubtensor) +def onnx_funcify_AdvancedSubtensor(op, node, get_var_name, **kwargs): + """Convert AdvancedSubtensor to ONNX Gather or GatherND node. + + AdvancedSubtensor implements NumPy's advanced indexing. + + For simple cases (single integer array on axis 0), this maps to Gather. + For complex multi-dimensional indexing, this would require GatherND. + + For now, we handle the simple case: x[indices] where indices is a vector. + This is the most common case and matches AdvancedSubtensor1 behavior. + + Example: + x = pt.vector('x') + indices = pt.vector('indices', dtype='int64') + y = x[indices] # AdvancedSubtensor (gets optimized to AdvancedSubtensor1 in normal mode) + + ONNX: Gather(x, indices, axis=0) + """ + # For now, we only handle the simple case that matches AdvancedSubtensor1 + # More complex cases would need GatherND or multiple operations + + if len(node.inputs) != 2: + raise NotImplementedError( + f"AdvancedSubtensor with {len(node.inputs)} inputs not supported. " + f"Only simple integer array indexing (2 inputs) is currently supported." + ) + + data_name = get_var_name(node.inputs[0]) + indices_name = get_var_name(node.inputs[1]) + output_name = get_var_name(node.outputs[0]) + + # Use Gather for simple indexing on axis 0 + gather_node = helper.make_node( + "Gather", + inputs=[data_name, indices_name], + outputs=[output_name], + name=f"Gather_{output_name}", + axis=0, # Simple indexing operates on axis 0 + ) + + return gather_node + + +@onnx_funcify.register(IncSubtensor) +def onnx_funcify_IncSubtensor(op, node, get_var_name, **kwargs): + """Convert IncSubtensor to ONNX Scatter operations. + + IncSubtensor has two modes: + 1. set_subtensor: x[indices] = values (op.set_instead_of_inc=True) + 2. inc_subtensor: x[indices] += values (op.set_instead_of_inc=False) + + ONNX doesn't have in-place ops, so we use ScatterElements or ScatterND. + + For basic slicing (e.g., x[2:5] = values), we implement this as: + 1. Extract the slice range as indices using ONNX Range + 2. Use ScatterElements to scatter the values at those indices + 3. For inc_subtensor, first extract current values, add, then scatter + + This implementation handles the basic slicing case with constant bounds. + Advanced cases (negative indices, dynamic bounds, multi-dim) are not yet supported. + """ + from pytensor.tensor.subtensor import indices_from_subtensor + + # Inputs: [data, values, ...slice_bounds...] + # Output: modified data + data_name = get_var_name(node.inputs[0]) + values_name = get_var_name(node.inputs[1]) + output_name = get_var_name(node.outputs[0]) + + # Reconstruct the actual slice objects from op.idx_list and node.inputs[2:] + actual_indices = indices_from_subtensor(node.inputs[2:], op.idx_list) + + # For now, only handle simple 1D slicing on the first axis + # x[start:stop] = values + if len(actual_indices) != 1 or not isinstance(actual_indices[0], slice): + raise NotImplementedError( + f"IncSubtensor only supports basic 1D slicing for ONNX export. " + f"Got indices: {actual_indices}. " + f"Only single-axis slice objects (e.g., x[2:5]) are supported." + ) + + slice_obj = actual_indices[0] + start = slice_obj.start + stop = slice_obj.stop + step = slice_obj.step + + # Extract constant values + if start is None: + start_val = 0 + elif isinstance(start, Constant): + start_val = int(start.data) + elif isinstance(start, int): + start_val = start + else: + raise NotImplementedError( + "IncSubtensor with dynamic start index not yet supported" + ) + + if stop is None: + raise NotImplementedError("IncSubtensor with unbounded stop not yet supported") + elif isinstance(stop, Constant): + stop_val = int(stop.data) + elif isinstance(stop, int): + stop_val = stop + else: + raise NotImplementedError( + "IncSubtensor with dynamic stop index not yet supported" + ) + + if step is None: + step_val = 1 + elif isinstance(step, Constant): + step_val = int(step.data) + elif isinstance(step, int): + step_val = step + else: + raise NotImplementedError("IncSubtensor with dynamic step not yet supported") + + if step_val != 1: + raise NotImplementedError("IncSubtensor with step != 1 not yet supported") + + if start_val < 0 or stop_val < 0: + raise NotImplementedError( + "IncSubtensor with negative indices not yet supported" + ) + + # Build ONNX graph: + # 1. Create indices tensor: [start, start+1, ..., stop-1] + # 2. For set_subtensor: ScatterElements(data, indices, values, axis=0) + # 3. For inc_subtensor: current = Gather(data, indices), + # new_values = Add(current, values), + # ScatterElements(data, indices, new_values, axis=0) + + nodes = [] + + # Create Range node to generate indices [start, start+1, ..., stop-1] + indices_name = f"{output_name}_indices" + start_name = f"{output_name}_start" + stop_name = f"{output_name}_stop" + step_name = f"{output_name}_step" + + # Create Constant nodes for start, stop, step + start_const = helper.make_node( + "Constant", + inputs=[], + outputs=[start_name], + name=f"Constant_{start_name}", + value=helper.make_tensor( + name=f"{start_name}_value", + data_type=helper.TensorProto.INT64, + dims=[], + vals=[start_val], + ), + ) + nodes.append(start_const) + + stop_const = helper.make_node( + "Constant", + inputs=[], + outputs=[stop_name], + name=f"Constant_{stop_name}", + value=helper.make_tensor( + name=f"{stop_name}_value", + data_type=helper.TensorProto.INT64, + dims=[], + vals=[stop_val], + ), + ) + nodes.append(stop_const) + + step_const = helper.make_node( + "Constant", + inputs=[], + outputs=[step_name], + name=f"Constant_{step_name}", + value=helper.make_tensor( + name=f"{step_name}_value", + data_type=helper.TensorProto.INT64, + dims=[], + vals=[step_val], + ), + ) + nodes.append(step_const) + + # Range node: creates [start, start+1, ..., stop-1] + range_node = helper.make_node( + "Range", + inputs=[start_name, stop_name, step_name], + outputs=[indices_name], + name=f"Range_{indices_name}", + ) + nodes.append(range_node) + + # Handle set_subtensor vs inc_subtensor + if op.set_instead_of_inc: + # set_subtensor: directly scatter the new values + scatter_node = helper.make_node( + "ScatterElements", + inputs=[data_name, indices_name, values_name], + outputs=[output_name], + name=f"ScatterElements_{output_name}", + axis=0, + ) + nodes.append(scatter_node) + else: + # inc_subtensor: gather current, add, then scatter + # 1. Gather current values + current_values_name = f"{output_name}_current" + gather_node = helper.make_node( + "Gather", + inputs=[data_name, indices_name], + outputs=[current_values_name], + name=f"Gather_{current_values_name}", + axis=0, + ) + nodes.append(gather_node) + + # 2. Add current + new values + sum_values_name = f"{output_name}_sum" + add_node = helper.make_node( + "Add", + inputs=[current_values_name, values_name], + outputs=[sum_values_name], + name=f"Add_{sum_values_name}", + ) + nodes.append(add_node) + + # 3. Scatter the summed values + scatter_node = helper.make_node( + "ScatterElements", + inputs=[data_name, indices_name, sum_values_name], + outputs=[output_name], + name=f"ScatterElements_{output_name}", + axis=0, + ) + nodes.append(scatter_node) + + # Return list of nodes + return nodes diff --git a/pytensor/link/onnx/dispatch/tensor_basic.py b/pytensor/link/onnx/dispatch/tensor_basic.py new file mode 100644 index 0000000000..d7b1d2a067 --- /dev/null +++ b/pytensor/link/onnx/dispatch/tensor_basic.py @@ -0,0 +1,435 @@ +"""ONNX conversion for tensor basic operations (allocation, etc.).""" + +import numpy as np +from onnx import helper + +from pytensor.graph.basic import Constant +from pytensor.link.onnx.dispatch.basic import onnx_funcify +from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, MakeVector + + +@onnx_funcify.register(Alloc) +def onnx_funcify_Alloc(op, node, get_var_name, **kwargs): + """Convert Alloc op to ONNX Expand node. + + Alloc broadcasts a value to a specified shape. + ONNX Expand does the same thing. + + Example: + x = pt.alloc(5.0, 3, 4) # Create 3x4 array filled with 5.0 + + ONNX: Expand(value=5.0, shape=[3, 4]) -> result + """ + value_input = node.inputs[0] + shape_inputs = node.inputs[1:] + + value_name = get_var_name(value_input) + output_name = get_var_name(node.outputs[0]) + + # Create shape tensor from shape inputs + # Shape inputs are scalars that specify each dimension + shape_name = f"{output_name}_shape" + nodes = [] + + if all(isinstance(inp, Constant) for inp in shape_inputs): + # All shape dimensions are constants + shape_data = np.array([inp.data for inp in shape_inputs], dtype=np.int64) + + shape_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[shape_name], + name=f"Constant_{shape_name}", + value=helper.make_tensor( + name=f"{shape_name}_value", + data_type=helper.TensorProto.INT64, + dims=[len(shape_data)], + vals=shape_data.tolist(), + ), + ) + nodes.append(shape_constant) + + expand_node = helper.make_node( + "Expand", + inputs=[value_name, shape_name], + outputs=[output_name], + name=f"Expand_{output_name}", + ) + nodes.append(expand_node) + + return nodes + else: + # Some shape dimensions are dynamic - need to use Concat + # First, unsqueeze each scalar shape dimension to make it 1D + unsqueezed_names = [] + for i, inp in enumerate(shape_inputs): + if isinstance(inp, Constant): + # Create constant for this dimension + dim_name = f"{shape_name}_dim{i}" + dim_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[dim_name], + name=f"Constant_{dim_name}", + value=helper.make_tensor( + name=f"{dim_name}_value", + data_type=helper.TensorProto.INT64, + dims=[1], + vals=[inp.data], + ), + ) + nodes.append(dim_constant) + unsqueezed_names.append(dim_name) + else: + # Dynamic dimension - need to unsqueeze it + inp_name = get_var_name(inp) + unsqueezed_name = f"{shape_name}_unsqueezed{i}" + + # Create axes constant for Unsqueeze + axes_name = f"{unsqueezed_name}_axes" + axes_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[axes_name], + name=f"Constant_{axes_name}", + value=helper.make_tensor( + name=f"{axes_name}_value", + data_type=helper.TensorProto.INT64, + dims=[1], + vals=[0], + ), + ) + nodes.append(axes_constant) + + unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=[inp_name, axes_name], + outputs=[unsqueezed_name], + name=f"Unsqueeze_{unsqueezed_name}", + ) + nodes.append(unsqueeze_node) + unsqueezed_names.append(unsqueezed_name) + + # Concatenate shape elements into shape vector + concat_node = helper.make_node( + "Concat", + inputs=unsqueezed_names, + outputs=[shape_name], + name=f"Concat_{shape_name}", + axis=0, + ) + nodes.append(concat_node) + + expand_node = helper.make_node( + "Expand", + inputs=[value_name, shape_name], + outputs=[output_name], + name=f"Expand_{output_name}", + ) + nodes.append(expand_node) + + return nodes + + +@onnx_funcify.register(AllocEmpty) +def onnx_funcify_AllocEmpty(op, node, get_var_name, **kwargs): + """Convert AllocEmpty to ONNX ConstantOfShape. + + AllocEmpty creates uninitialized array. In ONNX, we use + ConstantOfShape with value 0 (values don't matter, just shape/dtype). + + Example: + x = pt.AllocEmpty('float32')(3, 4) # Create uninitialized 3x4 array + + ONNX: ConstantOfShape(shape=[3, 4], value=0.0) -> result + """ + shape_inputs = node.inputs + output_name = get_var_name(node.outputs[0]) + + # Create shape tensor + shape_name = f"{output_name}_shape" + nodes = [] + + if all(isinstance(inp, Constant) for inp in shape_inputs): + # Constant shape + shape_data = np.array([inp.data for inp in shape_inputs], dtype=np.int64) + + shape_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[shape_name], + name=f"Constant_{shape_name}", + value=helper.make_tensor( + name=f"{shape_name}_value", + data_type=helper.TensorProto.INT64, + dims=[len(shape_data)], + vals=shape_data.tolist(), + ), + ) + nodes.append(shape_constant) + else: + # Dynamic shape - similar to Alloc + unsqueezed_names = [] + for i, inp in enumerate(shape_inputs): + if isinstance(inp, Constant): + dim_name = f"{shape_name}_dim{i}" + dim_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[dim_name], + name=f"Constant_{dim_name}", + value=helper.make_tensor( + name=f"{dim_name}_value", + data_type=helper.TensorProto.INT64, + dims=[1], + vals=[inp.data], + ), + ) + nodes.append(dim_constant) + unsqueezed_names.append(dim_name) + else: + inp_name = get_var_name(inp) + unsqueezed_name = f"{shape_name}_unsqueezed{i}" + + axes_name = f"{unsqueezed_name}_axes" + axes_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[axes_name], + name=f"Constant_{axes_name}", + value=helper.make_tensor( + name=f"{axes_name}_value", + data_type=helper.TensorProto.INT64, + dims=[1], + vals=[0], + ), + ) + nodes.append(axes_constant) + + unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=[inp_name, axes_name], + outputs=[unsqueezed_name], + name=f"Unsqueeze_{unsqueezed_name}", + ) + nodes.append(unsqueeze_node) + unsqueezed_names.append(unsqueezed_name) + + concat_node = helper.make_node( + "Concat", + inputs=unsqueezed_names, + outputs=[shape_name], + name=f"Concat_{shape_name}", + axis=0, + ) + nodes.append(concat_node) + + # ConstantOfShape with value 0 + dtype = op.dtype + dtype_map = { + "float32": helper.TensorProto.FLOAT, + "float64": helper.TensorProto.DOUBLE, + "int32": helper.TensorProto.INT32, + "int64": helper.TensorProto.INT64, + } + onnx_dtype = dtype_map.get(dtype, helper.TensorProto.FLOAT) + + constant_of_shape_node = helper.make_node( + "ConstantOfShape", + inputs=[shape_name], + outputs=[output_name], + name=f"ConstantOfShape_{output_name}", + value=helper.make_tensor( + name=f"{output_name}_value", + data_type=onnx_dtype, + dims=[1], + vals=[0], + ), + ) + nodes.append(constant_of_shape_node) + + return nodes + + +@onnx_funcify.register(MakeVector) +def onnx_funcify_MakeVector(op, node, get_var_name, **kwargs): + """Convert MakeVector to ONNX Concat of Unsqueezed scalars. + + MakeVector creates a 1D vector from scalars. + + Example: + x = pt.make_vector(1.0, 2.0, 3.0) # Create [1.0, 2.0, 3.0] + + ONNX: + Unsqueeze(1.0, axes=[0]) -> [1.0] + Unsqueeze(2.0, axes=[0]) -> [2.0] + Unsqueeze(3.0, axes=[0]) -> [3.0] + Concat([1.0], [2.0], [3.0], axis=0) -> [1.0, 2.0, 3.0] + """ + output_name = get_var_name(node.outputs[0]) + + if len(node.inputs) == 0: + # Empty vector + dtype = op.dtype + dtype_map = { + "float32": helper.TensorProto.FLOAT, + "float64": helper.TensorProto.DOUBLE, + "int32": helper.TensorProto.INT32, + "int64": helper.TensorProto.INT64, + } + onnx_dtype = dtype_map.get(dtype, helper.TensorProto.FLOAT) + + empty_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[output_name], + name=f"Constant_{output_name}", + value=helper.make_tensor( + name=f"{output_name}_value", + data_type=onnx_dtype, + dims=[0], + vals=[], + ), + ) + + return empty_constant + + # Unsqueeze each scalar to shape (1,), then concatenate + nodes = [] + unsqueezed_names = [] + + for i, inp in enumerate(node.inputs): + input_name = get_var_name(inp) + unsqueezed_name = f"{output_name}_elem_{i}" + + # Create axes constant for Unsqueeze + axes_name = f"{unsqueezed_name}_axes" + axes_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[axes_name], + name=f"Constant_{axes_name}", + value=helper.make_tensor( + name=f"{axes_name}_value", + data_type=helper.TensorProto.INT64, + dims=[1], + vals=[0], + ), + ) + nodes.append(axes_constant) + + unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=[input_name, axes_name], + outputs=[unsqueezed_name], + name=f"Unsqueeze_{unsqueezed_name}", + ) + nodes.append(unsqueeze_node) + unsqueezed_names.append(unsqueezed_name) + + # Concatenate all elements + concat_node = helper.make_node( + "Concat", + inputs=unsqueezed_names, + outputs=[output_name], + name=f"Concat_{output_name}", + axis=0, + ) + nodes.append(concat_node) + + return nodes + + +@onnx_funcify.register(ARange) +def onnx_funcify_ARange(op, node, get_var_name, **kwargs): + """Convert ARange to ONNX Range node. + + IMPORTANT: ONNX Range requires constant inputs (start, limit, delta). + Dynamic ranges are not supported in ONNX standard. + + Example: + x = pt.arange(0, 10, 2, dtype='int64') # Create [0, 2, 4, 6, 8] + + ONNX: + Constant(0) -> start + Constant(10) -> stop + Constant(2) -> step + Range(start, stop, step) -> [0, 2, 4, 6, 8] + """ + start_input = node.inputs[0] + stop_input = node.inputs[1] + step_input = node.inputs[2] + + # Verify all inputs are constants + if not all( + isinstance(inp, Constant) for inp in [start_input, stop_input, step_input] + ): + raise NotImplementedError( + "ARange with dynamic (non-constant) inputs is not supported in ONNX. " + "All start, stop, step values must be constants." + ) + + output_name = get_var_name(node.outputs[0]) + + # Create constant nodes for start, limit, delta + start_name = f"{output_name}_start" + stop_name = f"{output_name}_stop" + step_name = f"{output_name}_step" + + dtype = op.dtype + dtype_map = { + "int32": helper.TensorProto.INT32, + "int64": helper.TensorProto.INT64, + "float32": helper.TensorProto.FLOAT, + "float64": helper.TensorProto.DOUBLE, + } + onnx_dtype = dtype_map.get(dtype, helper.TensorProto.INT64) + + start_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[start_name], + name=f"Constant_{start_name}", + value=helper.make_tensor( + name=f"{start_name}_value", + data_type=onnx_dtype, + dims=[], + vals=[int(start_input.data) if "int" in dtype else float(start_input.data)], + ), + ) + + stop_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[stop_name], + name=f"Constant_{stop_name}", + value=helper.make_tensor( + name=f"{stop_name}_value", + data_type=onnx_dtype, + dims=[], + vals=[int(stop_input.data) if "int" in dtype else float(stop_input.data)], + ), + ) + + step_constant = helper.make_node( + "Constant", + inputs=[], + outputs=[step_name], + name=f"Constant_{step_name}", + value=helper.make_tensor( + name=f"{step_name}_value", + data_type=onnx_dtype, + dims=[], + vals=[int(step_input.data) if "int" in dtype else float(step_input.data)], + ), + ) + + # Range node + range_node = helper.make_node( + "Range", + inputs=[start_name, stop_name, step_name], + outputs=[output_name], + name=f"Range_{output_name}", + ) + + return [start_constant, stop_constant, step_constant, range_node] diff --git a/pytensor/link/onnx/export.py b/pytensor/link/onnx/export.py new file mode 100644 index 0000000000..58c167c141 --- /dev/null +++ b/pytensor/link/onnx/export.py @@ -0,0 +1,135 @@ +"""High-level ONNX export API for PyTensor.""" + +import onnx + +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.link.onnx.dispatch import onnx_funcify +from pytensor.link.onnx.linker import ONNXLinker + + +def export_onnx(inputs, outputs, filename, *, opset_version=18, **kwargs): + """Export a PyTensor graph to an ONNX file. + + Parameters + ---------- + inputs : list of Variable + Input variables for the graph + outputs : Variable or list of Variable + Output variable(s) for the graph + filename : str or Path + Path where the ONNX model will be saved + opset_version : int, default=18 + ONNX opset version to use + **kwargs : dict + Additional keyword arguments + + Returns + ------- + onnx.ModelProto + The created ONNX model + + Examples + -------- + >>> import pytensor.tensor as pt + >>> x = pt.vector("x", dtype="float32") + >>> y = x * 2 + 1 + >>> model = export_onnx([x], y, "model.onnx") + """ + # Ensure outputs is a list + if not isinstance(outputs, (list, tuple)): + outputs = [outputs] + + # Create a FunctionGraph (without cloning to preserve structure) + from pytensor.compile.builders import construct_nominal_fgraph + + # construct_nominal_fgraph returns a tuple: (fgraph, updates, unused_inputs, unused_outputs) + result = construct_nominal_fgraph(inputs, outputs) + fgraph = result[0] if isinstance(result, tuple) else result + + # Convert to ONNX ModelProto + onnx_model = onnx_funcify(fgraph, opset_version=opset_version, **kwargs) + + # Save to file + onnx.save(onnx_model, filename) + + return onnx_model + + +def compile_onnx(inputs, outputs, *, opset_version=18, **kwargs): + """Compile a PyTensor graph using the ONNX backend. + + This creates a function that executes the graph via ONNX Runtime. + + Parameters + ---------- + inputs : list of Variable + Input variables for the graph + outputs : Variable or list of Variable + Output variable(s) for the graph + opset_version : int, default=18 + ONNX opset version to use + **kwargs : dict + Additional keyword arguments passed to pytensor.function + + Returns + ------- + Function + Compiled function that executes via ONNX Runtime + + Examples + -------- + >>> import pytensor.tensor as pt + >>> import numpy as np + >>> x = pt.vector("x", dtype="float32") + >>> y = x * 2 + 1 + >>> fn = compile_onnx([x], y) + >>> result = fn(np.array([1, 2, 3], dtype="float32")) + """ + # Create ONNX mode + onnx_linker = ONNXLinker(opset_version=opset_version) + onnx_mode = Mode(linker=onnx_linker, optimizer=None) + + # Compile the function + return function(inputs, outputs, mode=onnx_mode, **kwargs) + + +def export_function_onnx(fn, filename, *, opset_version=18): + """Export an already-compiled PyTensor function to ONNX. + + Parameters + ---------- + fn : Function + A compiled PyTensor function + filename : str or Path + Path where the ONNX model will be saved + opset_version : int, default=18 + ONNX opset version to use (if the function wasn't compiled with ONNX) + + Returns + ------- + onnx.ModelProto + The created ONNX model + + Examples + -------- + >>> import pytensor + >>> import pytensor.tensor as pt + >>> x = pt.vector("x", dtype="float32") + >>> y = pt.sqrt(x) + >>> fn = pytensor.function([x], y) + >>> model = export_function_onnx(fn, "sqrt_model.onnx") + """ + # Check if the function was already compiled with ONNX linker + if isinstance(fn.maker.linker, ONNXLinker): + # Already have ONNX model + onnx_model = fn.maker.linker.onnx_model + else: + # Need to convert the FunctionGraph to ONNX + fgraph = fn.maker.fgraph + onnx_model = onnx_funcify(fgraph, opset_version=opset_version) + + # Save to file + onnx.save(onnx_model, filename) + + return onnx_model diff --git a/pytensor/link/onnx/linker.py b/pytensor/link/onnx/linker.py new file mode 100644 index 0000000000..7a4b10e202 --- /dev/null +++ b/pytensor/link/onnx/linker.py @@ -0,0 +1,169 @@ +"""ONNX linker for PyTensor.""" + +import numpy as np +import onnx +import onnxruntime as ort + +from pytensor.link.basic import JITLinker + + +class ONNXLinker(JITLinker): + """A `Linker` that converts PyTensor graphs to ONNX models and executes them with ONNX Runtime. + + This linker: + 1. Converts the PyTensor FunctionGraph to an ONNX ModelProto + 2. Creates an ONNX Runtime InferenceSession + 3. Returns a function that executes the model via ONNX Runtime + """ + + def __init__(self, opset_version=18, *args, **kwargs): + """Initialize the ONNX linker. + + Parameters + ---------- + opset_version : int, default=18 + ONNX opset version to use for the model + """ + super().__init__(*args, **kwargs) + self.opset_version = opset_version + self.onnx_model = None + + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): + """Convert FunctionGraph to ONNX and create executable function. + + Parameters + ---------- + fgraph : FunctionGraph + The function graph to convert + input_storage : list + Storage for inputs + storage_map : dict + Mapping from variables to storage + + Returns + ------- + callable + Function that executes the ONNX model + """ + from pytensor.link.onnx.dispatch import onnx_funcify + + # Convert the FunctionGraph to ONNX ModelProto + self.onnx_model = onnx_funcify( + fgraph, + opset_version=self.opset_version, + input_storage=input_storage, + storage_map=storage_map, + **kwargs, + ) + + # Create ONNX Runtime function + return self._create_onnx_runtime_function(fgraph) + + def _create_onnx_runtime_function(self, fgraph): + """Create a function that executes the ONNX model via ONNX Runtime. + + Parameters + ---------- + fgraph : FunctionGraph + The function graph (for input/output info) + + Returns + ------- + callable + Function that takes inputs and returns outputs + """ + # Serialize the model to bytes + model_bytes = self.onnx_model.SerializeToString() + + # Create ONNX Runtime session + sess_options = ort.SessionOptions() + sess_options.log_severity_level = 3 # Error level only + session = ort.InferenceSession(model_bytes, sess_options) + + # Get input and output names from the ONNX model + input_names = [inp.name for inp in self.onnx_model.graph.input] + output_names = [out.name for out in self.onnx_model.graph.output] + + def onnx_runtime_function(*args): + """Execute the ONNX model with ONNX Runtime. + + Parameters + ---------- + *args : array-like + Input values matching the graph inputs + + Returns + ------- + array or tuple of arrays + Output values from the ONNX model + """ + # Prepare inputs as numpy arrays + input_dict = {} + for name, arg in zip(input_names, args): + # Ensure inputs are numpy arrays with correct dtype + if not isinstance(arg, np.ndarray): + arg = np.array(arg) + input_dict[name] = arg + + # Run the model + outputs = session.run(output_names, input_dict) + + # Return outputs as tuple to match expected format + # (even for single outputs, as the thunk expects to iterate) + return tuple(outputs) + + return onnx_runtime_function + + def create_thunk_inputs(self, storage_map): + """Create thunk inputs from storage map. + + For ONNX, we simply return the storage list for each input variable. + + Parameters + ---------- + storage_map : dict + Mapping from variables to storage + + Returns + ------- + list + List of storage lists for inputs + """ + return [storage_map[n] for n in self.fgraph.inputs] + + def jit_compile(self, fn): + """JIT compile a converted FunctionGraph. + + For ONNX, there is no additional JIT compilation needed - + the function returned by fgraph_convert already executes via ONNX Runtime. + + Parameters + ---------- + fn : callable + The function to compile + + Returns + ------- + callable + The same function (no additional compilation needed) + """ + # No JIT compilation needed for ONNX - already compiled to ONNX Runtime + return fn + + def export_to_file(self, filename): + """Export the ONNX model to a file. + + Parameters + ---------- + filename : str or Path + Path to save the ONNX model + + Raises + ------ + ValueError + If no model has been created yet + """ + if self.onnx_model is None: + raise ValueError("No ONNX model available. Compile a function first.") + + onnx.save(self.onnx_model, filename) diff --git a/pytensor/link/onnx/rewrite.py b/pytensor/link/onnx/rewrite.py new file mode 100644 index 0000000000..604e986a4f --- /dev/null +++ b/pytensor/link/onnx/rewrite.py @@ -0,0 +1,48 @@ +"""Graph rewrites for ONNX backend compatibility. + +These rewrites expand operations that don't have direct ONNX equivalents +into compositions of basic operations that do have ONNX support. +""" + +import numpy as np + +from pytensor import scalar as ps +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.math import add, exp, log, sub + + +@node_rewriter([Elemwise]) +def expand_log1p_expm1_for_onnx(fgraph, node): + """Expand log1p and expm1 into basic operations for ONNX export. + + ONNX doesn't have Log1p or Expm1 operators in the standard opset. + We expand them as: + - log1p(x) -> log(1 + x) + - expm1(x) -> exp(x) - 1 + + This rewrite is specific to the ONNX backend and should be applied + before ONNX graph compilation. + """ + if not isinstance(node.op, Elemwise): + return None + + scalar_op = node.op.scalar_op + + # Expand log1p(x) -> log(1 + x) + if isinstance(scalar_op, ps.Log1p): + x = node.inputs[0] + # Create log(1 + x) + one = np.array(1, dtype=x.dtype) + result = log(add(x, one)) + return [result] + + # Expand expm1(x) -> exp(x) - 1 + if isinstance(scalar_op, ps.Expm1): + x = node.inputs[0] + # Create exp(x) - 1 + one = np.array(1, dtype=x.dtype) + result = sub(exp(x), one) + return [result] + + return None diff --git a/tests/link/onnx/__init__.py b/tests/link/onnx/__init__.py new file mode 100644 index 0000000000..96c9c0bdd7 --- /dev/null +++ b/tests/link/onnx/__init__.py @@ -0,0 +1 @@ +"""Tests for ONNX backend.""" diff --git a/tests/link/onnx/conftest.py b/tests/link/onnx/conftest.py new file mode 100644 index 0000000000..c09649e425 --- /dev/null +++ b/tests/link/onnx/conftest.py @@ -0,0 +1,57 @@ +"""Pytest configuration and fixtures for ONNX backend tests.""" + +import numpy as np +import pytest + +from pytensor.configdefaults import config + + +# Import hypothesis if available +try: + from hypothesis import HealthCheck, Phase, Verbosity, settings + + # Hypothesis profiles for different testing scenarios + settings.register_profile("dev", max_examples=10, deadline=None) + settings.register_profile( + "ci", + max_examples=100, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], + ) + settings.register_profile( + "debug", + max_examples=10, + verbosity=Verbosity.verbose, + phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target], + ) + + # Load dev profile by default + settings.load_profile("dev") +except ImportError: + # Hypothesis not available, tests will skip + pass + + +@pytest.fixture(scope="module", autouse=True) +def set_pytensor_flags(): + """Module-level PyTensor configuration.""" + with config.change_flags(cxx="", compute_test_value="ignore", floatX="float32"): + yield + + +@pytest.fixture +def rng(): + """Seeded random number generator for reproducible tests.""" + return np.random.default_rng(42) + + +@pytest.fixture +def float32_vector(rng): + """Sample float32 vector for testing.""" + return rng.normal(size=10).astype("float32") + + +@pytest.fixture +def float32_matrix(rng): + """Sample float32 matrix for testing.""" + return rng.normal(size=(5, 5)).astype("float32") diff --git a/tests/link/onnx/strategies.py b/tests/link/onnx/strategies.py new file mode 100644 index 0000000000..32e3649519 --- /dev/null +++ b/tests/link/onnx/strategies.py @@ -0,0 +1,748 @@ +"""Hypothesis strategies and operation registries for ONNX backend testing.""" + +from typing import Any + +import numpy as np +from hypothesis import strategies as st +from hypothesis.extra.numpy import array_shapes, arrays + +import pytensor.tensor as pt + + +# ============================================================================ +# HYPOTHESIS STRATEGIES (Custom Helpers) - Define first! +# ============================================================================ + + +def factorize(n): + """Simple factorization for shape generation.""" + factors = [] + d = 2 + while d * d <= n: + while n % d == 0: + factors.append(d) + n //= d + d += 1 + if n > 1: + factors.append(n) + return factors if factors else [n] + + +def compatible_shape_for_size(total_size): + """Generate shapes compatible with given total size.""" + # Simple factorizations + factors = factorize(total_size) + shapes = [ + (total_size,), + (1, total_size), + (total_size, 1), + ] + # Generate valid shapes from factors + # For 2-factor shapes, use pairs that multiply to total_size + if len(factors) >= 2: + # Use first factor and product of remaining factors + factor1 = factors[0] + remaining_product = total_size // factor1 + shapes.append((factor1, remaining_product)) + + # Also try middle split if we have at least 2 factors + if len(factors) >= 2: + mid = len(factors) // 2 + left_product = int(np.prod(factors[:mid])) + right_product = int(np.prod(factors[mid:])) + shapes.append((left_product, right_product)) + + return st.sampled_from(shapes) + + +def reshape_strategy(): + """Generate tensor and compatible reshape target.""" + + @st.composite + def strategy(draw): + # Original shape + shape = draw(array_shapes(min_dims=2, max_dims=3, min_side=2, max_side=6)) + total_size = int(np.prod(shape)) + + # Generate tensor + x = np.random.randn(*shape).astype("float32") + + # Generate compatible new shape (same total size) + new_shape = draw(compatible_shape_for_size(total_size)) + + return x, new_shape + + return strategy() + + +def concatenate_strategy(): + """Generate tensors and axis for concatenation.""" + + @st.composite + def strategy(draw): + # Generate base shape + shape = draw(array_shapes(min_dims=2, max_dims=3, min_side=2, max_side=8)) + axis = draw(st.integers(0, len(shape) - 1)) + + # Generate two tensors with same shape except along axis + a = np.random.randn(*shape).astype("float32") + + b_shape = list(shape) + b_shape[axis] = draw(st.integers(2, 8)) # Different size along axis + b = np.random.randn(*b_shape).astype("float32") + + return a, b, axis + + return strategy() + + +def tensor_with_axis_strategy(dtype="float32", allow_none=True): + """Generate tensor and valid axis for reduction operations.""" + + @st.composite + def strategy(draw): + # Generate shape + shape = draw(array_shapes(min_dims=2, max_dims=4, min_side=2, max_side=10)) + + # Generate tensor + if dtype == "bool": + x = draw(arrays(dtype=np.bool_, shape=shape)) + else: + x = draw( + arrays( + dtype=getattr(np, dtype), + shape=shape, + elements=st.floats(-10, 10, allow_nan=False, allow_infinity=False), + ) + ) + + # Generate axis + if allow_none: + axis = draw( + st.one_of( + st.none(), + st.integers(0, len(shape) - 1), + st.lists( + st.integers(0, len(shape) - 1), + min_size=1, + max_size=len(shape), + unique=True, + ), + ) + ) + else: + axis = draw(st.integers(0, len(shape) - 1)) + + return x, axis + + return strategy() + + +def alloc_strategy(): + """Generate scalar value and shape for Alloc.""" + return st.builds( + lambda val, s1, s2: (val, s1, s2), + val=st.floats(-10, 10, allow_nan=False, allow_infinity=False), + s1=st.integers(2, 10), + s2=st.integers(2, 10), + ) + + +def arange_strategy(): + """Generate valid start, stop, step for arange (constant only).""" + + @st.composite + def strategy(draw): + start = draw(st.integers(0, 5)) + stop = draw(st.integers(start + 2, start + 20)) + step = draw(st.integers(1, 3)) + return start, stop, step + + return strategy() + + +def set_subtensor_strategy(): + """Generate tensor and values for set_subtensor.""" + + @st.composite + def strategy(draw): + size = draw(st.integers(10, 20)) + x = np.arange(size, dtype="float32") + values = draw( + arrays( + dtype=np.float32, + shape=(3,), + elements=st.floats(-10, 10, allow_nan=False, allow_infinity=False), + ) + ) + return x, values + + return strategy() + + +def advanced_index_strategy(): + """Generate tensor and integer indices for advanced indexing.""" + + @st.composite + def strategy(draw): + size = draw(st.integers(10, 20)) + x = np.arange(size, dtype="float32") + indices = draw(st.lists(st.integers(0, size - 1), min_size=1, max_size=5)) + return x, np.array(indices, dtype="int64") + + return strategy() + + +def binary_float32_arrays_strategy(): + """ + Generate two float32 arrays for binary operations. + + Returns a Hypothesis strategy (lazy evaluation) that generates pairs of + arrays with identical shapes. Arrays are compatible for element-wise + operations but not tested for broadcasting in this phase. + + Shape range: 1-3 dimensions, 2-10 elements per dimension + Value range: [-10, 10] (finite values only) + + Note: Broadcasting validation is deferred to Phase 2. + """ + + @st.composite + def strategy(draw): + # Generate compatible shapes for broadcasting + shape = draw(array_shapes(min_dims=1, max_dims=3, min_side=2, max_side=10)) + + # Generate two arrays with same shape + x = draw( + arrays( + dtype=np.float32, + shape=shape, + elements=st.floats(-10, 10, allow_nan=False, allow_infinity=False), + ) + ) + y = draw( + arrays( + dtype=np.float32, + shape=shape, + elements=st.floats(-10, 10, allow_nan=False, allow_infinity=False), + ) + ) + + return x, y + + return strategy() + + +def unary_float32_array_strategy(): + """ + Generate one float32 array for unary operations. + + Returns a Hypothesis strategy for single array generation. + + Shape range: 1-3 dimensions, 2-10 elements per dimension + Value range: [-10, 10] (finite values only) + """ + return arrays( + dtype=np.float32, + shape=array_shapes(min_dims=1, max_dims=3, min_side=2, max_side=10), + elements=st.floats(-10, 10, allow_nan=False, allow_infinity=False), + ) + + +def positive_float32_array_strategy(): + """ + Generate positive float32 arrays for operations requiring x > 0. + + Used for: log (requires positive inputs) + + Constraint rationale: + - Lower bound 1e-3 (not 0) for numerical stability + - Avoids values too close to zero where log becomes unstable + - Upper bound 10 keeps values in reasonable range + + Shape range: 1-3 dimensions, 2-10 elements per dimension + Value range: [1e-3, 10] (strictly positive, finite values only) + """ + return arrays( + dtype=np.float32, + shape=array_shapes(min_dims=1, max_dims=3, min_side=2, max_side=10), + elements=st.floats(1e-3, 10, allow_nan=False, allow_infinity=False), + ) + + +def non_negative_float32_array_strategy(): + """ + Generate non-negative float32 arrays for operations requiring x >= 0. + + Used for: sqrt (requires non-negative inputs) + + Constraint rationale: + - Lower bound 0 (inclusive) is mathematically valid for sqrt + - No numerical stability issues at zero for sqrt + - Upper bound 10 keeps values in reasonable range + + Shape range: 1-3 dimensions, 2-10 elements per dimension + Value range: [0, 10] (non-negative, finite values only) + """ + return arrays( + dtype=np.float32, + shape=array_shapes(min_dims=1, max_dims=3, min_side=2, max_side=10), + elements=st.floats(0, 10, allow_nan=False, allow_infinity=False), + ) + + +# ============================================================================ +# SHAPE OPERATIONS REGISTRY (Tier 2) +# ============================================================================ + +SHAPE_OPERATIONS: dict[str, dict[str, Any]] = { + # Shape inspection (already implemented in Phase 0) + "shape": { + "build_graph": lambda x: ([x], x.shape), + "strategy": st.builds( + lambda shape: np.random.randn(*shape).astype("float32"), + shape=array_shapes(min_dims=1, max_dims=4, min_side=1, max_side=10), + ), + "expected_onnx_ops": ["Shape"], + "description": "Get tensor shape", + }, + "shape_i": { + "build_graph": lambda x, i: ( + [x], + # Use Shape_i directly instead of x.shape[i] to avoid Subtensor + # Shape_i is imported from pytensor.tensor.shape + __import__("pytensor.tensor.shape", fromlist=["Shape_i"]).Shape_i(i)(x), + ), + "strategy": st.builds( + lambda shape, i: ( + np.random.randn(*shape).astype("float32"), + min(i, len(shape) - 1), + ), + shape=array_shapes(min_dims=2, max_dims=4, min_side=1, max_side=10), + i=st.integers(0, 3), + ), + "expected_onnx_ops": ["Shape", "Gather"], + "description": "Get specific dimension", + }, + # Reshape operations + "reshape": { + "build_graph": lambda x, new_shape: ([x], x.reshape(new_shape)), + "strategy": reshape_strategy(), + "expected_onnx_ops": ["Reshape"], + "description": "Reshape tensor", + }, + "transpose": { + "build_graph": lambda x: ([x], x.T), + "strategy": st.builds( + lambda shape: np.random.randn(*shape).astype("float32"), + shape=st.tuples(st.integers(2, 10), st.integers(2, 10)), + ), + "expected_onnx_ops": ["Transpose"], + "description": "Transpose matrix", + }, + "dimshuffle_add_dim": { + "build_graph": lambda x: ([x], x.dimshuffle("x", 0)), + "strategy": st.builds( + lambda size: np.random.randn(size).astype("float32"), + size=st.integers(2, 20), + ), + "expected_onnx_ops": ["Unsqueeze"], + "description": "Add dimension via dimshuffle", + }, + "dimshuffle_squeeze": { + "build_graph": lambda x: ([x], x.dimshuffle(0, 2)), + "strategy": st.builds( + lambda s1, s2: np.random.randn(s1, 1, s2).astype("float32"), + s1=st.integers(2, 10), + s2=st.integers(2, 10), + ), + "expected_onnx_ops": ["Squeeze"], + "description": "Remove dimension via dimshuffle", + }, + # Join/Split operations + "concatenate": { + "build_graph": lambda a, b, axis: ([a, b], pt.concatenate([a, b], axis=axis)), + "strategy": concatenate_strategy(), + "expected_onnx_ops": ["Concat"], + "description": "Concatenate tensors", + }, + "stack": { + "build_graph": lambda a, b: ([a, b], pt.stack([a, b], axis=0)), + "strategy": st.builds( + lambda shape: ( + np.random.randn(*shape).astype("float32"), + np.random.randn(*shape).astype("float32"), + ), + shape=array_shapes(min_dims=1, max_dims=3, min_side=2, max_side=10), + ), + "expected_onnx_ops": ["Concat", "Unsqueeze"], + "description": "Stack tensors", + }, +} + + +# ============================================================================ +# REDUCTION OPERATIONS REGISTRY (Tier 3) +# ============================================================================ + +REDUCTION_OPERATIONS: dict[str, dict[str, Any]] = { + "sum": { + "build_graph": lambda x_data, axis: ( + lambda x_var: ([x_var], pt.sum(x_var, axis=axis)) + )(pt.tensor("x", dtype=x_data.dtype, shape=(None,) * x_data.ndim)), + "strategy": tensor_with_axis_strategy(), + "expected_onnx_ops": ["ReduceSum"], + "description": "Sum reduction", + }, + "prod": { + "build_graph": lambda x_data, axis: ( + lambda x_var: ([x_var], pt.prod(x_var, axis=axis)) + )(pt.tensor("x", dtype=x_data.dtype, shape=(None,) * x_data.ndim)), + "strategy": tensor_with_axis_strategy(), + "expected_onnx_ops": ["ReduceProd"], + "description": "Product reduction", + }, + "max": { + "build_graph": lambda x_data, axis: ( + lambda x_var: ([x_var], pt.max(x_var, axis=axis)) + )(pt.tensor("x", dtype=x_data.dtype, shape=(None,) * x_data.ndim)), + "strategy": tensor_with_axis_strategy(), + "expected_onnx_ops": ["ReduceMax"], + "description": "Max reduction", + }, + "min": { + "build_graph": lambda x_data, axis: ( + lambda x_var: ([x_var], pt.min(x_var, axis=axis)) + )(pt.tensor("x", dtype=x_data.dtype, shape=(None,) * x_data.ndim)), + "strategy": tensor_with_axis_strategy(), + "expected_onnx_ops": ["Neg", "ReduceMax"], # Min is implemented as -max(-x) + "description": "Min reduction", + }, + "argmax": { + "build_graph": lambda x_data, axis: ( + lambda x_var: ([x_var], pt.argmax(x_var, axis=axis)) + )(pt.tensor("x", dtype=x_data.dtype, shape=(None,) * x_data.ndim)), + "strategy": tensor_with_axis_strategy(allow_none=False), + "expected_onnx_ops": ["ArgMax"], + "description": "Argmax reduction", + }, + "argmin": { + "build_graph": lambda x_data, axis: ( + lambda x_var: ([x_var], pt.argmin(x_var, axis=axis)) + )(pt.tensor("x", dtype=x_data.dtype, shape=(None,) * x_data.ndim)), + "strategy": tensor_with_axis_strategy(allow_none=False), + "expected_onnx_ops": ["Neg", "ArgMax"], # Argmin is implemented as argmax(-x) + "description": "Argmin reduction", + }, + # Skip all/any for now - they have issues with boolean types in ONNX +} + + +# ============================================================================ +# ALLOCATION OPERATIONS REGISTRY (Tier 3) +# ============================================================================ + +ALLOCATION_OPERATIONS: dict[str, dict[str, Any]] = { + "alloc_scalar": { + "build_graph": lambda val, s1, s2: ([], pt.alloc(val, s1, s2)), + "strategy": alloc_strategy(), + "expected_onnx_ops": ["Expand"], + "description": "Allocate tensor from scalar", + }, + "alloc_empty": { + "build_graph": lambda s1, s2: ([], pt.empty((s1, s2), dtype="float32")), + "strategy": st.tuples(st.integers(2, 10), st.integers(2, 10)), + "expected_onnx_ops": ["ConstantOfShape"], + "description": "Allocate uninitialized tensor", + }, + "make_vector": { + "build_graph": lambda v1, v2, v3: ([], pt.stack([v1, v2, v3])), + "strategy": st.builds( + lambda: tuple(float(x) for x in np.random.randn(3)), + ), + "expected_onnx_ops": ["Concat", "Unsqueeze"], + "description": "Create vector from scalars", + }, + "arange": { + "build_graph": lambda start, stop, step: ( + [], + pt.arange(start, stop, step, dtype="int64"), + ), + "strategy": arange_strategy(), + "expected_onnx_ops": ["Range"], + "description": "Create range tensor", + }, +} + + +# ============================================================================ +# SUBTENSOR OPERATIONS REGISTRY +# ============================================================================ + +SUBTENSOR_OPERATIONS: dict[str, dict[str, Any]] = { + "slice_basic": { + "build_graph": lambda x_val: (lambda x: ([x], x[2:5]))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": st.builds( + lambda size: np.arange(size, dtype="float32"), size=st.integers(10, 20) + ), + "expected_onnx_ops": ["Slice"], + "description": "Basic slicing", + }, + "slice_multidim": { + "build_graph": lambda x_val: (lambda x: ([x], x[1:3, 2:4]))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": st.builds( + lambda s1, s2: np.arange(s1 * s2).reshape(s1, s2).astype("float32"), + s1=st.integers(5, 10), + s2=st.integers(5, 10), + ), + "expected_onnx_ops": ["Slice"], + "description": "Multi-dimensional slicing", + }, + "slice_with_step": { + "build_graph": lambda x_val: (lambda x: ([x], x[::2]))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": st.builds( + lambda size: np.arange(size, dtype="float32"), size=st.integers(10, 20) + ), + "expected_onnx_ops": ["Slice"], + "description": "Slicing with step", + }, + "advanced_index": { + "build_graph": lambda x_val, indices_val: ( + lambda x, indices: ([x, indices], x[indices]) + )( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("indices", dtype="int64", shape=(None,)), + ), + "strategy": advanced_index_strategy(), + "expected_onnx_ops": ["Gather"], + "description": "Advanced indexing with integer array", + }, +} + + +# ============================================================================ +# INCSUBTENSOR OPERATIONS REGISTRY +# ============================================================================ + +INCSUBTENSOR_OPERATIONS: dict[str, dict[str, Any]] = { + "set_subtensor": { + "build_graph": lambda x_val, values_val: ( + lambda x, values: ([x, values], pt.set_subtensor(x[2:5], values)) + )( + pt.tensor("x", dtype="float32", shape=(None,)), + pt.tensor("values", dtype="float32", shape=(None,)), + ), + "strategy": set_subtensor_strategy(), + "expected_onnx_ops": ["ScatterND", "ScatterElements"], + "description": "Set subtensor values", + }, + "inc_subtensor": { + "build_graph": lambda x_val, values_val: ( + lambda x, values: ([x, values], pt.inc_subtensor(x[2:5], values)) + )( + pt.tensor("x", dtype="float32", shape=(None,)), + pt.tensor("values", dtype="float32", shape=(None,)), + ), + "strategy": set_subtensor_strategy(), + "expected_onnx_ops": ["ScatterND", "ScatterElements", "Add"], + "description": "Increment subtensor values", + }, +} + + +# ============================================================================ +# ELEMWISE OPERATIONS REGISTRY (Tier 1) +# ============================================================================ + +ELEMWISE_OPERATIONS: dict[str, dict[str, Any]] = { + # ================================================================= + # BINARY ARITHMETIC OPERATIONS + # ================================================================= + "add": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], x + y))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Add"], + "description": "Element-wise addition", + }, + "mul": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], x * y))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Mul"], + "description": "Element-wise multiplication", + }, + "sub": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], x - y))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Sub"], + "description": "Element-wise subtraction", + }, + "div": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], x / y))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Div"], + "description": "Element-wise division", + }, + "int_div": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], x // y))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + # NOTE: expected_onnx_ops couples test to implementation details + # This specifies HOW int_div is implemented (div + floor) rather than + # just testing correctness. This is intentional for ONNX backend validation + # but makes tests brittle if implementation changes. + "expected_onnx_ops": ["Div", "Floor"], # Integer division is div + floor + "description": "Element-wise integer division", + }, + "pow": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], x**y))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Pow"], + "description": "Element-wise power", + }, + # ================================================================= + # ELEMENT-WISE MIN/MAX OPERATIONS + # ================================================================= + "maximum": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], pt.maximum(x, y)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Max"], + "description": "Element-wise maximum", + }, + "minimum": { + "build_graph": lambda x_val, y_val: (lambda x, y: ([x, y], pt.minimum(x, y)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim), + pt.tensor("y", dtype="float32", shape=(None,) * y_val.ndim), + ), + "strategy": binary_float32_arrays_strategy(), + "expected_onnx_ops": ["Min"], + "description": "Element-wise minimum", + }, + # ================================================================= + # UNARY OPERATIONS + # ================================================================= + "neg": { + "build_graph": lambda x_val: (lambda x: ([x], -x))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Neg"], + "description": "Element-wise negation", + }, + "abs": { + "build_graph": lambda x_val: (lambda x: ([x], pt.abs(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Abs"], + "description": "Element-wise absolute value", + }, + "exp": { + "build_graph": lambda x_val: (lambda x: ([x], pt.exp(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Exp"], + "description": "Element-wise exponential", + }, + # ================================================================= + # CONSTRAINED UNARY OPERATIONS + # ================================================================= + "log": { + "build_graph": lambda x_val: (lambda x: ([x], pt.log(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": positive_float32_array_strategy(), + "expected_onnx_ops": ["Log"], + "description": "Element-wise natural logarithm", + }, + "sqrt": { + "build_graph": lambda x_val: (lambda x: ([x], pt.sqrt(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": non_negative_float32_array_strategy(), + "expected_onnx_ops": ["Sqrt"], + "description": "Element-wise square root", + }, + # ================================================================= + # ROUNDING OPERATIONS + # ================================================================= + "floor": { + "build_graph": lambda x_val: (lambda x: ([x], pt.floor(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Floor"], + "description": "Element-wise floor", + }, + "ceil": { + "build_graph": lambda x_val: (lambda x: ([x], pt.ceil(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Ceil"], + "description": "Element-wise ceiling", + }, + "round": { + "build_graph": lambda x_val: (lambda x: ([x], pt.round(x)))( + pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + ), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Round"], + "description": "Element-wise rounding (half to even)", + }, + "round_away": { + "build_graph": lambda x_val: ( + lambda x: ([x], pt.round(x, mode="half_away_from_zero")) + )(pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim)), + "strategy": unary_float32_array_strategy(), + "expected_onnx_ops": ["Round"], + "description": "Element-wise rounding (half away from zero)", + }, + # ================================================================= + # SPECIAL OPERATIONS + # ================================================================= + "clip": { + "build_graph": lambda x_val, min_val, max_val: ( + lambda x: ([x], pt.clip(x, min_val, max_val)) + )(pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim)), + # Strategy ensures min_v < max_v by construction: + # min_v from [-5, 0] and max_v from [0, 5] guarantees min_v <= 0 <= max_v + # Edge case: min_v == max_v == 0 is possible but rare + # This edge case (all values clipped to same value) is worth testing + # separately in Phase 2 manual tests if needed + "strategy": st.builds( + lambda x, min_v, max_v: (x, float(min_v), float(max_v)), + x=unary_float32_array_strategy(), + min_v=st.floats(-5, 0), + max_v=st.floats(0, 5), + ), + "expected_onnx_ops": ["Clip"], + "description": "Element-wise clipping", + }, +} diff --git a/tests/link/onnx/test_basic.py b/tests/link/onnx/test_basic.py new file mode 100644 index 0000000000..163ee240c0 --- /dev/null +++ b/tests/link/onnx/test_basic.py @@ -0,0 +1,169 @@ +"""Core testing utilities for ONNX backend.""" + +from collections.abc import Callable, Iterable +from functools import partial + +import numpy as np +import pytest + +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.graph.basic import Variable + + +# These will be imported once the ONNX backend is implemented +# For now, we'll set up the structure so tests can use them +try: + from pytensor.link.onnx import ONNXLinker + + onnx = pytest.importorskip("onnx") + onnxruntime = pytest.importorskip("onnxruntime") + + onnx_mode = Mode(linker=ONNXLinker(), optimizer=None) + py_mode = Mode(linker="py", optimizer=None) +except ImportError: + # ONNX backend not yet implemented + onnx_mode = None + py_mode = Mode(linker="py", optimizer=None) + + +def compare_onnx_and_py( + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], + test_inputs: Iterable, + *, + assert_fn: Callable | None = None, + must_validate: bool = True, + onnx_mode=onnx_mode, + py_mode=py_mode, +): + """Compare ONNX Runtime output to Python reference. + + This is the core testing utility that: + 1. Compiles graph with ONNX backend + 2. Compiles graph with Python backend + 3. Executes both with test_inputs + 4. Asserts outputs match + 5. Validates ONNX model + + Parameters + ---------- + graph_inputs : Iterable[Variable] + Symbolic inputs to the graph + graph_outputs : Variable | Iterable[Variable] + Symbolic outputs of the graph + test_inputs : Iterable + Numerical inputs for testing the function + assert_fn : Callable, optional + Assert function used to check for equality between ONNX and Python. + If not provided, uses np.testing.assert_allclose with rtol=1e-4 + must_validate : bool, optional + If True, validates the ONNX model with onnx.checker.check_model + onnx_mode : Mode, optional + Mode to use for ONNX compilation + py_mode : Mode, optional + Mode to use for Python reference compilation + + Returns + ------- + tuple + (onnx_function, onnx_result) + + Raises + ------ + AssertionError + If ONNX output doesn't match Python output + """ + if assert_fn is None: + assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) + + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") + + # Compile with ONNX backend + onnx_fn = function(graph_inputs, graph_outputs, mode=onnx_mode) + onnx_res = onnx_fn(*test_inputs) + + # Compile with Python reference + py_fn = function(graph_inputs, graph_outputs, mode=py_mode) + py_res = py_fn(*test_inputs) + + # Compare outputs + if isinstance(graph_outputs, list | tuple): + for o, p in zip(onnx_res, py_res, strict=True): + assert_fn(o, p) + else: + assert_fn(onnx_res, py_res) + + # Validate ONNX model + if must_validate and hasattr(onnx_fn.maker.linker, "onnx_model"): + import onnx + + onnx.checker.check_model(onnx_fn.maker.linker.onnx_model) + + return onnx_fn, onnx_res + + +def get_onnx_node_types(fn): + """Get list of ONNX node types in compiled function. + + Parameters + ---------- + fn : Function + Compiled PyTensor function with ONNX linker + + Returns + ------- + list of str + List of ONNX operation types (e.g., ['Add', 'Mul', 'Sub']) + """ + if not hasattr(fn.maker.linker, "onnx_model"): + raise ValueError("Function was not compiled with ONNX linker") + + return [node.op_type for node in fn.maker.linker.onnx_model.graph.node] + + +# Meta-test: test the test utilities themselves +def test_compare_onnx_and_py_simple(): + """Test that compare_onnx_and_py works for a simple identity operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + import pytensor.tensor as pt + + # Simple identity + x = pt.vector("x", dtype="float32") + y = x + + # Test data + x_val = np.array([1, 2, 3], dtype="float32") + + # Should not raise + try: + _fn, result = compare_onnx_and_py([x], y, [x_val]) + np.testing.assert_array_equal(result, x_val) + except Exception as e: + pytest.fail(f"compare_onnx_and_py raised unexpectedly: {e}") + + +def test_get_onnx_node_types(): + """Test that get_onnx_node_types utility works.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + import pytensor + import pytensor.tensor as pt + from pytensor.link.onnx.linker import ONNXLinker + + # Create a graph with Add operation + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = x + y + + # Compile + fn = pytensor.function([x, y], z, mode=Mode(linker=ONNXLinker())) + + # Get node types + node_types = get_onnx_node_types(fn) + + assert "Add" in node_types, f"Expected 'Add' in node types, got {node_types}" diff --git a/tests/link/onnx/test_dispatch_basic.py b/tests/link/onnx/test_dispatch_basic.py new file mode 100644 index 0000000000..f97421f65c --- /dev/null +++ b/tests/link/onnx/test_dispatch_basic.py @@ -0,0 +1,78 @@ +"""Tests for ONNX dispatch system.""" + +import numpy as np +import pytest + + +def test_onnx_funcify_unregistered_op(): + """Test that onnx_funcify raises informative error for unregistered ops.""" + from pytensor.link.onnx.dispatch import onnx_funcify + + # Create a fake op + class FakeOp: + pass + + fake_op = FakeOp() + + with pytest.raises(NotImplementedError) as exc_info: + onnx_funcify(fake_op) + + error_msg = str(exc_info.value) + assert "No ONNX conversion available" in error_msg, ( + f"Error should mention no conversion available, got: {error_msg}" + ) + assert "FakeOp" in error_msg, f"Error should mention the op type, got: {error_msg}" + + +def test_onnx_typify_ndarray(): + """Test that onnx_typify converts numpy arrays to ONNX tensors.""" + pytest.importorskip("onnx") + + import onnx + from onnx import numpy_helper + + from pytensor.link.onnx.dispatch import onnx_typify + + # Test data + arr = np.array([1, 2, 3], dtype="float32") + + # Convert + result = onnx_typify(arr, name="test_tensor") + + # Verify it's a TensorProto + assert isinstance(result, onnx.TensorProto), ( + f"Expected TensorProto, got {type(result)}" + ) + + # Verify data is correct + result_arr = numpy_helper.to_array(result) + np.testing.assert_array_equal(result_arr, arr) + + +def test_make_value_info_basic(): + """Test that make_value_info creates correct ONNX ValueInfo.""" + pytest.importorskip("onnx") + + import onnx + + import pytensor.tensor as pt + from pytensor.link.onnx.dispatch.basic import make_value_info + + # Create a PyTensor variable + x = pt.vector("x", dtype="float32") + + # Create ValueInfo + value_info = make_value_info(x, "x") + + # Verify type + assert isinstance(value_info, onnx.ValueInfoProto), ( + f"Expected ValueInfoProto, got {type(value_info)}" + ) + + # Verify name + assert value_info.name == "x", f"Expected name 'x', got {value_info.name}" + + # Verify dtype + assert value_info.type.tensor_type.elem_type == onnx.TensorProto.FLOAT, ( + f"Expected FLOAT dtype, got {value_info.type.tensor_type.elem_type}" + ) diff --git a/tests/link/onnx/test_elemwise.py b/tests/link/onnx/test_elemwise.py new file mode 100644 index 0000000000..7db433c0e3 --- /dev/null +++ b/tests/link/onnx/test_elemwise.py @@ -0,0 +1,526 @@ +"""Tests for ONNX elemwise operations. + +Test Strategy: +- Property-based tests provide primary coverage (180+ scenarios) +- Main property test covers 13 unconstrained operations +- Separate property tests for constrained operations (log, sqrt, pow, clip) +- Manual tests retained for edge cases and compositions + +Coverage: 18 elemwise operations total +""" + +from functools import partial + +import numpy as np +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +import pytensor.tensor as pt +from tests.link.onnx.strategies import ELEMWISE_OPERATIONS +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# ============================================================================ +# NUMERICAL TOLERANCE CONSTANTS +# ============================================================================ +# These tolerances account for numerical precision differences between +# PyTensor and ONNX implementations. Documented rationale for each: + +# Standard tolerance for stable operations (add, mul, sub, etc.) +STANDARD_TOLERANCE = {"rtol": 1e-5, "atol": 1e-8} + +# Relaxed tolerance for numerically unstable operations +# Used for: pow (negative base + fractional exponent), exp (large values) +# Rationale: These operations amplify floating-point errors +RELAXED_TOLERANCE = {"rtol": 1e-3, "atol": 1e-5} + +# Log-specific tolerance (between standard and relaxed) +# Used for: log (values near zero are numerically sensitive) +# Rationale: log(x) for small x has larger relative error +LOG_TOLERANCE = {"rtol": 1e-4, "atol": 1e-6} + + +# ============================================================================ +# PROPERTY-BASED TESTS (Primary Coverage) +# ============================================================================ + + +@given( + op_name=st.sampled_from( + [ + # Binary arithmetic (5) + "add", + "mul", + "sub", + "div", + "int_div", + # Binary min/max (2) + "maximum", + "minimum", + # Unary (3) + "neg", + "abs", + "exp", + # Rounding (3) + "floor", + "ceil", + "round", + # Total: 13 unconstrained operations + ] + ), + data=st.data(), +) +@settings(max_examples=10, deadline=None) +def test_elemwise_operations_correctness(op_name, data): + """ + Property test: All unconstrained elemwise operations produce correct ONNX results. + + This test verifies: + - ONNX output matches Python reference implementation + - Correct ONNX node types are generated + - Operations handle diverse inputs correctly + + Operations tested (13 unconstrained Tier 1 operations): + - Binary arithmetic: add, mul, sub, div, int_div (5) + - Binary min/max: maximum, minimum (2) + - Unary: neg, abs, exp (3) + - Rounding: floor, ceil, round (3) + + Total: 13 operations x 10 examples = 130 test scenarios + + Constrained operations tested separately: + - pow, log, sqrt, clip (separate tests with constrained strategies) + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + # Get operation configuration from registry + op_config = ELEMWISE_OPERATIONS[op_name] + + # Generate test data using operation's strategy + test_data = data.draw(op_config["strategy"]) + + # Handle both tuple and single value returns + if isinstance(test_data, tuple): + inputs_tuple = test_data + else: + inputs_tuple = (test_data,) + + # Build PyTensor graph + graph_inputs, graph_output = op_config["build_graph"](*inputs_tuple) + + # Prepare test inputs for execution + if isinstance(test_data, tuple): + test_inputs = list(test_data) + else: + test_inputs = [test_data] + + # Compare ONNX vs PyTensor + fn, _result = compare_onnx_and_py(graph_inputs, graph_output, test_inputs) + + # Verify ONNX node types + node_types = get_onnx_node_types(fn) + expected_ops = op_config["expected_onnx_ops"] + + # Check that at least one expected operation is present + assert any(op in node_types for op in expected_ops), ( + f"{op_name}: Expected one of {expected_ops}, got {node_types}" + ) + + +@given(data=st.data()) +@settings(max_examples=50, deadline=None) # Higher count for critical operation +def test_log_operation_correctness(data): + """ + Property test: Log operation produces correct ONNX results. + + This test verifies: + - Log operation works with positive inputs + - ONNX output matches Python reference + - Correct ONNX node type (Log) is generated + + Note: Uses positive_float32_array_strategy to ensure valid inputs + (log requires x > 0). Uses 50 examples (vs standard 10) due to + numerical sensitivity. + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + op_config = ELEMWISE_OPERATIONS["log"] + + # Generate positive test data + test_data = data.draw(op_config["strategy"]) + + # Verify inputs are positive (strategy constraint) + assert np.all(test_data > 0), "Log operation requires positive inputs" + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](test_data) + + # Compare ONNX vs PyTensor with log-specific tolerance + # Uses LOG_TOLERANCE (rtol=1e-4, atol=1e-6) - see tolerance constants + fn, _result = compare_onnx_and_py( + graph_inputs, + graph_output, + [test_data], + assert_fn=partial(np.testing.assert_allclose, **LOG_TOLERANCE), + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Log" in node_types, f"Expected 'Log' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_sqrt_operation_correctness(data): + """ + Property test: Sqrt operation produces correct ONNX results. + + This test verifies: + - Sqrt operation works with non-negative inputs + - ONNX output matches Python reference + - Correct ONNX node type (Sqrt) is generated + + Note: Uses non_negative_float32_array_strategy to ensure valid inputs + (sqrt requires x >= 0) + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + op_config = ELEMWISE_OPERATIONS["sqrt"] + + # Generate non-negative test data + test_data = data.draw(op_config["strategy"]) + + # Verify inputs are non-negative (strategy constraint) + assert np.all(test_data >= 0), "Sqrt operation requires non-negative inputs" + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](test_data) + + # Compare ONNX vs PyTensor + fn, _result = compare_onnx_and_py(graph_inputs, graph_output, [test_data]) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Sqrt" in node_types, f"Expected 'Sqrt' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=50, deadline=None) # Higher count for critical operation +def test_pow_operation_correctness(data): + """ + Property test: Pow operation produces correct ONNX results. + + This test verifies: + - Pow operation works with float inputs + - ONNX output matches Python reference + - Correct ONNX node type (Pow) is generated + + Note: May have numerical precision issues with negative bases + and fractional exponents. Using relaxed tolerance. Uses + 50 examples (vs standard 10) due to numerical complexity. + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + op_config = ELEMWISE_OPERATIONS["pow"] + + # Generate test data (two arrays) + test_data = data.draw(op_config["strategy"]) + x_val, y_val = test_data + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](x_val, y_val) + + # Compare ONNX vs PyTensor with relaxed tolerance + # Uses RELAXED_TOLERANCE (rtol=1e-3, atol=1e-5) - see tolerance constants + # Rationale: Pow with negative base + fractional exponent amplifies errors + fn, _result = compare_onnx_and_py( + graph_inputs, + graph_output, + [x_val, y_val], + assert_fn=partial(np.testing.assert_allclose, **RELAXED_TOLERANCE), + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Pow" in node_types, f"Expected 'Pow' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_clip_operation_correctness(data): + """ + Property test: Clip operation produces correct ONNX results. + + This test verifies: + - Clip operation correctly bounds values + - ONNX output matches Python reference + - Correct ONNX node type (Clip) is generated + - Min/max bounds are respected + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + op_config = ELEMWISE_OPERATIONS["clip"] + + # Generate test data (array, min, max) + test_data = data.draw(op_config["strategy"]) + x_val, min_val, max_val = test_data + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](x_val, min_val, max_val) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val]) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Clip" in node_types, f"Expected 'Clip' node, got {node_types}" + + # Additional validation: verify bounds are respected + assert np.all(result >= min_val), f"Result contains values below min_val={min_val}" + assert np.all(result <= max_val), f"Result contains values above max_val={max_val}" + + +# ============================================================================ +# MANUAL EDGE CASE TESTS +# ============================================================================ + + +# Test binary arithmetic operations +def test_add_vectors(): + """Test that vector addition exports correctly to ONNX.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + # Define graph + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = x + y + + # Test data + x_val = np.array([1, 2, 3], dtype="float32") + y_val = np.array([4, 5, 6], dtype="float32") + + # Compare outputs + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Add" in node_types, f"Expected 'Add' node in ONNX graph, got {node_types}" + + +def test_mul_vectors(): + """Test that vector multiplication exports correctly to ONNX.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = x * y + + x_val = np.array([1, 2, 3], dtype="float32") + y_val = np.array([2, 3, 4], dtype="float32") + + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + + assert "Mul" in get_onnx_node_types(fn) + + +def test_sub_vectors(): + """Test vector subtraction.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = x - y + + x_val = np.array([5, 6, 7], dtype="float32") + y_val = np.array([1, 2, 3], dtype="float32") + + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + assert "Sub" in get_onnx_node_types(fn) + + +def test_div_vectors(): + """Test vector division.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = x / y + + x_val = np.array([6, 8, 10], dtype="float32") + y_val = np.array([2, 4, 5], dtype="float32") + + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + assert "Div" in get_onnx_node_types(fn) + + +def test_chained_arithmetic(): + """Test that chained arithmetic operations work correctly.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + # (x * 2 + 3) / 4 + z = ((x * 2) + 3) / 4 + + x_val = np.array([1, 2, 3], dtype="float32") + + fn, _result = compare_onnx_and_py([x], z, [x_val]) + + # Should have multiple operation nodes + node_types = get_onnx_node_types(fn) + assert "Mul" in node_types + assert "Add" in node_types + assert "Div" in node_types + + +# Test unary operations +def test_neg(): + """Test negation operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = -x + + x_val = np.array([1, -2, 3], dtype="float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + assert "Neg" in get_onnx_node_types(fn) + + +def test_abs(): + """Test absolute value operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.abs(x) + + x_val = np.array([1, -2, 3, -4], dtype="float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + assert "Abs" in get_onnx_node_types(fn) + + +def test_exp(): + """Test exponential operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.exp(x) + + x_val = np.array([0, 1, 2], dtype="float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + assert "Exp" in get_onnx_node_types(fn) + + +def test_log(): + """Test natural logarithm operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.log(x) + + x_val = np.array([1, 2, np.e], dtype="float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + assert "Log" in get_onnx_node_types(fn) + + +def test_sqrt(): + """Test square root operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.sqrt(x) + + x_val = np.array([1, 4, 9, 16], dtype="float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + assert "Sqrt" in get_onnx_node_types(fn) + + +def test_pow(): + """Test power operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = x**y + + x_val = np.array([2, 3, 4], dtype="float32") + y_val = np.array([2, 2, 3], dtype="float32") + + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + assert "Pow" in get_onnx_node_types(fn) + + +@pytest.mark.parametrize( + "op_name,op_func,expected_node", + [ + ("floor", pt.floor, "Floor"), + ("ceil", pt.ceil, "Ceil"), + ("round", pt.round, "Round"), + ], +) +def test_rounding_operations(op_name, op_func, expected_node): + """Test floor, ceil, and round operations.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = op_func(x) + + x_val = np.array([1.2, 2.5, 3.7, -1.5], dtype="float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + assert expected_node in get_onnx_node_types(fn), ( + f"Expected {expected_node} node for {op_name}" + ) + + +def test_maximum(): + """Test element-wise maximum operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = pt.maximum(x, y) + + x_val = np.array([1, 5, 3], dtype="float32") + y_val = np.array([4, 2, 6], dtype="float32") + + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + assert "Max" in get_onnx_node_types(fn) + + +def test_minimum(): + """Test element-wise minimum operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = pt.minimum(x, y) + + x_val = np.array([1, 5, 3], dtype="float32") + y_val = np.array([4, 2, 6], dtype="float32") + + fn, _result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + assert "Min" in get_onnx_node_types(fn) diff --git a/tests/link/onnx/test_export.py b/tests/link/onnx/test_export.py new file mode 100644 index 0000000000..872ad32821 --- /dev/null +++ b/tests/link/onnx/test_export.py @@ -0,0 +1,79 @@ +"""Tests for ONNX export API.""" + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt + + +def test_export_onnx_basic(tmp_path): + """Test that export_onnx creates a valid ONNX file.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + import onnx + + from pytensor.link.onnx import export_onnx + + # Define graph + x = pt.vector("x", dtype="float32") + y = x * 2 + + # Export + output_path = tmp_path / "test_model.onnx" + model = export_onnx([x], y, str(output_path)) + + # Verify file exists + assert output_path.exists(), f"ONNX file not created at {output_path}" + + # Verify model is valid + onnx.checker.check_model(model) + + # Verify model can be loaded + loaded_model = onnx.load(str(output_path)) + assert loaded_model is not None + + +def test_compile_onnx_basic(): + """Test that compile_onnx returns an executable function.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + from pytensor.link.onnx import compile_onnx + + x = pt.vector("x", dtype="float32") + y = x + 1 + + # Compile + fn = compile_onnx([x], y) + + # Test execution + x_val = np.array([1, 2, 3], dtype="float32") + result = fn(x_val) + + expected = np.array([2, 3, 4], dtype="float32") + np.testing.assert_array_equal(result, expected) + + +def test_export_function_onnx(tmp_path): + """Test exporting a compiled PyTensor function to ONNX.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + import onnx + + from pytensor.link.onnx import export_function_onnx + + # Create and compile function + x = pt.vector("x", dtype="float32") + y = pt.sqrt(x) + fn = pytensor.function([x], y) + + # Export + output_path = tmp_path / "function.onnx" + model = export_function_onnx(fn, str(output_path)) + + # Verify + assert output_path.exists() + onnx.checker.check_model(model) diff --git a/tests/link/onnx/test_extra_ops.py b/tests/link/onnx/test_extra_ops.py new file mode 100644 index 0000000000..46118ac3b2 --- /dev/null +++ b/tests/link/onnx/test_extra_ops.py @@ -0,0 +1,102 @@ +"""Tests for ONNX backend extra operations (Tier 5).""" + +import numpy as np +import pytest + +import pytensor.tensor as pt +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# CumSum Tests + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_cumsum(axis): + """Test cumulative sum operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.matrix("x", dtype="float32") + y = pt.cumsum(x, axis=axis) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.cumsum(x_val, axis=axis) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert "CumSum" in node_types + + +# Repeat Tests + + +def test_repeat(): + """Test repeat operation (repeat elements).""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.repeat(x, repeats=3, axis=0) + + x_val = np.array([1, 2, 3], dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.repeat(x_val, repeats=3, axis=0) + np.testing.assert_array_equal(result, expected) + + +# Unique Tests + + +def test_unique(): + """Test unique operation (find unique elements). + + Note: ONNX Unique has different semantics than NumPy. + May need special handling. + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="int64") + y = pt.unique(x) + + x_val = np.array([1, 2, 3, 2, 1, 4, 3], dtype="int64") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.unique(x_val) + + # Result may be sorted differently + np.testing.assert_array_equal(sorted(result), sorted(expected)) + + node_types = get_onnx_node_types(fn) + assert "Unique" in node_types + + +# Pad Tests + + +def test_pad(): + """Test pad operation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.matrix("x", dtype="float32") + # Pad with 1 zero on each side + y = pt.pad(x, pad_width=((1, 1), (1, 1)), mode="constant", constant_values=0) + + x_val = np.array([[1, 2], [3, 4]], dtype="float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.pad( + x_val, pad_width=((1, 1), (1, 1)), mode="constant", constant_values=0 + ) + np.testing.assert_array_equal(result, expected) + + node_types = get_onnx_node_types(fn) + assert "Pad" in node_types diff --git a/tests/link/onnx/test_imports.py b/tests/link/onnx/test_imports.py new file mode 100644 index 0000000000..4da6e718fc --- /dev/null +++ b/tests/link/onnx/test_imports.py @@ -0,0 +1,41 @@ +"""Tests for ONNX backend module structure and imports.""" + +import pytest + + +def test_onnx_module_exists(): + """Test that pytensor.link.onnx module exists and is importable.""" + try: + import pytensor.link.onnx # noqa: F401 + except ImportError as e: + pytest.fail(f"Failed to import pytensor.link.onnx: {e}") + + +def test_onnx_public_api(): + """Test that ONNX backend exports expected public API.""" + from pytensor.link.onnx import ( + ONNX_OPSET_VERSION, + ONNXLinker, + compile_onnx, + export_onnx, + onnx_funcify, + ) + + assert ONNXLinker is not None, "ONNXLinker not exported" + assert export_onnx is not None, "export_onnx not exported" + assert compile_onnx is not None, "compile_onnx not exported" + assert onnx_funcify is not None, "onnx_funcify not exported" + assert ONNX_OPSET_VERSION == 18, f"Expected opset 18, got {ONNX_OPSET_VERSION}" + + +def test_dispatch_module_structure(): + """Test that dispatch module has expected structure.""" + from pytensor.link.onnx.dispatch import onnx_funcify, onnx_typify + + # Check they're singledispatch functions + assert hasattr(onnx_funcify, "register"), ( + "onnx_funcify should be a singledispatch function" + ) + assert hasattr(onnx_typify, "register"), ( + "onnx_typify should be a singledispatch function" + ) diff --git a/tests/link/onnx/test_integration.py b/tests/link/onnx/test_integration.py new file mode 100644 index 0000000000..34ce8d1737 --- /dev/null +++ b/tests/link/onnx/test_integration.py @@ -0,0 +1,53 @@ +"""Integration tests for ONNX backend - complete models and workflows.""" + +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.tensor.special import softmax +from tests.link.onnx.test_basic import compare_onnx_and_py + + +def test_simple_mlp(): + """Test simple MLP using matmul, add, and activation. + + This integration test verifies that a complete neural network + layer can be exported to ONNX. + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + # Input + x = pt.matrix("x", dtype="float32") + + # Weights and biases + W1 = pt.matrix("W1", dtype="float32") + b1 = pt.vector("b1", dtype="float32") + W2 = pt.matrix("W2", dtype="float32") + b2 = pt.vector("b2", dtype="float32") + + # Layer 1: x @ W1 + b1, then ReLU + h = pt.maximum(pt.dot(x, W1) + b1, 0) + + # Layer 2: h @ W2 + b2, then softmax (axis=-1 for row-wise probabilities) + logits = pt.dot(h, W2) + b2 + output = softmax(logits, axis=-1) + + # Test data + rng = np.random.default_rng(42) + x_val = rng.normal(size=(5, 10)).astype("float32") + W1_val = rng.normal(size=(10, 20)).astype("float32") + b1_val = rng.normal(size=(20,)).astype("float32") + W2_val = rng.normal(size=(20, 3)).astype("float32") + b2_val = rng.normal(size=(3,)).astype("float32") + + _fn, result = compare_onnx_and_py( + [x, W1, b1, W2, b2], output, [x_val, W1_val, b1_val, W2_val, b2_val] + ) + + # Verify output is valid probabilities + assert result.shape == (5, 3), f"Expected shape (5, 3), got {result.shape}" + assert np.allclose(result.sum(axis=1), 1.0), "Softmax should sum to 1" + assert np.all(result >= 0) and np.all(result <= 1), ( + "Probabilities should be in [0, 1]" + ) diff --git a/tests/link/onnx/test_linker.py b/tests/link/onnx/test_linker.py new file mode 100644 index 0000000000..4a458ecb59 --- /dev/null +++ b/tests/link/onnx/test_linker.py @@ -0,0 +1,65 @@ +"""Tests for ONNXLinker.""" + +import numpy as np + +from pytensor.compile.mode import Mode + + +def test_linker_instantiation(): + """Test that ONNXLinker can be instantiated.""" + from pytensor.link.onnx.linker import ONNXLinker + + linker = ONNXLinker(opset_version=18) + + assert linker is not None, "Linker instantiation returned None" + assert linker.opset_version == 18, f"Expected opset 18, got {linker.opset_version}" + + +def test_linker_empty_graph(): + """Test that linker can convert a trivial passthrough graph.""" + import pytensor + import pytensor.tensor as pt + from pytensor.link.onnx.linker import ONNXLinker + + # Create identity graph + x = pt.scalar("x", dtype="float32") + y = x # Passthrough + + # Compile with ONNX linker + fn = pytensor.function([x], y, mode=Mode(linker=ONNXLinker())) + + # Test execution + result = fn(5.0) + assert result == 5.0, f"Expected 5.0, got {result}" + + # Verify ONNX model exists + assert hasattr(fn.maker.linker, "onnx_model"), ( + "Linker should have onnx_model attribute" + ) + assert fn.maker.linker.onnx_model is not None, "onnx_model should not be None" + + +def test_linker_constant_graph(): + """Test that linker correctly handles constants as initializers.""" + import pytensor + import pytensor.tensor as pt + from pytensor.link.onnx.linker import ONNXLinker + + # Create graph with constant + x = pt.scalar("x", dtype="float32") + c = pt.constant(2.0, dtype="float32") + y = x * c + + # Compile + fn = pytensor.function([x], y, mode=Mode(linker=ONNXLinker())) + + # Test + result = fn(3.0) + expected = 6.0 + np.testing.assert_allclose(result, expected, rtol=1e-5) + + # Verify ONNX model has initializer for constant + model = fn.maker.linker.onnx_model + assert len(model.graph.initializer) > 0, ( + "Model should have at least one initializer for the constant" + ) diff --git a/tests/link/onnx/test_math.py b/tests/link/onnx/test_math.py new file mode 100644 index 0000000000..da738ca7a4 --- /dev/null +++ b/tests/link/onnx/test_math.py @@ -0,0 +1,208 @@ +"""Tests for ONNX math operations (reductions).""" + +import numpy as np +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +import pytensor.tensor as pt +from tests.link.onnx.strategies import REDUCTION_OPERATIONS +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Import ONNX and skip if not available +onnx = pytest.importorskip("onnx") +ort = pytest.importorskip("onnxruntime") + + +# ============================================================================ +# Property-Based Tests for Reduction Operations +# ============================================================================ + + +@given( + op_name=st.sampled_from(list(REDUCTION_OPERATIONS.keys())), + data=st.data(), +) +@settings(max_examples=10, deadline=None) +def test_reduction_operations_correctness(op_name, data): + """Property test: All reduction operations produce correct ONNX results. + + Tests: sum, prod, max, min, argmax, argmin, all, any + Total: 8 operations x 10 examples = 80 test scenarios + """ + op_config = REDUCTION_OPERATIONS[op_name] + + # Generate tensor and axis + test_data = data.draw(op_config["strategy"]) + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](*test_data) + + # Compare ONNX vs PyTensor + fn, _result = compare_onnx_and_py(graph_inputs, graph_output, [test_data[0]]) + + # Verify ONNX nodes + node_types = get_onnx_node_types(fn) + expected_ops = op_config["expected_onnx_ops"] + assert any(op in node_types for op in expected_ops), ( + f"{op_name}: Expected {expected_ops}, got {node_types}" + ) + + +# ============================================================================ +# Specific Tests for Edge Cases +# ============================================================================ + + +def test_reduction_keepdims(): + """Reduction with keepdims parameter.""" + x = pt.matrix("x", dtype="float32") + y = pt.sum(x, axis=1, keepdims=True) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + assert result.shape == (3, 1) + assert "ReduceSum" in get_onnx_node_types(fn) + + +@pytest.mark.parametrize("axis", [None, 0, 1, [0, 1]]) +def test_reduction_axis_variations(axis): + """Test reductions with different axis specifications.""" + x = pt.matrix("x", dtype="float32") + y = pt.sum(x, axis=axis) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, _result = compare_onnx_and_py([x], y, [x_val]) + + assert "ReduceSum" in get_onnx_node_types(fn) + + +def test_sum_reduction(): + """Basic sum reduction.""" + x = pt.matrix("x", dtype="float32") + y = pt.sum(x, axis=1) + + x_val = np.random.randn(4, 5).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.sum(x_val, axis=1) + np.testing.assert_allclose(result, expected, rtol=1e-4) + assert "ReduceSum" in get_onnx_node_types(fn) + + +def test_prod_reduction(): + """Product reduction.""" + x = pt.matrix("x", dtype="float32") + y = pt.prod(x, axis=0) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.prod(x_val, axis=0) + np.testing.assert_allclose(result, expected, rtol=1e-4) + assert "ReduceProd" in get_onnx_node_types(fn) + + +def test_max_min_reduction(): + """Max and min reductions.""" + x = pt.matrix("x", dtype="float32") + y_max = pt.max(x, axis=1) + y_min = pt.min(x, axis=1) + + x_val = np.random.randn(4, 5).astype("float32") + + fn_max, result_max = compare_onnx_and_py([x], y_max, [x_val]) + fn_min, result_min = compare_onnx_and_py([x], y_min, [x_val]) + + expected_max = np.max(x_val, axis=1) + expected_min = np.min(x_val, axis=1) + + np.testing.assert_allclose(result_max, expected_max, rtol=1e-4) + np.testing.assert_allclose(result_min, expected_min, rtol=1e-4) + + assert "ReduceMax" in get_onnx_node_types(fn_max) + # Min is implemented as -max(-x), so we expect Neg and ReduceMax + node_types_min = get_onnx_node_types(fn_min) + assert "ReduceMax" in node_types_min and "Neg" in node_types_min + + +def test_argmax_argmin(): + """Argmax and argmin reductions.""" + x = pt.matrix("x", dtype="float32") + y_argmax = pt.argmax(x, axis=1) + y_argmin = pt.argmin(x, axis=1) + + x_val = np.random.randn(4, 5).astype("float32") + + fn_argmax, result_argmax = compare_onnx_and_py([x], y_argmax, [x_val]) + fn_argmin, result_argmin = compare_onnx_and_py([x], y_argmin, [x_val]) + + expected_argmax = np.argmax(x_val, axis=1) + expected_argmin = np.argmin(x_val, axis=1) + + np.testing.assert_array_equal(result_argmax, expected_argmax) + np.testing.assert_array_equal(result_argmin, expected_argmin) + + assert "ArgMax" in get_onnx_node_types(fn_argmax) + # ArgMin is implemented as ArgMax of negated input + node_types_argmin = get_onnx_node_types(fn_argmin) + assert "ArgMax" in node_types_argmin or "ArgMin" in node_types_argmin + + +@pytest.mark.skip( + reason="Boolean reduction operations (all/any) not yet fully supported in ONNX backend" +) +def test_logical_reductions(): + """Test logical all and any reductions.""" + x = pt.matrix("x", dtype="bool") + y_all = pt.all(x, axis=1) + y_any = pt.any(x, axis=1) + + x_val = np.random.rand(4, 5) > 0.5 + + fn_all, result_all = compare_onnx_and_py([x], y_all, [x_val]) + fn_any, result_any = compare_onnx_and_py([x], y_any, [x_val]) + + expected_all = np.all(x_val, axis=1) + expected_any = np.any(x_val, axis=1) + + np.testing.assert_array_equal(result_all, expected_all) + np.testing.assert_array_equal(result_any, expected_any) + + # All/Any map to ReduceMin/ReduceMax for boolean tensors + node_types_all = get_onnx_node_types(fn_all) + node_types_any = get_onnx_node_types(fn_any) + assert "ReduceMin" in node_types_all or "ReduceMax" in node_types_all + assert "ReduceMin" in node_types_any or "ReduceMax" in node_types_any + + +def test_reduction_no_axis(): + """Reduction over all axes (axis=None).""" + x = pt.matrix("x", dtype="float32") + y = pt.sum(x) # Sum over all axes + + x_val = np.random.randn(3, 4).astype("float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.sum(x_val) + np.testing.assert_allclose(result, expected, rtol=1e-4) + + +def test_reduction_multiple_axes(): + """Reduction over multiple axes.""" + x = pt.tensor3("x", dtype="float32") + y = pt.sum(x, axis=[0, 2]) + + x_val = np.random.randn(2, 3, 4).astype("float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.sum(x_val, axis=(0, 2)) + np.testing.assert_allclose(result, expected, rtol=1e-4) diff --git a/tests/link/onnx/test_nlinalg.py b/tests/link/onnx/test_nlinalg.py new file mode 100644 index 0000000000..b450230854 --- /dev/null +++ b/tests/link/onnx/test_nlinalg.py @@ -0,0 +1,261 @@ +"""Tests for ONNX backend linear algebra operations (Tier 4).""" + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.compile.mode import Mode +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Matrix Multiplication Tests + + +def test_dot_2d(): + """Test 2D matrix multiplication (Dot op).""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + A = pt.matrix("A", dtype="float32") + B = pt.matrix("B", dtype="float32") + C = pt.dot(A, B) + + A_val = np.random.randn(3, 4).astype("float32") + B_val = np.random.randn(4, 5).astype("float32") + + fn, result = compare_onnx_and_py([A, B], C, [A_val, B_val]) + + expected = np.dot(A_val, B_val) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + # Verify ONNX uses MatMul + node_types = get_onnx_node_types(fn) + assert "MatMul" in node_types, f"Expected 'MatMul' node, got {node_types}" + + +def test_dot_1d_2d(): + """Test vector-matrix multiplication.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + v = pt.vector("v", dtype="float32") + M = pt.matrix("M", dtype="float32") + result = pt.dot(v, M) + + v_val = np.random.randn(4).astype("float32") + M_val = np.random.randn(4, 5).astype("float32") + + _fn, output = compare_onnx_and_py([v, M], result, [v_val, M_val]) + + expected = np.dot(v_val, M_val) + np.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-6) + + # Should be 1D output + assert output.ndim == 1, f"Expected 1D output, got shape {output.shape}" + + +def test_batched_dot(): + """Test batched matrix multiplication.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + A = pt.tensor3("A", dtype="float32") + B = pt.tensor3("B", dtype="float32") + C = pt.batched_dot(A, B) + + A_val = np.random.randn(2, 3, 4).astype("float32") + B_val = np.random.randn(2, 4, 5).astype("float32") + + fn, result = compare_onnx_and_py([A, B], C, [A_val, B_val]) + + expected = np.einsum("bij,bjk->bik", A_val, B_val) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + # ONNX MatMul handles batched operations natively + node_types = get_onnx_node_types(fn) + assert "MatMul" in node_types + + +def test_gemm(): + """Test GEMM: beta*C + alpha*A@B.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + from pytensor.tensor.blas import gemm + + A = pt.matrix("A", dtype="float32") + B = pt.matrix("B", dtype="float32") + C = pt.matrix("C", dtype="float32") + + # GEMM: gemm(C, alpha, A, B, beta) = beta*C + alpha*dot(A, B) + # GEMM: 0.5 * C + 2.0 * A @ B + alpha = np.float32(2.0) + beta = np.float32(0.5) + result = gemm(C, alpha, A, B, beta) + + A_val = np.random.randn(3, 4).astype("float32") + B_val = np.random.randn(4, 5).astype("float32") + C_val = np.random.randn(3, 5).astype("float32") + + fn, output = compare_onnx_and_py([A, B, C], result, [A_val, B_val, C_val]) + + expected = beta * C_val + alpha * np.dot(A_val, B_val) + np.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-6) + + # ONNX has Gemm operator + node_types = get_onnx_node_types(fn) + assert "Gemm" in node_types, f"Expected 'Gemm' node, got {node_types}" + + +# Matrix Decomposition Tests (Unsupported) + + +@pytest.mark.skip( + reason="SVD not in standard ONNX opset - requires contrib ops or custom implementation" +) +def test_svd_not_supported(): + """Test SVD - expected to be unsupported in standard ONNX. + + SVD decomposes A into U, S, V.T where A = U @ diag(S) @ V.T + This is NOT available in standard ONNX opset. + + Options: + 1. Use ONNX Runtime contrib op (platform-specific) + 2. Implement as sequence of operations (very complex) + 3. Skip and document as unsupported + + This test documents the expected behavior if we choose to implement. + """ + from pytensor.link.onnx.linker import ONNXLinker + from pytensor.tensor.nlinalg import svd + + A = pt.matrix("A", dtype="float32") + U, s, Vt = svd(A, full_matrices=False) + + # This will raise NotImplementedError + onnx_mode = Mode(linker=ONNXLinker(), optimizer=None) + with pytest.raises(NotImplementedError, match="SVD not supported"): + pytensor.function([A], [U, s, Vt], mode=onnx_mode) + + +@pytest.mark.skip(reason="Cholesky not in standard ONNX opset") +def test_cholesky_not_supported(): + """Test Cholesky decomposition - not in standard ONNX. + + Cholesky decomposes positive definite A into L @ L.T + where L is lower triangular. + + Not available in standard ONNX opset. ONNX Runtime may have + contrib op: com.microsoft.Cholesky + """ + from pytensor.link.onnx.linker import ONNXLinker + from pytensor.tensor.slinalg import cholesky + + A = pt.matrix("A", dtype="float32") + L = cholesky(A) + + onnx_mode = Mode(linker=ONNXLinker(), optimizer=None) + with pytest.raises(NotImplementedError, match="Cholesky not supported"): + pytensor.function([A], L, mode=onnx_mode) + + +# Linear System Solving Tests (Unsupported) + + +@pytest.mark.skip(reason="Solve not in standard ONNX opset") +def test_solve_not_supported(): + """Test Solve operation - not in standard ONNX. + + Solve finds X such that A @ X = B. + Not available in standard ONNX. Would require: + - LU decomposition (not in ONNX) + - Forward/backward substitution + - Or matrix inverse + matmul + """ + from pytensor.link.onnx.linker import ONNXLinker + from pytensor.tensor.slinalg import solve + + A = pt.matrix("A", dtype="float32") + B = pt.matrix("B", dtype="float32") + X = solve(A, B) + + onnx_mode = Mode(linker=ONNXLinker(), optimizer=None) + with pytest.raises(NotImplementedError, match="Solve not supported"): + pytensor.function([A, B], X, mode=onnx_mode) + + +# Matrix Properties Tests (Unsupported) + + +@pytest.mark.skip( + reason="Det requires LU decomposition - complex custom implementation needed" +) +def test_det_custom_implementation(): + """Test matrix determinant - requires custom implementation. + + Determinant can be computed via: + 1. LU decomposition + product of diagonal (preferred) + 2. QR decomposition + product of R diagonal + 3. Direct computation for small matrices + + All approaches require operations not in standard ONNX. + """ + from pytensor.link.onnx.linker import ONNXLinker + from pytensor.tensor.nlinalg import det + + A = pt.matrix("A", dtype="float32") + d = det(A) + + onnx_mode = Mode(linker=ONNXLinker(), optimizer=None) + with pytest.raises(NotImplementedError, match="Det not supported"): + pytensor.function([A], d, mode=onnx_mode) + + +@pytest.mark.skip(reason="Matrix inverse not in standard ONNX opset") +def test_matrix_inverse_not_supported(): + """Test matrix inverse - not in standard ONNX. + + Matrix inverse could be implemented via: + 1. LU decomposition + solving (not available) + 2. Adjugate method (very complex) + 3. Gradient descent (iterative, expensive) + + Not practical for standard ONNX export. + """ + from pytensor.link.onnx.linker import ONNXLinker + from pytensor.tensor.nlinalg import matrix_inverse + + A = pt.matrix("A", dtype="float32") + A_inv = matrix_inverse(A) + + onnx_mode = Mode(linker=ONNXLinker(), optimizer=None) + with pytest.raises(NotImplementedError, match="Matrix inverse not supported"): + pytensor.function([A], A_inv, mode=onnx_mode) + + +# Extract Diagonal Tests + + +def test_extract_diag(): + """Test extracting diagonal from matrix. + + This CAN be implemented in ONNX using: + - Identity matrix of appropriate size + - Element-wise multiply with input + - ReduceSum along one axis + + Or using Gather operations. + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + A = pt.matrix("A", dtype="float32") + d = pt.diag(A) # Extract diagonal + + A_val = np.random.randn(4, 4).astype("float32") + + _fn, result = compare_onnx_and_py([A], d, [A_val]) + + expected = np.diag(A_val) + np.testing.assert_array_equal(result, expected) diff --git a/tests/link/onnx/test_nnet.py b/tests/link/onnx/test_nnet.py new file mode 100644 index 0000000000..56ffc1d368 --- /dev/null +++ b/tests/link/onnx/test_nnet.py @@ -0,0 +1,93 @@ +"""Tests for ONNX backend neural network operations (Tier 5).""" + +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.tensor.special import log_softmax, softmax +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Softmax Tests + + +@pytest.mark.parametrize("axis", [None, -1, 0, 1]) +def test_softmax(axis): + """Test softmax activation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + from scipy.special import softmax as scipy_softmax + + x = pt.matrix("x", dtype="float32") + y = softmax(x, axis=axis) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + # Compute expected with scipy + # Note: axis=None applies to the entire flattened array + expected = scipy_softmax(x_val, axis=axis) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert "Softmax" in node_types + + +def test_logsoftmax(): + """Test log-softmax activation.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + from scipy.special import log_softmax as scipy_log_softmax + + x = pt.matrix("x", dtype="float32") + # Explicitly specify axis=1 to match typical neural network usage + y = log_softmax(x, axis=1) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = scipy_log_softmax(x_val, axis=1) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert "LogSoftmax" in node_types + + +# Switch Test + + +def test_switch(): + """Test Switch operation (element-wise conditional). + + Switch(condition, then_value, else_value) returns: + - then_value where condition is True + - else_value where condition is False + + In ONNX this maps to Where operator. + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + condition = pt.vector("condition", dtype="bool") + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + + result = pt.switch(condition, x, y) + + cond_val = np.array([True, False, True, False, True], dtype=bool) + x_val = np.array([1, 2, 3, 4, 5], dtype="float32") + y_val = np.array([10, 20, 30, 40, 50], dtype="float32") + + fn, output = compare_onnx_and_py( + [condition, x, y], result, [cond_val, x_val, y_val] + ) + + expected = np.where(cond_val, x_val, y_val) + np.testing.assert_array_equal(output, expected) + + node_types = get_onnx_node_types(fn) + assert "Where" in node_types, f"Expected 'Where' node, got {node_types}" diff --git a/tests/link/onnx/test_shape.py b/tests/link/onnx/test_shape.py new file mode 100644 index 0000000000..2ab6815494 --- /dev/null +++ b/tests/link/onnx/test_shape.py @@ -0,0 +1,577 @@ +"""Tests for ONNX shape operations.""" + +import numpy as np +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import array_shapes + +import pytensor.tensor as pt +from pytensor.tensor.shape import Shape_i +from tests.link.onnx.strategies import SHAPE_OPERATIONS +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Import ONNX and skip if not available +onnx = pytest.importorskip("onnx") +ort = pytest.importorskip("onnxruntime") + + +# ============================================================================ +# PROPERTY-BASED TESTS - Shape Inspection +# ============================================================================ + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_shape_operation_correctness(data): + """ + Property test: Shape operation returns correct tensor shape. + + This test verifies: + - Shape operation returns correct dimensions + - Output is int64 array + - Correct ONNX node type (Shape) is generated + - Works with tensors of various dimensionalities (1D-4D) + """ + op_config = SHAPE_OPERATIONS["shape"] + + # Generate test tensor + test_data = data.draw(op_config["strategy"]) + + # Build graph + x = pt.tensor("x", dtype="float32", shape=(None,) * test_data.ndim) + graph_inputs, graph_output = op_config["build_graph"](x) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [test_data]) + + # Validate result + expected_shape = np.array(test_data.shape, dtype="int64") + np.testing.assert_array_equal(result, expected_shape) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Shape" in node_types, f"Expected 'Shape' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_shape_i_operation_correctness(data): + """ + Property test: Shape_i operation returns correct dimension. + + This test verifies: + - Shape_i returns correct dimension value + - Output is scalar integer + - Correct ONNX node pattern (Constant + Shape + Gather) + - Works with valid dimension indices + """ + op_config = SHAPE_OPERATIONS["shape_i"] + + # Generate test data (tensor and valid dimension index) + test_data = data.draw(op_config["strategy"]) + x_val, dim_index = test_data + + # Build graph + x = pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + graph_inputs, graph_output = op_config["build_graph"](x, dim_index) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val]) + + # Validate result + expected_dim = x_val.shape[dim_index] + assert result == expected_dim, ( + f"Expected dimension {dim_index} to be {expected_dim}, got {result}" + ) + + # Verify ONNX node pattern (multi-node return) + node_types = get_onnx_node_types(fn) + assert "Shape" in node_types, "Expected 'Shape' node" + assert "Gather" in node_types, "Expected 'Gather' node" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_specify_shape_passthrough_correctness(data): + """ + Property test: SpecifyShape passes through without creating ONNX nodes. + + This test verifies: + - SpecifyShape doesn't appear in ONNX graph + - Computation continues correctly after SpecifyShape + - Numerical correctness maintained + - Return pattern: None (pass-through) + """ + from pytensor.tensor.shape import specify_shape + + # Generate random tensor + shape = data.draw(array_shapes(min_dims=1, max_dims=3, min_side=2, max_side=10)) + x_val = np.random.randn(*shape).astype("float32") + + # Build graph with SpecifyShape in the middle + x = pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + x_specified = specify_shape(x, x_val.shape) + y = x_specified * 2.0 # Some computation after SpecifyShape + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py([x], y, [x_val]) + + # Validate numerical correctness + expected = x_val * 2.0 + np.testing.assert_allclose(result, expected, rtol=1e-5) + + # Verify SpecifyShape doesn't appear in ONNX + node_types = get_onnx_node_types(fn) + assert "SpecifyShape" not in node_types, ( + "SpecifyShape should not appear in ONNX graph (it's a pass-through)" + ) + + +# ============================================================================ +# PROPERTY-BASED TESTS - Reshape Operations +# ============================================================================ + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_reshape_operation_correctness(data): + """ + Property test: Reshape operation correctly transforms tensor shape. + + This test verifies: + - Reshape produces correct output shape + - Element values preserved (same data, different shape) + - Total element count preserved + - Correct ONNX node type (Reshape) + """ + op_config = SHAPE_OPERATIONS["reshape"] + + # Generate tensor and compatible reshape target + test_data = data.draw(op_config["strategy"]) + x_val, new_shape = test_data + + # Build graph + x = pt.tensor("x", dtype="float32", shape=(None,) * x_val.ndim) + graph_inputs, graph_output = op_config["build_graph"](x, new_shape) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val]) + + # Validate shape transformation + expected = x_val.reshape(new_shape) + np.testing.assert_array_equal(result, expected) + assert result.shape == new_shape, f"Expected shape {new_shape}, got {result.shape}" + + # Verify total elements preserved + assert result.size == x_val.size, ( + f"Element count changed: {x_val.size} -> {result.size}" + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Reshape" in node_types, f"Expected 'Reshape' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_transpose_operation_correctness(data): + """ + Property test: Transpose operation correctly transposes matrices. + + This test verifies: + - Transpose swaps axes (shape becomes (cols, rows)) + - Element values correctly repositioned + - Correct ONNX node type (Transpose) + - Works with various matrix sizes + """ + op_config = SHAPE_OPERATIONS["transpose"] + + # Generate 2D matrix + test_data = data.draw(op_config["strategy"]) + + # Build graph + x = pt.tensor("x", dtype="float32", shape=(None,) * test_data.ndim) + graph_inputs, graph_output = op_config["build_graph"](x) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [test_data]) + + # Validate transposition + expected = test_data.T + np.testing.assert_allclose(result, expected, rtol=1e-5) + assert result.shape == (test_data.shape[1], test_data.shape[0]), ( + f"Expected shape {test_data.T.shape}, got {result.shape}" + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Transpose" in node_types, f"Expected 'Transpose' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_dimshuffle_add_dim_correctness(data): + """ + Property test: DimShuffle correctly adds dimensions. + + This test verifies: + - DimShuffle adds dimension at correct position + - Shape changes correctly (e.g., (5,) -> (1, 5)) + - Element values unchanged + - Correct ONNX node type (Unsqueeze) + """ + op_config = SHAPE_OPERATIONS["dimshuffle_add_dim"] + + # Generate vector + test_data = data.draw(op_config["strategy"]) + + # Build graph (adds dimension at position 0) + x = pt.tensor("x", dtype="float32", shape=(None,) * test_data.ndim) + graph_inputs, graph_output = op_config["build_graph"](x) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [test_data]) + + # Validate dimension addition + expected = test_data[np.newaxis, :] # Add dimension at position 0 + np.testing.assert_allclose(result, expected, rtol=1e-5) + assert result.shape == (1, test_data.shape[0]), ( + f"Expected shape (1, {test_data.shape[0]}), got {result.shape}" + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Unsqueeze" in node_types, f"Expected 'Unsqueeze' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_dimshuffle_squeeze_correctness(data): + """ + Property test: DimShuffle correctly removes singleton dimensions. + + This test verifies: + - DimShuffle removes dimension of size 1 + - Shape changes correctly (e.g., (3, 1, 4) -> (3, 4)) + - Element values unchanged + - Correct ONNX node type (Squeeze) + """ + op_config = SHAPE_OPERATIONS["dimshuffle_squeeze"] + + # Generate tensor with singleton dimension + test_data = data.draw(op_config["strategy"]) + + # Build graph (removes dimension at position 1) + x = pt.tensor("x", dtype="float32", shape=(None,) * test_data.ndim) + graph_inputs, graph_output = op_config["build_graph"](x) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [test_data]) + + # Validate dimension removal + expected = test_data.squeeze(axis=1) + np.testing.assert_allclose(result, expected, rtol=1e-5) + assert result.ndim == test_data.ndim - 1, ( + f"Expected {test_data.ndim - 1} dimensions, got {result.ndim}" + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Squeeze" in node_types, f"Expected 'Squeeze' node, got {node_types}" + + +# ============================================================================ +# PROPERTY-BASED TESTS - Join/Split Operations +# ============================================================================ + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_concatenate_operation_correctness(data): + """ + Property test: Concatenate correctly joins tensors. + + This test verifies: + - Concatenate joins tensors along specified axis + - Output shape is correct (sum of input dimensions) + - Element values correctly positioned + - Correct ONNX node type (Concat) + """ + op_config = SHAPE_OPERATIONS["concatenate"] + + # Generate two compatible tensors and axis + test_data = data.draw(op_config["strategy"]) + a_val, b_val, axis = test_data + + # Build graph + a = pt.tensor("a", dtype="float32", shape=(None,) * a_val.ndim) + b = pt.tensor("b", dtype="float32", shape=(None,) * b_val.ndim) + graph_inputs, graph_output = op_config["build_graph"](a, b, axis) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [a_val, b_val]) + + # Validate concatenation + expected = np.concatenate([a_val, b_val], axis=axis) + np.testing.assert_allclose(result, expected, rtol=1e-5) + + # Verify shape along concatenation axis + expected_shape = list(a_val.shape) + expected_shape[axis] = a_val.shape[axis] + b_val.shape[axis] + assert result.shape == tuple(expected_shape), ( + f"Expected shape {tuple(expected_shape)}, got {result.shape}" + ) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + assert "Concat" in node_types, f"Expected 'Concat' node, got {node_types}" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_stack_operation_correctness(data): + """ + Property test: Stack correctly stacks tensors with new dimension. + + This test verifies: + - Stack adds new dimension for stacking + - Output shape is correct (adds 1 to ndim) + - Element values correctly positioned + - Correct ONNX node types (Unsqueeze + Concat) + """ + op_config = SHAPE_OPERATIONS["stack"] + + # Generate two tensors with same shape + test_data = data.draw(op_config["strategy"]) + a_val, b_val = test_data + + # Build graph (stack along axis 0) + a = pt.tensor("a", dtype="float32", shape=(None,) * a_val.ndim) + b = pt.tensor("b", dtype="float32", shape=(None,) * b_val.ndim) + graph_inputs, graph_output = op_config["build_graph"](a, b) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [a_val, b_val]) + + # Validate stacking + expected = np.stack([a_val, b_val], axis=0) + np.testing.assert_allclose(result, expected, rtol=1e-5) + + # Verify shape (added dimension) + assert result.ndim == a_val.ndim + 1, ( + f"Expected {a_val.ndim + 1} dimensions, got {result.ndim}" + ) + assert result.shape[0] == 2, f"Expected size 2 along axis 0, got {result.shape[0]}" + + # Verify ONNX node types + node_types = get_onnx_node_types(fn) + assert "Concat" in node_types or "Unsqueeze" in node_types, ( + f"Expected 'Concat' or 'Unsqueeze' nodes, got {node_types}" + ) + + +# ============================================================================ +# MANUAL EDGE CASE TESTS +# ============================================================================ + + +def test_shape_basic(): + """Test Shape operation (single node return).""" + x = pt.matrix("x", dtype="float32") + y = x.shape + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.array([3, 4], dtype="int64") + np.testing.assert_array_equal(result, expected) + + node_types = get_onnx_node_types(fn) + assert "Shape" in node_types + + +def test_shape_i_dim0(): + """Test Shape_i getting dimension 0 (multi-node return).""" + x = pt.matrix("x", dtype="float32") + # Use Shape_i directly to test the multi-node return pattern + shape_i_op = Shape_i(0) + y = shape_i_op(x) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + assert result == 3 + + # Verify multi-node pattern: Constant + Shape + Gather + node_types = get_onnx_node_types(fn) + assert "Constant" in node_types + assert "Shape" in node_types + assert "Gather" in node_types + + +def test_shape_i_dim1(): + """Test Shape_i getting dimension 1 (multi-node return).""" + x = pt.matrix("x", dtype="float32") + # Use Shape_i directly + shape_i_op = Shape_i(1) + y = shape_i_op(x) + + x_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + assert result == 4 + + node_types = get_onnx_node_types(fn) + assert "Shape" in node_types + assert "Gather" in node_types + + +def test_shape_i_3d_tensor(): + """Test Shape_i with 3D tensor.""" + x = pt.tensor3("x", dtype="float32") + # Use Shape_i directly for each dimension + dim0 = Shape_i(0)(x) + dim1 = Shape_i(1)(x) + dim2 = Shape_i(2)(x) + + x_val = np.random.randn(2, 3, 4).astype("float32") + + # Test each dimension separately + _fn0, result0 = compare_onnx_and_py([x], dim0, [x_val]) + assert result0 == 2 + + _fn1, result1 = compare_onnx_and_py([x], dim1, [x_val]) + assert result1 == 3 + + _fn2, result2 = compare_onnx_and_py([x], dim2, [x_val]) + assert result2 == 4 + + +def test_specify_shape_passthrough(): + """Test that SpecifyShape creates no ONNX nodes (None return).""" + from pytensor.tensor.shape import specify_shape + + x = pt.vector("x", dtype="float32") + # SpecifyShape should pass through without creating ONNX nodes + x_specified = specify_shape(x, (4,)) + y = x_specified * 2.0 + + x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + # SpecifyShape should not appear in ONNX graph + node_types = get_onnx_node_types(fn) + assert "SpecifyShape" not in node_types + assert "Mul" in node_types + + expected = x_val * 2.0 + np.testing.assert_allclose(result, expected, rtol=1e-5) + + +def test_concatenate_axis0(): + """Test concatenate operation along axis 0.""" + x = pt.matrix("x", dtype="float32") + y = pt.matrix("y", dtype="float32") + z = pt.concatenate([x, y], axis=0) + + x_val = np.random.randn(2, 3).astype("float32") + y_val = np.random.randn(4, 3).astype("float32") + + fn, result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + + expected = np.concatenate([x_val, y_val], axis=0) + np.testing.assert_allclose(result, expected, rtol=1e-5) + + node_types = get_onnx_node_types(fn) + assert "Concat" in node_types + + +def test_concatenate_axis1(): + """Test concatenate operation along axis 1.""" + x = pt.matrix("x", dtype="float32") + y = pt.matrix("y", dtype="float32") + z = pt.concatenate([x, y], axis=1) + + x_val = np.random.randn(3, 2).astype("float32") + y_val = np.random.randn(3, 4).astype("float32") + + fn, result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + + expected = np.concatenate([x_val, y_val], axis=1) + np.testing.assert_allclose(result, expected, rtol=1e-5) + + node_types = get_onnx_node_types(fn) + assert "Concat" in node_types + + +def test_stack_axis0(): + """Test stack operation along axis 0.""" + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = pt.stack([x, y], axis=0) + + x_val = np.array([1.0, 2.0, 3.0], dtype="float32") + y_val = np.array([4.0, 5.0, 6.0], dtype="float32") + + fn, result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + + expected = np.stack([x_val, y_val], axis=0) + np.testing.assert_allclose(result, expected, rtol=1e-5) + + node_types = get_onnx_node_types(fn) + # Stack uses Join which maps to Concat, along with Unsqueeze + assert "Concat" in node_types or "Unsqueeze" in node_types + + +def test_split_equal(): + """Test split operation with equal sizes.""" + from pytensor.tensor.basic import split + + x = pt.vector("x", dtype="float32") + splits_var = pt.constant([2, 2, 2], dtype="int64") + a, b, c = split(x, splits_var, n_splits=3, axis=0) + + x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype="float32") + + fn, results = compare_onnx_and_py([x], [a, b, c], [x_val]) + + expected_a = x_val[:2] + expected_b = x_val[2:4] + expected_c = x_val[4:] + + np.testing.assert_allclose(results[0], expected_a, rtol=1e-5) + np.testing.assert_allclose(results[1], expected_b, rtol=1e-5) + np.testing.assert_allclose(results[2], expected_c, rtol=1e-5) + + node_types = get_onnx_node_types(fn) + assert "Split" in node_types + + +def test_split_unequal(): + """Test split operation with unequal sizes.""" + from pytensor.tensor.basic import split + + x = pt.vector("x", dtype="float32") + splits_var = pt.constant([3, 2, 1], dtype="int64") + a, b, c = split(x, splits_var, n_splits=3, axis=0) + + x_val = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], dtype="float32") + + fn, results = compare_onnx_and_py([x], [a, b, c], [x_val]) + + expected_a = x_val[:3] + expected_b = x_val[3:5] + expected_c = x_val[5:] + + np.testing.assert_allclose(results[0], expected_a, rtol=1e-5) + np.testing.assert_allclose(results[1], expected_b, rtol=1e-5) + np.testing.assert_allclose(results[2], expected_c, rtol=1e-5) + + node_types = get_onnx_node_types(fn) + assert "Split" in node_types diff --git a/tests/link/onnx/test_special.py b/tests/link/onnx/test_special.py new file mode 100644 index 0000000000..7b7ad1cd7e --- /dev/null +++ b/tests/link/onnx/test_special.py @@ -0,0 +1,269 @@ +"""Tests for ONNX backend special operations (Tier 5).""" + +import numpy as np +import pytest + +import pytensor.tensor as pt +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Trigonometric Functions + + +@pytest.mark.parametrize( + "pt_op,np_op,onnx_op", + [ + (pt.sin, np.sin, "Sin"), + (pt.cos, np.cos, "Cos"), + (pt.tan, np.tan, "Tan"), + (pt.arcsin, np.arcsin, "Asin"), + (pt.arccos, np.arccos, "Acos"), + (pt.arctan, np.arctan, "Atan"), + ], +) +def test_trigonometric_functions(pt_op, np_op, onnx_op): + """Test trigonometric functions.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt_op(x) + + # Use values in appropriate domain + if pt_op in [pt.arcsin, pt.arccos]: + # Domain [-1, 1] + x_val = np.linspace(-0.9, 0.9, 10).astype("float32") + else: + x_val = np.linspace(-3, 3, 10).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np_op(x_val) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert onnx_op in node_types, f"Expected '{onnx_op}' node, got {node_types}" + + +# Hyperbolic Functions + + +@pytest.mark.parametrize( + "pt_op,np_op,onnx_op", + [ + (pt.sinh, np.sinh, "Sinh"), + (pt.cosh, np.cosh, "Cosh"), + (pt.tanh, np.tanh, "Tanh"), + (pt.arcsinh, np.arcsinh, "Asinh"), + (pt.arccosh, np.arccosh, "Acosh"), + (pt.arctanh, np.arctanh, "Atanh"), + ], +) +def test_hyperbolic_functions(pt_op, np_op, onnx_op): + """Test hyperbolic functions.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt_op(x) + + # Use values in appropriate domain + if pt_op == pt.arccosh: + # Domain [1, inf) + x_val = np.linspace(1.1, 3, 10).astype("float32") + elif pt_op == pt.arctanh: + # Domain (-1, 1) + x_val = np.linspace(-0.9, 0.9, 10).astype("float32") + else: + x_val = np.linspace(-2, 2, 10).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np_op(x_val) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert onnx_op in node_types + + +# Comparison Operations + + +@pytest.mark.parametrize( + "pt_op,np_op,onnx_op", + [ + (pt.lt, np.less, "Less"), + (pt.gt, np.greater, "Greater"), + (pt.le, np.less_equal, "LessOrEqual"), + (pt.ge, np.greater_equal, "GreaterOrEqual"), + (pt.eq, np.equal, "Equal"), + (pt.neq, np.not_equal, "Not"), # Not + Equal + ], +) +def test_comparison_ops(pt_op, np_op, onnx_op): + """Test comparison operations.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + z = pt_op(x, y) + + x_val = np.array([1, 2, 3, 4, 5], dtype="float32") + y_val = np.array([2, 2, 2, 2, 2], dtype="float32") + + _fn, result = compare_onnx_and_py([x, y], z, [x_val, y_val]) + + expected = np_op(x_val, y_val) + np.testing.assert_array_equal(result, expected) + + # Result should be boolean + assert result.dtype == bool or result.dtype == np.bool_ + + +# Logical Operations + + +@pytest.mark.parametrize( + "pt_op,np_op,onnx_op", + [ + (pt.and_, np.logical_and, "And"), + (pt.or_, np.logical_or, "Or"), + (pt.xor, np.logical_xor, "Xor"), + (pt.invert, np.logical_not, "Not"), + ], +) +def test_logical_ops(pt_op, np_op, onnx_op): + """Test logical operations.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + if pt_op == pt.invert: + # Unary operation + x = pt.vector("x", dtype="bool") + y = pt_op(x) + + x_val = np.array([True, False, True, False, True], dtype=bool) + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np_op(x_val) + np.testing.assert_array_equal(result, expected) + else: + # Binary operation + x = pt.vector("x", dtype="bool") + y_tensor = pt.vector("y", dtype="bool") + z = pt_op(x, y_tensor) + + x_val = np.array([True, True, False, False], dtype=bool) + y_val = np.array([True, False, True, False], dtype=bool) + + fn, result = compare_onnx_and_py([x, y_tensor], z, [x_val, y_val]) + + expected = np_op(x_val, y_val) + np.testing.assert_array_equal(result, expected) + + node_types = get_onnx_node_types(fn) + assert onnx_op in node_types + + +# Special Math Functions + + +@pytest.mark.parametrize( + "pt_op,onnx_op", + [ + (pt.sigmoid, "Sigmoid"), + (pt.softplus, "Softplus"), + ], +) +def test_sigmoid_softplus(pt_op, onnx_op): + """Test sigmoid and softplus activations.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt_op(x) + + x_val = np.linspace(-5, 5, 20).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + # Verify with manual computation + if pt_op == pt.sigmoid: + expected = 1 / (1 + np.exp(-x_val)) + else: # softplus + expected = np.log(1 + np.exp(x_val)) + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert onnx_op in node_types + + +def test_erf(): + """Test error function.""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + from scipy import special + + x = pt.vector("x", dtype="float32") + y = pt.erf(x) + + x_val = np.linspace(-3, 3, 20).astype("float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = special.erf(x_val) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + node_types = get_onnx_node_types(fn) + assert "Erf" in node_types + + +@pytest.mark.parametrize( + "pt_op,np_op", + [ + (pt.log1p, np.log1p), + (pt.expm1, np.expm1), + ], +) +def test_log1p_expm1(pt_op, np_op): + """Test log1p and expm1 functions. + + These may not have direct ONNX ops, but can be composed: + - log1p(x) = log(1 + x) using Add + Log + - expm1(x) = exp(x) - 1 using Exp + Sub + """ + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt_op(x) + + x_val = np.linspace(-0.5, 2, 20).astype("float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np_op(x_val) + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +def test_clip(): + """Test clip operation (clamp values to range).""" + pytest.importorskip("onnx") + pytest.importorskip("onnxruntime") + + x = pt.vector("x", dtype="float32") + y = pt.clip(x, -1.0, 1.0) + + x_val = np.array([-2, -0.5, 0, 0.5, 2], dtype="float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + expected = np.clip(x_val, -1.0, 1.0) + np.testing.assert_array_equal(result, expected) + + node_types = get_onnx_node_types(fn) + assert "Clip" in node_types, f"Expected 'Clip' node, got {node_types}" diff --git a/tests/link/onnx/test_strategies.py b/tests/link/onnx/test_strategies.py new file mode 100644 index 0000000000..b57f2b0bc1 --- /dev/null +++ b/tests/link/onnx/test_strategies.py @@ -0,0 +1,299 @@ +"""Tests for ONNX strategy registries. + +This module validates the structure and correctness of operation registries +used for property-based testing of the ONNX backend. +""" + +import numpy as np +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +import pytensor.tensor as pt + + +# ============================================================================ +# REGISTRY STRUCTURE TESTS +# ============================================================================ + + +def test_elemwise_registry_exists(): + """ + Test that ELEMWISE_OPERATIONS registry exists and is accessible. + + This test verifies: + - Registry is defined in strategies module + - Registry is a dictionary + - Registry is not empty + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + assert isinstance(ELEMWISE_OPERATIONS, dict), ( + "ELEMWISE_OPERATIONS should be a dictionary" + ) + assert len(ELEMWISE_OPERATIONS) > 0, "ELEMWISE_OPERATIONS should not be empty" + + +def test_elemwise_registry_completeness(): + """ + Test that all 18 Tier 1 elemwise operations are registered. + + This test verifies: + - All expected Tier 1 operations are present + - No unexpected operations are present (optional) + - Operation names follow naming conventions + + Tier 1 Operations from SCALAR_OP_TO_ONNX (pytensor/link/onnx/dispatch/elemwise.py:10-30): + - Binary arithmetic: Add, Mul, Sub, TrueDiv, IntDiv, Pow (6) + - Unary math: Neg, Abs, Exp, Log, Sqrt (5) + - Rounding: Floor, Ceil, RoundHalfToEven, RoundHalfAwayFromZero (4) + - Min/Max: Maximum, Minimum (2) + - Special: Clip (1) + Total: 18 operations + + Note: Both RoundHalfToEven and RoundHalfAwayFromZero should be in registry as 'round' + and 'round_away' to enable testing both behaviors. + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + expected_ops = { + # Binary arithmetic operations (6) + "add", + "mul", + "sub", + "div", + "int_div", + "pow", + # Unary math operations (5) + "neg", + "abs", + "exp", + "log", + "sqrt", + # Rounding operations (4 - two Python operations, both mapped to ONNX "Round") + "floor", + "ceil", + "round", + "round_away", + # Element-wise min/max operations (2) + "maximum", + "minimum", + # Special operations (1) + "clip", + } + + actual_ops = set(ELEMWISE_OPERATIONS.keys()) + missing_ops = expected_ops - actual_ops + + assert len(expected_ops) == 18, ( + f"Expected ops count should be 18 Tier 1 operations, got {len(expected_ops)}" + ) + assert missing_ops == set(), f"Missing operations in registry: {missing_ops}" + # Note: extra operations in actual_ops are OK if testing Tier 4-5 operations + + +@pytest.mark.parametrize( + "op_name", + [ + "add", + "mul", + "sub", + "div", + "int_div", + "pow", + "neg", + "abs", + "exp", + "log", + "sqrt", + "floor", + "ceil", + "round", + "maximum", + "minimum", + "clip", + ], +) +def test_elemwise_registry_entry_structure(op_name): + """ + Test that each registry entry has required fields with correct types. + + This test verifies: + - Entry has 'build_graph' (callable) + - Entry has 'strategy' (hypothesis strategy) + - Entry has 'expected_onnx_ops' (list of strings) + - Entry has 'description' (string) + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + entry = ELEMWISE_OPERATIONS[op_name] + + # Check all required fields present + required_fields = {"build_graph", "strategy", "expected_onnx_ops", "description"} + actual_fields = set(entry.keys()) + missing_fields = required_fields - actual_fields + + assert missing_fields == set(), ( + f"{op_name}: Missing required fields: {missing_fields}" + ) + + # Check field types + assert callable(entry["build_graph"]), ( + f"{op_name}: 'build_graph' should be callable" + ) + assert isinstance(entry["expected_onnx_ops"], list), ( + f"{op_name}: 'expected_onnx_ops' should be a list" + ) + assert all(isinstance(op, str) for op in entry["expected_onnx_ops"]), ( + f"{op_name}: 'expected_onnx_ops' should contain strings" + ) + assert isinstance(entry["description"], str), ( + f"{op_name}: 'description' should be a string" + ) + + +# ============================================================================ +# STRATEGY VALIDATION TESTS +# ============================================================================ + + +@given(data=st.data()) +@settings(max_examples=5, deadline=None) +def test_binary_op_strategy_generates_valid_data(data): + """ + Test that binary operation strategies generate valid tensor pairs. + + This test verifies: + - Strategy generates two arrays + - Arrays have float32 dtype + - Arrays have compatible shapes (for broadcasting) + - Arrays contain finite values + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + # Test with 'add' as representative binary op + op_config = ELEMWISE_OPERATIONS["add"] + test_inputs = data.draw(op_config["strategy"]) + + assert isinstance(test_inputs, tuple), "Binary op strategy should return tuple" + assert len(test_inputs) >= 2, "Binary op strategy should return at least 2 arrays" + + x_val, y_val = test_inputs[0], test_inputs[1] + + assert x_val.dtype == np.float32, f"Expected float32, got {x_val.dtype}" + assert y_val.dtype == np.float32, f"Expected float32, got {y_val.dtype}" + assert np.all(np.isfinite(x_val)), "Generated data should be finite" + assert np.all(np.isfinite(y_val)), "Generated data should be finite" + + +@given(data=st.data()) +@settings(max_examples=5, deadline=None) +def test_unary_op_strategy_generates_valid_data(data): + """ + Test that unary operation strategies generate valid tensors. + + This test verifies: + - Strategy generates one array (or tuple with one array) + - Array has float32 dtype + - Array contains finite values + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + # Test with 'neg' as representative unary op + op_config = ELEMWISE_OPERATIONS["neg"] + test_inputs = data.draw(op_config["strategy"]) + + # Handle both tuple and direct array returns + if isinstance(test_inputs, tuple): + x_val = test_inputs[0] + else: + x_val = test_inputs + + assert x_val.dtype == np.float32, f"Expected float32, got {x_val.dtype}" + assert np.all(np.isfinite(x_val)), "Generated data should be finite" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_log_strategy_generates_positive_values(data): + """ + Test that log strategy generates positive values. + + This test verifies: + - Strategy generates positive values (log requires x > 0) + - Values are not too close to zero (numerical stability) + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + op_config = ELEMWISE_OPERATIONS["log"] + test_inputs = data.draw(op_config["strategy"]) + + if isinstance(test_inputs, tuple): + x_val = test_inputs[0] + else: + x_val = test_inputs + + assert np.all(x_val > 0), "Log operation requires positive inputs" + assert np.all(x_val > 1e-6), ( + "Values should not be too close to zero for numerical stability" + ) + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_sqrt_strategy_generates_non_negative_values(data): + """ + Test that sqrt strategy generates non-negative values. + + This test verifies: + - Strategy generates non-negative values (sqrt requires x >= 0) + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + op_config = ELEMWISE_OPERATIONS["sqrt"] + test_inputs = data.draw(op_config["strategy"]) + + if isinstance(test_inputs, tuple): + x_val = test_inputs[0] + else: + x_val = test_inputs + + assert np.all(x_val >= 0), "Sqrt operation requires non-negative inputs" + + +# ============================================================================ +# BUILD GRAPH VALIDATION TESTS +# ============================================================================ + + +def test_build_graph_returns_valid_structure(): + """ + Test that build_graph functions return valid graph structure. + + This test verifies: + - build_graph returns a tuple + - First element is a list of PyTensor Variables (inputs) + - Second element is a PyTensor Variable (output) + """ + from tests.link.onnx.strategies import ELEMWISE_OPERATIONS + + # Test with 'add' as representative + op_config = ELEMWISE_OPERATIONS["add"] + + # Create dummy inputs + x_val = np.array([1, 2, 3], dtype="float32") + y_val = np.array([4, 5, 6], dtype="float32") + + # Call build_graph + result = op_config["build_graph"](x_val, y_val) + + assert isinstance(result, tuple), "build_graph should return a tuple" + assert len(result) == 2, "build_graph should return (inputs, output)" + + graph_inputs, graph_output = result + + assert isinstance(graph_inputs, list), "First element should be list of inputs" + assert all(isinstance(inp, pt.Variable) for inp in graph_inputs), ( + "All inputs should be PyTensor Variables" + ) + assert isinstance(graph_output, pt.Variable), "Output should be PyTensor Variable" diff --git a/tests/link/onnx/test_subtensor.py b/tests/link/onnx/test_subtensor.py new file mode 100644 index 0000000000..6b2f55f3b8 --- /dev/null +++ b/tests/link/onnx/test_subtensor.py @@ -0,0 +1,482 @@ +"""Tests for ONNX subtensor (slicing) operations. + +Test Strategy: +- Property-based tests provide primary coverage (40+ scenarios) +- Individual property test per operation type (4 operations) +- Manual tests retained for specific patterns and edge cases + +Operations: Subtensor (slicing), AdvancedSubtensor (integer indexing), + set_subtensor, inc_subtensor + +Known Limitations: +- Negative indices NOT supported (limitation documented in subtensor.py:122-127) +- Property tests explicitly exclude negative indices +- Manual tests for negative indices are skipped (will be enabled when supported) +""" + +import numpy as np +import pytest +from hypothesis import assume, given, settings +from hypothesis import strategies as st + +import pytensor.tensor as pt +from tests.link.onnx.strategies import INCSUBTENSOR_OPERATIONS, SUBTENSOR_OPERATIONS +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Import ONNX and skip if not available +onnx = pytest.importorskip("onnx") +ort = pytest.importorskip("onnxruntime") + + +# ============================================================================ +# PROPERTY-BASED TESTS (Primary Coverage) +# ============================================================================ + + +@given( + op_name=st.sampled_from(["slice_basic", "slice_multidim", "slice_with_step"]), + data=st.data(), +) +@settings(max_examples=20, deadline=None) # Higher count for slicing edge cases +def test_subtensor_basic_slicing_correctness(op_name, data): + """ + Property test: Basic subtensor slicing operations produce correct results. + + This test verifies: + - Basic slicing (x[2:5]) works correctly + - Multi-dimensional slicing (x[1:3, 2:4]) works correctly + - Slicing with step (x[::2], x[1:8:2]) works correctly + - ONNX output matches Python reference + - Correct ONNX node type (Slice) + + Operations tested: slice_basic, slice_multidim, slice_with_step + Total: 3 patterns x 20 examples = 60 test scenarios + + Note: This test does NOT cover negative indices (not yet supported in ONNX backend) + """ + op_config = SUBTENSOR_OPERATIONS[op_name] + + # Generate test data (tensor with valid size for slicing) + x_val = data.draw(op_config["strategy"]) + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](x_val) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val]) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + expected_ops = op_config["expected_onnx_ops"] + assert any(op in node_types for op in expected_ops), ( + f"{op_name}: Expected one of {expected_ops}, got {node_types}" + ) + + # Additional validation: verify result shape is reasonable + assert result.ndim <= x_val.ndim, ( + "Result should not have more dimensions than input" + ) + assert result.size <= x_val.size, "Slice result should not be larger than input" + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_advanced_subtensor_indexing_correctness(data): + """ + Property test: Advanced subtensor indexing produces correct results. + + This test verifies: + - Integer array indexing (x[indices]) works correctly + - Selected elements match Python reference + - ONNX output matches PyTensor + - Correct ONNX node type (Gather) + + Note: Uses advanced_index_strategy to generate valid indices + (all indices are non-negative and within bounds) + """ + op_config = SUBTENSOR_OPERATIONS["advanced_index"] + + # Generate test data (tensor and valid integer indices) + test_data = data.draw(op_config["strategy"]) + x_val, indices_val = test_data + + # Verify indices are valid (strategy constraint) + assert np.all(indices_val >= 0), ( + "Indices should be non-negative (negative indices not supported)" + ) + assert np.all(indices_val < x_val.shape[0]), "Indices should be within bounds" + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](x_val, indices_val) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val, indices_val]) + + # Verify ONNX node type + node_types = get_onnx_node_types(fn) + expected_ops = op_config["expected_onnx_ops"] + assert any(op in node_types for op in expected_ops), ( + f"Expected one of {expected_ops}, got {node_types}" + ) + + # Validate result shape + expected_shape = (indices_val.shape[0],) + assert result.shape == expected_shape, ( + f"Expected shape {expected_shape}, got {result.shape}" + ) + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_set_subtensor_operation_correctness(data): + """ + Property test: set_subtensor correctly replaces slice with values. + + This test verifies: + - set_subtensor replaces slice with provided values + - Other elements remain unchanged + - ONNX output matches PyTensor + - Correct ONNX node types (ScatterElements/ScatterND) + + Note: Uses set_subtensor_strategy to generate compatible shapes + """ + op_config = INCSUBTENSOR_OPERATIONS["set_subtensor"] + + # Generate test data (tensor and replacement values) + x_val, values_val = data.draw(op_config["strategy"]) + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](x_val, values_val) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val, values_val]) + + # Verify ONNX node types + node_types = get_onnx_node_types(fn) + expected_ops = op_config["expected_onnx_ops"] + assert any(op in node_types for op in expected_ops), ( + f"Expected one of {expected_ops}, got {node_types}" + ) + + # Use Hypothesis assume() to filter edge case where new values equal old + # This avoids false failures when values_val happens to equal x_val[2:5] + assume(not np.array_equal(values_val, x_val[2:5])) + + # Validate that slice was modified + # (This assertion is now guaranteed to be meaningful) + assert not np.array_equal(result[2:5], x_val[2:5]), ( + "Slice should have been modified" + ) + + # Validate that values were set correctly + np.testing.assert_array_equal(result[2:5], values_val) + + # Validate that other elements unchanged + np.testing.assert_array_equal(result[:2], x_val[:2]) + np.testing.assert_array_equal(result[5:], x_val[5:]) + + +@given(data=st.data()) +@settings(max_examples=10, deadline=None) +def test_inc_subtensor_operation_correctness(data): + """ + Property test: inc_subtensor correctly increments slice values. + + This test verifies: + - inc_subtensor adds values to existing slice + - Other elements remain unchanged + - ONNX output matches PyTensor + - Correct ONNX node types (Gather, Add, ScatterElements) + + Note: inc_subtensor is more complex than set_subtensor + (requires gather, add, then scatter) + """ + op_config = INCSUBTENSOR_OPERATIONS["inc_subtensor"] + + # Generate test data (tensor and increment values) + x_val, values_val = data.draw(op_config["strategy"]) + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](x_val, values_val) + + # Compare ONNX vs PyTensor + fn, result = compare_onnx_and_py(graph_inputs, graph_output, [x_val, values_val]) + + # Verify ONNX node types (should include Gather, Add, ScatterElements) + node_types = get_onnx_node_types(fn) + # Note: inc_subtensor requires multiple operations + assert "Gather" in node_types or "Slice" in node_types, ( + "Expected gather/slice operation" + ) + assert "Add" in node_types, "Expected Add operation (for increment)" + assert "ScatterElements" in node_types or "ScatterND" in node_types, ( + "Expected scatter operation" + ) + + # Use Hypothesis assume() to filter edge case where increment values are zero + # This avoids false failures when values_val is all zeros + assume(not np.allclose(values_val, 0)) + + # Validate that slice was modified + # (This assertion is now guaranteed to be meaningful) + assert not np.array_equal(result[2:5], x_val[2:5]), ( + "Slice should have been modified" + ) + + # Validate that values were incremented correctly + expected_slice = x_val[2:5] + values_val + np.testing.assert_allclose(result[2:5], expected_slice, rtol=1e-5) + + # Validate that other elements unchanged + np.testing.assert_array_equal(result[:2], x_val[:2]) + np.testing.assert_array_equal(result[5:], x_val[5:]) + + +# ============================================================================ +# MANUAL EDGE CASE TESTS +# ============================================================================ +# These tests complement the property-based tests above by: +# - Testing specific edge cases and patterns +# - Providing readable examples for documentation +# - Validating 3D operations (more complex than property tests cover) +# ============================================================================ + + +class TestSubtensorBasic: + """Test specific basic slicing patterns. + + Note: Many of these patterns are also covered by property-based tests above, + but are retained for: + - Explicit documentation of supported patterns + - Quick debugging when property tests fail + - Testing specific slice boundaries + """ + + def test_slice_1d_basic(self): + """Test basic 1D slicing: x[2:5]""" + x = pt.vector("x", dtype="float32") + y = x[2:5] + + x_val = np.arange(10, dtype="float32") + + fn, result = compare_onnx_and_py([x], y, [x_val]) + + # Verify correct output + expected = x_val[2:5] + np.testing.assert_array_equal(result, expected) + + # Verify ONNX uses Slice operation + node_types = get_onnx_node_types(fn) + assert "Slice" in node_types, f"Expected 'Slice' in {node_types}" + + def test_slice_1d_from_start(self): + """Test slicing from start: x[:5]""" + x = pt.vector("x", dtype="float32") + y = x[:5] + + x_val = np.arange(10, dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[:5] + np.testing.assert_array_equal(result, expected) + + def test_slice_1d_to_end(self): + """Test slicing to end: x[3:]""" + x = pt.vector("x", dtype="float32") + y = x[3:] + + x_val = np.arange(10, dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[3:] + np.testing.assert_array_equal(result, expected) + + def test_slice_1d_with_step(self): + """Test slicing with step: x[::2]""" + x = pt.vector("x", dtype="float32") + y = x[::2] + + x_val = np.arange(10, dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[::2] + np.testing.assert_array_equal(result, expected) + + def test_slice_1d_with_step_range(self): + """Test slicing with step and range: x[1:8:2]""" + x = pt.vector("x", dtype="float32") + y = x[1:8:2] + + x_val = np.arange(10, dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[1:8:2] + np.testing.assert_array_equal(result, expected) + + def test_slice_2d_basic(self): + """Test 2D slicing: x[1:3, 2:4]""" + x = pt.matrix("x", dtype="float32") + y = x[1:3, 2:4] + + x_val = np.arange(20, dtype="float32").reshape(4, 5) + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[1:3, 2:4] + np.testing.assert_array_equal(result, expected) + + def test_slice_2d_one_axis(self): + """Test 2D slicing on one axis: x[1:3, :]""" + x = pt.matrix("x", dtype="float32") + y = x[1:3, :] + + x_val = np.arange(20, dtype="float32").reshape(4, 5) + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[1:3, :] + np.testing.assert_array_equal(result, expected) + + def test_slice_3d(self): + """Test 3D slicing: x[0:2, 1:3, 2:4]""" + x = pt.tensor3("x", dtype="float32") + y = x[0:2, 1:3, 2:4] + + x_val = np.arange(60, dtype="float32").reshape(3, 4, 5) + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[0:2, 1:3, 2:4] + np.testing.assert_array_equal(result, expected) + + +class TestSubtensorNegativeIndices: + """Test slicing with negative indices (when implemented). + + IMPORTANT: These tests are currently skipped because negative indices are NOT + yet supported in the ONNX backend. This is a known limitation documented at: + pytensor/link/onnx/dispatch/subtensor.py:122-127 + + These tests document the expected behavior when the feature is implemented. + Remove @pytest.mark.skip decorators when negative index support is added. + """ + + @pytest.mark.skip(reason="Negative indices not yet implemented") + def test_slice_negative_start(self): + """Test slicing with negative start: x[-3:]""" + x = pt.vector("x", dtype="float32") + y = x[-3:] + + x_val = np.arange(10, dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[-3:] + np.testing.assert_array_equal(result, expected) + + @pytest.mark.skip(reason="Negative indices not yet implemented") + def test_slice_negative_end(self): + """Test slicing with negative end: x[:-2]""" + x = pt.vector("x", dtype="float32") + y = x[:-2] + + x_val = np.arange(10, dtype="float32") + + _fn, result = compare_onnx_and_py([x], y, [x_val]) + expected = x_val[:-2] + np.testing.assert_array_equal(result, expected) + + +class TestAdvancedSubtensor: + """Test advanced indexing with integer arrays. + + These tests verify that integer array indexing (fancy indexing) works correctly. + Also covered by test_advanced_subtensor_indexing_correctness property test. + """ + + def test_integer_array_indexing(self): + """Test integer array indexing: x[indices]""" + x = pt.vector("x", dtype="float32") + indices = pt.vector("indices", dtype="int64") + y = x[indices] + + x_val = np.arange(10, dtype="float32") + indices_val = np.array([0, 2, 5], dtype="int64") + + fn, result = compare_onnx_and_py([x, indices], y, [x_val, indices_val]) + expected = x_val[indices_val] + np.testing.assert_array_equal(result, expected) + + # Verify ONNX uses Gather operation + node_types = get_onnx_node_types(fn) + assert "Gather" in node_types, f"Expected 'Gather' in {node_types}" + + def test_integer_array_indexing_2d(self): + """Test integer array indexing on 2D array: x[indices, :]""" + x = pt.matrix("x", dtype="float32") + indices = pt.vector("indices", dtype="int64") + y = x[indices] + + x_val = np.arange(20, dtype="float32").reshape(4, 5) + indices_val = np.array([0, 2], dtype="int64") + + fn, result = compare_onnx_and_py([x, indices], y, [x_val, indices_val]) + expected = x_val[indices_val] + np.testing.assert_array_equal(result, expected) + + # Verify ONNX uses Gather operation + node_types = get_onnx_node_types(fn) + assert "Gather" in node_types, f"Expected 'Gather' in {node_types}" + + +class TestIncSubtensor: + """Test set_subtensor and inc_subtensor operations. + + These tests verify that setting and incrementing subtensor slices works correctly. + They also document the expected ONNX node patterns (ScatterElements for both, + plus Gather and Add for inc_subtensor). + + Also covered by property tests: test_set_subtensor_operation_correctness and + test_inc_subtensor_operation_correctness. + """ + + def test_set_subtensor(self): + """Test set_subtensor: x[2:5] = values""" + x = pt.vector("x", dtype="float32") + values = pt.vector("values", dtype="float32") + y = pt.set_subtensor(x[2:5], values) + + x_val = np.arange(10, dtype="float32") + values_val = np.array([100, 200, 300], dtype="float32") + + fn, result = compare_onnx_and_py([x, values], y, [x_val, values_val]) + + expected = x_val.copy() + expected[2:5] = values_val + np.testing.assert_array_equal(result, expected) + + # Verify ONNX uses ScatterElements operation + node_types = get_onnx_node_types(fn) + assert "ScatterElements" in node_types, ( + f"Expected 'ScatterElements' in {node_types}" + ) + + def test_inc_subtensor(self): + """Test inc_subtensor: x[2:5] += values""" + x = pt.vector("x", dtype="float32") + values = pt.vector("values", dtype="float32") + y = pt.inc_subtensor(x[2:5], values) + + x_val = np.arange(10, dtype="float32") + values_val = np.array([1, 2, 3], dtype="float32") + + fn, result = compare_onnx_and_py([x, values], y, [x_val, values_val]) + + expected = x_val.copy() + expected[2:5] += values_val + np.testing.assert_array_equal(result, expected) + + # Verify ONNX uses Gather, Add, and ScatterElements operations + node_types = get_onnx_node_types(fn) + assert "Gather" in node_types, f"Expected 'Gather' in {node_types}" + assert "Add" in node_types, f"Expected 'Add' in {node_types}" + assert "ScatterElements" in node_types, ( + f"Expected 'ScatterElements' in {node_types}" + ) diff --git a/tests/link/onnx/test_tensor_basic.py b/tests/link/onnx/test_tensor_basic.py new file mode 100644 index 0000000000..76699644fb --- /dev/null +++ b/tests/link/onnx/test_tensor_basic.py @@ -0,0 +1,163 @@ +"""Tests for ONNX tensor basic operations (allocation, etc.).""" + +import numpy as np +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +import pytensor.tensor as pt +from tests.link.onnx.strategies import ALLOCATION_OPERATIONS +from tests.link.onnx.test_basic import compare_onnx_and_py, get_onnx_node_types + + +# Import ONNX and skip if not available +onnx = pytest.importorskip("onnx") +ort = pytest.importorskip("onnxruntime") + + +# ============================================================================ +# Property-Based Tests for Allocation Operations +# ============================================================================ + + +@given( + op_name=st.sampled_from(list(ALLOCATION_OPERATIONS.keys())), + data=st.data(), +) +@settings(max_examples=10, deadline=None) +def test_allocation_operations_correctness(op_name, data): + """Property test: All allocation operations produce correct ONNX results. + + Tests: alloc, alloc_empty, make_vector, arange + Total: 4 operations x 10 examples = 40 test scenarios + """ + op_config = ALLOCATION_OPERATIONS[op_name] + + # Generate test data + test_data = data.draw(op_config["strategy"]) + inputs_tuple = test_data if isinstance(test_data, tuple) else (test_data,) + + # Build graph + graph_inputs, graph_output = op_config["build_graph"](*inputs_tuple) + + # Prepare test inputs (many allocation ops have no inputs) + test_inputs = [] + + # Special handling for AllocEmpty (only check shape/dtype) + if op_name == "alloc_empty": + + def assert_shape_dtype(a, b): + assert a.shape == b.shape + assert a.dtype == b.dtype + + fn, _result = compare_onnx_and_py( + graph_inputs, graph_output, test_inputs, assert_fn=assert_shape_dtype + ) + else: + fn, _result = compare_onnx_and_py(graph_inputs, graph_output, test_inputs) + + # Verify ONNX nodes + node_types = get_onnx_node_types(fn) + expected_ops = op_config["expected_onnx_ops"] + assert any(op in node_types for op in expected_ops), ( + f"{op_name}: Expected {expected_ops}, got {node_types}" + ) + + +# ============================================================================ +# Specific Tests for Edge Cases +# ============================================================================ + + +def test_arange_requires_constants(): + """ARange requires constant inputs (ONNX limitation).""" + x = pt.arange(0, 10, 2, dtype="int64") + + fn, result = compare_onnx_and_py([], x, []) + + expected = np.arange(0, 10, 2, dtype="int64") + np.testing.assert_array_equal(result, expected) + assert "Range" in get_onnx_node_types(fn) + + +def test_alloc_constant_shape(): + """Alloc with constant shape.""" + val = 5.0 + x = pt.alloc(val, 3, 4) + + fn, result = compare_onnx_and_py([], x, []) + + expected = np.full((3, 4), val, dtype="float32") + np.testing.assert_allclose(result, expected) + assert "Expand" in get_onnx_node_types(fn) + + +def test_alloc_dynamic_shape(): + """Alloc with dynamic shape from scalar inputs.""" + val = pt.scalar("val", dtype="float32") + s1 = pt.scalar("s1", dtype="int64") + s2 = pt.scalar("s2", dtype="int64") + x = pt.alloc(val, s1, s2) + + val_data = np.array(3.5, dtype="float32") + s1_data = np.array(4, dtype="int64") + s2_data = np.array(5, dtype="int64") + + fn, result = compare_onnx_and_py([val, s1, s2], x, [val_data, s1_data, s2_data]) + + expected = np.full((4, 5), 3.5, dtype="float32") + np.testing.assert_allclose(result, expected) + assert "Expand" in get_onnx_node_types(fn) + + +def test_make_vector_from_scalars(): + """MakeVector creates vector from scalar values.""" + a = 1.0 + b = 2.0 + c = 3.0 + vec = pt.stack([a, b, c]) + + fn, result = compare_onnx_and_py([], vec, []) + + expected = np.array([1.0, 2.0, 3.0], dtype="float32") + np.testing.assert_allclose(result, expected) + + node_types = get_onnx_node_types(fn) + # MakeVector uses Unsqueeze + Concat + assert "Concat" in node_types + + +def test_alloc_empty_shape_dtype(): + """AllocEmpty creates tensor with correct shape and dtype.""" + x = pt.empty((3, 4), dtype="float32") + + fn, result = compare_onnx_and_py( + [], + x, + [], + assert_fn=lambda a, b: (a.shape == b.shape and a.dtype == b.dtype) + or (_ for _ in ()).throw( + AssertionError( + f"Shape/dtype mismatch: {a.shape}/{a.dtype} vs {b.shape}/{b.dtype}" + ) + ), + ) + + assert result.shape == (3, 4) + assert result.dtype == np.float32 + assert "ConstantOfShape" in get_onnx_node_types(fn) + + +def test_arange_with_different_dtypes(): + """ARange works with different dtypes.""" + # int64 + x_int = pt.arange(0, 10, 1, dtype="int64") + _fn_int, result_int = compare_onnx_and_py([], x_int, []) + expected_int = np.arange(0, 10, 1, dtype="int64") + np.testing.assert_array_equal(result_int, expected_int) + + # float32 + x_float = pt.arange(0.0, 5.0, 0.5, dtype="float32") + _fn_float, result_float = compare_onnx_and_py([], x_float, []) + expected_float = np.arange(0.0, 5.0, 0.5, dtype="float32") + np.testing.assert_allclose(result_float, expected_float) diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000..2b9f15eddd --- /dev/null +++ b/uv.lock @@ -0,0 +1,1083 @@ +version = 1 +revision = 3 +requires-python = ">=3.11, <3.14" +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version == '3.12.*'", + "python_full_version < '3.12'", +] + +[[package]] +name = "alabaster" +version = "0.7.16" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/3e/13dd8e5ed9094e734ac430b5d0eb4f2bb001708a8b7856cbf8e084e001ba/alabaster-0.7.16.tar.gz", hash = "sha256:75a8b99c28a5dad50dd7f8ccdd447a121ddb3892da9e53d1ca5cca3106d58d65", size = 23776, upload-time = "2024-01-10T00:56:10.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92", size = 13511, upload-time = "2024-01-10T00:56:08.388Z" }, +] + +[[package]] +name = "babel" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852, upload-time = "2025-02-01T15:17:41.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, +] + +[[package]] +name = "certifi" +version = "2025.10.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" }, +] + +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114, upload-time = "2023-08-12T20:38:17.776Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249, upload-time = "2023-08-12T20:38:16.269Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988, upload-time = "2025-10-14T04:40:33.79Z" }, + { url = "https://files.pythonhosted.org/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324, upload-time = "2025-10-14T04:40:34.961Z" }, + { url = "https://files.pythonhosted.org/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742, upload-time = "2025-10-14T04:40:36.105Z" }, + { url = "https://files.pythonhosted.org/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863, upload-time = "2025-10-14T04:40:37.188Z" }, + { url = "https://files.pythonhosted.org/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837, upload-time = "2025-10-14T04:40:38.435Z" }, + { url = "https://files.pythonhosted.org/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550, upload-time = "2025-10-14T04:40:40.053Z" }, + { url = "https://files.pythonhosted.org/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162, upload-time = "2025-10-14T04:40:41.163Z" }, + { url = "https://files.pythonhosted.org/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019, upload-time = "2025-10-14T04:40:42.276Z" }, + { url = "https://files.pythonhosted.org/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310, upload-time = "2025-10-14T04:40:43.439Z" }, + { url = "https://files.pythonhosted.org/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022, upload-time = "2025-10-14T04:40:44.547Z" }, + { url = "https://files.pythonhosted.org/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383, upload-time = "2025-10-14T04:40:46.018Z" }, + { url = "https://files.pythonhosted.org/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098, upload-time = "2025-10-14T04:40:47.081Z" }, + { url = "https://files.pythonhosted.org/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991, upload-time = "2025-10-14T04:40:48.246Z" }, + { url = "https://files.pythonhosted.org/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456, upload-time = "2025-10-14T04:40:49.376Z" }, + { url = "https://files.pythonhosted.org/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978, upload-time = "2025-10-14T04:40:50.844Z" }, + { url = "https://files.pythonhosted.org/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969, upload-time = "2025-10-14T04:40:52.272Z" }, + { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, + { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, + { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, + { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, + { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, + { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, + { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, + { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, + { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, + { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, + { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, + { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, + { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, + { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, + { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, + { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, + { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, + { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346, upload-time = "2025-10-14T04:41:16.738Z" }, + { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874, upload-time = "2025-10-14T04:41:17.923Z" }, + { url = "https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076, upload-time = "2025-10-14T04:41:19.106Z" }, + { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601, upload-time = "2025-10-14T04:41:20.245Z" }, + { url = "https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376, upload-time = "2025-10-14T04:41:21.398Z" }, + { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825, upload-time = "2025-10-14T04:41:22.583Z" }, + { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583, upload-time = "2025-10-14T04:41:23.754Z" }, + { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366, upload-time = "2025-10-14T04:41:25.27Z" }, + { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300, upload-time = "2025-10-14T04:41:26.725Z" }, + { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465, upload-time = "2025-10-14T04:41:28.322Z" }, + { url = "https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404, upload-time = "2025-10-14T04:41:29.95Z" }, + { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092, upload-time = "2025-10-14T04:41:31.188Z" }, + { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408, upload-time = "2025-10-14T04:41:32.624Z" }, + { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "cons" +version = "0.4.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "logical-unification" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/20/0eca1dcdbac64a570e60df66119847f94cdd513178d9c222c15101ca1022/cons-0.4.7.tar.gz", hash = "sha256:0a96cd2abd6a9f494816c1272cf5583a960041750c2d7a48eeeccd47ce369dfd", size = 8690, upload-time = "2025-07-11T18:01:31.534Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/9f/bffa3362895e5437d9d12e3bbd242f86d91af1d7cd26f6e14ebb6376581b/cons-0.4.7-py3-none-any.whl", hash = "sha256:e38ee12cf703559ea744c94f725bee0e2329f32daf0249b49db1b0437cc6cb94", size = 8603, upload-time = "2025-07-11T18:01:28.706Z" }, +] + +[[package]] +name = "coverage" +version = "7.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/38/ee22495420457259d2f3390309505ea98f98a5eed40901cf62196abad006/coverage-7.11.0.tar.gz", hash = "sha256:167bd504ac1ca2af7ff3b81d245dfea0292c5032ebef9d66cc08a7d28c1b8050", size = 811905, upload-time = "2025-10-15T15:15:08.542Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/3a/ee1074c15c408ddddddb1db7dd904f6b81bc524e01f5a1c5920e13dbde23/coverage-7.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d58ecaa865c5b9fa56e35efc51d1014d4c0d22838815b9fce57a27dd9576847", size = 215912, upload-time = "2025-10-15T15:12:40.665Z" }, + { url = "https://files.pythonhosted.org/packages/70/c4/9f44bebe5cb15f31608597b037d78799cc5f450044465bcd1ae8cb222fe1/coverage-7.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b679e171f1c104a5668550ada700e3c4937110dbdd153b7ef9055c4f1a1ee3cc", size = 216310, upload-time = "2025-10-15T15:12:42.461Z" }, + { url = "https://files.pythonhosted.org/packages/42/01/5e06077cfef92d8af926bdd86b84fb28bf9bc6ad27343d68be9b501d89f2/coverage-7.11.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ca61691ba8c5b6797deb221a0d09d7470364733ea9c69425a640f1f01b7c5bf0", size = 246706, upload-time = "2025-10-15T15:12:44.001Z" }, + { url = "https://files.pythonhosted.org/packages/40/b8/7a3f1f33b35cc4a6c37e759137533119560d06c0cc14753d1a803be0cd4a/coverage-7.11.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:aef1747ede4bd8ca9cfc04cc3011516500c6891f1b33a94add3253f6f876b7b7", size = 248634, upload-time = "2025-10-15T15:12:45.768Z" }, + { url = "https://files.pythonhosted.org/packages/7a/41/7f987eb33de386bc4c665ab0bf98d15fcf203369d6aacae74f5dd8ec489a/coverage-7.11.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1839d08406e4cba2953dcc0ffb312252f14d7c4c96919f70167611f4dee2623", size = 250741, upload-time = "2025-10-15T15:12:47.222Z" }, + { url = "https://files.pythonhosted.org/packages/23/c1/a4e0ca6a4e83069fb8216b49b30a7352061ca0cb38654bd2dc96b7b3b7da/coverage-7.11.0-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e0eb0a2dcc62478eb5b4cbb80b97bdee852d7e280b90e81f11b407d0b81c4287", size = 246837, upload-time = "2025-10-15T15:12:48.904Z" }, + { url = "https://files.pythonhosted.org/packages/5d/03/ced062a17f7c38b4728ff76c3acb40d8465634b20b4833cdb3cc3a74e115/coverage-7.11.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:bc1fbea96343b53f65d5351d8fd3b34fd415a2670d7c300b06d3e14a5af4f552", size = 248429, upload-time = "2025-10-15T15:12:50.73Z" }, + { url = "https://files.pythonhosted.org/packages/97/af/a7c6f194bb8c5a2705ae019036b8fe7f49ea818d638eedb15fdb7bed227c/coverage-7.11.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:214b622259dd0cf435f10241f1333d32caa64dbc27f8790ab693428a141723de", size = 246490, upload-time = "2025-10-15T15:12:52.646Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c3/aab4df02b04a8fde79068c3c41ad7a622b0ef2b12e1ed154da986a727c3f/coverage-7.11.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:258d9967520cca899695d4eb7ea38be03f06951d6ca2f21fb48b1235f791e601", size = 246208, upload-time = "2025-10-15T15:12:54.586Z" }, + { url = "https://files.pythonhosted.org/packages/30/d8/e282ec19cd658238d60ed404f99ef2e45eed52e81b866ab1518c0d4163cf/coverage-7.11.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cf9e6ff4ca908ca15c157c409d608da77a56a09877b97c889b98fb2c32b6465e", size = 247126, upload-time = "2025-10-15T15:12:56.485Z" }, + { url = "https://files.pythonhosted.org/packages/d1/17/a635fa07fac23adb1a5451ec756216768c2767efaed2e4331710342a3399/coverage-7.11.0-cp311-cp311-win32.whl", hash = "sha256:fcc15fc462707b0680cff6242c48625da7f9a16a28a41bb8fd7a4280920e676c", size = 218314, upload-time = "2025-10-15T15:12:58.365Z" }, + { url = "https://files.pythonhosted.org/packages/2a/29/2ac1dfcdd4ab9a70026edc8d715ece9b4be9a1653075c658ee6f271f394d/coverage-7.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:865965bf955d92790f1facd64fe7ff73551bd2c1e7e6b26443934e9701ba30b9", size = 219203, upload-time = "2025-10-15T15:12:59.902Z" }, + { url = "https://files.pythonhosted.org/packages/03/21/5ce8b3a0133179115af4c041abf2ee652395837cb896614beb8ce8ddcfd9/coverage-7.11.0-cp311-cp311-win_arm64.whl", hash = "sha256:5693e57a065760dcbeb292d60cc4d0231a6d4b6b6f6a3191561e1d5e8820b745", size = 217879, upload-time = "2025-10-15T15:13:01.35Z" }, + { url = "https://files.pythonhosted.org/packages/c4/db/86f6906a7c7edc1a52b2c6682d6dd9be775d73c0dfe2b84f8923dfea5784/coverage-7.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9c49e77811cf9d024b95faf86c3f059b11c0c9be0b0d61bc598f453703bd6fd1", size = 216098, upload-time = "2025-10-15T15:13:02.916Z" }, + { url = "https://files.pythonhosted.org/packages/21/54/e7b26157048c7ba555596aad8569ff903d6cd67867d41b75287323678ede/coverage-7.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a61e37a403a778e2cda2a6a39abcc895f1d984071942a41074b5c7ee31642007", size = 216331, upload-time = "2025-10-15T15:13:04.403Z" }, + { url = "https://files.pythonhosted.org/packages/b9/19/1ce6bf444f858b83a733171306134a0544eaddf1ca8851ede6540a55b2ad/coverage-7.11.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c79cae102bb3b1801e2ef1511fb50e91ec83a1ce466b2c7c25010d884336de46", size = 247825, upload-time = "2025-10-15T15:13:05.92Z" }, + { url = "https://files.pythonhosted.org/packages/71/0b/d3bcbbc259fcced5fb67c5d78f6e7ee965f49760c14afd931e9e663a83b2/coverage-7.11.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:16ce17ceb5d211f320b62df002fa7016b7442ea0fd260c11cec8ce7730954893", size = 250573, upload-time = "2025-10-15T15:13:07.471Z" }, + { url = "https://files.pythonhosted.org/packages/58/8d/b0ff3641a320abb047258d36ed1c21d16be33beed4152628331a1baf3365/coverage-7.11.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:80027673e9d0bd6aef86134b0771845e2da85755cf686e7c7c59566cf5a89115", size = 251706, upload-time = "2025-10-15T15:13:09.4Z" }, + { url = "https://files.pythonhosted.org/packages/59/c8/5a586fe8c7b0458053d9c687f5cff515a74b66c85931f7fe17a1c958b4ac/coverage-7.11.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:4d3ffa07a08657306cd2215b0da53761c4d73cb54d9143b9303a6481ec0cd415", size = 248221, upload-time = "2025-10-15T15:13:10.964Z" }, + { url = "https://files.pythonhosted.org/packages/d0/ff/3a25e3132804ba44cfa9a778cdf2b73dbbe63ef4b0945e39602fc896ba52/coverage-7.11.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a3b6a5f8b2524fd6c1066bc85bfd97e78709bb5e37b5b94911a6506b65f47186", size = 249624, upload-time = "2025-10-15T15:13:12.5Z" }, + { url = "https://files.pythonhosted.org/packages/c5/12/ff10c8ce3895e1b17a73485ea79ebc1896a9e466a9d0f4aef63e0d17b718/coverage-7.11.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fcc0a4aa589de34bc56e1a80a740ee0f8c47611bdfb28cd1849de60660f3799d", size = 247744, upload-time = "2025-10-15T15:13:14.554Z" }, + { url = "https://files.pythonhosted.org/packages/16/02/d500b91f5471b2975947e0629b8980e5e90786fe316b6d7299852c1d793d/coverage-7.11.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:dba82204769d78c3fd31b35c3d5f46e06511936c5019c39f98320e05b08f794d", size = 247325, upload-time = "2025-10-15T15:13:16.438Z" }, + { url = "https://files.pythonhosted.org/packages/77/11/dee0284fbbd9cd64cfce806b827452c6df3f100d9e66188e82dfe771d4af/coverage-7.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81b335f03ba67309a95210caf3eb43bd6fe75a4e22ba653ef97b4696c56c7ec2", size = 249180, upload-time = "2025-10-15T15:13:17.959Z" }, + { url = "https://files.pythonhosted.org/packages/59/1b/cdf1def928f0a150a057cab03286774e73e29c2395f0d30ce3d9e9f8e697/coverage-7.11.0-cp312-cp312-win32.whl", hash = "sha256:037b2d064c2f8cc8716fe4d39cb705779af3fbf1ba318dc96a1af858888c7bb5", size = 218479, upload-time = "2025-10-15T15:13:19.608Z" }, + { url = "https://files.pythonhosted.org/packages/ff/55/e5884d55e031da9c15b94b90a23beccc9d6beee65e9835cd6da0a79e4f3a/coverage-7.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:d66c0104aec3b75e5fd897e7940188ea1892ca1d0235316bf89286d6a22568c0", size = 219290, upload-time = "2025-10-15T15:13:21.593Z" }, + { url = "https://files.pythonhosted.org/packages/23/a8/faa930cfc71c1d16bc78f9a19bb73700464f9c331d9e547bfbc1dbd3a108/coverage-7.11.0-cp312-cp312-win_arm64.whl", hash = "sha256:d91ebeac603812a09cf6a886ba6e464f3bbb367411904ae3790dfe28311b15ad", size = 217924, upload-time = "2025-10-15T15:13:23.39Z" }, + { url = "https://files.pythonhosted.org/packages/60/7f/85e4dfe65e400645464b25c036a26ac226cf3a69d4a50c3934c532491cdd/coverage-7.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cc3f49e65ea6e0d5d9bd60368684fe52a704d46f9e7fc413918f18d046ec40e1", size = 216129, upload-time = "2025-10-15T15:13:25.371Z" }, + { url = "https://files.pythonhosted.org/packages/96/5d/dc5fa98fea3c175caf9d360649cb1aa3715e391ab00dc78c4c66fabd7356/coverage-7.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f39ae2f63f37472c17b4990f794035c9890418b1b8cca75c01193f3c8d3e01be", size = 216380, upload-time = "2025-10-15T15:13:26.976Z" }, + { url = "https://files.pythonhosted.org/packages/b2/f5/3da9cc9596708273385189289c0e4d8197d37a386bdf17619013554b3447/coverage-7.11.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7db53b5cdd2917b6eaadd0b1251cf4e7d96f4a8d24e174bdbdf2f65b5ea7994d", size = 247375, upload-time = "2025-10-15T15:13:28.923Z" }, + { url = "https://files.pythonhosted.org/packages/65/6c/f7f59c342359a235559d2bc76b0c73cfc4bac7d61bb0df210965cb1ecffd/coverage-7.11.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10ad04ac3a122048688387828b4537bc9cf60c0bf4869c1e9989c46e45690b82", size = 249978, upload-time = "2025-10-15T15:13:30.525Z" }, + { url = "https://files.pythonhosted.org/packages/e7/8c/042dede2e23525e863bf1ccd2b92689692a148d8b5fd37c37899ba882645/coverage-7.11.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4036cc9c7983a2b1f2556d574d2eb2154ac6ed55114761685657e38782b23f52", size = 251253, upload-time = "2025-10-15T15:13:32.174Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a9/3c58df67bfa809a7bddd786356d9c5283e45d693edb5f3f55d0986dd905a/coverage-7.11.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7ab934dd13b1c5e94b692b1e01bd87e4488cb746e3a50f798cb9464fd128374b", size = 247591, upload-time = "2025-10-15T15:13:34.147Z" }, + { url = "https://files.pythonhosted.org/packages/26/5b/c7f32efd862ee0477a18c41e4761305de6ddd2d49cdeda0c1116227570fd/coverage-7.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59a6e5a265f7cfc05f76e3bb53eca2e0dfe90f05e07e849930fecd6abb8f40b4", size = 249411, upload-time = "2025-10-15T15:13:38.425Z" }, + { url = "https://files.pythonhosted.org/packages/76/b5/78cb4f1e86c1611431c990423ec0768122905b03837e1b4c6a6f388a858b/coverage-7.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:df01d6c4c81e15a7c88337b795bb7595a8596e92310266b5072c7e301168efbd", size = 247303, upload-time = "2025-10-15T15:13:40.464Z" }, + { url = "https://files.pythonhosted.org/packages/87/c9/23c753a8641a330f45f221286e707c427e46d0ffd1719b080cedc984ec40/coverage-7.11.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:8c934bd088eed6174210942761e38ee81d28c46de0132ebb1801dbe36a390dcc", size = 247157, upload-time = "2025-10-15T15:13:42.087Z" }, + { url = "https://files.pythonhosted.org/packages/c5/42/6e0cc71dc8a464486e944a4fa0d85bdec031cc2969e98ed41532a98336b9/coverage-7.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a03eaf7ec24078ad64a07f02e30060aaf22b91dedf31a6b24d0d98d2bba7f48", size = 248921, upload-time = "2025-10-15T15:13:43.715Z" }, + { url = "https://files.pythonhosted.org/packages/e8/1c/743c2ef665e6858cccb0f84377dfe3a4c25add51e8c7ef19249be92465b6/coverage-7.11.0-cp313-cp313-win32.whl", hash = "sha256:695340f698a5f56f795b2836abe6fb576e7c53d48cd155ad2f80fd24bc63a040", size = 218526, upload-time = "2025-10-15T15:13:45.336Z" }, + { url = "https://files.pythonhosted.org/packages/ff/d5/226daadfd1bf8ddbccefbd3aa3547d7b960fb48e1bdac124e2dd13a2b71a/coverage-7.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:2727d47fce3ee2bac648528e41455d1b0c46395a087a229deac75e9f88ba5a05", size = 219317, upload-time = "2025-10-15T15:13:47.401Z" }, + { url = "https://files.pythonhosted.org/packages/97/54/47db81dcbe571a48a298f206183ba8a7ba79200a37cd0d9f4788fcd2af4a/coverage-7.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:0efa742f431529699712b92ecdf22de8ff198df41e43aeaaadf69973eb93f17a", size = 217948, upload-time = "2025-10-15T15:13:49.096Z" }, + { url = "https://files.pythonhosted.org/packages/e5/8b/cb68425420154e7e2a82fd779a8cc01549b6fa83c2ad3679cd6c088ebd07/coverage-7.11.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:587c38849b853b157706407e9ebdca8fd12f45869edb56defbef2daa5fb0812b", size = 216837, upload-time = "2025-10-15T15:13:51.09Z" }, + { url = "https://files.pythonhosted.org/packages/33/55/9d61b5765a025685e14659c8d07037247de6383c0385757544ffe4606475/coverage-7.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b971bdefdd75096163dd4261c74be813c4508477e39ff7b92191dea19f24cd37", size = 217061, upload-time = "2025-10-15T15:13:52.747Z" }, + { url = "https://files.pythonhosted.org/packages/52/85/292459c9186d70dcec6538f06ea251bc968046922497377bf4a1dc9a71de/coverage-7.11.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:269bfe913b7d5be12ab13a95f3a76da23cf147be7fa043933320ba5625f0a8de", size = 258398, upload-time = "2025-10-15T15:13:54.45Z" }, + { url = "https://files.pythonhosted.org/packages/1f/e2/46edd73fb8bf51446c41148d81944c54ed224854812b6ca549be25113ee0/coverage-7.11.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:dadbcce51a10c07b7c72b0ce4a25e4b6dcb0c0372846afb8e5b6307a121eb99f", size = 260574, upload-time = "2025-10-15T15:13:56.145Z" }, + { url = "https://files.pythonhosted.org/packages/07/5e/1df469a19007ff82e2ca8fe509822820a31e251f80ee7344c34f6cd2ec43/coverage-7.11.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ed43fa22c6436f7957df036331f8fe4efa7af132054e1844918866cd228af6c", size = 262797, upload-time = "2025-10-15T15:13:58.635Z" }, + { url = "https://files.pythonhosted.org/packages/f9/50/de216b31a1434b94d9b34a964c09943c6be45069ec704bfc379d8d89a649/coverage-7.11.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9516add7256b6713ec08359b7b05aeff8850c98d357784c7205b2e60aa2513fa", size = 257361, upload-time = "2025-10-15T15:14:00.409Z" }, + { url = "https://files.pythonhosted.org/packages/82/1e/3f9f8344a48111e152e0fd495b6fff13cc743e771a6050abf1627a7ba918/coverage-7.11.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:eb92e47c92fcbcdc692f428da67db33337fa213756f7adb6a011f7b5a7a20740", size = 260349, upload-time = "2025-10-15T15:14:02.188Z" }, + { url = "https://files.pythonhosted.org/packages/65/9b/3f52741f9e7d82124272f3070bbe316006a7de1bad1093f88d59bfc6c548/coverage-7.11.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d06f4fc7acf3cabd6d74941d53329e06bab00a8fe10e4df2714f0b134bfc64ef", size = 258114, upload-time = "2025-10-15T15:14:03.907Z" }, + { url = "https://files.pythonhosted.org/packages/0b/8b/918f0e15f0365d50d3986bbd3338ca01178717ac5678301f3f547b6619e6/coverage-7.11.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:6fbcee1a8f056af07ecd344482f711f563a9eb1c2cad192e87df00338ec3cdb0", size = 256723, upload-time = "2025-10-15T15:14:06.324Z" }, + { url = "https://files.pythonhosted.org/packages/44/9e/7776829f82d3cf630878a7965a7d70cc6ca94f22c7d20ec4944f7148cb46/coverage-7.11.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dbbf012be5f32533a490709ad597ad8a8ff80c582a95adc8d62af664e532f9ca", size = 259238, upload-time = "2025-10-15T15:14:08.002Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b8/49cf253e1e7a3bedb85199b201862dd7ca4859f75b6cf25ffa7298aa0760/coverage-7.11.0-cp313-cp313t-win32.whl", hash = "sha256:cee6291bb4fed184f1c2b663606a115c743df98a537c969c3c64b49989da96c2", size = 219180, upload-time = "2025-10-15T15:14:09.786Z" }, + { url = "https://files.pythonhosted.org/packages/ac/e1/1a541703826be7ae2125a0fb7f821af5729d56bb71e946e7b933cc7a89a4/coverage-7.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a386c1061bf98e7ea4758e4313c0ab5ecf57af341ef0f43a0bf26c2477b5c268", size = 220241, upload-time = "2025-10-15T15:14:11.471Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d1/5ee0e0a08621140fd418ec4020f595b4d52d7eb429ae6a0c6542b4ba6f14/coverage-7.11.0-cp313-cp313t-win_arm64.whl", hash = "sha256:f9ea02ef40bb83823b2b04964459d281688fe173e20643870bb5d2edf68bc836", size = 218510, upload-time = "2025-10-15T15:14:13.46Z" }, + { url = "https://files.pythonhosted.org/packages/5f/04/642c1d8a448ae5ea1369eac8495740a79eb4e581a9fb0cbdce56bbf56da1/coverage-7.11.0-py3-none-any.whl", hash = "sha256:4b7589765348d78fb4e5fb6ea35d07564e387da2fc5efff62e0222971f155f68", size = 207761, upload-time = "2025-10-15T15:15:06.439Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + +[[package]] +name = "docutils" +version = "0.19" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6b/5c/330ea8d383eb2ce973df34d1239b3b21e91cd8c865d21ff82902d952f91f/docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6", size = 2056383, upload-time = "2022-07-05T20:17:31.045Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/69/e391bd51bc08ed9141ecd899a0ddb61ab6465309f1eb470905c0c8868081/docutils-0.19-py3-none-any.whl", hash = "sha256:5e1de4d849fee02c63b040a4a3fd567f4ab104defd8a5511fbbc24a8a017efbc", size = 570472, upload-time = "2022-07-05T20:17:26.388Z" }, +] + +[[package]] +name = "etuples" +version = "0.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cons" }, + { name = "multipledispatch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/c0/ba049efa7d216221713cffc303641bd73bbb309ff0e4e2a623f32af2a4ea/etuples-0.3.10.tar.gz", hash = "sha256:26fde81d7e822837146231bfce4d6ba67eab5d7ed55bc58ba7437c2568051167", size = 21493, upload-time = "2025-07-14T18:49:35.654Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/19/bf11636df040a9f9c3fd6959aedea5b5cfddd751272732278fb04ee0a78c/etuples-0.3.10-py3-none-any.whl", hash = "sha256:4408c7940ef06af52dbbea0954a8a1817ed5750ce905ff48091ac3cd3aeb720b", size = 12201, upload-time = "2025-07-14T18:49:34.557Z" }, +] + +[[package]] +name = "filelock" +version = "3.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, +] + +[[package]] +name = "identify" +version = "2.6.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/e7/685de97986c916a6d93b3876139e00eef26ad5bbbd61925d670ae8013449/identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf", size = 99311, upload-time = "2025-10-02T17:43:40.631Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "imagesize" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/84/62473fb57d61e31fef6e36d64a179c8781605429fd927b5dd608c997be31/imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a", size = 1280026, upload-time = "2022-07-01T12:21:05.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769, upload-time = "2022-07-01T12:21:02.467Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + +[[package]] +name = "jax" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jaxlib" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "opt-einsum" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/1c/9baf805e6c969a1a7afeb37d359e8a10585e8b2621f103626998b42ae838/jax-0.8.0.tar.gz", hash = "sha256:0ea5a7be7068c25934450dfd87d7d80a18a5d30e0a53454e7aade525b23accd5", size = 2489031, upload-time = "2025-10-15T23:10:11.839Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/77/4e6c9a54247810eff8ac8a1af7dc1be0779b52df0d82f3fc8586061914f3/jax-0.8.0-py3-none-any.whl", hash = "sha256:d190158bc019756c6a0f6b3d5fc8783471fb407e6deaff559eaac60dd5ee850a", size = 2900279, upload-time = "2025-10-15T23:10:09.88Z" }, +] + +[[package]] +name = "jaxlib" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/15/91c4fbd4017bdeaa0800b9aee02cce967b65e1ce79ece93c1b79a92a5a41/jaxlib-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb602a8c24c614cb8ca6eeed3e70a733d9399c6a2f88900a0252623cd67276b5", size = 54952368, upload-time = "2025-10-15T23:10:22.823Z" }, + { url = "https://files.pythonhosted.org/packages/68/ac/5a0469a9611c9e2886bd0315771dc75f582e467f2c814718cf35c5a46e51/jaxlib-0.8.0-cp311-cp311-manylinux_2_27_aarch64.whl", hash = "sha256:41aebddef67a555a6de17427a4e66ce60a528a815847e2dd96dabce579f7acf8", size = 73156932, upload-time = "2025-10-15T23:10:26.573Z" }, + { url = "https://files.pythonhosted.org/packages/bd/a5/eb6ef4bf19bbb8acb878579fd48c37e15d0803f6aded0dd91e77958dae20/jaxlib-0.8.0-cp311-cp311-manylinux_2_27_x86_64.whl", hash = "sha256:ff53e8baf978f6b7c4076215af78f0ba969cac434ed2f72565d87e38c23f00e7", size = 79692924, upload-time = "2025-10-15T23:10:29.816Z" }, + { url = "https://files.pythonhosted.org/packages/7b/4e/ea4540fec3388d9984fce3afbed99a6d9ab14a40a9c4745071e46ff0fa50/jaxlib-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:9cd4c7a8acc5b3dee4ad28a5d101264d89754e29553b0cdb92c79f5b460a511b", size = 59300184, upload-time = "2025-10-15T23:10:32.968Z" }, + { url = "https://files.pythonhosted.org/packages/17/3c/939138d7ee36d124d02bf411f8a76dda9606fb4adc3e1452cdc8ce7cb1f7/jaxlib-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f60aac0f64e9e70a5cef341fe292684518695514c71ad00036774bbed5f7312e", size = 54964234, upload-time = "2025-10-15T23:10:35.969Z" }, + { url = "https://files.pythonhosted.org/packages/4d/58/61e951fb2b0618fdaec6819a3e0f575ccf9dd7003a56598bb21c2a75dfe0/jaxlib-0.8.0-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:d83ff8cf1b070299639cda4f8427707f69051dc8421e59fbb73305523937570d", size = 73158965, upload-time = "2025-10-15T23:10:39.497Z" }, + { url = "https://files.pythonhosted.org/packages/07/57/3e4abd3e8af698834c261a39247e4a098fef38378b9bd7b44f78b30f52ae/jaxlib-0.8.0-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:2c8675bf86e391afe4f8d863080be1a024d734dfd3dd137f7aa8e7f22091adcd", size = 79698853, upload-time = "2025-10-15T23:10:43.35Z" }, + { url = "https://files.pythonhosted.org/packages/2a/17/c6d9dc31001a495cb3c52fa69b22a0d8812880cb853f7c0573e2a5edad82/jaxlib-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:659d894d93876e3675c2132d13c3d241f204b21172a58f928b96f654f603f6dc", size = 59323262, upload-time = "2025-10-15T23:10:46.607Z" }, + { url = "https://files.pythonhosted.org/packages/f6/76/f11130a3a6318a50662be4ee8c7ab6e61f3f334978653243ebc9d6f5d0bb/jaxlib-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5fcf33a5639f8f164a473a9c78a1fa0b2e15ac3fcbecd6d96aa0f88bf25ea6bb", size = 54964169, upload-time = "2025-10-15T23:10:49.524Z" }, + { url = "https://files.pythonhosted.org/packages/24/2b/31ded3e83f3e198edc54519dc72cc829aa4875481ee6e19f123ef474f065/jaxlib-0.8.0-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:b3eac503b90ffecc68f11fa122133eef2c62c536db28e801e436d7e7a9b67bf8", size = 73160932, upload-time = "2025-10-15T23:10:52.47Z" }, + { url = "https://files.pythonhosted.org/packages/8f/f0/cde1d84c737bdb75712f70d69561120ce91f3f294acf2fba573c0de740b6/jaxlib-0.8.0-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:66c6f576f54a63ed052f5c469bef4db723f5f050b839ec0c429573011341bd58", size = 79698354, upload-time = "2025-10-15T23:10:55.822Z" }, + { url = "https://files.pythonhosted.org/packages/f1/be/88fa119a05525f7b683588b789c0e8f51292280dfcfbf7d0193bd3f7b651/jaxlib-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:72759ebbfb40a717349f174712207d306aa28630359f05cd69b091bd4efa0603", size = 59323012, upload-time = "2025-10-15T23:10:59.475Z" }, + { url = "https://files.pythonhosted.org/packages/88/c9/2eabf3126424625dc0390a5382b8911c494b7dd8e902aa7c9d5607259664/jaxlib-0.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:df2781e0fc93fb6f42111b385b90126b9571eafe0e860f033615ff7156b76817", size = 55067941, upload-time = "2025-10-15T23:11:02.235Z" }, + { url = "https://files.pythonhosted.org/packages/72/7e/1d6ef4d730b381c382847e30e39b906d5bc7ba3c13c394c0412aa0a7261e/jaxlib-0.8.0-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:7eb3be931de77bfcde27df659ada432719aa1e19a2fa5b835638e7404c74cb63", size = 73278908, upload-time = "2025-10-15T23:11:05.299Z" }, + { url = "https://files.pythonhosted.org/packages/1f/3c/d1d424e5483a8bc5eba631892c58f6c6e738844195c065bc50e6506561c0/jaxlib-0.8.0-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:accebe89a36e28306a4db3f68f527a0f87b8a0fd253b3c1556fbd24f16bec22c", size = 79805682, upload-time = "2025-10-15T23:11:08.962Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "llvmlite" +version = "0.45.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/99/8d/5baf1cef7f9c084fb35a8afbde88074f0d6a727bc63ef764fe0e7543ba40/llvmlite-0.45.1.tar.gz", hash = "sha256:09430bb9d0bb58fc45a45a57c7eae912850bedc095cd0810a57de109c69e1c32", size = 185600, upload-time = "2025-10-01T17:59:52.046Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/ad/9bdc87b2eb34642c1cfe6bcb4f5db64c21f91f26b010f263e7467e7536a3/llvmlite-0.45.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:60f92868d5d3af30b4239b50e1717cb4e4e54f6ac1c361a27903b318d0f07f42", size = 43043526, upload-time = "2025-10-01T18:03:15.051Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ea/c25c6382f452a943b4082da5e8c1665ce29a62884e2ec80608533e8e82d5/llvmlite-0.45.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98baab513e19beb210f1ef39066288784839a44cd504e24fff5d17f1b3cf0860", size = 37253118, upload-time = "2025-10-01T18:04:06.783Z" }, + { url = "https://files.pythonhosted.org/packages/fe/af/85fc237de98b181dbbe8647324331238d6c52a3554327ccdc83ced28efba/llvmlite-0.45.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3adc2355694d6a6fbcc024d59bb756677e7de506037c878022d7b877e7613a36", size = 56288209, upload-time = "2025-10-01T18:01:00.168Z" }, + { url = "https://files.pythonhosted.org/packages/0a/df/3daf95302ff49beff4230065e3178cd40e71294968e8d55baf4a9e560814/llvmlite-0.45.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2f3377a6db40f563058c9515dedcc8a3e562d8693a106a28f2ddccf2c8fcf6ca", size = 55140958, upload-time = "2025-10-01T18:02:11.199Z" }, + { url = "https://files.pythonhosted.org/packages/a4/56/4c0d503fe03bac820ecdeb14590cf9a248e120f483bcd5c009f2534f23f0/llvmlite-0.45.1-cp311-cp311-win_amd64.whl", hash = "sha256:f9c272682d91e0d57f2a76c6d9ebdfccc603a01828cdbe3d15273bdca0c3363a", size = 38132232, upload-time = "2025-10-01T18:04:52.181Z" }, + { url = "https://files.pythonhosted.org/packages/e2/7c/82cbd5c656e8991bcc110c69d05913be2229302a92acb96109e166ae31fb/llvmlite-0.45.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:28e763aba92fe9c72296911e040231d486447c01d4f90027c8e893d89d49b20e", size = 43043524, upload-time = "2025-10-01T18:03:30.666Z" }, + { url = "https://files.pythonhosted.org/packages/9d/bc/5314005bb2c7ee9f33102c6456c18cc81745d7055155d1218f1624463774/llvmlite-0.45.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a53f4b74ee9fd30cb3d27d904dadece67a7575198bd80e687ee76474620735f", size = 37253123, upload-time = "2025-10-01T18:04:18.177Z" }, + { url = "https://files.pythonhosted.org/packages/96/76/0f7154952f037cb320b83e1c952ec4a19d5d689cf7d27cb8a26887d7bbc1/llvmlite-0.45.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b3796b1b1e1c14dcae34285d2f4ea488402fbd2c400ccf7137603ca3800864f", size = 56288211, upload-time = "2025-10-01T18:01:24.079Z" }, + { url = "https://files.pythonhosted.org/packages/00/b1/0b581942be2683ceb6862d558979e87387e14ad65a1e4db0e7dd671fa315/llvmlite-0.45.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:779e2f2ceefef0f4368548685f0b4adde34e5f4b457e90391f570a10b348d433", size = 55140958, upload-time = "2025-10-01T18:02:30.482Z" }, + { url = "https://files.pythonhosted.org/packages/33/94/9ba4ebcf4d541a325fd8098ddc073b663af75cc8b065b6059848f7d4dce7/llvmlite-0.45.1-cp312-cp312-win_amd64.whl", hash = "sha256:9e6c9949baf25d9aa9cd7cf0f6d011b9ca660dd17f5ba2b23bdbdb77cc86b116", size = 38132231, upload-time = "2025-10-01T18:05:03.664Z" }, + { url = "https://files.pythonhosted.org/packages/1d/e2/c185bb7e88514d5025f93c6c4092f6120c6cea8fe938974ec9860fb03bbb/llvmlite-0.45.1-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:d9ea9e6f17569a4253515cc01dade70aba536476e3d750b2e18d81d7e670eb15", size = 43043524, upload-time = "2025-10-01T18:03:43.249Z" }, + { url = "https://files.pythonhosted.org/packages/09/b8/b5437b9ecb2064e89ccf67dccae0d02cd38911705112dd0dcbfa9cd9a9de/llvmlite-0.45.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:c9f3cadee1630ce4ac18ea38adebf2a4f57a89bd2740ce83746876797f6e0bfb", size = 37253121, upload-time = "2025-10-01T18:04:30.557Z" }, + { url = "https://files.pythonhosted.org/packages/f7/97/ad1a907c0173a90dd4df7228f24a3ec61058bc1a9ff8a0caec20a0cc622e/llvmlite-0.45.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:57c48bf2e1083eedbc9406fb83c4e6483017879714916fe8be8a72a9672c995a", size = 56288210, upload-time = "2025-10-01T18:01:40.26Z" }, + { url = "https://files.pythonhosted.org/packages/32/d8/c99c8ac7a326e9735401ead3116f7685a7ec652691aeb2615aa732b1fc4a/llvmlite-0.45.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3aa3dfceda4219ae39cf18806c60eeb518c1680ff834b8b311bd784160b9ce40", size = 55140957, upload-time = "2025-10-01T18:02:46.244Z" }, + { url = "https://files.pythonhosted.org/packages/09/56/ed35668130e32dbfad2eb37356793b0a95f23494ab5be7d9bf5cb75850ee/llvmlite-0.45.1-cp313-cp313-win_amd64.whl", hash = "sha256:080e6f8d0778a8239cd47686d402cb66eb165e421efa9391366a9b7e5810a38b", size = 38132232, upload-time = "2025-10-01T18:05:14.477Z" }, +] + +[[package]] +name = "logical-unification" +version = "0.4.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "multipledispatch" }, + { name = "toolz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/5d/37673e494a4eed550785ad1268df0202e69aa081bcbf7c0aafd0a853b0fc/logical_unification-0.4.7.tar.gz", hash = "sha256:3d73b263a870827b3f52d89c94f3336afd7fcaecf1e0c67fa18e73025399775c", size = 13513, upload-time = "2025-10-20T21:42:24.904Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/d0/337b3c49cbe742ab5c118d14730fbc7b14b57d1a130d4f39efaa9ec04226/logical_unification-0.4.7-py3-none-any.whl", hash = "sha256:077f49e32693bc66a418f08c1de540f55b5a20f237ffb80ea85d99bfc6139c3b", size = 13469, upload-time = "2025-10-20T21:42:24.024Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/db/fefacb2136439fc8dd20e797950e749aa1f4997ed584c62cfb8ef7c2be0e/markupsafe-3.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad", size = 11631, upload-time = "2025-09-27T18:36:18.185Z" }, + { url = "https://files.pythonhosted.org/packages/e1/2e/5898933336b61975ce9dc04decbc0a7f2fee78c30353c5efba7f2d6ff27a/markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a", size = 12058, upload-time = "2025-09-27T18:36:19.444Z" }, + { url = "https://files.pythonhosted.org/packages/1d/09/adf2df3699d87d1d8184038df46a9c80d78c0148492323f4693df54e17bb/markupsafe-3.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50", size = 24287, upload-time = "2025-09-27T18:36:20.768Z" }, + { url = "https://files.pythonhosted.org/packages/30/ac/0273f6fcb5f42e314c6d8cd99effae6a5354604d461b8d392b5ec9530a54/markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf", size = 22940, upload-time = "2025-09-27T18:36:22.249Z" }, + { url = "https://files.pythonhosted.org/packages/19/ae/31c1be199ef767124c042c6c3e904da327a2f7f0cd63a0337e1eca2967a8/markupsafe-3.0.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f", size = 21887, upload-time = "2025-09-27T18:36:23.535Z" }, + { url = "https://files.pythonhosted.org/packages/b2/76/7edcab99d5349a4532a459e1fe64f0b0467a3365056ae550d3bcf3f79e1e/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a", size = 23692, upload-time = "2025-09-27T18:36:24.823Z" }, + { url = "https://files.pythonhosted.org/packages/a4/28/6e74cdd26d7514849143d69f0bf2399f929c37dc2b31e6829fd2045b2765/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115", size = 21471, upload-time = "2025-09-27T18:36:25.95Z" }, + { url = "https://files.pythonhosted.org/packages/62/7e/a145f36a5c2945673e590850a6f8014318d5577ed7e5920a4b3448e0865d/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a", size = 22923, upload-time = "2025-09-27T18:36:27.109Z" }, + { url = "https://files.pythonhosted.org/packages/0f/62/d9c46a7f5c9adbeeeda52f5b8d802e1094e9717705a645efc71b0913a0a8/markupsafe-3.0.3-cp311-cp311-win32.whl", hash = "sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19", size = 14572, upload-time = "2025-09-27T18:36:28.045Z" }, + { url = "https://files.pythonhosted.org/packages/83/8a/4414c03d3f891739326e1783338e48fb49781cc915b2e0ee052aa490d586/markupsafe-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01", size = 15077, upload-time = "2025-09-27T18:36:29.025Z" }, + { url = "https://files.pythonhosted.org/packages/35/73/893072b42e6862f319b5207adc9ae06070f095b358655f077f69a35601f0/markupsafe-3.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c", size = 13876, upload-time = "2025-09-27T18:36:29.954Z" }, + { url = "https://files.pythonhosted.org/packages/5a/72/147da192e38635ada20e0a2e1a51cf8823d2119ce8883f7053879c2199b5/markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e", size = 11615, upload-time = "2025-09-27T18:36:30.854Z" }, + { url = "https://files.pythonhosted.org/packages/9a/81/7e4e08678a1f98521201c3079f77db69fb552acd56067661f8c2f534a718/markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce", size = 12020, upload-time = "2025-09-27T18:36:31.971Z" }, + { url = "https://files.pythonhosted.org/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d", size = 24332, upload-time = "2025-09-27T18:36:32.813Z" }, + { url = "https://files.pythonhosted.org/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d", size = 22947, upload-time = "2025-09-27T18:36:33.86Z" }, + { url = "https://files.pythonhosted.org/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a", size = 21962, upload-time = "2025-09-27T18:36:35.099Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b", size = 23760, upload-time = "2025-09-27T18:36:36.001Z" }, + { url = "https://files.pythonhosted.org/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f", size = 21529, upload-time = "2025-09-27T18:36:36.906Z" }, + { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, + { url = "https://files.pythonhosted.org/packages/2f/e1/78ee7a023dac597a5825441ebd17170785a9dab23de95d2c7508ade94e0e/markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d", size = 14540, upload-time = "2025-09-27T18:36:38.761Z" }, + { url = "https://files.pythonhosted.org/packages/aa/5b/bec5aa9bbbb2c946ca2733ef9c4ca91c91b6a24580193e891b5f7dbe8e1e/markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c", size = 15105, upload-time = "2025-09-27T18:36:39.701Z" }, + { url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" }, + { url = "https://files.pythonhosted.org/packages/38/2f/907b9c7bbba283e68f20259574b13d005c121a0fa4c175f9bed27c4597ff/markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795", size = 11622, upload-time = "2025-09-27T18:36:41.777Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d9/5f7756922cdd676869eca1c4e3c0cd0df60ed30199ffd775e319089cb3ed/markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219", size = 12029, upload-time = "2025-09-27T18:36:43.257Z" }, + { url = "https://files.pythonhosted.org/packages/00/07/575a68c754943058c78f30db02ee03a64b3c638586fba6a6dd56830b30a3/markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6", size = 24374, upload-time = "2025-09-27T18:36:44.508Z" }, + { url = "https://files.pythonhosted.org/packages/a9/21/9b05698b46f218fc0e118e1f8168395c65c8a2c750ae2bab54fc4bd4e0e8/markupsafe-3.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676", size = 22980, upload-time = "2025-09-27T18:36:45.385Z" }, + { url = "https://files.pythonhosted.org/packages/7f/71/544260864f893f18b6827315b988c146b559391e6e7e8f7252839b1b846a/markupsafe-3.0.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9", size = 21990, upload-time = "2025-09-27T18:36:46.916Z" }, + { url = "https://files.pythonhosted.org/packages/c2/28/b50fc2f74d1ad761af2f5dcce7492648b983d00a65b8c0e0cb457c82ebbe/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1", size = 23784, upload-time = "2025-09-27T18:36:47.884Z" }, + { url = "https://files.pythonhosted.org/packages/ed/76/104b2aa106a208da8b17a2fb72e033a5a9d7073c68f7e508b94916ed47a9/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc", size = 21588, upload-time = "2025-09-27T18:36:48.82Z" }, + { url = "https://files.pythonhosted.org/packages/b5/99/16a5eb2d140087ebd97180d95249b00a03aa87e29cc224056274f2e45fd6/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12", size = 23041, upload-time = "2025-09-27T18:36:49.797Z" }, + { url = "https://files.pythonhosted.org/packages/19/bc/e7140ed90c5d61d77cea142eed9f9c303f4c4806f60a1044c13e3f1471d0/markupsafe-3.0.3-cp313-cp313-win32.whl", hash = "sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed", size = 14543, upload-time = "2025-09-27T18:36:51.584Z" }, + { url = "https://files.pythonhosted.org/packages/05/73/c4abe620b841b6b791f2edc248f556900667a5a1cf023a6646967ae98335/markupsafe-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5", size = 15113, upload-time = "2025-09-27T18:36:52.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/3a/fa34a0f7cfef23cf9500d68cb7c32dd64ffd58a12b09225fb03dd37d5b80/markupsafe-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485", size = 13911, upload-time = "2025-09-27T18:36:53.513Z" }, + { url = "https://files.pythonhosted.org/packages/e4/d7/e05cd7efe43a88a17a37b3ae96e79a19e846f3f456fe79c57ca61356ef01/markupsafe-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73", size = 11658, upload-time = "2025-09-27T18:36:54.819Z" }, + { url = "https://files.pythonhosted.org/packages/99/9e/e412117548182ce2148bdeacdda3bb494260c0b0184360fe0d56389b523b/markupsafe-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37", size = 12066, upload-time = "2025-09-27T18:36:55.714Z" }, + { url = "https://files.pythonhosted.org/packages/bc/e6/fa0ffcda717ef64a5108eaa7b4f5ed28d56122c9a6d70ab8b72f9f715c80/markupsafe-3.0.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19", size = 25639, upload-time = "2025-09-27T18:36:56.908Z" }, + { url = "https://files.pythonhosted.org/packages/96/ec/2102e881fe9d25fc16cb4b25d5f5cde50970967ffa5dddafdb771237062d/markupsafe-3.0.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025", size = 23569, upload-time = "2025-09-27T18:36:57.913Z" }, + { url = "https://files.pythonhosted.org/packages/4b/30/6f2fce1f1f205fc9323255b216ca8a235b15860c34b6798f810f05828e32/markupsafe-3.0.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6", size = 23284, upload-time = "2025-09-27T18:36:58.833Z" }, + { url = "https://files.pythonhosted.org/packages/58/47/4a0ccea4ab9f5dcb6f79c0236d954acb382202721e704223a8aafa38b5c8/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f", size = 24801, upload-time = "2025-09-27T18:36:59.739Z" }, + { url = "https://files.pythonhosted.org/packages/6a/70/3780e9b72180b6fecb83a4814d84c3bf4b4ae4bf0b19c27196104149734c/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb", size = 22769, upload-time = "2025-09-27T18:37:00.719Z" }, + { url = "https://files.pythonhosted.org/packages/98/c5/c03c7f4125180fc215220c035beac6b9cb684bc7a067c84fc69414d315f5/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009", size = 23642, upload-time = "2025-09-27T18:37:01.673Z" }, + { url = "https://files.pythonhosted.org/packages/80/d6/2d1b89f6ca4bff1036499b1e29a1d02d282259f3681540e16563f27ebc23/markupsafe-3.0.3-cp313-cp313t-win32.whl", hash = "sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354", size = 14612, upload-time = "2025-09-27T18:37:02.639Z" }, + { url = "https://files.pythonhosted.org/packages/2b/98/e48a4bfba0a0ffcf9925fe2d69240bfaa19c6f7507b8cd09c70684a53c1e/markupsafe-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218", size = 15200, upload-time = "2025-09-27T18:37:03.582Z" }, + { url = "https://files.pythonhosted.org/packages/0e/72/e3cc540f351f316e9ed0f092757459afbc595824ca724cbc5a5d4263713f/markupsafe-3.0.3-cp313-cp313t-win_arm64.whl", hash = "sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287", size = 13973, upload-time = "2025-09-27T18:37:04.929Z" }, +] + +[[package]] +name = "minikanren" +version = "1.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cons" }, + { name = "etuples" }, + { name = "logical-unification" }, + { name = "multipledispatch" }, + { name = "toolz" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/3d/bbab3c19771efbfafc52de98db8ad7cf3c2c444bbbd7241c2b06e9f305bc/minikanren-1.0.5.tar.gz", hash = "sha256:c030e3e9a3fa5f372f84b66966776a8dc63b16b98768b78be0401982b892e00d", size = 21699, upload-time = "2025-06-24T21:38:51.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/02/5e9ae831946db26f172e03e896fe83b07c5ca643df2b32c1b81557f0e77f/minikanren-1.0.5-py3-none-any.whl", hash = "sha256:22c24f4fdf009a56e30655787af45c90f0704bcc24e8d3e651378675b4bccb21", size = 24072, upload-time = "2025-06-24T21:38:50.113Z" }, +] + +[[package]] +name = "ml-dtypes" +version = "0.5.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/a7/aad060393123cfb383956dca68402aff3db1e1caffd5764887ed5153f41b/ml_dtypes-0.5.3.tar.gz", hash = "sha256:95ce33057ba4d05df50b1f3cfefab22e351868a843b3b15a46c65836283670c9", size = 692316, upload-time = "2025-07-29T18:39:19.454Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/f1/720cb1409b5d0c05cff9040c0e9fba73fa4c67897d33babf905d5d46a070/ml_dtypes-0.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4a177b882667c69422402df6ed5c3428ce07ac2c1f844d8a1314944651439458", size = 667412, upload-time = "2025-07-29T18:38:25.275Z" }, + { url = "https://files.pythonhosted.org/packages/6a/d5/05861ede5d299f6599f86e6bc1291714e2116d96df003cfe23cc54bcc568/ml_dtypes-0.5.3-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9849ce7267444c0a717c80c6900997de4f36e2815ce34ac560a3edb2d9a64cd2", size = 4964606, upload-time = "2025-07-29T18:38:27.045Z" }, + { url = "https://files.pythonhosted.org/packages/db/dc/72992b68de367741bfab8df3b3fe7c29f982b7279d341aa5bf3e7ef737ea/ml_dtypes-0.5.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c3f5ae0309d9f888fd825c2e9d0241102fadaca81d888f26f845bc8c13c1e4ee", size = 4938435, upload-time = "2025-07-29T18:38:29.193Z" }, + { url = "https://files.pythonhosted.org/packages/81/1c/d27a930bca31fb07d975a2d7eaf3404f9388114463b9f15032813c98f893/ml_dtypes-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:58e39349d820b5702bb6f94ea0cb2dc8ec62ee81c0267d9622067d8333596a46", size = 206334, upload-time = "2025-07-29T18:38:30.687Z" }, + { url = "https://files.pythonhosted.org/packages/1a/d8/6922499effa616012cb8dc445280f66d100a7ff39b35c864cfca019b3f89/ml_dtypes-0.5.3-cp311-cp311-win_arm64.whl", hash = "sha256:66c2756ae6cfd7f5224e355c893cfd617fa2f747b8bbd8996152cbdebad9a184", size = 157584, upload-time = "2025-07-29T18:38:32.187Z" }, + { url = "https://files.pythonhosted.org/packages/0d/eb/bc07c88a6ab002b4635e44585d80fa0b350603f11a2097c9d1bfacc03357/ml_dtypes-0.5.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:156418abeeda48ea4797db6776db3c5bdab9ac7be197c1233771e0880c304057", size = 663864, upload-time = "2025-07-29T18:38:33.777Z" }, + { url = "https://files.pythonhosted.org/packages/cf/89/11af9b0f21b99e6386b6581ab40fb38d03225f9de5f55cf52097047e2826/ml_dtypes-0.5.3-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1db60c154989af253f6c4a34e8a540c2c9dce4d770784d426945e09908fbb177", size = 4951313, upload-time = "2025-07-29T18:38:36.45Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a9/b98b86426c24900b0c754aad006dce2863df7ce0bb2bcc2c02f9cc7e8489/ml_dtypes-0.5.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1b255acada256d1fa8c35ed07b5f6d18bc21d1556f842fbc2d5718aea2cd9e55", size = 4928805, upload-time = "2025-07-29T18:38:38.29Z" }, + { url = "https://files.pythonhosted.org/packages/50/c1/85e6be4fc09c6175f36fb05a45917837f30af9a5146a5151cb3a3f0f9e09/ml_dtypes-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:da65e5fd3eea434ccb8984c3624bc234ddcc0d9f4c81864af611aaebcc08a50e", size = 208182, upload-time = "2025-07-29T18:38:39.72Z" }, + { url = "https://files.pythonhosted.org/packages/9e/17/cf5326d6867be057f232d0610de1458f70a8ce7b6290e4b4a277ea62b4cd/ml_dtypes-0.5.3-cp312-cp312-win_arm64.whl", hash = "sha256:8bb9cd1ce63096567f5f42851f5843b5a0ea11511e50039a7649619abfb4ba6d", size = 161560, upload-time = "2025-07-29T18:38:41.072Z" }, + { url = "https://files.pythonhosted.org/packages/2d/87/1bcc98a66de7b2455dfb292f271452cac9edc4e870796e0d87033524d790/ml_dtypes-0.5.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:5103856a225465371fe119f2fef737402b705b810bd95ad5f348e6e1a6ae21af", size = 663781, upload-time = "2025-07-29T18:38:42.984Z" }, + { url = "https://files.pythonhosted.org/packages/fd/2c/bd2a79ba7c759ee192b5601b675b180a3fd6ccf48ffa27fe1782d280f1a7/ml_dtypes-0.5.3-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4cae435a68861660af81fa3c5af16b70ca11a17275c5b662d9c6f58294e0f113", size = 4956217, upload-time = "2025-07-29T18:38:44.65Z" }, + { url = "https://files.pythonhosted.org/packages/14/f3/091ba84e5395d7fe5b30c081a44dec881cd84b408db1763ee50768b2ab63/ml_dtypes-0.5.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6936283b56d74fbec431ca57ce58a90a908fdbd14d4e2d22eea6d72bb208a7b7", size = 4933109, upload-time = "2025-07-29T18:38:46.405Z" }, + { url = "https://files.pythonhosted.org/packages/bc/24/054036dbe32c43295382c90a1363241684c4d6aaa1ecc3df26bd0c8d5053/ml_dtypes-0.5.3-cp313-cp313-win_amd64.whl", hash = "sha256:d0f730a17cf4f343b2c7ad50cee3bd19e969e793d2be6ed911f43086460096e4", size = 208187, upload-time = "2025-07-29T18:38:48.24Z" }, + { url = "https://files.pythonhosted.org/packages/a6/3d/7dc3ec6794a4a9004c765e0c341e32355840b698f73fd2daff46f128afc1/ml_dtypes-0.5.3-cp313-cp313-win_arm64.whl", hash = "sha256:2db74788fc01914a3c7f7da0763427280adfc9cd377e9604b6b64eb8097284bd", size = 161559, upload-time = "2025-07-29T18:38:50.493Z" }, + { url = "https://files.pythonhosted.org/packages/12/91/e6c7a0d67a152b9330445f9f0cf8ae6eee9b83f990b8c57fe74631e42a90/ml_dtypes-0.5.3-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:93c36a08a6d158db44f2eb9ce3258e53f24a9a4a695325a689494f0fdbc71770", size = 689321, upload-time = "2025-07-29T18:38:52.03Z" }, + { url = "https://files.pythonhosted.org/packages/9e/6c/b7b94b84a104a5be1883305b87d4c6bd6ae781504474b4cca067cb2340ec/ml_dtypes-0.5.3-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0e44a3761f64bc009d71ddb6d6c71008ba21b53ab6ee588dadab65e2fa79eafc", size = 5274495, upload-time = "2025-07-29T18:38:53.797Z" }, + { url = "https://files.pythonhosted.org/packages/5b/38/6266604dffb43378055394ea110570cf261a49876fc48f548dfe876f34cc/ml_dtypes-0.5.3-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bdf40d2aaabd3913dec11840f0d0ebb1b93134f99af6a0a4fd88ffe924928ab4", size = 5285422, upload-time = "2025-07-29T18:38:56.603Z" }, +] + +[[package]] +name = "multipledispatch" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/3e/a62c3b824c7dec33c4a1578bcc842e6c30300051033a4e5975ed86cc2536/multipledispatch-1.0.0.tar.gz", hash = "sha256:5c839915465c68206c3e9c473357908216c28383b425361e5d144594bf85a7e0", size = 12385, upload-time = "2023-06-27T16:45:11.074Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/c0/00c9809d8b9346eb238a6bbd5f83e846a4ce4503da94a4c08cb7284c325b/multipledispatch-1.0.0-py3-none-any.whl", hash = "sha256:0c53cd8b077546da4e48869f49b13164bebafd0c2a5afceb6bb6a316e7fb46e4", size = 12818, upload-time = "2023-06-27T16:45:09.418Z" }, +] + +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, +] + +[[package]] +name = "numba" +version = "0.62.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/20/33dbdbfe60e5fd8e3dbfde299d106279a33d9f8308346022316781368591/numba-0.62.1.tar.gz", hash = "sha256:7b774242aa890e34c21200a1fc62e5b5757d5286267e71103257f4e2af0d5161", size = 2749817, upload-time = "2025-09-29T10:46:31.551Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/5f/8b3491dd849474f55e33c16ef55678ace1455c490555337899c35826836c/numba-0.62.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:f43e24b057714e480fe44bc6031de499e7cf8150c63eb461192caa6cc8530bc8", size = 2684279, upload-time = "2025-09-29T10:43:37.213Z" }, + { url = "https://files.pythonhosted.org/packages/bf/18/71969149bfeb65a629e652b752b80167fe8a6a6f6e084f1f2060801f7f31/numba-0.62.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:57cbddc53b9ee02830b828a8428757f5c218831ccc96490a314ef569d8342b7b", size = 2687330, upload-time = "2025-09-29T10:43:59.601Z" }, + { url = "https://files.pythonhosted.org/packages/0e/7d/403be3fecae33088027bc8a95dc80a2fda1e3beff3e0e5fc4374ada3afbe/numba-0.62.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:604059730c637c7885386521bb1b0ddcbc91fd56131a6dcc54163d6f1804c872", size = 3739727, upload-time = "2025-09-29T10:42:45.922Z" }, + { url = "https://files.pythonhosted.org/packages/e0/c3/3d910d08b659a6d4c62ab3cd8cd93c4d8b7709f55afa0d79a87413027ff6/numba-0.62.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d6c540880170bee817011757dc9049dba5a29db0c09b4d2349295991fe3ee55f", size = 3445490, upload-time = "2025-09-29T10:43:12.692Z" }, + { url = "https://files.pythonhosted.org/packages/5b/82/9d425c2f20d9f0a37f7cb955945a553a00fa06a2b025856c3550227c5543/numba-0.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:03de6d691d6b6e2b76660ba0f38f37b81ece8b2cc524a62f2a0cfae2bfb6f9da", size = 2745550, upload-time = "2025-09-29T10:44:20.571Z" }, + { url = "https://files.pythonhosted.org/packages/5e/fa/30fa6873e9f821c0ae755915a3ca444e6ff8d6a7b6860b669a3d33377ac7/numba-0.62.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:1b743b32f8fa5fff22e19c2e906db2f0a340782caf024477b97801b918cf0494", size = 2685346, upload-time = "2025-09-29T10:43:43.677Z" }, + { url = "https://files.pythonhosted.org/packages/a9/d5/504ce8dc46e0dba2790c77e6b878ee65b60fe3e7d6d0006483ef6fde5a97/numba-0.62.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fa21b0142bcf08ad8e32a97d25d0b84b1e921bc9423f8dda07d3652860eef6", size = 2688139, upload-time = "2025-09-29T10:44:04.894Z" }, + { url = "https://files.pythonhosted.org/packages/50/5f/6a802741176c93f2ebe97ad90751894c7b0c922b52ba99a4395e79492205/numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6ef84d0ac19f1bf80431347b6f4ce3c39b7ec13f48f233a48c01e2ec06ecbc59", size = 3796453, upload-time = "2025-09-29T10:42:52.771Z" }, + { url = "https://files.pythonhosted.org/packages/7e/df/efd21527d25150c4544eccc9d0b7260a5dec4b7e98b5a581990e05a133c0/numba-0.62.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9315cc5e441300e0ca07c828a627d92a6802bcbf27c5487f31ae73783c58da53", size = 3496451, upload-time = "2025-09-29T10:43:19.279Z" }, + { url = "https://files.pythonhosted.org/packages/80/44/79bfdab12a02796bf4f1841630355c82b5a69933b1d50eb15c7fa37dabe8/numba-0.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:44e3aa6228039992f058f5ebfcfd372c83798e9464297bdad8cc79febcf7891e", size = 2745552, upload-time = "2025-09-29T10:44:26.399Z" }, + { url = "https://files.pythonhosted.org/packages/22/76/501ea2c07c089ef1386868f33dff2978f43f51b854e34397b20fc55e0a58/numba-0.62.1-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:b72489ba8411cc9fdcaa2458d8f7677751e94f0109eeb53e5becfdc818c64afb", size = 2685766, upload-time = "2025-09-29T10:43:49.161Z" }, + { url = "https://files.pythonhosted.org/packages/80/68/444986ed95350c0611d5c7b46828411c222ce41a0c76707c36425d27ce29/numba-0.62.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:44a1412095534a26fb5da2717bc755b57da5f3053965128fe3dc286652cc6a92", size = 2688741, upload-time = "2025-09-29T10:44:10.07Z" }, + { url = "https://files.pythonhosted.org/packages/78/7e/bf2e3634993d57f95305c7cee4c9c6cb3c9c78404ee7b49569a0dfecfe33/numba-0.62.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8c9460b9e936c5bd2f0570e20a0a5909ee6e8b694fd958b210e3bde3a6dba2d7", size = 3804576, upload-time = "2025-09-29T10:42:59.53Z" }, + { url = "https://files.pythonhosted.org/packages/e8/b6/8a1723fff71f63bbb1354bdc60a1513a068acc0f5322f58da6f022d20247/numba-0.62.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:728f91a874192df22d74e3fd42c12900b7ce7190b1aad3574c6c61b08313e4c5", size = 3503367, upload-time = "2025-09-29T10:43:26.326Z" }, + { url = "https://files.pythonhosted.org/packages/9c/ec/9d414e7a80d6d1dc4af0e07c6bfe293ce0b04ea4d0ed6c45dad9bd6e72eb/numba-0.62.1-cp313-cp313-win_amd64.whl", hash = "sha256:bbf3f88b461514287df66bc8d0307e949b09f2b6f67da92265094e8fa1282dd8", size = 2745529, upload-time = "2025-09-29T10:44:31.738Z" }, +] + +[[package]] +name = "numpy" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/f4/098d2270d52b41f1bd7db9fc288aaa0400cb48c2a3e2af6fa365d9720947/numpy-2.3.4.tar.gz", hash = "sha256:a7d018bfedb375a8d979ac758b120ba846a7fe764911a64465fd87b8729f4a6a", size = 20582187, upload-time = "2025-10-15T16:18:11.77Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/e7/0e07379944aa8afb49a556a2b54587b828eb41dc9adc56fb7615b678ca53/numpy-2.3.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e78aecd2800b32e8347ce49316d3eaf04aed849cd5b38e0af39f829a4e59f5eb", size = 21259519, upload-time = "2025-10-15T16:15:19.012Z" }, + { url = "https://files.pythonhosted.org/packages/d0/cb/5a69293561e8819b09e34ed9e873b9a82b5f2ade23dce4c51dc507f6cfe1/numpy-2.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd09cc5d65bda1e79432859c40978010622112e9194e581e3415a3eccc7f43f", size = 14452796, upload-time = "2025-10-15T16:15:23.094Z" }, + { url = "https://files.pythonhosted.org/packages/e4/04/ff11611200acd602a1e5129e36cfd25bf01ad8e5cf927baf2e90236eb02e/numpy-2.3.4-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:1b219560ae2c1de48ead517d085bc2d05b9433f8e49d0955c82e8cd37bd7bf36", size = 5381639, upload-time = "2025-10-15T16:15:25.572Z" }, + { url = "https://files.pythonhosted.org/packages/ea/77/e95c757a6fe7a48d28a009267408e8aa382630cc1ad1db7451b3bc21dbb4/numpy-2.3.4-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:bafa7d87d4c99752d07815ed7a2c0964f8ab311eb8168f41b910bd01d15b6032", size = 6914296, upload-time = "2025-10-15T16:15:27.079Z" }, + { url = "https://files.pythonhosted.org/packages/a3/d2/137c7b6841c942124eae921279e5c41b1c34bab0e6fc60c7348e69afd165/numpy-2.3.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36dc13af226aeab72b7abad501d370d606326a0029b9f435eacb3b8c94b8a8b7", size = 14591904, upload-time = "2025-10-15T16:15:29.044Z" }, + { url = "https://files.pythonhosted.org/packages/bb/32/67e3b0f07b0aba57a078c4ab777a9e8e6bc62f24fb53a2337f75f9691699/numpy-2.3.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a7b2f9a18b5ff9824a6af80de4f37f4ec3c2aab05ef08f51c77a093f5b89adda", size = 16939602, upload-time = "2025-10-15T16:15:31.106Z" }, + { url = "https://files.pythonhosted.org/packages/95/22/9639c30e32c93c4cee3ccdb4b09c2d0fbff4dcd06d36b357da06146530fb/numpy-2.3.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9984bd645a8db6ca15d850ff996856d8762c51a2239225288f08f9050ca240a0", size = 16372661, upload-time = "2025-10-15T16:15:33.546Z" }, + { url = "https://files.pythonhosted.org/packages/12/e9/a685079529be2b0156ae0c11b13d6be647743095bb51d46589e95be88086/numpy-2.3.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:64c5825affc76942973a70acf438a8ab618dbd692b84cd5ec40a0a0509edc09a", size = 18884682, upload-time = "2025-10-15T16:15:36.105Z" }, + { url = "https://files.pythonhosted.org/packages/cf/85/f6f00d019b0cc741e64b4e00ce865a57b6bed945d1bbeb1ccadbc647959b/numpy-2.3.4-cp311-cp311-win32.whl", hash = "sha256:ed759bf7a70342f7817d88376eb7142fab9fef8320d6019ef87fae05a99874e1", size = 6570076, upload-time = "2025-10-15T16:15:38.225Z" }, + { url = "https://files.pythonhosted.org/packages/7d/10/f8850982021cb90e2ec31990291f9e830ce7d94eef432b15066e7cbe0bec/numpy-2.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:faba246fb30ea2a526c2e9645f61612341de1a83fb1e0c5edf4ddda5a9c10996", size = 13089358, upload-time = "2025-10-15T16:15:40.404Z" }, + { url = "https://files.pythonhosted.org/packages/d1/ad/afdd8351385edf0b3445f9e24210a9c3971ef4de8fd85155462fc4321d79/numpy-2.3.4-cp311-cp311-win_arm64.whl", hash = "sha256:4c01835e718bcebe80394fd0ac66c07cbb90147ebbdad3dcecd3f25de2ae7e2c", size = 10462292, upload-time = "2025-10-15T16:15:42.896Z" }, + { url = "https://files.pythonhosted.org/packages/96/7a/02420400b736f84317e759291b8edaeee9dc921f72b045475a9cbdb26b17/numpy-2.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ef1b5a3e808bc40827b5fa2c8196151a4c5abe110e1726949d7abddfe5c7ae11", size = 20957727, upload-time = "2025-10-15T16:15:44.9Z" }, + { url = "https://files.pythonhosted.org/packages/18/90/a014805d627aa5750f6f0e878172afb6454552da929144b3c07fcae1bb13/numpy-2.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c2f91f496a87235c6aaf6d3f3d89b17dba64996abadccb289f48456cff931ca9", size = 14187262, upload-time = "2025-10-15T16:15:47.761Z" }, + { url = "https://files.pythonhosted.org/packages/c7/e4/0a94b09abe89e500dc748e7515f21a13e30c5c3fe3396e6d4ac108c25fca/numpy-2.3.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f77e5b3d3da652b474cc80a14084927a5e86a5eccf54ca8ca5cbd697bf7f2667", size = 5115992, upload-time = "2025-10-15T16:15:50.144Z" }, + { url = "https://files.pythonhosted.org/packages/88/dd/db77c75b055c6157cbd4f9c92c4458daef0dd9cbe6d8d2fe7f803cb64c37/numpy-2.3.4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:8ab1c5f5ee40d6e01cbe96de5863e39b215a4d24e7d007cad56c7184fdf4aeef", size = 6648672, upload-time = "2025-10-15T16:15:52.442Z" }, + { url = "https://files.pythonhosted.org/packages/e1/e6/e31b0d713719610e406c0ea3ae0d90760465b086da8783e2fd835ad59027/numpy-2.3.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77b84453f3adcb994ddbd0d1c5d11db2d6bda1a2b7fd5ac5bd4649d6f5dc682e", size = 14284156, upload-time = "2025-10-15T16:15:54.351Z" }, + { url = "https://files.pythonhosted.org/packages/f9/58/30a85127bfee6f108282107caf8e06a1f0cc997cb6b52cdee699276fcce4/numpy-2.3.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4121c5beb58a7f9e6dfdee612cb24f4df5cd4db6e8261d7f4d7450a997a65d6a", size = 16641271, upload-time = "2025-10-15T16:15:56.67Z" }, + { url = "https://files.pythonhosted.org/packages/06/f2/2e06a0f2adf23e3ae29283ad96959267938d0efd20a2e25353b70065bfec/numpy-2.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65611ecbb00ac9846efe04db15cbe6186f562f6bb7e5e05f077e53a599225d16", size = 16059531, upload-time = "2025-10-15T16:15:59.412Z" }, + { url = "https://files.pythonhosted.org/packages/b0/e7/b106253c7c0d5dc352b9c8fab91afd76a93950998167fa3e5afe4ef3a18f/numpy-2.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dabc42f9c6577bcc13001b8810d300fe814b4cfbe8a92c873f269484594f9786", size = 18578983, upload-time = "2025-10-15T16:16:01.804Z" }, + { url = "https://files.pythonhosted.org/packages/73/e3/04ecc41e71462276ee867ccbef26a4448638eadecf1bc56772c9ed6d0255/numpy-2.3.4-cp312-cp312-win32.whl", hash = "sha256:a49d797192a8d950ca59ee2d0337a4d804f713bb5c3c50e8db26d49666e351dc", size = 6291380, upload-time = "2025-10-15T16:16:03.938Z" }, + { url = "https://files.pythonhosted.org/packages/3d/a8/566578b10d8d0e9955b1b6cd5db4e9d4592dd0026a941ff7994cedda030a/numpy-2.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:985f1e46358f06c2a09921e8921e2c98168ed4ae12ccd6e5e87a4f1857923f32", size = 12787999, upload-time = "2025-10-15T16:16:05.801Z" }, + { url = "https://files.pythonhosted.org/packages/58/22/9c903a957d0a8071b607f5b1bff0761d6e608b9a965945411f867d515db1/numpy-2.3.4-cp312-cp312-win_arm64.whl", hash = "sha256:4635239814149e06e2cb9db3dd584b2fa64316c96f10656983b8026a82e6e4db", size = 10197412, upload-time = "2025-10-15T16:16:07.854Z" }, + { url = "https://files.pythonhosted.org/packages/57/7e/b72610cc91edf138bc588df5150957a4937221ca6058b825b4725c27be62/numpy-2.3.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c090d4860032b857d94144d1a9976b8e36709e40386db289aaf6672de2a81966", size = 20950335, upload-time = "2025-10-15T16:16:10.304Z" }, + { url = "https://files.pythonhosted.org/packages/3e/46/bdd3370dcea2f95ef14af79dbf81e6927102ddf1cc54adc0024d61252fd9/numpy-2.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a13fc473b6db0be619e45f11f9e81260f7302f8d180c49a22b6e6120022596b3", size = 14179878, upload-time = "2025-10-15T16:16:12.595Z" }, + { url = "https://files.pythonhosted.org/packages/ac/01/5a67cb785bda60f45415d09c2bc245433f1c68dd82eef9c9002c508b5a65/numpy-2.3.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:3634093d0b428e6c32c3a69b78e554f0cd20ee420dcad5a9f3b2a63762ce4197", size = 5108673, upload-time = "2025-10-15T16:16:14.877Z" }, + { url = "https://files.pythonhosted.org/packages/c2/cd/8428e23a9fcebd33988f4cb61208fda832800ca03781f471f3727a820704/numpy-2.3.4-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:043885b4f7e6e232d7df4f51ffdef8c36320ee9d5f227b380ea636722c7ed12e", size = 6641438, upload-time = "2025-10-15T16:16:16.805Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d1/913fe563820f3c6b079f992458f7331278dcd7ba8427e8e745af37ddb44f/numpy-2.3.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4ee6a571d1e4f0ea6d5f22d6e5fbd6ed1dc2b18542848e1e7301bd190500c9d7", size = 14281290, upload-time = "2025-10-15T16:16:18.764Z" }, + { url = "https://files.pythonhosted.org/packages/9e/7e/7d306ff7cb143e6d975cfa7eb98a93e73495c4deabb7d1b5ecf09ea0fd69/numpy-2.3.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fc8a63918b04b8571789688b2780ab2b4a33ab44bfe8ccea36d3eba51228c953", size = 16636543, upload-time = "2025-10-15T16:16:21.072Z" }, + { url = "https://files.pythonhosted.org/packages/47/6a/8cfc486237e56ccfb0db234945552a557ca266f022d281a2f577b98e955c/numpy-2.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:40cc556d5abbc54aabe2b1ae287042d7bdb80c08edede19f0c0afb36ae586f37", size = 16056117, upload-time = "2025-10-15T16:16:23.369Z" }, + { url = "https://files.pythonhosted.org/packages/b1/0e/42cb5e69ea901e06ce24bfcc4b5664a56f950a70efdcf221f30d9615f3f3/numpy-2.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ecb63014bb7f4ce653f8be7f1df8cbc6093a5a2811211770f6606cc92b5a78fd", size = 18577788, upload-time = "2025-10-15T16:16:27.496Z" }, + { url = "https://files.pythonhosted.org/packages/86/92/41c3d5157d3177559ef0a35da50f0cda7fa071f4ba2306dd36818591a5bc/numpy-2.3.4-cp313-cp313-win32.whl", hash = "sha256:e8370eb6925bb8c1c4264fec52b0384b44f675f191df91cbe0140ec9f0955646", size = 6282620, upload-time = "2025-10-15T16:16:29.811Z" }, + { url = "https://files.pythonhosted.org/packages/09/97/fd421e8bc50766665ad35536c2bb4ef916533ba1fdd053a62d96cc7c8b95/numpy-2.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:56209416e81a7893036eea03abcb91c130643eb14233b2515c90dcac963fe99d", size = 12784672, upload-time = "2025-10-15T16:16:31.589Z" }, + { url = "https://files.pythonhosted.org/packages/ad/df/5474fb2f74970ca8eb978093969b125a84cc3d30e47f82191f981f13a8a0/numpy-2.3.4-cp313-cp313-win_arm64.whl", hash = "sha256:a700a4031bc0fd6936e78a752eefb79092cecad2599ea9c8039c548bc097f9bc", size = 10196702, upload-time = "2025-10-15T16:16:33.902Z" }, + { url = "https://files.pythonhosted.org/packages/11/83/66ac031464ec1767ea3ed48ce40f615eb441072945e98693bec0bcd056cc/numpy-2.3.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:86966db35c4040fdca64f0816a1c1dd8dbd027d90fca5a57e00e1ca4cd41b879", size = 21049003, upload-time = "2025-10-15T16:16:36.101Z" }, + { url = "https://files.pythonhosted.org/packages/5f/99/5b14e0e686e61371659a1d5bebd04596b1d72227ce36eed121bb0aeab798/numpy-2.3.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:838f045478638b26c375ee96ea89464d38428c69170360b23a1a50fa4baa3562", size = 14302980, upload-time = "2025-10-15T16:16:39.124Z" }, + { url = "https://files.pythonhosted.org/packages/2c/44/e9486649cd087d9fc6920e3fc3ac2aba10838d10804b1e179fb7cbc4e634/numpy-2.3.4-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d7315ed1dab0286adca467377c8381cd748f3dc92235f22a7dfc42745644a96a", size = 5231472, upload-time = "2025-10-15T16:16:41.168Z" }, + { url = "https://files.pythonhosted.org/packages/3e/51/902b24fa8887e5fe2063fd61b1895a476d0bbf46811ab0c7fdf4bd127345/numpy-2.3.4-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:84f01a4d18b2cc4ade1814a08e5f3c907b079c847051d720fad15ce37aa930b6", size = 6739342, upload-time = "2025-10-15T16:16:43.777Z" }, + { url = "https://files.pythonhosted.org/packages/34/f1/4de9586d05b1962acdcdb1dc4af6646361a643f8c864cef7c852bf509740/numpy-2.3.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:817e719a868f0dacde4abdfc5c1910b301877970195db9ab6a5e2c4bd5b121f7", size = 14354338, upload-time = "2025-10-15T16:16:46.081Z" }, + { url = "https://files.pythonhosted.org/packages/1f/06/1c16103b425de7969d5a76bdf5ada0804b476fed05d5f9e17b777f1cbefd/numpy-2.3.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85e071da78d92a214212cacea81c6da557cab307f2c34b5f85b628e94803f9c0", size = 16702392, upload-time = "2025-10-15T16:16:48.455Z" }, + { url = "https://files.pythonhosted.org/packages/34/b2/65f4dc1b89b5322093572b6e55161bb42e3e0487067af73627f795cc9d47/numpy-2.3.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2ec646892819370cf3558f518797f16597b4e4669894a2ba712caccc9da53f1f", size = 16134998, upload-time = "2025-10-15T16:16:51.114Z" }, + { url = "https://files.pythonhosted.org/packages/d4/11/94ec578896cdb973aaf56425d6c7f2aff4186a5c00fac15ff2ec46998b46/numpy-2.3.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:035796aaaddfe2f9664b9a9372f089cfc88bd795a67bd1bfe15e6e770934cf64", size = 18651574, upload-time = "2025-10-15T16:16:53.429Z" }, + { url = "https://files.pythonhosted.org/packages/62/b7/7efa763ab33dbccf56dade36938a77345ce8e8192d6b39e470ca25ff3cd0/numpy-2.3.4-cp313-cp313t-win32.whl", hash = "sha256:fea80f4f4cf83b54c3a051f2f727870ee51e22f0248d3114b8e755d160b38cfb", size = 6413135, upload-time = "2025-10-15T16:16:55.992Z" }, + { url = "https://files.pythonhosted.org/packages/43/70/aba4c38e8400abcc2f345e13d972fb36c26409b3e644366db7649015f291/numpy-2.3.4-cp313-cp313t-win_amd64.whl", hash = "sha256:15eea9f306b98e0be91eb344a94c0e630689ef302e10c2ce5f7e11905c704f9c", size = 12928582, upload-time = "2025-10-15T16:16:57.943Z" }, + { url = "https://files.pythonhosted.org/packages/67/63/871fad5f0073fc00fbbdd7232962ea1ac40eeaae2bba66c76214f7954236/numpy-2.3.4-cp313-cp313t-win_arm64.whl", hash = "sha256:b6c231c9c2fadbae4011ca5e7e83e12dc4a5072f1a1d85a0a7b3ed754d145a40", size = 10266691, upload-time = "2025-10-15T16:17:00.048Z" }, + { url = "https://files.pythonhosted.org/packages/b1/b6/64898f51a86ec88ca1257a59c1d7fd077b60082a119affefcdf1dd0df8ca/numpy-2.3.4-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:6e274603039f924c0fe5cb73438fa9246699c78a6df1bd3decef9ae592ae1c05", size = 21131552, upload-time = "2025-10-15T16:17:55.845Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4c/f135dc6ebe2b6a3c77f4e4838fa63d350f85c99462012306ada1bd4bc460/numpy-2.3.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d149aee5c72176d9ddbc6803aef9c0f6d2ceeea7626574fc68518da5476fa346", size = 14377796, upload-time = "2025-10-15T16:17:58.308Z" }, + { url = "https://files.pythonhosted.org/packages/d0/a4/f33f9c23fcc13dd8412fc8614559b5b797e0aba9d8e01dfa8bae10c84004/numpy-2.3.4-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:6d34ed9db9e6395bb6cd33286035f73a59b058169733a9db9f85e650b88df37e", size = 5306904, upload-time = "2025-10-15T16:18:00.596Z" }, + { url = "https://files.pythonhosted.org/packages/28/af/c44097f25f834360f9fb960fa082863e0bad14a42f36527b2a121abdec56/numpy-2.3.4-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:fdebe771ca06bb8d6abce84e51dca9f7921fe6ad34a0c914541b063e9a68928b", size = 6819682, upload-time = "2025-10-15T16:18:02.32Z" }, + { url = "https://files.pythonhosted.org/packages/c5/8c/cd283b54c3c2b77e188f63e23039844f56b23bba1712318288c13fe86baf/numpy-2.3.4-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:957e92defe6c08211eb77902253b14fe5b480ebc5112bc741fd5e9cd0608f847", size = 14422300, upload-time = "2025-10-15T16:18:04.271Z" }, + { url = "https://files.pythonhosted.org/packages/b0/f0/8404db5098d92446b3e3695cf41c6f0ecb703d701cb0b7566ee2177f2eee/numpy-2.3.4-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13b9062e4f5c7ee5c7e5be96f29ba71bc5a37fed3d1d77c37390ae00724d296d", size = 16760806, upload-time = "2025-10-15T16:18:06.668Z" }, + { url = "https://files.pythonhosted.org/packages/95/8e/2844c3959ce9a63acc7c8e50881133d86666f0420bcde695e115ced0920f/numpy-2.3.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:81b3a59793523e552c4a96109dde028aa4448ae06ccac5a76ff6532a85558a7f", size = 12973130, upload-time = "2025-10-15T16:18:09.397Z" }, +] + +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/61/33/9611380c2bdb1225fdef633e2a9610622310fed35ab11dac9620972ee088/platformdirs-4.5.0.tar.gz", hash = "sha256:70ddccdd7c99fc5942e9fc25636a8b34d04c24b335100223152c2803e4063312", size = 21632, upload-time = "2025-10-08T17:44:48.791Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pre-commit" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ff/29/7cf5bbc236333876e4b41f56e06857a87937ce4bf91e117a6991a2dbb02a/pre_commit-4.3.0.tar.gz", hash = "sha256:499fe450cc9d42e9d58e606262795ecb64dd05438943c62b66f6a8673da30b16", size = 193792, upload-time = "2025-08-09T18:56:14.651Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/a5/987a405322d78a73b66e39e4a90e4ef156fd7141bf71df987e50717c321b/pre_commit-4.3.0-py2.py3-none-any.whl", hash = "sha256:2b0747ad7e6e967169136edffee14c16e148a778a54e4f967921aa1ebf2308d8", size = 220965, upload-time = "2025-08-09T18:56:13.192Z" }, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, +] + +[[package]] +name = "pydot" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/35/b17cb89ff865484c6a20ef46bf9d95a5f07328292578de0b295f4a6beec2/pydot-4.0.1.tar.gz", hash = "sha256:c2148f681c4a33e08bf0e26a9e5f8e4099a82e0e2a068098f32ce86577364ad5", size = 162594, upload-time = "2025-06-17T20:09:56.454Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/32/a7125fb28c4261a627f999d5fb4afff25b523800faed2c30979949d6facd/pydot-4.0.1-py3-none-any.whl", hash = "sha256:869c0efadd2708c0be1f916eb669f3d664ca684bc57ffb7ecc08e70d5e93fee6", size = 37087, upload-time = "2025-06-17T20:09:55.25Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pyparsing" +version = "3.2.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/a5/181488fc2b9d093e3972d2a472855aae8a03f000592dbfce716a512b3359/pyparsing-3.2.5.tar.gz", hash = "sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6", size = 1099274, upload-time = "2025-09-21T04:11:06.277Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/5e/1aa9a93198c6b64513c9d7752de7422c06402de6600a8767da1524f9570b/pyparsing-3.2.5-py3-none-any.whl", hash = "sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e", size = 113890, upload-time = "2025-09-21T04:11:04.117Z" }, +] + +[[package]] +name = "pytensor" +source = { editable = "." } +dependencies = [ + { name = "cons" }, + { name = "etuples" }, + { name = "filelock" }, + { name = "logical-unification" }, + { name = "minikanren" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "setuptools" }, +] + +[package.optional-dependencies] +complete = [ + { name = "jax" }, + { name = "jaxlib" }, + { name = "llvmlite" }, + { name = "numba" }, +] +development = [ + { name = "coverage" }, + { name = "jax" }, + { name = "jaxlib" }, + { name = "llvmlite" }, + { name = "numba" }, + { name = "pre-commit" }, + { name = "pydot" }, + { name = "pygments" }, + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "pytest-cov" }, + { name = "pytest-mock" }, + { name = "pytest-sphinx" }, + { name = "sphinx" }, +] +jax = [ + { name = "jax" }, + { name = "jaxlib" }, +] +numba = [ + { name = "llvmlite" }, + { name = "numba" }, +] +rtd = [ + { name = "pydot" }, + { name = "pygments" }, + { name = "sphinx" }, +] +tests = [ + { name = "coverage" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "pytest-cov" }, + { name = "pytest-mock" }, + { name = "pytest-sphinx" }, +] + +[package.metadata] +requires-dist = [ + { name = "cons" }, + { name = "coverage", marker = "extra == 'tests'", specifier = ">=5.1" }, + { name = "etuples" }, + { name = "filelock", specifier = ">=3.15" }, + { name = "jax", marker = "extra == 'jax'" }, + { name = "jaxlib", marker = "extra == 'jax'" }, + { name = "llvmlite", marker = "extra == 'numba'" }, + { name = "logical-unification" }, + { name = "minikanren" }, + { name = "numba", marker = "extra == 'numba'", specifier = ">=0.57" }, + { name = "numpy", specifier = ">=2.0" }, + { name = "pre-commit", marker = "extra == 'tests'" }, + { name = "pydot", marker = "extra == 'rtd'" }, + { name = "pygments", marker = "extra == 'rtd'" }, + { name = "pytensor", extras = ["complete"], marker = "extra == 'development'" }, + { name = "pytensor", extras = ["jax"], marker = "extra == 'complete'" }, + { name = "pytensor", extras = ["numba"], marker = "extra == 'complete'" }, + { name = "pytensor", extras = ["rtd"], marker = "extra == 'development'" }, + { name = "pytensor", extras = ["tests"], marker = "extra == 'development'" }, + { name = "pytest", marker = "extra == 'tests'" }, + { name = "pytest-benchmark", marker = "extra == 'tests'" }, + { name = "pytest-cov", marker = "extra == 'tests'", specifier = ">=2.6.1" }, + { name = "pytest-mock", marker = "extra == 'tests'" }, + { name = "pytest-sphinx", marker = "extra == 'tests'" }, + { name = "scipy", specifier = ">=1,<2" }, + { name = "setuptools", specifier = ">=59.0.0" }, + { name = "sphinx", marker = "extra == 'rtd'", specifier = ">=5.1.0,<6" }, +] +provides-extras = ["complete", "development", "tests", "rtd", "jax", "numba"] + +[[package]] +name = "pytest" +version = "8.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, +] + +[[package]] +name = "pytest-benchmark" +version = "5.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/84/84ba011c4b2a44c8fce772be6124821a27cecd0f69b324f24ef4c1172863/pytest_benchmark-5.2.0.tar.gz", hash = "sha256:75731991edf6c807d0699130afbb4ba77d8ce8e3b8314662c340ee8e1db19f43", size = 339143, upload-time = "2025-10-30T18:11:02.264Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/c2/57de9aa286a2f6d00c52a7bb4b16dbbfa2a6c80b4a4f0e415c874269a4a6/pytest_benchmark-5.2.0-py3-none-any.whl", hash = "sha256:0631cdf19f6032fc46d6bf9e8d15931d78473228b579a3fd84ca5e2f0e8ee06c", size = 44194, upload-time = "2025-10-30T18:11:00.311Z" }, +] + +[[package]] +name = "pytest-cov" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, +] + +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + +[[package]] +name = "pytest-sphinx" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8f/12/a6e99712955b7057accbe43f7f709cf212e6fc00f570bfdc93574335ba5b/pytest_sphinx-0.6.3.tar.gz", hash = "sha256:3b63c8181b9de6a5e5c9826d1b4dc0c827245bec8e64c9f16f269be08be5ecd5", size = 13690, upload-time = "2024-04-13T19:11:51.905Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/db/79570f7eebfa0f24b670d985423f4fa45fee67ef8feb25c6b58cbe2b0bb7/pytest_sphinx-0.6.3-py3-none-any.whl", hash = "sha256:856e760e64dfbfc89e362e187d641140a267b97881d3ef8aeefb72cc8438ac40", size = 10349, upload-time = "2024-04-13T19:11:50.394Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" }, + { url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" }, + { url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" }, + { url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" }, + { url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" }, + { url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 767463, upload-time = "2025-09-25T21:32:06.152Z" }, + { url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" }, + { url = "https://files.pythonhosted.org/packages/45/91/47a6e1c42d9ee337c4839208f30d9f09caa9f720ec7582917b264defc875/pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b", size = 142543, upload-time = "2025-09-25T21:32:08.95Z" }, + { url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" }, + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, + { url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, + { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, + { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, + { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, + { url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, + { url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" }, + { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "scipy" +version = "1.16.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/ca/d8ace4f98322d01abcd52d381134344bf7b431eba7ed8b42bdea5a3c2ac9/scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb", size = 30597883, upload-time = "2025-10-28T17:38:54.068Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/5f/6f37d7439de1455ce9c5a556b8d1db0979f03a796c030bafdf08d35b7bf9/scipy-1.16.3-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97", size = 36630881, upload-time = "2025-10-28T17:31:47.104Z" }, + { url = "https://files.pythonhosted.org/packages/7c/89/d70e9f628749b7e4db2aa4cd89735502ff3f08f7b9b27d2e799485987cd9/scipy-1.16.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511", size = 28941012, upload-time = "2025-10-28T17:31:53.411Z" }, + { url = "https://files.pythonhosted.org/packages/a8/a8/0e7a9a6872a923505dbdf6bb93451edcac120363131c19013044a1e7cb0c/scipy-1.16.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005", size = 20931935, upload-time = "2025-10-28T17:31:57.361Z" }, + { url = "https://files.pythonhosted.org/packages/bd/c7/020fb72bd79ad798e4dbe53938543ecb96b3a9ac3fe274b7189e23e27353/scipy-1.16.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb", size = 23534466, upload-time = "2025-10-28T17:32:01.875Z" }, + { url = "https://files.pythonhosted.org/packages/be/a0/668c4609ce6dbf2f948e167836ccaf897f95fb63fa231c87da7558a374cd/scipy-1.16.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876", size = 33593618, upload-time = "2025-10-28T17:32:06.902Z" }, + { url = "https://files.pythonhosted.org/packages/ca/6e/8942461cf2636cdae083e3eb72622a7fbbfa5cf559c7d13ab250a5dbdc01/scipy-1.16.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2", size = 35899798, upload-time = "2025-10-28T17:32:12.665Z" }, + { url = "https://files.pythonhosted.org/packages/79/e8/d0f33590364cdbd67f28ce79368b373889faa4ee959588beddf6daef9abe/scipy-1.16.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e", size = 36226154, upload-time = "2025-10-28T17:32:17.961Z" }, + { url = "https://files.pythonhosted.org/packages/39/c1/1903de608c0c924a1749c590064e65810f8046e437aba6be365abc4f7557/scipy-1.16.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733", size = 38878540, upload-time = "2025-10-28T17:32:23.907Z" }, + { url = "https://files.pythonhosted.org/packages/f1/d0/22ec7036ba0b0a35bccb7f25ab407382ed34af0b111475eb301c16f8a2e5/scipy-1.16.3-cp311-cp311-win_amd64.whl", hash = "sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78", size = 38722107, upload-time = "2025-10-28T17:32:29.921Z" }, + { url = "https://files.pythonhosted.org/packages/7b/60/8a00e5a524bb3bf8898db1650d350f50e6cffb9d7a491c561dc9826c7515/scipy-1.16.3-cp311-cp311-win_arm64.whl", hash = "sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184", size = 25506272, upload-time = "2025-10-28T17:32:34.577Z" }, + { url = "https://files.pythonhosted.org/packages/40/41/5bf55c3f386b1643812f3a5674edf74b26184378ef0f3e7c7a09a7e2ca7f/scipy-1.16.3-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6", size = 36659043, upload-time = "2025-10-28T17:32:40.285Z" }, + { url = "https://files.pythonhosted.org/packages/1e/0f/65582071948cfc45d43e9870bf7ca5f0e0684e165d7c9ef4e50d783073eb/scipy-1.16.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07", size = 28898986, upload-time = "2025-10-28T17:32:45.325Z" }, + { url = "https://files.pythonhosted.org/packages/96/5e/36bf3f0ac298187d1ceadde9051177d6a4fe4d507e8f59067dc9dd39e650/scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9", size = 20889814, upload-time = "2025-10-28T17:32:49.277Z" }, + { url = "https://files.pythonhosted.org/packages/80/35/178d9d0c35394d5d5211bbff7ac4f2986c5488b59506fef9e1de13ea28d3/scipy-1.16.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686", size = 23565795, upload-time = "2025-10-28T17:32:53.337Z" }, + { url = "https://files.pythonhosted.org/packages/fa/46/d1146ff536d034d02f83c8afc3c4bab2eddb634624d6529a8512f3afc9da/scipy-1.16.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203", size = 33349476, upload-time = "2025-10-28T17:32:58.353Z" }, + { url = "https://files.pythonhosted.org/packages/79/2e/415119c9ab3e62249e18c2b082c07aff907a273741b3f8160414b0e9193c/scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1", size = 35676692, upload-time = "2025-10-28T17:33:03.88Z" }, + { url = "https://files.pythonhosted.org/packages/27/82/df26e44da78bf8d2aeaf7566082260cfa15955a5a6e96e6a29935b64132f/scipy-1.16.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe", size = 36019345, upload-time = "2025-10-28T17:33:09.773Z" }, + { url = "https://files.pythonhosted.org/packages/82/31/006cbb4b648ba379a95c87262c2855cd0d09453e500937f78b30f02fa1cd/scipy-1.16.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70", size = 38678975, upload-time = "2025-10-28T17:33:15.809Z" }, + { url = "https://files.pythonhosted.org/packages/c2/7f/acbd28c97e990b421af7d6d6cd416358c9c293fc958b8529e0bd5d2a2a19/scipy-1.16.3-cp312-cp312-win_amd64.whl", hash = "sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc", size = 38555926, upload-time = "2025-10-28T17:33:21.388Z" }, + { url = "https://files.pythonhosted.org/packages/ce/69/c5c7807fd007dad4f48e0a5f2153038dc96e8725d3345b9ee31b2b7bed46/scipy-1.16.3-cp312-cp312-win_arm64.whl", hash = "sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2", size = 25463014, upload-time = "2025-10-28T17:33:25.975Z" }, + { url = "https://files.pythonhosted.org/packages/72/f1/57e8327ab1508272029e27eeef34f2302ffc156b69e7e233e906c2a5c379/scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c", size = 36617856, upload-time = "2025-10-28T17:33:31.375Z" }, + { url = "https://files.pythonhosted.org/packages/44/13/7e63cfba8a7452eb756306aa2fd9b37a29a323b672b964b4fdeded9a3f21/scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d", size = 28874306, upload-time = "2025-10-28T17:33:36.516Z" }, + { url = "https://files.pythonhosted.org/packages/15/65/3a9400efd0228a176e6ec3454b1fa998fbbb5a8defa1672c3f65706987db/scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9", size = 20865371, upload-time = "2025-10-28T17:33:42.094Z" }, + { url = "https://files.pythonhosted.org/packages/33/d7/eda09adf009a9fb81827194d4dd02d2e4bc752cef16737cc4ef065234031/scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4", size = 23524877, upload-time = "2025-10-28T17:33:48.483Z" }, + { url = "https://files.pythonhosted.org/packages/7d/6b/3f911e1ebc364cb81320223a3422aab7d26c9c7973109a9cd0f27c64c6c0/scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959", size = 33342103, upload-time = "2025-10-28T17:33:56.495Z" }, + { url = "https://files.pythonhosted.org/packages/21/f6/4bfb5695d8941e5c570a04d9fcd0d36bce7511b7d78e6e75c8f9791f82d0/scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88", size = 35697297, upload-time = "2025-10-28T17:34:04.722Z" }, + { url = "https://files.pythonhosted.org/packages/04/e1/6496dadbc80d8d896ff72511ecfe2316b50313bfc3ebf07a3f580f08bd8c/scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234", size = 36021756, upload-time = "2025-10-28T17:34:13.482Z" }, + { url = "https://files.pythonhosted.org/packages/fe/bd/a8c7799e0136b987bda3e1b23d155bcb31aec68a4a472554df5f0937eef7/scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d", size = 38696566, upload-time = "2025-10-28T17:34:22.384Z" }, + { url = "https://files.pythonhosted.org/packages/cd/01/1204382461fcbfeb05b6161b594f4007e78b6eba9b375382f79153172b4d/scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304", size = 38529877, upload-time = "2025-10-28T17:35:51.076Z" }, + { url = "https://files.pythonhosted.org/packages/7f/14/9d9fbcaa1260a94f4bb5b64ba9213ceb5d03cd88841fe9fd1ffd47a45b73/scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2", size = 25455366, upload-time = "2025-10-28T17:35:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/e2/a3/9ec205bd49f42d45d77f1730dbad9ccf146244c1647605cf834b3a8c4f36/scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b", size = 37027931, upload-time = "2025-10-28T17:34:31.451Z" }, + { url = "https://files.pythonhosted.org/packages/25/06/ca9fd1f3a4589cbd825b1447e5db3a8ebb969c1eaf22c8579bd286f51b6d/scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079", size = 29400081, upload-time = "2025-10-28T17:34:39.087Z" }, + { url = "https://files.pythonhosted.org/packages/6a/56/933e68210d92657d93fb0e381683bc0e53a965048d7358ff5fbf9e6a1b17/scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a", size = 21391244, upload-time = "2025-10-28T17:34:45.234Z" }, + { url = "https://files.pythonhosted.org/packages/a8/7e/779845db03dc1418e215726329674b40576879b91814568757ff0014ad65/scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119", size = 23929753, upload-time = "2025-10-28T17:34:51.793Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4b/f756cf8161d5365dcdef9e5f460ab226c068211030a175d2fc7f3f41ca64/scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c", size = 33496912, upload-time = "2025-10-28T17:34:59.8Z" }, + { url = "https://files.pythonhosted.org/packages/09/b5/222b1e49a58668f23839ca1542a6322bb095ab8d6590d4f71723869a6c2c/scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e", size = 35802371, upload-time = "2025-10-28T17:35:08.173Z" }, + { url = "https://files.pythonhosted.org/packages/c1/8d/5964ef68bb31829bde27611f8c9deeac13764589fe74a75390242b64ca44/scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135", size = 36190477, upload-time = "2025-10-28T17:35:16.7Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f2/b31d75cb9b5fa4dd39a0a931ee9b33e7f6f36f23be5ef560bf72e0f92f32/scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6", size = 38796678, upload-time = "2025-10-28T17:35:26.354Z" }, + { url = "https://files.pythonhosted.org/packages/b4/1e/b3723d8ff64ab548c38d87055483714fefe6ee20e0189b62352b5e015bb1/scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc", size = 38640178, upload-time = "2025-10-28T17:35:35.304Z" }, + { url = "https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, +] + +[[package]] +name = "setuptools" +version = "80.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, +] + +[[package]] +name = "snowballstemmer" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/a7/9810d872919697c9d01295633f5d574fb416d47e535f258272ca1f01f447/snowballstemmer-3.0.1.tar.gz", hash = "sha256:6d5eeeec8e9f84d4d56b847692bacf79bc2c8e90c7f80ca4444ff8b6f2e52895", size = 105575, upload-time = "2025-05-09T16:34:51.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/78/3565d011c61f5a43488987ee32b6f3f656e7f107ac2782dd57bdd7d91d9a/snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064", size = 103274, upload-time = "2025-05-09T16:34:50.371Z" }, +] + +[[package]] +name = "sphinx" +version = "5.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "alabaster" }, + { name = "babel" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "docutils" }, + { name = "imagesize" }, + { name = "jinja2" }, + { name = "packaging" }, + { name = "pygments" }, + { name = "requests" }, + { name = "snowballstemmer" }, + { name = "sphinxcontrib-applehelp" }, + { name = "sphinxcontrib-devhelp" }, + { name = "sphinxcontrib-htmlhelp" }, + { name = "sphinxcontrib-jsmath" }, + { name = "sphinxcontrib-qthelp" }, + { name = "sphinxcontrib-serializinghtml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/af/b2/02a43597980903483fe5eb081ee8e0ba2bb62ea43a70499484343795f3bf/Sphinx-5.3.0.tar.gz", hash = "sha256:51026de0a9ff9fc13c05d74913ad66047e104f56a129ff73e174eb5c3ee794b5", size = 6811365, upload-time = "2022-10-16T09:58:25.963Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/a7/01dd6fd9653c056258d65032aa09a615b5d7b07dd840845a9f41a8860fbc/sphinx-5.3.0-py3-none-any.whl", hash = "sha256:060ca5c9f7ba57a08a1219e547b269fadf125ae25b06b9fa7f66768efb652d6d", size = 3183160, upload-time = "2022-10-16T09:58:21.63Z" }, +] + +[[package]] +name = "sphinxcontrib-applehelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/6e/b837e84a1a704953c62ef8776d45c3e8d759876b4a84fe14eba2859106fe/sphinxcontrib_applehelp-2.0.0.tar.gz", hash = "sha256:2f29ef331735ce958efa4734873f084941970894c6090408b079c61b2e1c06d1", size = 20053, upload-time = "2024-07-29T01:09:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/85/9ebeae2f76e9e77b952f4b274c27238156eae7979c5421fba91a28f4970d/sphinxcontrib_applehelp-2.0.0-py3-none-any.whl", hash = "sha256:4cd3f0ec4ac5dd9c17ec65e9ab272c9b867ea77425228e68ecf08d6b28ddbdb5", size = 119300, upload-time = "2024-07-29T01:08:58.99Z" }, +] + +[[package]] +name = "sphinxcontrib-devhelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/d2/5beee64d3e4e747f316bae86b55943f51e82bb86ecd325883ef65741e7da/sphinxcontrib_devhelp-2.0.0.tar.gz", hash = "sha256:411f5d96d445d1d73bb5d52133377b4248ec79db5c793ce7dbe59e074b4dd1ad", size = 12967, upload-time = "2024-07-29T01:09:23.417Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/7a/987e583882f985fe4d7323774889ec58049171828b58c2217e7f79cdf44e/sphinxcontrib_devhelp-2.0.0-py3-none-any.whl", hash = "sha256:aefb8b83854e4b0998877524d1029fd3e6879210422ee3780459e28a1f03a8a2", size = 82530, upload-time = "2024-07-29T01:09:21.945Z" }, +] + +[[package]] +name = "sphinxcontrib-htmlhelp" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/93/983afd9aa001e5201eab16b5a444ed5b9b0a7a010541e0ddfbbfd0b2470c/sphinxcontrib_htmlhelp-2.1.0.tar.gz", hash = "sha256:c9e2916ace8aad64cc13a0d233ee22317f2b9025b9cf3295249fa985cc7082e9", size = 22617, upload-time = "2024-07-29T01:09:37.889Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/7b/18a8c0bcec9182c05a0b3ec2a776bba4ead82750a55ff798e8d406dae604/sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl", hash = "sha256:166759820b47002d22914d64a075ce08f4c46818e17cfc9470a9786b759b19f8", size = 98705, upload-time = "2024-07-29T01:09:36.407Z" }, +] + +[[package]] +name = "sphinxcontrib-jsmath" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/e8/9ed3830aeed71f17c026a07a5097edcf44b692850ef215b161b8ad875729/sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8", size = 5787, upload-time = "2019-01-21T16:10:16.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071, upload-time = "2019-01-21T16:10:14.333Z" }, +] + +[[package]] +name = "sphinxcontrib-qthelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/68/bc/9104308fc285eb3e0b31b67688235db556cd5b0ef31d96f30e45f2e51cae/sphinxcontrib_qthelp-2.0.0.tar.gz", hash = "sha256:4fe7d0ac8fc171045be623aba3e2a8f613f8682731f9153bb2e40ece16b9bbab", size = 17165, upload-time = "2024-07-29T01:09:56.435Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/83/859ecdd180cacc13b1f7e857abf8582a64552ea7a061057a6c716e790fce/sphinxcontrib_qthelp-2.0.0-py3-none-any.whl", hash = "sha256:b18a828cdba941ccd6ee8445dbe72ffa3ef8cbe7505d8cd1fa0d42d3f2d5f3eb", size = 88743, upload-time = "2024-07-29T01:09:54.885Z" }, +] + +[[package]] +name = "sphinxcontrib-serializinghtml" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/44/6716b257b0aa6bfd51a1b31665d1c205fb12cb5ad56de752dfa15657de2f/sphinxcontrib_serializinghtml-2.0.0.tar.gz", hash = "sha256:e9d912827f872c029017a53f0ef2180b327c3f7fd23c87229f7a8e8b70031d4d", size = 16080, upload-time = "2024-07-29T01:10:09.332Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072, upload-time = "2024-07-29T01:10:08.203Z" }, +] + +[[package]] +name = "tomli" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, + { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, + { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, + { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, + { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, + { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, + { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, + { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, + { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, + { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, + { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, + { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, + { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, + { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, + { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, + { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" }, + { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, + { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, + { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, + { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, + { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, + { url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" }, + { url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, + { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, +] + +[[package]] +name = "toolz" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/d6/114b492226588d6ff54579d95847662fc69196bdeec318eb45393b24c192/toolz-1.1.0.tar.gz", hash = "sha256:27a5c770d068c110d9ed9323f24f1543e83b2f300a687b7891c1a6d56b697b5b", size = 52613, upload-time = "2025-10-17T04:03:21.661Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/12/5911ae3eeec47800503a238d971e51722ccea5feb8569b735184d5fcdbc0/toolz-1.1.0-py3-none-any.whl", hash = "sha256:15ccc861ac51c53696de0a5d6d4607f99c210739caf987b5d2054f3efed429d8", size = 58093, upload-time = "2025-10-17T04:03:20.435Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] + +[[package]] +name = "virtualenv" +version = "20.35.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/28/e6f1a6f655d620846bd9df527390ecc26b3805a0c5989048c210e22c5ca9/virtualenv-20.35.4.tar.gz", hash = "sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c", size = 6028799, upload-time = "2025-10-29T06:57:40.511Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/0c/c05523fa3181fdf0c9c52a6ba91a23fbf3246cc095f26f6516f9c60e6771/virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b", size = 6005095, upload-time = "2025-10-29T06:57:37.598Z" }, +]