Skip to content

Commit c1adda8

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
If LAPACK kernels are linked statically, don't override them from scipy.
PiperOrigin-RevId: 839343548
1 parent a915a31 commit c1adda8

File tree

6 files changed

+12
-1
lines changed

6 files changed

+12
-1
lines changed

jax/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ py_library_providing_imports_info(
245245
"//jax/_src:xla_metadata",
246246
"//jax/_src:xla_metadata_lib",
247247
"//jax/_src/lib",
248-
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
248+
] + py_deps("numpy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
249249
)
250250

251251
pytype_strict_library(

jaxlib/cpu/lapack.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ using ::xla::ffi::DataType;
2727

2828
void GetLapackKernelsFromScipy() {
2929
static absl::once_flag initialized;
30+
if (lapack_kernels_initialized) {
31+
return;
32+
}
3033
// For reasons I'm not entirely sure of, if the import_ call is done inside
3134
// the call_once scope, we sometimes observe deadlocks in the test suite.
3235
// However it probably doesn't do much harm to just import them a second time,

jaxlib/cpu/lapack_kernels.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::Sort);
4848

4949
namespace jax {
5050

51+
bool lapack_kernels_initialized = false;
52+
5153
template <typename T>
5254
inline T CastNoOverflow(int64_t value, std::string_view source = __FILE__) {
5355
auto result = MaybeCastNoOverflow<T>(value, source);

jaxlib/cpu/lapack_kernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ limitations under the License.
3030

3131
namespace jax {
3232

33+
extern bool lapack_kernels_initialized;
34+
3335
struct MatrixParams {
3436
enum class Side : char { kLeft = 'L', kRight = 'R' };
3537
enum class UpLo : char { kLower = 'L', kUpper = 'U' };

jaxlib/cpu/lapack_kernels_using_lapack.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ static auto init = []() -> int {
173173
AssignKernelFn<TridiagonalSolver<ffi::DataType::C64>>(cgtsv_);
174174
AssignKernelFn<TridiagonalSolver<ffi::DataType::C128>>(zgtsv_);
175175

176+
lapack_kernels_initialized = true;
176177
return 0;
177178
}();
178179

jaxlib/gpu/hybrid.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ namespace nb = nanobind;
2929

3030
void GetLapackKernelsFromScipy() {
3131
static absl::once_flag initialized;
32+
if (lapack_kernels_initialized) {
33+
return;
34+
}
3235
// For reasons I'm not entirely sure of, if the import_ call is done inside
3336
// the call_once scope, we sometimes observe deadlocks in the test suite.
3437
// However it probably doesn't do much harm to just import them a second time,

0 commit comments

Comments
 (0)