77 dmatrix ,
88 dvector ,
99 float_dtypes ,
10+ fscalar ,
1011 integer_dtypes ,
1112 lscalar ,
1213 matrix ,
@@ -31,6 +32,12 @@ def setup_method(self):
3132 self .m_val = self .rng .random ((3 , 2 ))
3233 self .v_val = self .rng .random (4 )
3334
35+ def test_invalid_axis_dtype (self ):
36+ with pytest .raises (
37+ ValueError , match = "Sort axis must have an integer dtype, got float32"
38+ ):
39+ sort (dmatrix (), fscalar ())
40+
3441 def test1 (self ):
3542 a = dmatrix ()
3643 w = sort (a )
@@ -39,7 +46,7 @@ def test1(self):
3946
4047 def test2 (self ):
4148 a = dmatrix ()
42- axis = scalar ()
49+ axis = scalar (dtype = "int64" )
4350 w = sort (a , axis )
4451 f = pytensor .function ([a , axis ], w )
4552 for axis_val in 0 , 1 :
@@ -57,12 +64,12 @@ def test3(self):
5764
5865 def test4 (self ):
5966 a = dmatrix ()
60- axis = scalar ()
67+ axis = scalar (dtype = "int8" )
6168 l = sort (a , axis , "mergesort" )
6269 f = pytensor .function ([a , axis ], l )
6370 for axis_val in 0 , 1 :
64- gv = f (self .m_val , axis_val )
65- gt = np .sort (self .m_val , axis_val )
71+ gv = f (self .m_val , np . array ( axis_val , dtype = "int8" ) )
72+ gt = np .sort (self .m_val , np . array ( axis_val , dtype = "int8" ) )
6673 utt .assert_allclose (gv , gt )
6774
6875 def test5 (self ):
@@ -199,12 +206,12 @@ def test_argsort():
199206
200207 # Example 4
201208 a = dmatrix ()
202- axis = lscalar ( )
209+ axis = scalar ( dtype = "int8" )
203210 l = argsort (a , axis , "mergesort" )
204211 f = pytensor .function ([a , axis ], l )
205212 for axis_val in 0 , 1 :
206- gv = f (m_val , axis_val )
207- gt = np .argsort (m_val , axis_val )
213+ gv = f (m_val , np . array ( axis_val , dtype = "int8" ) )
214+ gt = np .argsort (m_val , np . array ( axis_val , dtype = "int8" ) )
208215 utt .assert_allclose (gv , gt )
209216
210217 # Example 5
@@ -222,6 +229,11 @@ def test_argsort():
222229 gt = np .argsort (m_val , None )
223230 utt .assert_allclose (gv , gt )
224231
232+ with pytest .raises (
233+ ValueError , match = "ArgSort axis must have an integer dtype, got float32"
234+ ):
235+ argsort (dmatrix (), fscalar ())
236+
225237
226238def test_argsort_grad ():
227239 rng = np .random .default_rng (seed = utt .fetch_seed ())
0 commit comments