
Commit 3ce54c7

Finish Copilot code
1 parent 53adf9a commit 3ce54c7

8 files changed: 460 additions & 135 deletions


pytensor/link/jax/dispatch/subtensor.py

Lines changed: 27 additions & 6 deletions
@@ -31,11 +31,18 @@
 """
 
 
+@jax_funcify.register(AdvancedSubtensor1)
+def jax_funcify_AdvancedSubtensor1(op, node, **kwargs):
+    def advanced_subtensor1(x, ilist):
+        return x[ilist]
+
+    return advanced_subtensor1
+
+
 @jax_funcify.register(Subtensor)
 @jax_funcify.register(AdvancedSubtensor)
-@jax_funcify.register(AdvancedSubtensor1)
 def jax_funcify_Subtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list
 
     def subtensor(x, *ilists):
         indices = indices_from_subtensor(ilists, idx_list)
@@ -47,10 +54,24 @@ def subtensor(x, *ilists):
     return subtensor
 
 
-@jax_funcify.register(IncSubtensor)
 @jax_funcify.register(AdvancedIncSubtensor1)
+def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
+    if getattr(op, "set_instead_of_inc", False):
+
+        def jax_fn(x, y, ilist):
+            return x.at[ilist].set(y)
+
+    else:
+
+        def jax_fn(x, y, ilist):
+            return x.at[ilist].add(y)
+
+    return jax_fn
+
+
+@jax_funcify.register(IncSubtensor)
 def jax_funcify_IncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list
 
     if getattr(op, "set_instead_of_inc", False):
 
@@ -77,8 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
 
 @jax_funcify.register(AdvancedIncSubtensor)
 def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
-
+    idx_list = op.idx_list
+
     if getattr(op, "set_instead_of_inc", False):
 
         def jax_fn(x, indices, y):
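For context, a minimal sketch (not part of this commit, values made up for illustration) of the JAX .at semantics the new AdvancedSubtensor1 / AdvancedIncSubtensor1 dispatchers rely on:

# JAX arrays are immutable, so set/increment go through the functional .at API.
import jax.numpy as jnp

x = jnp.zeros(5)
ilist = jnp.array([0, 2, 2])

print(x[ilist])              # advanced_subtensor1: plain take -> [0. 0. 0.]
print(x.at[ilist].set(1.0))  # set_instead_of_inc=True  -> [1. 0. 1. 0. 0.]
print(x.at[ilist].add(1.0))  # set_instead_of_inc=False -> [1. 0. 2. 0. 0.] (duplicates accumulate)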

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 9 additions & 8 deletions
@@ -20,7 +20,6 @@
 )
 from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
 from pytensor.tensor import TensorType
-from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -247,7 +246,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     basic_idxs = []
     adv_idxs = []
     input_idx = 0
-
+
     for i, entry in enumerate(op.idx_list):
         if isinstance(entry, slice):
             # Basic slice index
@@ -256,12 +255,14 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
             # Advanced tensor index
             if input_idx < len(tensor_inputs):
                 idx_input = tensor_inputs[input_idx]
-                adv_idxs.append({
-                    "axis": i,
-                    "dtype": idx_input.type.dtype,
-                    "bcast": idx_input.type.broadcastable,
-                    "ndim": idx_input.type.ndim,
-                })
+                adv_idxs.append(
+                    {
+                        "axis": i,
+                        "dtype": idx_input.type.dtype,
+                        "bcast": idx_input.type.broadcastable,
+                        "ndim": idx_input.type.ndim,
+                    }
+                )
                 input_idx += 1
 
     # Special implementation for consecutive integer vector indices
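A hedged sketch of the idx_list convention the loop above assumes: entries are either slice objects (basic indexing) or Type instances standing in for an advanced tensor index taken from the node's inputs in order. The helper name split_indices is made up for illustration and is not part of the codebase:

# Illustrative helper: classify idx_list entries the way
# numba_funcify_AdvancedSubtensor does, pairing each Type entry with
# the next tensor input.
from pytensor.graph.type import Type


def split_indices(idx_list, tensor_inputs):
    basic_idxs, adv_idxs = [], []
    input_idx = 0
    for axis, entry in enumerate(idx_list):
        if isinstance(entry, slice):
            basic_idxs.append((axis, entry))  # basic slice index
        elif isinstance(entry, Type):
            adv_idxs.append((axis, tensor_inputs[input_idx]))  # advanced index
            input_idx += 1
    return basic_idxs, adv_idxs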

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 7 additions & 5 deletions
@@ -9,7 +9,7 @@
     Subtensor,
     indices_from_subtensor,
 )
-from pytensor.tensor.type_other import MakeSlice, SliceType
+from pytensor.tensor.type_other import MakeSlice
 
 
 def check_negative_steps(indices):
@@ -63,8 +63,8 @@ def makeslice(start, stop, step):
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
-
+    idx_list = op.idx_list
+
     def advsubtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
@@ -105,7 +105,7 @@ def inc_subtensor(x, y, *flattened_indices):
 @pytorch_funcify.register(AdvancedIncSubtensor)
 @pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list
     inplace = op.inplace
     ignore_duplicates = getattr(op, "ignore_duplicates", False)
 
@@ -139,7 +139,9 @@ def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
 
     else:
        # Check if we have slice indexing in idx_list
-        has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        has_slice_indexing = (
+            any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        )
         if has_slice_indexing:
             raise NotImplementedError(
                 "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"

pytensor/tensor/basic.py

Lines changed: 27 additions & 0 deletions
@@ -1818,6 +1818,33 @@ def do_constant_folding(self, fgraph, node):
         return True
 
 
+@_vectorize_node.register(Alloc)
+def vectorize_alloc(op: Alloc, node: Apply, batch_val, *batch_shapes):
+    # batch_shapes are usually not batched (they are scalars for the shape)
+    # batch_val is the value being allocated.
+
+    # If shapes are batched, we fall back (complex case)
+    if any(
+        b_shp.type.ndim > shp.type.ndim
+        for b_shp, shp in zip(batch_shapes, node.inputs[1:], strict=True)
+    ):
+        return vectorize_node_fallback(op, node, batch_val, *batch_shapes)
+
+    # If value is batched, we need to prepend batch dims to the output shape
+    val = node.inputs[0]
+    batch_ndim = batch_val.type.ndim - val.type.ndim
+
+    if batch_ndim == 0:
+        return op.make_node(batch_val, *batch_shapes)
+
+    # We need the size of the batch dimensions
+    # batch_val has shape (B1, B2, ..., val_dims...)
+    batch_dims = [batch_val.shape[i] for i in range(batch_ndim)]
+
+    new_shapes = batch_dims + list(batch_shapes)
+    return op.make_node(batch_val, *new_shapes)
+
+
 alloc = Alloc()
 pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))

pytensor/tensor/rewriting/subtensor.py

Lines changed: 11 additions & 8 deletions
@@ -14,6 +14,7 @@
     in2out,
     node_rewriter,
 )
+from pytensor.graph.type import Type
 from pytensor.raise_op import Assert
 from pytensor.scalar import Add, ScalarConstant, ScalarType
 from pytensor.scalar import constant as scalar_constant
@@ -229,7 +230,7 @@ def local_replace_AdvancedSubtensor(fgraph, node):
 
     indexed_var = node.inputs[0]
     tensor_inputs = node.inputs[1:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
     indices = []
     input_idx = 0
@@ -267,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     res = node.inputs[0]
     val = node.inputs[1]
     tensor_inputs = node.inputs[2:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
     indices = []
     input_idx = 0
@@ -1112,6 +1113,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
 def local_inplace_AdvancedIncSubtensor(fgraph, node):
     if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
         new_op = type(node.op)(
+            node.op.idx_list,
             inplace=True,
             set_instead_of_inc=node.op.set_instead_of_inc,
             ignore_duplicates=node.op.ignore_duplicates,
@@ -1376,6 +1378,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
                 z_broad[k]
                 and not same_shape(xi, y, dim_x=k, dim_y=k)
                 and shape_of[y][k] != 1
+                and shape_of[xi][k] == 1
             )
         ]
 
@@ -1778,7 +1781,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     else:
         x, y = node.inputs[0], node.inputs[1]
         tensor_inputs = node.inputs[2:]
-
+
         # Reconstruct indices from idx_list and tensor inputs
         idxs = []
         input_idx = 0
@@ -1829,36 +1832,36 @@ def ravel_multidimensional_bool_idx(fgraph, node):
         # Create new AdvancedSubtensor with updated idx_list
         new_idx_list = list(node.op.idx_list)
         new_tensor_inputs = list(tensor_inputs)
-
+
         # Update the idx_list and tensor_inputs for the raveled boolean index
         input_idx = 0
        for i, entry in enumerate(node.op.idx_list):
             if isinstance(entry, Type):
                 if input_idx == bool_idx_pos:
                     new_tensor_inputs[input_idx] = raveled_bool_idx
                 input_idx += 1
-
+
         new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
     else:
         # Create new AdvancedIncSubtensor with updated idx_list
         new_idx_list = list(node.op.idx_list)
         new_tensor_inputs = list(tensor_inputs)
-
+
         # Update the tensor_inputs for the raveled boolean index
         input_idx = 0
         for i, entry in enumerate(node.op.idx_list):
             if isinstance(entry, Type):
                 if input_idx == bool_idx_pos:
                     new_tensor_inputs[input_idx] = raveled_bool_idx
                 input_idx += 1
-
+
         # The dimensions of y that correspond to the boolean indices
         # must already be raveled in the original graph, so we don't need to do anything to it
         new_out = AdvancedIncSubtensor(
             new_idx_list,
             inplace=node.op.inplace,
             set_instead_of_inc=node.op.set_instead_of_inc,
-            ignore_duplicates=node.op.ignore_duplicates
+            ignore_duplicates=node.op.ignore_duplicates,
         )(raveled_x, y, *new_tensor_inputs)
         # But we must reshape the output to match the original shape
         new_out = new_out.reshape(x_shape)
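The ravel_multidimensional_bool_idx rewrite appears to lean on a plain NumPy equivalence; a small sketch in standard NumPy (not from this commit):

# Indexing with an n-d boolean mask is equivalent to indexing the raveled
# array with the raveled mask; the rewrite can then reshape the result back.
import numpy as np

x = np.arange(12).reshape(3, 4)
mask = (x % 2) == 0

assert np.array_equal(x[mask], x.ravel()[mask.ravel()])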
