
Commit d6d7cac

Move pack/unpack to shape_ops.py and add JoinDims and SplitDims Ops
1 parent 1497963 commit d6d7cac

5 files changed: +616 −474 lines changed

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
     specify_broadcastable,
     specify_shape,
 )
+from pytensor.tensor.shape_ops import *
 
 # We import as `_shared` instead of `shared` to avoid confusion between
 # `pytensor.shared` and `tensor._shared`.

pytensor/tensor/extra_ops.py

Lines changed: 4 additions & 220 deletions
@@ -1,10 +1,9 @@
 import warnings
-from collections.abc import Collection, Iterable, Sequence
-from itertools import pairwise
+from collections.abc import Collection, Iterable
 from textwrap import dedent
 
 import numpy as np
-from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
+from numpy.lib.array_utils import normalize_axis_index
 
 import pytensor
 import pytensor.scalar.basic as ps
@@ -26,7 +25,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, as_tensor, join, second, split
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -44,7 +43,7 @@
 )
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import Shape_i, ShapeValueType
+from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
 from pytensor.tensor.utils import normalize_reduce_axis
@@ -2012,221 +2011,6 @@ def concat_with_broadcast(tensor_list, axis=0):
     return join(axis, *bcast_tensor_inputs)
 
 
-def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
-    if axis is None:
-        axis = list(range(x.ndim))
-    if not isinstance(axis, (list, tuple)):
-        axis = [axis]
-
-    if not axis:
-        # The user passed an empty list/tuple, so we return the input as is
-        return x
-
-    axis = normalize_axis_tuple(axis, x.ndim)
-
-    if len(axis) > 1 and np.diff(axis).max() > 1:
-        raise ValueError(
-            f"join_dims axis must be consecutive, got normalized axis: {axis}"
-        )
-
-    x_shape = tuple(x.shape)
-
-    return x.reshape((*x_shape[: min(axis)], -1, *x_shape[max(axis) + 1 :]))
-
-
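
As a rough usage sketch (not part of the diff): join_dims merges a run of consecutive axes into one via reshape. This assumes the function keeps the semantics of the removed code above and, after this commit, is re-exported from pytensor.tensor via the shape_ops import:

    import pytensor.tensor as pt

    x = pt.tensor3("x")               # symbolic shape (a, b, c)
    y = pt.join_dims(x, axis=(0, 1))  # merge the consecutive axes 0 and 1
    # Equivalent to x.reshape((-1, x.shape[2])), so y has shape (a*b, c).
    # Non-consecutive axes such as (0, 2) raise a ValueError.
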
-def split_dims(x, shape: ShapeValueType, axis: int | None = None) -> Variable:
-    if axis is None:
-        if x.ndim != 1:
-            raise ValueError(
-                "split_dims can only be called with axis=None for 1d inputs"
-            )
-        axis = 0
-
-    if isinstance(shape, int):
-        shape = [shape]
-
-    shape = list(shape)
-    new_shape = list(x.shape)
-
-    # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
-    # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
-    # (3,) and (3, 3) to (3, 4)).
-    if not shape:
-        return x.squeeze(axis=axis)
-
-    new_shape[axis] = shape.pop(-1)
-    for s in shape[::-1]:
-        new_shape.insert(axis, s)
-
-    return x.reshape(tuple(new_shape))
-
-
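
Conversely, split_dims expands one axis back into a given shape; a minimal sketch under the same re-export assumption:

    import pytensor.tensor as pt

    y = pt.matrix("y")                          # symbolic shape (6, c)
    z = pt.split_dims(y, shape=(2, 3), axis=0)
    # Equivalent to y.reshape((2, 3, y.shape[1])), so z has shape (2, 3, c).
    # An empty shape, (), squeezes the axis away instead, matching the dummy-
    # dimension case described in the comment above.
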
-def _analyze_axes_list(axes) -> tuple[int, int, int]:
-    """
-    Analyze the provided axes list to determine how many axes come before and after the interval to be raveled,
-    as well as the minimum number of axes that the inputs must have.
-
-    The rules are:
-    - Axes must be strictly increasing in both the positive and negative parts of the list.
-    - Negative axes must come after positive axes.
-    - There can be at most one "hole" in the axes list, which can be either an implicit hole at an endpoint
-      (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [1, -1]).
-
-    Returns
-    -------
-    n_axes_before: int
-        The number of axes before the interval to be raveled.
-    n_axes_after: int
-        The number of axes after the interval to be raveled.
-    min_axes: int
-        The minimum number of axes that the inputs must have.
-    """
-    if axes is None:
-        return 0, 0, 0
-
-    if isinstance(axes, int):
-        axes = [axes]
-    elif not isinstance(axes, Iterable):
-        raise TypeError("axes must be an int, an iterable of ints, or None")
-
-    axes = list(axes)
-
-    if len(axes) == 0:
-        raise ValueError("axes=[] is ambiguous; use None to ravel all")
-
-    if len(set(axes)) != len(axes):
-        raise ValueError("axes must have no duplicates")
-
-    first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes))
-    positive_axes = list(axes[:first_negative_idx])
-    negative_axes = list(axes[first_negative_idx:])
-
-    if not all(a < 0 for a in negative_axes):
-        raise ValueError("Negative axes must come after positive")
-
-    def strictly_increasing(s):
-        return all(b > a for a, b in pairwise(s))
-
-    if positive_axes and not strictly_increasing(positive_axes):
-        raise ValueError("Axes must be strictly increasing in the positive part")
-    if negative_axes and not strictly_increasing(negative_axes):
-        raise ValueError("Axes must be strictly increasing in the negative part")
-
-    def find_gaps(s):
-        """Return positions where b - a > 1."""
-        return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1]
-
-    pos_gaps = find_gaps(positive_axes)
-    neg_gaps = find_gaps(negative_axes)
-
-    if pos_gaps:
-        raise ValueError("Positive axes must be contiguous")
-    if neg_gaps:
-        raise ValueError("Negative axes must be contiguous")
-
-    if positive_axes and positive_axes[0] != 0:
-        raise ValueError(
-            "If positive axes are provided, the first positive axis must be 0 to avoid ambiguity. To ravel indices "
-            "starting from the front, use negative axes only."
-        )
-
-    if negative_axes and negative_axes[-1] != -1:
-        raise ValueError(
-            "If negative axes are provided, the last negative axis must be -1 to avoid ambiguity. To ravel indices "
-            "up to the end, use positive axes only."
-        )
-
-    positive_only = positive_axes and not negative_axes
-    negative_only = negative_axes and not positive_axes
-    mixed_case = positive_axes and negative_axes
-
-    if positive_only:
-        n_before = len(positive_axes)
-        n_after = 0
-        min_axes = n_before
-
-        return n_before, n_after, min_axes
-
-    if negative_only:
-        n_before = 0
-        n_after = len(negative_axes)
-        min_axes = n_after
-
-        return n_before, n_after, min_axes
-
-    if mixed_case:
-        n_before = len(positive_axes)
-        n_after = len(negative_axes)
-        min_axes = n_before + n_after
-
-        return n_before, n_after, min_axes
-
-
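
For intuition, here is how the helper maps a few axes specs to (n_axes_before, n_axes_after, min_axes); these illustrative values follow directly from the positive-only, negative-only, and mixed branches above:

    _analyze_axes_list(None)      # (0, 0, 0): ravel everything
    _analyze_axes_list([0, 1])    # (2, 0, 2): keep two leading axes, ravel the rest
    _analyze_axes_list([-2, -1])  # (0, 2, 2): keep two trailing axes, ravel the rest
    _analyze_axes_list([0, -1])   # (1, 1, 2): ravel everything between axis 0 and axis -1
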
-def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None):
-    n_before, n_after, min_axes = _analyze_axes_list(axes)
-
-    if all([n_before == 0, n_after == 0, min_axes == 0]):
-        # Special case -- we're raveling everything
-        packed_shapes = [tensor.shape for tensor in tensors]
-        reshaped_tensors = [tensor.ravel() for tensor in tensors]
-
-        return join(0, *reshaped_tensors), packed_shapes
-
-    reshaped_tensors: list[TensorLike] = []
-    packed_shapes: list[ShapeValueType] = []
-
-    for i, tensor in enumerate(tensors):
-        n_dim = tensor.ndim
-
-        if n_dim < min_axes:
-            raise ValueError(
-                f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
-                f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
-            )
-
-        shapes = [
-            shape if shape is not None else symbolic_shape
-            for shape, symbolic_shape in zip(tensor.type.shape, tensor.shape)
-        ]
-        axis_after_packed_axes = n_dim - n_after
-        packed_shapes.append(as_tensor(shapes[n_before:axis_after_packed_axes]))
-
-        new_shape = (*shapes[:n_before], -1, *shapes[axis_after_packed_axes:])
-
-        reshaped_tensors.append(tensor.reshape(new_shape))
-
-    return join(n_before, *reshaped_tensors), packed_shapes
-
-
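
A hedged sketch of pack in use, assuming it is likewise re-exported from pytensor.tensor after the move to shape_ops:

    import pytensor.tensor as pt

    x = pt.tensor3("x")  # shape (n, a, b)
    y = pt.matrix("y")   # shape (n, c)
    packed, packed_shapes = pt.pack(x, y, axes=[0])
    # Each input keeps axis 0 and ravels its remaining axes, giving shapes
    # (n, a*b) and (n, c); joining on axis 1 yields packed of shape (n, a*b + c).
    # packed_shapes records [(a, b), (c,)] so the inputs can be recovered.
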
-def unpack(packed_input, axes, packed_shapes):
-    if axes is None:
-        if packed_input.ndim != 1:
-            raise ValueError(
-                "unpack can only be called with axes=None for 1d inputs"
-            )
-        split_axis = 0
-    else:
-        axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
-        try:
-            [split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
-        except ValueError as err:
-            raise ValueError(
-                "Unpack input must have exactly one more dimension than implied by axes"
-            ) from err
-
-    split_inputs = split(
-        packed_input,
-        splits_size=[prod(shape).astype(int) for shape in packed_shapes],
-        n_splits=len(packed_shapes),
-        axis=split_axis,
-    )
-
-    return [
-        split_dims(inp, shape, split_axis)
-        for inp, shape in zip(split_inputs, packed_shapes, strict=True)
-    ]
-
-
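
And the round trip back, continuing the pack sketch above under the same assumption:

    outs = pt.unpack(packed, axes=[0], packed_shapes=packed_shapes)
    # Splits axis 1 into chunks of size a*b and c, then reshapes each chunk,
    # so outs[0] has shape (n, a, b) and outs[1] has shape (n, c) again.
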
 __all__ = [
     "bartlett",
     "bincount",

Comments (0)