@@ -116,6 +116,7 @@ limitations under the License.
116116#include " xla/hlo/builder/lib/approx_topk_shape.h"
117117#include " xla/pjrt/c_api_client/pjrt_c_api_client.h"
118118#include " xla/pjrt/distributed/key_value_store_interface.h"
119+ #include " xla/pjrt/distributed/preemption/preemption_sync_manager.h"
119120#include " xla/pjrt/exceptions.h"
120121#include " xla/pjrt/pjrt_api.h"
121122#include " xla/pjrt/pjrt_client.h"
@@ -580,6 +581,30 @@ NB_MODULE(_jax, m) {
580581 aux::RegisterTransferServerTypes (m);
581582#endif // defined(__linux__)
582583
584+ #if JAX_IFRT_VERSION_NUMBER >= 38
585+ nb::class_<xla::PreemptionSyncManager> preemption_sync_manager (
586+ m, " PreemptionSyncManager" );
587+ preemption_sync_manager
588+ .def (
589+ " initialize" ,
590+ [](xla::PreemptionSyncManager& manager,
591+ xla::DistributedRuntimeClient* client) {
592+ xla::CoordinationServiceAgent* agent =
593+ xla::ValueOrThrow (client->GetCoordinationServiceAgent ());
594+ xla::ThrowIfError (manager.Initialize (agent));
595+ },
596+ nb::arg (" distributed_client" ))
597+ .def (" reached_sync_point" ,
598+ [](xla::PreemptionSyncManager& manager, int step_counter) {
599+ return manager.ReachedSyncPoint (step_counter);
600+ })
601+ .def (" shutdown" , [](xla::PreemptionSyncManager& manager) {
602+ nb::gil_scoped_release gil_release;
603+ manager.Shutdown ();
604+ });
605+ m.def (" create_preemption_sync_manager" ,
606+ []() { return xla::CreatePreemptionSyncManager (); });
607+ #else
583608 nb::class_<tsl::PreemptionSyncManager> preemption_sync_manager (
584609 m, " PreemptionSyncManager" );
585610 preemption_sync_manager
@@ -602,6 +627,7 @@ NB_MODULE(_jax, m) {
602627 });
603628 m.def (" create_preemption_sync_manager" ,
604629 []() { return tsl::CreatePreemptionSyncManager (); });
630+ #endif
605631
606632 nb::class_<xla::DistributedRuntimeService> distributed_runtime_service (
607633 m, " DistributedRuntimeService" );
@@ -898,7 +924,6 @@ NB_MODULE(_jax, m) {
898924 nb::class_<xla::ifrt::TransferServerInterfaceFactory>(
899925 m, " TransferServerInterfaceFactory" );
900926
901-
902927 m.def (" is_asan" , IsAsan);
903928 m.def (" is_msan" , IsMsan);
904929 m.def (" is_tsan" , IsTsan);
0 commit comments