@@ -115,6 +115,8 @@ limitations under the License.
115115#include " xla/executable_run_options.h"
116116#include " xla/ffi/ffi.h"
117117#include " xla/ffi/ffi_api.h"
118+ #include " xla/service/custom_call_status.h"
119+ #include " xla/service/custom_call_target_registry.h"
118120#include " xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
119121#include " xla/service/llvm_ir/llvm_command_line_options.h"
120122#include " xla/stream_executor/cuda/assemble_compilation_provider.h"
@@ -639,6 +641,40 @@ absl::StatusOr<CompiledKernel*> CachedCompileAndInit(CacheKey key,
639641 return &cache.kernels .at (key);
640642}
641643
644+ void MosaicGPUCustomCall (void * stream, void ** buffers, char * opaque,
645+ size_t opaque_len, XlaCustomCallStatus* status) {
646+ // Forward-compatible version using the legacy FFI API
647+ if (reinterpret_cast <uintptr_t >(opaque) % alignof (KernelHash)) {
648+ fprintf (stderr, " Misaligned opaque pointer\n " );
649+ abort ();
650+ }
651+ auto hash = *reinterpret_cast <KernelHash*>(opaque);
652+ CUcontext ctx;
653+ if (cuCtxGetCurrent (&ctx) != CUDA_SUCCESS) {
654+ fprintf (stderr, " Failed to get current CUDA context\n " );
655+ abort ();
656+ }
657+ CacheKey key (hash, reinterpret_cast <uintptr_t >(ctx));
658+ auto compiled_kernel = CachedCompileAndInit (key, opaque + sizeof (KernelHash));
659+ if (!compiled_kernel.ok ()) {
660+ XlaCustomCallStatusSetFailure (status,
661+ compiled_kernel.status ().message ().data (),
662+ compiled_kernel.status ().message ().size ());
663+ return ;
664+ }
665+ auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch ();
666+ bool is_comm_used = std::get<2 >(ctx_kernel_comm);
667+ void * args[4 ] = {&std::get<0 >(ctx_kernel_comm), &stream, &buffers};
668+ if (is_comm_used) {
669+ mosaic::gpu::NvshmemApi::Default ().barrier_all_on_stream (
670+ reinterpret_cast <cudaStream_t>(stream));
671+ }
672+ std::get<1 >(ctx_kernel_comm)(args);
673+ }
674+
// Registers MosaicGPUCustomCall with XLA's custom-call target registry under
// the "mosaic_gpu" target name for the "CUDA" platform, so that custom calls
// naming that target are dispatched to the legacy entry point above.
675+ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" mosaic_gpu" , &MosaicGPUCustomCall,
676+ " CUDA" );
677+
642678absl::Status MosaicGpuExecute (gpuStream_t stream, ffi::RemainingArgs inputs,
643679 ffi::RemainingRets results,
644680 std::string_view kernel_hash,
0 commit comments