Skip to content

Commit d5f91dc

Browse files
bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Revert deletion of the legacy FFI API.
Reverts 9138c20. PiperOrigin-RevId: 838381988
1 parent 5e6a2c8 commit d5f91dc

File tree

4 files changed

+49
-11
lines changed

4 files changed

+49
-11
lines changed

jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

jaxlib/mosaic/gpu/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,8 @@ cc_library(
330330
"@xla//xla/backends/gpu:ffi",
331331
"@xla//xla/ffi",
332332
"@xla//xla/ffi:ffi_api",
333+
"@xla//xla/service:custom_call_status",
334+
"@xla//xla/service:custom_call_target_registry",
333335
"@xla//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path",
334336
"@xla//xla/service/llvm_ir:llvm_command_line_options",
335337
"@xla//xla/stream_executor/cuda:assemble_compilation_provider",

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ limitations under the License.
115115
#include "xla/executable_run_options.h"
116116
#include "xla/ffi/ffi.h"
117117
#include "xla/ffi/ffi_api.h"
118+
#include "xla/service/custom_call_status.h"
119+
#include "xla/service/custom_call_target_registry.h"
118120
#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
119121
#include "xla/service/llvm_ir/llvm_command_line_options.h"
120122
#include "xla/stream_executor/cuda/assemble_compilation_provider.h"
@@ -639,6 +641,40 @@ absl::StatusOr<CompiledKernel*> CachedCompileAndInit(CacheKey key,
639641
return &cache.kernels.at(key);
640642
}
641643

644+
void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
645+
size_t opaque_len, XlaCustomCallStatus* status) {
646+
// Forward-compatible version using the legacy FFI API
647+
if (reinterpret_cast<uintptr_t>(opaque) % alignof(KernelHash)) {
648+
fprintf(stderr, "Misaligned opaque pointer\n");
649+
abort();
650+
}
651+
auto hash = *reinterpret_cast<KernelHash*>(opaque);
652+
CUcontext ctx;
653+
if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) {
654+
fprintf(stderr, "Failed to get current CUDA context\n");
655+
abort();
656+
}
657+
CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
658+
auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
659+
if (!compiled_kernel.ok()) {
660+
XlaCustomCallStatusSetFailure(status,
661+
compiled_kernel.status().message().data(),
662+
compiled_kernel.status().message().size());
663+
return;
664+
}
665+
auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch();
666+
bool is_comm_used = std::get<2>(ctx_kernel_comm);
667+
void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers};
668+
if (is_comm_used) {
669+
mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(
670+
reinterpret_cast<cudaStream_t>(stream));
671+
}
672+
std::get<1>(ctx_kernel_comm)(args);
673+
}
674+
675+
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
676+
"CUDA");
677+
642678
absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs,
643679
ffi::RemainingRets results,
644680
std::string_view kernel_hash,

tests/pallas/export_back_compat_pallas_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_mosaic_gpu_add_one(self):
8181
def add_one(x_ref, o_ref):
8282
o_ref[...] = x_ref[...] + 1
8383

84-
data = self.load_testdata(mosaic_gpu_add_one.data_2025_11_27)
85-
self.run_one_test(add_one, data)
84+
data = self.load_testdata(mosaic_gpu_add_one.data_2025_04_22)
85+
self.run_one_test(add_one, data, expect_current_custom_calls=["mosaic_gpu_v2"])
8686

8787
def test_mosaic_gpu_kernel_add_one(self):
8888
if not jtu.is_cuda_compute_capability_at_least("9.0"):

0 commit comments

Comments (0)