1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -142,6 +142,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
specify_broadcastable,
specify_shape,
)
from pytensor.tensor.shape_ops import *

# We import as `_shared` instead of `shared` to avoid confusion between
# `pytensor.shared` and `tensor._shared`.
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
@@ -10,6 +10,7 @@
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.reshape_ops
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
42 changes: 42 additions & 0 deletions pytensor/tensor/rewriting/reshape_ops.py
@@ -0,0 +1,42 @@
from pytensor.graph import node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.shape_ops import JoinDims, SplitDims


@register_canonicalize
@node_rewriter([SplitDims])
def local_split_dims_to_reshape(fgraph, node):
"""
Canonicalize SplitDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends).
"""

x, shape = node.inputs
axis = node.op.axis

output_shape = [
*x.shape[:axis],
*shape,
*x.shape[axis + 1 :],
]

return [x.reshape(output_shape)]
Review comment (Member): copy stack trace in the rewrites
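A minimal sketch (not part of the PR) of how the reviewer's suggestion could be applied in both rewrites, assuming copy_stack_trace is importable from pytensor.graph.rewriting.basic as in the existing rewrite modules:

from pytensor.graph.rewriting.basic import copy_stack_trace

# Inside each rewrite, keep a handle on the replacement output and copy the
# stack trace of the replaced output onto it before returning:
new_out = x.reshape(output_shape)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]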



@register_canonicalize
@node_rewriter([JoinDims])
def local_join_dims_to_reshape(fgraph, node):
"""
Canonicalize JoinDims Ops to Reshape Ops for further graph reasoning (and dispatch to other backends).
"""

(x,) = node.inputs
start_axis = node.op.start_axis
n_axes = node.op.n_axes

output_shape = [
*x.shape[:start_axis],
-1,
*x.shape[start_axis + n_axes :],
]

return [x.reshape(output_shape)]
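
For reference, a minimal sketch (not part of the PR) of the reshape graphs these canonicalizations produce, assuming a rank-3 input whose axis 1 is split into two dimensions and then joined back:

import pytensor.tensor as pt

x = pt.tensor3("x")  # symbolic input with runtime shape (a, b, c)

# SplitDims(axis=1) with shape input (2, b // 2) canonicalizes to this reshape:
split = x.reshape([x.shape[0], 2, x.shape[1] // 2, x.shape[2]])

# JoinDims(start_axis=1, n_axes=2) canonicalizes to a reshape with -1 standing
# in for the joined axes:
joined = split.reshape([split.shape[0], -1, split.shape[3]])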