1 | 1 | import warnings |
2 | | -from collections.abc import Collection, Iterable, Sequence |
3 | | -from itertools import pairwise |
| 2 | +from collections.abc import Collection, Iterable |
4 | 3 | from textwrap import dedent |
5 | 4 |
6 | 5 | import numpy as np |
7 | | -from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple |
| 6 | +from numpy.lib.array_utils import normalize_axis_index |
8 | 7 |
9 | 8 | import pytensor |
10 | 9 | import pytensor.scalar.basic as ps |
26 | 25 | from pytensor.scalar import upcast |
27 | 26 | from pytensor.tensor import TensorLike, as_tensor_variable |
28 | 27 | from pytensor.tensor import basic as ptb |
29 | | -from pytensor.tensor.basic import alloc, as_tensor, join, second, split |
| 28 | +from pytensor.tensor.basic import alloc, join, second |
30 | 29 | from pytensor.tensor.exceptions import NotScalarConstantError |
31 | 30 | from pytensor.tensor.math import abs as pt_abs |
32 | 31 | from pytensor.tensor.math import all as pt_all |
44 | 43 | ) |
45 | 44 | from pytensor.tensor.math import max as pt_max |
46 | 45 | from pytensor.tensor.math import sum as pt_sum |
47 | | -from pytensor.tensor.shape import Shape_i, ShapeValueType |
| 46 | +from pytensor.tensor.shape import Shape_i |
48 | 47 | from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor |
49 | 48 | from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes |
50 | 49 | from pytensor.tensor.utils import normalize_reduce_axis |
@@ -2012,221 +2011,6 @@ def concat_with_broadcast(tensor_list, axis=0): |
2012 | 2011 | return join(axis, *bcast_tensor_inputs) |
2013 | 2012 |
2014 | 2013 |
2015 | | -def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable: |
2016 | | - if axis is None: |
2017 | | - axis = list(range(x.ndim)) |
2018 | | - if not isinstance(axis, (list, tuple)): |
2019 | | - axis = [axis] |
2020 | | - |
2021 | | - if not axis: |
2022 | | - # The user passed an empty list/tuple, so we return the input as is |
2023 | | - return x |
2024 | | - |
2025 | | - axis = normalize_axis_tuple(axis, x.ndim) |
2026 | | - |
2027 | | - if len(axis) > 1 and np.diff(axis).max() > 1: |
2028 | | - raise ValueError( |
2029 | | - f"join_dims axis must be consecutive, got normalized axis: {axis}" |
2030 | | - ) |
2031 | | - |
2032 | | - x_shape = tuple(x.shape) |
2033 | | - |
2034 | | - return x.reshape((*x_shape[: min(axis)], -1, *x_shape[max(axis) + 1 :])) |
2035 | | - |
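For reference, join_dims amounts to a single reshape that collapses a run of consecutive axes into one. A minimal sketch of the intended use; the names and shapes below are illustrative, not from the source:

    import pytensor.tensor as pt

    x = pt.tensor3("x")            # symbolic shape (a, b, c)
    y = join_dims(x, axis=(1, 2))  # reshaped to (a, b * c)
    flat = join_dims(x)            # axis=None ravels everything: (a * b * c,)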
2036 | | - |
2037 | | -def split_dims(x, shape: ShapeValueType, axis: int | None = None) -> Variable: |
2038 | | - if axis is None: |
2039 | | - if x.ndim != 1: |
2040 | | - raise ValueError( |
2041 | | - "split_dims can only be called with axis=None for 1d inputs" |
2042 | | - ) |
2043 | | - axis = 0 |
2044 | | - |
2045 | | - if isinstance(shape, int): |
2046 | | - shape = [shape] |
2047 | | - |
2048 | | - shape = list(shape) |
2049 | | - new_shape = list(x.shape) |
2050 | | - |
2051 | | - # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens, for |
2052 | | - # example, when splitting a packed tensor that had its dims expanded before packing (e.g. when packing |
2053 | | - # shapes (3,) and (3, 3) into (3, 4)). |
2054 | | - if not shape: |
2055 | | - return x.squeeze(axis=axis) |
2056 | | - |
2057 | | - new_shape[axis] = shape.pop(-1) |
2058 | | - for s in shape[::-1]: |
2059 | | - new_shape.insert(axis, s) |
2060 | | - |
2061 | | - return x.reshape(tuple(new_shape)) |
2062 | | - |
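split_dims is the inverse: it expands one axis back into the given shape, or squeezes the axis away when the shape is empty. A round-trip sketch, continuing the illustrative example above:

    joined = join_dims(x, axis=(1, 2))  # (a, b * c)
    restored = split_dims(joined, shape=(x.shape[1], x.shape[2]), axis=1)  # back to (a, b, c)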
2063 | | - |
2064 | | -def _analyze_axes_list(axes) -> tuple[int, int, int]: |
2065 | | - """ |
2066 | | - Analyze the provided axes list to determine how many axes come before and after the interval to be raveled, as |
2067 | | - well as the minimum number of axes that the inputs must have. |
2068 | | -
2069 | | - The rules are: |
2070 | | - - Axes must be strictly increasing in both the positive and negative parts of the list. |
2071 | | - - Negative axes must come after positive axes. |
2072 | | - - There can be at most one "hole" in the axes list, which can be either an implicit hole at an endpoint |
2073 | | - (e.g. [0, 1] leaves the trailing axes free) or an explicit hole in the middle (e.g. [0, -1]). |
2074 | | -
2075 | | - Returns |
2076 | | - ------- |
2077 | | - n_axes_before: int |
2078 | | - The number of axes before the interval to be raveled. |
2079 | | - n_axes_after: int |
2080 | | - The number of axes after the interval to be raveled. |
2081 | | - min_axes: int |
2082 | | - The minimum number of axes that the inputs must have. |
2083 | | - """ |
2084 | | - if axes is None: |
2085 | | - return 0, 0, 0 |
2086 | | - |
2087 | | - if isinstance(axes, int): |
2088 | | - axes = [axes] |
2089 | | - elif not isinstance(axes, Iterable): |
2090 | | - raise TypeError("axes must be an int, an iterable of ints, or None") |
2091 | | - |
2092 | | - axes = list(axes) |
2093 | | - |
2094 | | - if len(axes) == 0: |
2095 | | - raise ValueError("axes=[] is ambiguous; use None to ravel all") |
2096 | | - |
2097 | | - if len(set(axes)) != len(axes): |
2098 | | - raise ValueError("axes must have no duplicates") |
2099 | | - |
2100 | | - first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes)) |
2101 | | - positive_axes = list(axes[:first_negative_idx]) |
2102 | | - negative_axes = list(axes[first_negative_idx:]) |
2103 | | - |
2104 | | - if not all(a < 0 for a in negative_axes): |
2105 | | - raise ValueError("Negative axes must come after positive") |
2106 | | - |
2107 | | - def strictly_increasing(s): |
2108 | | - return all(b > a for a, b in pairwise(s)) |
2109 | | - |
2110 | | - if positive_axes and not strictly_increasing(positive_axes): |
2111 | | - raise ValueError("Axes must be strictly increasing in the positive part") |
2112 | | - if negative_axes and not strictly_increasing(negative_axes): |
2113 | | - raise ValueError("Axes must be strictly increasing in the negative part") |
2114 | | - |
2115 | | - def find_gaps(s): |
2116 | | - """Return positions where b - a > 1.""" |
2117 | | - return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1] |
2118 | | - |
2119 | | - pos_gaps = find_gaps(positive_axes) |
2120 | | - neg_gaps = find_gaps(negative_axes) |
2121 | | - |
2122 | | - if pos_gaps: |
2123 | | - raise ValueError("Positive axes must be contiguous") |
2124 | | - if neg_gaps: |
2125 | | - raise ValueError("Negative axes must be contiguous") |
2126 | | - |
2127 | | - if positive_axes and positive_axes[0] != 0: |
2128 | | - raise ValueError( |
2129 | | - "If positive axes are provided, the first positive axis must be 0 to avoid ambiguity. To ravel indices " |
2130 | | - "starting from the front, use negative axes only." |
2131 | | - ) |
2132 | | - |
2133 | | - if negative_axes and negative_axes[-1] != -1: |
2134 | | - raise ValueError( |
2135 | | - "If negative axes are provided, the last negative axis must be -1 to avoid ambiguity. To ravel indices " |
2136 | | - "up to the end, use positive axes only." |
2137 | | - ) |
2138 | | - |
2139 | | - positive_only = positive_axes and not negative_axes |
2140 | | - negative_only = negative_axes and not positive_axes |
2141 | | - mixed_case = positive_axes and negative_axes |
2142 | | - |
2143 | | - if positive_only: |
2144 | | - n_before = len(positive_axes) |
2145 | | - n_after = 0 |
2146 | | - min_axes = n_before |
2147 | | - |
2148 | | - return n_before, n_after, min_axes |
2149 | | - |
2150 | | - if negative_only: |
2151 | | - n_before = 0 |
2152 | | - n_after = len(negative_axes) |
2153 | | - min_axes = n_after |
2154 | | - |
2155 | | - return n_before, n_after, min_axes |
2156 | | - |
2157 | | - if mixed_case: |
2158 | | - n_before = len(positive_axes) |
2159 | | - n_after = len(negative_axes) |
2160 | | - min_axes = n_before + n_after |
2161 | | - |
2162 | | - return n_before, n_after, min_axes |
2163 | | - |
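Concretely, the three branches above classify axes lists as follows; the listed axes are the ones kept, and the single hole is the interval that gets raveled:

    _analyze_axes_list(None)      # -> (0, 0, 0): ravel everything
    _analyze_axes_list([0, 1])    # -> (2, 0, 2): keep two leading axes, ravel the trailing rest
    _analyze_axes_list([-2, -1])  # -> (0, 2, 2): ravel the leading axes, keep the last two
    _analyze_axes_list([0, -1])   # -> (1, 1, 2): keep one axis on each side, ravel the middle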
2164 | | - |
2165 | | -def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None): |
2166 | | - n_before, n_after, min_axes = _analyze_axes_list(axes) |
2167 | | - |
2168 | | - if n_before == n_after == min_axes == 0: |
2169 | | - # Special case -- we're raveling everything |
2170 | | - packed_shapes = [tensor.shape for tensor in tensors] |
2171 | | - reshaped_tensors = [tensor.ravel() for tensor in tensors] |
2172 | | - |
2173 | | - return join(0, *reshaped_tensors), packed_shapes |
2174 | | - |
2175 | | - reshaped_tensors: list[TensorLike] = [] |
2176 | | - packed_shapes: list[ShapeValueType] = [] |
2177 | | - |
2178 | | - for i, tensor in enumerate(tensors): |
2179 | | - n_dim = tensor.ndim |
2180 | | - |
2181 | | - if n_dim < min_axes: |
2182 | | - raise ValueError( |
2183 | | - f"Input {i} (zero indexed) to pack has {n_dim} dimensions, " |
2184 | | - f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}." |
2185 | | - ) |
2186 | | - |
2187 | | - shapes = [ |
2188 | | - shape if shape is not None else symbolic_shape |
2189 | | - for shape, symbolic_shape in zip(tensor.type.shape, tensor.shape) |
2190 | | - ] |
2191 | | - axis_after_packed_axes = n_dim - n_after |
2192 | | - packed_shapes.append(as_tensor(shapes[n_before:axis_after_packed_axes])) |
2193 | | - |
2194 | | - new_shape = (*shapes[:n_before], -1, *shapes[axis_after_packed_axes:]) |
2195 | | - |
2196 | | - reshaped_tensors.append(tensor.reshape(new_shape)) |
2197 | | - |
2198 | | - return join(n_before, *reshaped_tensors), packed_shapes |
2199 | | - |
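Putting it together, pack reshapes each input so the raveled interval becomes a single -1 axis, joins the results along the first non-kept axis, and returns the per-tensor middle shapes needed to invert the operation. A usage sketch with hypothetical inputs:

    import pytensor.tensor as pt

    a = pt.matrix("a")                            # shape (n, p)
    b = pt.tensor3("b")                           # shape (n, q, r)
    packed, packed_shapes = pack(a, b, axes=[0])  # packed: (n, p + q * r)
    # packed_shapes holds the raveled middle shapes, here (p,) and (q, r)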
2200 | | - |
2201 | | -def unpack(packed_input, axes, packed_shapes): |
2202 | | - if axes is None: |
2203 | | - if packed_input.ndim != 1: |
2204 | | - raise ValueError( |
2205 | | - "unpack can only be called with axes=None for 1d inputs" |
2206 | | - ) |
2207 | | - split_axis = 0 |
2208 | | - else: |
2209 | | - axes = normalize_axis_tuple(axes, ndim=packed_input.ndim) |
2210 | | - try: |
2211 | | - [split_axis] = (i for i in range(packed_input.ndim) if i not in axes) |
2212 | | - except ValueError as err: |
2213 | | - raise ValueError( |
2214 | | - "Unpack must have exactly one more dimension than is implied by axes" |
2215 | | - ) from err |
2216 | | - |
2217 | | - split_inputs = split( |
2218 | | - packed_input, |
2219 | | - splits_size=[prod(shape).astype(int) for shape in packed_shapes], |
2220 | | - n_splits=len(packed_shapes), |
2221 | | - axis=split_axis, |
2222 | | - ) |
2223 | | - |
2224 | | - return [ |
2225 | | - split_dims(inp, shape, split_axis) |
2226 | | - for inp, shape in zip(split_inputs, packed_shapes, strict=True) |
2227 | | - ] |
2228 | | - |
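unpack inverts this: it splits along the one axis not named in axes, then split_dims restores each piece to its original shape. Continuing the sketch above:

    a_restored, b_restored = unpack(packed, axes=[0], packed_shapes=packed_shapes)
    # a_restored: (n, p); b_restored: (n, q, r)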
2229 | | - |
2230 | 2014 | __all__ = [ |
2231 | 2015 | "bartlett", |
2232 | 2016 | "bincount", |