44import numpy as np
55
66import pytensor .link .numba .dispatch .basic as numba_basic
7+ from pytensor import config
78from pytensor .link .numba .dispatch .basic import (
89 get_numba_type ,
9- int_to_float_fn ,
1010 register_funcify_default_op_cache_key ,
1111)
1212from pytensor .tensor .nlinalg import (
@@ -26,65 +26,88 @@ def numba_funcify_SVD(op, node, **kwargs):
2626 compute_uv = op .compute_uv
2727 out_dtype = np .dtype (node .outputs [0 ].dtype )
2828
29- inputs_cast = int_to_float_fn (node .inputs , out_dtype )
29+ discrete_input = node .inputs [0 ].type .numpy_dtype .kind in "ibu"
30+ if discrete_input and config .compiler_verbose :
31+ print ("SVD requires casting discrete input to float" ) # noqa: T201
3032
3133 if not compute_uv :
3234
3335 @numba_basic .numba_njit
3436 def svd (x ):
35- _ , ret , _ = np .linalg .svd (inputs_cast (x ), full_matrices )
37+ if discrete_input :
38+ x = x .astype (out_dtype )
39+ _ , ret , _ = np .linalg .svd (x , full_matrices )
3640 return ret
3741
3842 else :
3943
4044 @numba_basic .numba_njit
4145 def svd (x ):
42- return np .linalg .svd (inputs_cast (x ), full_matrices )
46+ if discrete_input :
47+ x = x .astype (out_dtype )
48+ return np .linalg .svd (x , full_matrices )
4349
44- return svd
50+ cache_version = 1
51+ return svd , cache_version
4552
4653
4754@register_funcify_default_op_cache_key (Det )
4855def numba_funcify_Det (op , node , ** kwargs ):
4956 out_dtype = node .outputs [0 ].type .numpy_dtype
50- inputs_cast = int_to_float_fn (node .inputs , out_dtype )
57+ discrete_input = node .inputs [0 ].type .numpy_dtype .kind in "ibu"
58+ if discrete_input and config .compiler_verbose :
59+ print ("Det requires casting discrete input to float" ) # noqa: T201
5160
5261 @numba_basic .numba_njit
5362 def det (x ):
54- return np .array (np .linalg .det (inputs_cast (x ))).astype (out_dtype )
63+ if discrete_input :
64+ x = x .astype (out_dtype )
65+ return np .array (np .linalg .det (x ), dtype = out_dtype )
5566
56- return det
67+ cache_version = 1
68+ return det , cache_version
5769
5870
5971@register_funcify_default_op_cache_key (SLogDet )
6072def numba_funcify_SLogDet (op , node , ** kwargs ):
61- out_dtype_1 = node .outputs [0 ].type .numpy_dtype
62- out_dtype_2 = node .outputs [1 ].type .numpy_dtype
73+ out_dtype_sign = node .outputs [0 ].type .numpy_dtype
74+ out_dtype_det = node .outputs [1 ].type .numpy_dtype
6375
64- inputs_cast = int_to_float_fn (node .inputs , out_dtype_1 )
76+ discrete_input = node .inputs [0 ].type .numpy_dtype .kind in "ibu"
77+ if discrete_input and config .compiler_verbose :
78+ print ("SLogDet requires casting discrete input to float" ) # noqa: T201
6579
6680 @numba_basic .numba_njit
6781 def slogdet (x ):
68- sign , det = np .linalg .slogdet (inputs_cast (x ))
82+ if discrete_input :
83+ x = x .astype (out_dtype_det )
84+ sign , det = np .linalg .slogdet (x )
6985 return (
70- np .array (sign ). astype ( out_dtype_1 ),
71- np .array (det ). astype ( out_dtype_2 ),
86+ np .array (sign , dtype = out_dtype_sign ),
87+ np .array (det , dtype = out_dtype_det ),
7288 )
7389
74- return slogdet
90+ cache_version = 1
91+ return slogdet , cache_version
7592
7693
7794@register_funcify_default_op_cache_key (Eig )
7895def numba_funcify_Eig (op , node , ** kwargs ):
7996 w_dtype = node .outputs [0 ].type .numpy_dtype
80- inputs_cast = int_to_float_fn (node .inputs , w_dtype )
97+ non_complex_input = node .inputs [0 ].type .numpy_dtype .kind != "c"
98+ if non_complex_input and config .compiler_verbose :
99+ print ("Eig requires casting input to complex" ) # noqa: T201
81100
82101 @numba_basic .numba_njit
83102 def eig (x ):
84- w , v = np .linalg .eig (inputs_cast (x ))
103+ if non_complex_input :
104+ # Even floats are better cast to complex, otherwise numba may raise
105+ # ValueError: eig() argument must not cause a domain change.
106+ x = x .astype (w_dtype )
107+ w , v = np .linalg .eig (x )
85108 return w .astype (w_dtype ), v .astype (w_dtype )
86109
87- cache_version = 1
110+ cache_version = 2
88111 return eig , cache_version
89112
90113
@@ -125,22 +148,32 @@ def eigh(x):
125148@register_funcify_default_op_cache_key (MatrixInverse )
126149def numba_funcify_MatrixInverse (op , node , ** kwargs ):
127150 out_dtype = node .outputs [0 ].type .numpy_dtype
128- inputs_cast = int_to_float_fn (node .inputs , out_dtype )
151+ discrete_input = node .inputs [0 ].type .numpy_dtype .kind in "ibu"
152+ if discrete_input and config .compiler_verbose :
153+ print ("MatrixInverse requires casting discrete input to float" ) # noqa: T201
129154
130155 @numba_basic .numba_njit
131156 def matrix_inverse (x ):
132- return np .linalg .inv (inputs_cast (x )).astype (out_dtype )
157+ if discrete_input :
158+ x = x .astype (out_dtype )
159+ return np .linalg .inv (x )
133160
134- return matrix_inverse
161+ cache_version = 1
162+ return matrix_inverse , cache_version
135163
136164
137165@register_funcify_default_op_cache_key (MatrixPinv )
138166def numba_funcify_MatrixPinv (op , node , ** kwargs ):
139167 out_dtype = node .outputs [0 ].type .numpy_dtype
140- inputs_cast = int_to_float_fn (node .inputs , out_dtype )
168+ discrete_input = node .inputs [0 ].type .numpy_dtype .kind in "ibu"
169+ if discrete_input and config .compiler_verbose :
170+ print ("MatrixPinv requires casting discrete input to float" ) # noqa: T201
141171
142172 @numba_basic .numba_njit
143- def matrixpinv (x ):
144- return np .linalg .pinv (inputs_cast (x )).astype (out_dtype )
173+ def matrix_pinv (x ):
174+ if discrete_input :
175+ x = x .astype (out_dtype )
176+ return np .linalg .pinv (x )
145177
146- return matrixpinv
178+ cache_version = 1
179+ return matrix_pinv , cache_version
0 commit comments