Skip to content

Commit ae499a4

Browse files
committed
Do not recreate Scalar Ops with custom TransferType for Elemwise inplacing
This helper could arbitrarily override the default output type from `ScalarOp.make_node` so that the output type matched one of the input types. This could be used to create artificial Op signatures that don't make sense or can't be cleanly implemented in other backends — for instance, an Add with signature (int8, int64) -> int8. The helper was historically used in: 1. The Elemwise inplace rewrite, presumably as a preventive measure. However, regular use should never require changing the ScalarOp signature: we only try to inplace on inputs that match the output dtype, and recreating the same Op with the same input types should always return the same output type. ScalarOps don't have a concept of inplace — only the Elemwise wrapper does — and that shouldn't require recreating or mutating the original Op. 2. Second. Here it makes sense, but a custom static method works just as well. 3. Inplace tests using the inplace variants of the `@scalar_elemwise` decorator. These test classes were removed, since it didn't make sense to force Ops into an artificial signature just for the sake of tests.
1 parent e0a2a86 commit ae499a4

File tree

5 files changed

+40
-73
lines changed

5 files changed

+40
-73
lines changed

pytensor/scalar/basic.py

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]:
11011101
return (type,)
11021102

11031103

1104-
class transfer_type(MetaObject):
1105-
__props__ = ("transfer",)
1106-
1107-
def __init__(self, *transfer):
1108-
assert all(isinstance(x, int | str) or x is None for x in transfer)
1109-
self.transfer = transfer
1110-
1111-
def __str__(self):
1112-
return f"transfer_type{self.transfer}"
1113-
1114-
def __call__(self, *types):
1115-
upcast = upcast_out(*types)
1116-
retval = []
1117-
for i in self.transfer:
1118-
if i is None:
1119-
retval += [upcast]
1120-
elif isinstance(i, str):
1121-
retval += [i]
1122-
else:
1123-
retval += [types[i]]
1124-
return retval
1125-
# return [upcast if i is None else types[i] for i in self.transfer]
1126-
1127-
11281104
class specific_out(MetaObject):
11291105
__props__ = ("spec",)
11301106

@@ -2446,6 +2422,10 @@ def handle_int(v):
24462422

24472423

24482424
class Second(BinaryScalarOp):
2425+
@staticmethod
2426+
def output_types_preference(_first_type, second_type):
2427+
return [second_type]
2428+
24492429
def impl(self, x, y):
24502430
return y
24512431

@@ -2474,7 +2454,7 @@ def grad(self, inputs, gout):
24742454
return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
24752455

24762456

2477-
second = Second(transfer_type(1), name="second")
2457+
second = Second(name="second")
24782458

24792459

24802460
class Identity(UnaryScalarOp):
@@ -2515,18 +2495,6 @@ def clone_float32(self):
25152495
return convert_to_float32
25162496
return self
25172497

2518-
def make_new_inplace(self, output_types_preference=None, name=None):
2519-
"""
2520-
This op.__init__ fct don't have the same parameter as other scalar op.
2521-
This breaks the insert_inplace_optimizer optimization.
2522-
This function is a fix to patch this, by ignoring the
2523-
output_types_preference passed by the optimization, and replacing it
2524-
by the current output type. This should only be triggered when
2525-
both input and output have the same dtype anyway.
2526-
2527-
"""
2528-
return self.__class__(self.o_type, name)
2529-
25302498
def impl(self, input):
25312499
return self.ctor(input)
25322500

@@ -4322,22 +4290,6 @@ def __str__(self):
43224290

43234291
return self._name
43244292

4325-
def make_new_inplace(self, output_types_preference=None, name=None):
4326-
"""
4327-
This op.__init__ fct don't have the same parameter as other scalar op.
4328-
This break the insert_inplace_optimizer optimization.
4329-
This fct allow fix patch this.
4330-
4331-
"""
4332-
d = {k: getattr(self, k) for k in self.init_param}
4333-
out = self.__class__(**d)
4334-
if name:
4335-
out.name = name
4336-
else:
4337-
name = out.name
4338-
super(Composite, out).__init__(output_types_preference, name)
4339-
return out
4340-
43414293
@property
43424294
def fgraph(self):
43434295
if hasattr(self, "_fgraph"):

