Skip to content

Commit 0758de4

Browse files
committed
Sort and Argsort: Check axis are integers
1 parent 22cda11 commit 0758de4

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

pytensor/tensor/sort.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,17 @@ def __init__(self, kind: KIND):
4242
def make_node(self, input, axis=-1):
4343
input = as_tensor_variable(input)
4444
axis = as_tensor_variable(axis, ndim=0, dtype=int)
45+
if axis.type.numpy_dtype.kind != "i":
46+
raise ValueError(
47+
f"Sort axis must have an integer dtype, got {axis.type.dtype}"
48+
)
4549
out_type = input.type()
4650
return Apply(self, [input, axis], [out_type])
4751

4852
def perform(self, node, inputs, output_storage):
4953
a, axis = inputs
5054
z = output_storage[0]
51-
z[0] = np.sort(a, int(axis), self.kind)
55+
z[0] = np.sort(a, axis, self.kind)
5256

5357
def infer_shape(self, fgraph, node, inputs_shapes):
5458
assert node.inputs[0].ndim == node.outputs[0].ndim
@@ -163,6 +167,10 @@ def __init__(self, kind: KIND):
163167
def make_node(self, input, axis=-1):
164168
input = as_tensor_variable(input)
165169
axis = as_tensor_variable(axis, ndim=0, dtype=int)
170+
if axis.type.numpy_dtype.kind != "i":
171+
raise ValueError(
172+
f"ArgSort axis must have an integer dtype, got {axis.type.dtype}"
173+
)
166174
return Apply(
167175
self,
168176
[input, axis],
@@ -173,7 +181,7 @@ def perform(self, node, inputs, output_storage):
173181
a, axis = inputs
174182
z = output_storage[0]
175183
z[0] = np.asarray(
176-
np.argsort(a, int(axis), self.kind),
184+
np.argsort(a, axis, self.kind),
177185
dtype=node.outputs[0].dtype,
178186
)
179187

tests/tensor/test_sort.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
dmatrix,
88
dvector,
99
float_dtypes,
10+
fscalar,
1011
integer_dtypes,
1112
lscalar,
1213
matrix,
@@ -31,6 +32,12 @@ def setup_method(self):
3132
self.m_val = self.rng.random((3, 2))
3233
self.v_val = self.rng.random(4)
3334

35+
def test_invalid_axis_dtype(self):
36+
with pytest.raises(
37+
ValueError, match="Sort axis must have an integer dtype, got float32"
38+
):
39+
sort(dmatrix(), fscalar())
40+
3441
def test1(self):
3542
a = dmatrix()
3643
w = sort(a)
@@ -39,7 +46,7 @@ def test1(self):
3946

4047
def test2(self):
4148
a = dmatrix()
42-
axis = scalar()
49+
axis = scalar(dtype="int64")
4350
w = sort(a, axis)
4451
f = pytensor.function([a, axis], w)
4552
for axis_val in 0, 1:
@@ -57,12 +64,12 @@ def test3(self):
5764

5865
def test4(self):
5966
a = dmatrix()
60-
axis = scalar()
67+
axis = scalar(dtype="int8")
6168
l = sort(a, axis, "mergesort")
6269
f = pytensor.function([a, axis], l)
6370
for axis_val in 0, 1:
64-
gv = f(self.m_val, axis_val)
65-
gt = np.sort(self.m_val, axis_val)
71+
gv = f(self.m_val, np.array(axis_val, dtype="int8"))
72+
gt = np.sort(self.m_val, np.array(axis_val, dtype="int8"))
6673
utt.assert_allclose(gv, gt)
6774

6875
def test5(self):
@@ -199,12 +206,12 @@ def test_argsort():
199206

200207
# Example 4
201208
a = dmatrix()
202-
axis = lscalar()
209+
axis = scalar(dtype="int8")
203210
l = argsort(a, axis, "mergesort")
204211
f = pytensor.function([a, axis], l)
205212
for axis_val in 0, 1:
206-
gv = f(m_val, axis_val)
207-
gt = np.argsort(m_val, axis_val)
213+
gv = f(m_val, np.array(axis_val, dtype="int8"))
214+
gt = np.argsort(m_val, np.array(axis_val, dtype="int8"))
208215
utt.assert_allclose(gv, gt)
209216

210217
# Example 5
@@ -222,6 +229,11 @@ def test_argsort():
222229
gt = np.argsort(m_val, None)
223230
utt.assert_allclose(gv, gt)
224231

232+
with pytest.raises(
233+
ValueError, match="ArgSort axis must have an integer dtype, got float32"
234+
):
235+
argsort(dmatrix(), fscalar())
236+
225237

226238
def test_argsort_grad():
227239
rng = np.random.default_rng(seed=utt.fetch_seed())

0 commit comments

Comments
 (0)