File tree Expand file tree Collapse file tree 6 files changed +12
-1
lines changed
Expand file tree Collapse file tree 6 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -245,7 +245,7 @@ py_library_providing_imports_info(
245245 "//jax/_src:xla_metadata" ,
246246 "//jax/_src:xla_metadata_lib" ,
247247 "//jax/_src/lib" ,
248- ] + py_deps ("numpy" ) + py_deps ("scipy" ) + py_deps ( " opt_einsum" ) + py_deps ("flatbuffers" ) + jax_extra_deps ,
248+ ] + py_deps ("numpy" ) + py_deps ("opt_einsum" ) + py_deps ("flatbuffers" ) + jax_extra_deps ,
249249)
250250
251251pytype_strict_library (
Original file line number Diff line number Diff line change @@ -27,6 +27,9 @@ using ::xla::ffi::DataType;
2727
2828void GetLapackKernelsFromScipy () {
2929 static absl::once_flag initialized;
30+ if (lapack_kernels_initialized) {
31+ return ;
32+ }
3033 // For reasons I'm not entirely sure of, if the import_ call is done inside
3134 // the call_once scope, we sometimes observe deadlocks in the test suite.
3235 // However it probably doesn't do much harm to just import them a second time,
Original file line number Diff line number Diff line change @@ -48,6 +48,8 @@ XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::Sort);
4848
4949namespace jax {
5050
51+ bool lapack_kernels_initialized = false ;
52+
5153template <typename T>
5254inline T CastNoOverflow (int64_t value, std::string_view source = __FILE__) {
5355 auto result = MaybeCastNoOverflow<T>(value, source);
Original file line number Diff line number Diff line change @@ -30,6 +30,8 @@ limitations under the License.
3030
3131namespace jax {
3232
33+ extern bool lapack_kernels_initialized;
34+
3335struct MatrixParams {
3436 enum class Side : char { kLeft = ' L' , kRight = ' R' };
3537 enum class UpLo : char { kLower = ' L' , kUpper = ' U' };
Original file line number Diff line number Diff line change @@ -173,6 +173,7 @@ static auto init = []() -> int {
173173 AssignKernelFn<TridiagonalSolver<ffi::DataType::C64>>(cgtsv_);
174174 AssignKernelFn<TridiagonalSolver<ffi::DataType::C128>>(zgtsv_);
175175
176+ lapack_kernels_initialized = true ;
176177 return 0 ;
177178}();
178179
Original file line number Diff line number Diff line change @@ -29,6 +29,9 @@ namespace nb = nanobind;
2929
3030void GetLapackKernelsFromScipy () {
3131 static absl::once_flag initialized;
32+ if (lapack_kernels_initialized) {
33+ return ;
34+ }
3235 // For reasons I'm not entirely sure of, if the import_ call is done inside
3336 // the call_once scope, we sometimes observe deadlocks in the test suite.
3437 // However it probably doesn't do much harm to just import them a second time,
You can’t perform that action at this time.
0 commit comments