pytensor/scalar/loop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,6 @@ def clone(self, name=None, **kwargs):
136136
def fn(self):
137137
raise NotImplementedError
138138

139-
def make_new_inplace(self, output_types_preference=None, name=None):
140-
return self.clone(output_types_preference=output_types_preference, name=name)
141-
142139
def make_node(self, n_steps, *inputs):
143140
assert len(inputs) == self.nin - 1
144141

pytensor/tensor/rewriting/elemwise.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
Mul,
3636
ScalarOp,
3737
get_scalar_type,
38-
transfer_type,
3938
upcast_out,
4039
upgrade_to_float,
4140
)
@@ -287,22 +286,17 @@ def create_inplace_node(self, node, inplace_pattern):
287286
op = node.op
288287
scalar_op = op.scalar_op
289288
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
290-
if hasattr(scalar_op, "make_new_inplace"):
291-
new_scalar_op = scalar_op.make_new_inplace(
292-
transfer_type(
293-
*[
294-
inplace_pattern.get(i, o.dtype)
295-
for i, o in enumerate(node.outputs)
296-
]
289+
try:
290+
return type(op)(scalar_op, inplace_pattern).make_node(*node.inputs)
291+
except TypeError:
292+
# Elemwise raises TypeError if we try to inplace an output on an input of a different dtype
293+
if config.optimizer_verbose:
294+
print( # noqa: T201
295+
f"InplaceElemwise failed because the output dtype of {node} changed when rebuilt. "
296+
"Perhaps due to a change in config.floatX or config.cast_policy"
297297
)
298-
)
299-
else:
300-
new_scalar_op = type(scalar_op)(
301-
transfer_type(
302-
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
303-
)
304-
)
305-
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
298+
# InplaceGraphOptimizer will chug along fine if we return the original node
299+
return node
306300

307301

308302
optdb.register(

tests/tensor/test_basic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2797,7 +2797,6 @@ def test_infer_shape(self, cast_policy):
27972797
out = arange(start, stop, 1)
27982798
f = function([start, stop], out.shape, mode=mode)
27992799
assert len(f.maker.fgraph.toposort()) == 5
2800-
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
28012800
if config.cast_policy == "custom":
28022801
assert out.dtype == "int64"
28032802
elif config.cast_policy == "numpy+floatX":

tests/tensor/test_elemwise.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,3 +1200,28 @@ def test_XOR_inplace():
12001200
_ = gn(l, r)
12011201
# test the in-place stuff
12021202
assert np.all(l == np.asarray([0, 1, 1, 0])), l
1203+
1204+
1205+
def test_inplace_dtype_changed():
1206+
with pytensor.config.change_flags(cast_policy="numpy+floatX", floatX="float64"):
1207+
x = pt.vector("x", dtype="float32")
1208+
y = pt.vector("y", dtype="int32")
1209+
with pytensor.config.change_flags(floatX="float32"):
1210+
out = pt.add(x, y)
1211+
1212+
assert out.dtype == "float32"
1213+
with pytensor.config.change_flags(floatX="float32"):
1214+
fn32 = pytensor.function(
1215+
[In(x, mutable=True), In(y, mutable=True)],
1216+
out,
1217+
mode="fast_run",
1218+
)
1219+
assert fn32.maker.fgraph.outputs[0].owner.op.destroy_map == {0: [0]}
1220+
1221+
with pytensor.config.change_flags(floatX="float64"):
1222+
fn64 = pytensor.function(
1223+
[In(x, mutable=True), In(y, mutable=True)],
1224+
out,
1225+
mode="fast_run",
1226+
)
1227+
assert fn64.maker.fgraph.outputs[0].owner.op.destroy_map == {}

0 commit comments

Comments
 (0)