Skip to content

Commit c2905b9

Browse files
mwhittakerGoogle-ML-Automation
authored andcommitted
Change PjRt to use new copy of coordination service.
PiperOrigin-RevId: 839422254
1 parent 0625fdc commit c2905b9

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

jaxlib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ nanobind_pywrap_extension(
397397
"@xla//xla/pjrt/distributed:key_value_store_interface",
398398
"@xla//xla/pjrt/distributed:protocol_proto_cc",
399399
"@xla//xla/pjrt/distributed:service",
400+
"@xla//xla/pjrt/distributed/preemption:preemption_sync_manager",
400401
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
401402
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
402403
"@xla//xla/python:logging",

jaxlib/jax.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ limitations under the License.
116116
#include "xla/hlo/builder/lib/approx_topk_shape.h"
117117
#include "xla/pjrt/c_api_client/pjrt_c_api_client.h"
118118
#include "xla/pjrt/distributed/key_value_store_interface.h"
119+
#include "xla/pjrt/distributed/preemption/preemption_sync_manager.h"
119120
#include "xla/pjrt/exceptions.h"
120121
#include "xla/pjrt/pjrt_api.h"
121122
#include "xla/pjrt/pjrt_client.h"
@@ -580,6 +581,30 @@ NB_MODULE(_jax, m) {
580581
aux::RegisterTransferServerTypes(m);
581582
#endif // defined(__linux__)
582583

584+
#if JAX_IFRT_VERSION_NUMBER >= 38
585+
nb::class_<xla::PreemptionSyncManager> preemption_sync_manager(
586+
m, "PreemptionSyncManager");
587+
preemption_sync_manager
588+
.def(
589+
"initialize",
590+
[](xla::PreemptionSyncManager& manager,
591+
xla::DistributedRuntimeClient* client) {
592+
xla::CoordinationServiceAgent* agent =
593+
xla::ValueOrThrow(client->GetCoordinationServiceAgent());
594+
xla::ThrowIfError(manager.Initialize(agent));
595+
},
596+
nb::arg("distributed_client"))
597+
.def("reached_sync_point",
598+
[](xla::PreemptionSyncManager& manager, int step_counter) {
599+
return manager.ReachedSyncPoint(step_counter);
600+
})
601+
.def("shutdown", [](xla::PreemptionSyncManager& manager) {
602+
nb::gil_scoped_release gil_release;
603+
manager.Shutdown();
604+
});
605+
m.def("create_preemption_sync_manager",
606+
[]() { return xla::CreatePreemptionSyncManager(); });
607+
#else
583608
nb::class_<tsl::PreemptionSyncManager> preemption_sync_manager(
584609
m, "PreemptionSyncManager");
585610
preemption_sync_manager
@@ -602,6 +627,7 @@ NB_MODULE(_jax, m) {
602627
});
603628
m.def("create_preemption_sync_manager",
604629
[]() { return tsl::CreatePreemptionSyncManager(); });
630+
#endif
605631

606632
nb::class_<xla::DistributedRuntimeService> distributed_runtime_service(
607633
m, "DistributedRuntimeService");
@@ -898,7 +924,6 @@ NB_MODULE(_jax, m) {
898924
nb::class_<xla::ifrt::TransferServerInterfaceFactory>(
899925
m, "TransferServerInterfaceFactory");
900926

901-
902927
m.def("is_asan", IsAsan);
903928
m.def("is_msan", IsMsan);
904929
m.def("is_tsan", IsTsan);

0 commit comments

Comments
 (0)