Skip to content

Commit 9ebf41c

Browse files
committed
Update PjRt registry to use Status QOL functions
Replace `XLA_CHECK` macros with `XLA_ASSIGN_OR_RETURN` and `XLA_RETURN_IF_ERROR` for better error handling:
- `InitializePjRt()` now returns `StatusOr<T>` instead of throwing on errors
- Enhanced error messages with location context
- Consistent error handling across all device types
1 parent dfcdd73 commit 9ebf41c

File tree

5 files changed

+49
-38
lines changed

5 files changed

+49
-38
lines changed

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,7 @@ cc_library(
212212
"//torch_xla/csrc:status",
213213
"@torch//:headers",
214214
"@com_google_absl//absl/log:initialize",
215+
"@com_google_absl//absl/log:absl_check",
215216
"@xla//xla/pjrt:pjrt_c_api_client",
216217
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
217218
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,8 @@ std::vector<std::string> IfrtComputationClient::IfrtDevicesToString(
123123
IfrtComputationClient::IfrtComputationClient() {
124124
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
125125
std::unique_ptr<xla::PjRtClient> pjrt_client;
126-
std::tie(pjrt_client, coordinator_) = InitializePjRt(device_type);
126+
std::tie(pjrt_client, coordinator_) =
127+
GetValueOrThrow(InitializePjRt(device_type));
127128

128129
client_ = xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
129130

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,8 @@ std::vector<std::string> PjRtComputationClient::PjRtDevicesToString(
117117

118118
PjRtComputationClient::PjRtComputationClient() {
119119
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
120-
std::tie(client_, coordinator_) = InitializePjRt(device_type);
120+
std::tie(client_, coordinator_) =
121+
GetValueOrThrow(InitializePjRt(device_type));
121122

122123
// PjRtDevice IDs are not guaranteed to be dense, so we need to track
123124
// a device's global ordinal separately from its device ID. Order the

torch_xla/csrc/runtime/pjrt_registry.cpp

Lines changed: 42 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include <c10/util/Exception.h>
44

5+
#include "absl/log/absl_check.h"
56
#include "absl/log/initialize.h"
67
#include "torch_xla/csrc/runtime/debug_macros.h"
78
#include "torch_xla/csrc/runtime/env_vars.h"
@@ -74,7 +75,8 @@ void RegisterPjRtPlugin(std::string name,
7475
pjrt_plugins_[name] = plugin;
7576
}
7677

77-
std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
78+
absl::StatusOr<std::tuple<absl_nonnull std::unique_ptr<xla::PjRtClient>,
79+
std::unique_ptr<XlaCoordinator>>>
7880
InitializePjRt(const std::string& device_type) {
7981
std::unique_ptr<xla::PjRtClient> client;
8082
std::unique_ptr<XlaCoordinator> coordinator;
@@ -110,28 +112,34 @@ InitializePjRt(const std::string& device_type) {
110112
<< ", coordinator address=" << master_addr << ":" << port;
111113

112114
// Use the XlaCoordinator as the distributed key-value store.
113-
coordinator = GetValueOrThrow(XlaCoordinator::Create(
114-
global_process_rank, global_world_size, master_addr, port));
115+
XLA_ASSIGN_OR_RETURN(
116+
coordinator,
117+
XlaCoordinator::Create(global_process_rank, global_world_size,
118+
master_addr, port));
115119
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
116120
coordinator->GetClient();
117121
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
118122
/*key_prefix=*/"pjrt:");
119123
}
120-
const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
121-
absl::AsciiStrToLower(device_type), plugin->library_path());
122-
XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
124+
XLA_ASSIGN_OR_RETURN(
125+
const PJRT_Api* c_api,
126+
pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type),
127+
plugin->library_path()));
128+
XLA_RETURN_IF_ERROR(pjrt::InitializePjrtPlugin(device_type));
123129
auto create_options = plugin->client_create_options();
124-
client = xla::GetCApiClient(
125-
absl::AsciiStrToUpper(device_type),
126-
{create_options.begin(), create_options.end()}, kv_store)
127-
.value();
130+
XLA_ASSIGN_OR_RETURN(
131+
client,
132+
xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
133+
{create_options.begin(), create_options.end()},
134+
kv_store));
128135
profiler::RegisterProfilerForPlugin(c_api);
129136
}
130137
} else if (device_type == "CPU") {
131138
TF_VLOG(1) << "Initializing PjRt CPU client...";
132139
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
133140
int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
134-
client = std::move(xla::GetPjRtCpuClient(async, cpu_device_count).value());
141+
XLA_ASSIGN_OR_RETURN(client,
142+
xla::GetPjRtCpuClient(async, cpu_device_count));
135143
} else if (device_type == "TPU") {
136144
TF_VLOG(1) << "Initializing TFRT TPU client...";
137145
// Init the absl logging to avoid the log spam.
@@ -140,15 +148,14 @@ InitializePjRt(const std::string& device_type) {
140148
auto tpu_library_path = sys_util::GetEnvString(
141149
env::kEnvTpuLibraryPath,
142150
sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
143-
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
144-
absl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
145-
XLA_CHECK_OK(tpu_status);
146-
client = std::move(xla::GetCApiClient("TPU").value());
147-
const PJRT_Api* c_api =
148-
static_cast<xla::PjRtCApiClient*>(client.get())->pjrt_c_api();
151+
XLA_ASSIGN_OR_RETURN(const PJRT_Api* c_api,
152+
pjrt::LoadPjrtPlugin("tpu", tpu_library_path));
153+
XLA_RETURN_IF_ERROR(pjrt::InitializePjrtPlugin("tpu"));
154+
XLA_ASSIGN_OR_RETURN(client, xla::GetCApiClient("TPU"));
149155
profiler::RegisterProfilerForPlugin(c_api);
150156
} else if (device_type == "TPU_LEGACY") {
151-
XLA_ERROR() << "TPU_LEGACY client is no longer available.";
157+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
158+
"TPU_LEGACY client is no longer available."));
152159
} else if (device_type == "CUDA") {
153160
TORCH_WARN("The XLA:CUDA device is deprecated in release 2.8. ",
154161
"Future releases might remove XLA:CUDA support entirely. ",
@@ -183,8 +190,10 @@ InitializePjRt(const std::string& device_type) {
183190
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
184191
std::string port = runtime::sys_util::GetEnvString(
185192
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
186-
coordinator = GetValueOrThrow(XlaCoordinator::Create(
187-
global_process_rank, global_world_size, master_addr, port));
193+
XLA_ASSIGN_OR_RETURN(
194+
coordinator,
195+
XlaCoordinator::Create(global_process_rank, global_world_size,
196+
master_addr, port));
188197
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
189198
coordinator->GetClient();
190199
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
@@ -199,27 +208,25 @@ InitializePjRt(const std::string& device_type) {
199208
options.platform_name = "gpu";
200209
options.should_stage_host_to_device_transfers = true;
201210
options.kv_store = kv_store;
202-
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
211+
XLA_ASSIGN_OR_RETURN(client, xla::GetStreamExecutorGpuClient(options));
203212
} else if (device_type == "XPU") {
204213
TF_VLOG(1) << "Initializing PjRt XPU client...";
205-
XLA_CHECK_OK(
206-
pjrt::LoadPjrtPlugin(
207-
"xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so"))
208-
.status());
209-
client = std::move(xla::GetCApiClient("XPU").value());
214+
XLA_RETURN_IF_ERROR(pjrt::LoadPjrtPlugin(
215+
"xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")));
216+
XLA_ASSIGN_OR_RETURN(client, xla::GetCApiClient("XPU"));
210217
} else if (device_type == "NEURON") {
211218
TF_VLOG(1) << "Initializing PjRt NEURON client...";
212-
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString(
213-
env::kEnvNeuronLibraryPath,
214-
"libneuronpjrt.so"))
215-
.status());
216-
client = std::move(xla::GetCApiClient("NEURON").value());
219+
XLA_RETURN_IF_ERROR(pjrt::LoadPjrtPlugin(
220+
"NEURON", sys_util::GetEnvString(env::kEnvNeuronLibraryPath,
221+
"libneuronpjrt.so")));
222+
XLA_ASSIGN_OR_RETURN(client, xla::GetCApiClient("NEURON"));
223+
} else {
224+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
225+
"Unknown ", env::kEnvPjRtDevice, ": '", device_type, "'")));
217226
}
218227

219-
XLA_CHECK(client) << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
220-
device_type);
221-
222-
return {std::move(client), std::move(coordinator)};
228+
ABSL_CHECK(client);
229+
return std::make_tuple(std::move(client), std::move(coordinator));
223230
}
224231

225232
} // namespace runtime

torch_xla/csrc/runtime/pjrt_registry.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,8 @@ class PjRtPlugin {
2121
void RegisterPjRtPlugin(std::string name,
2222
std::shared_ptr<const PjRtPlugin> plugin);
2323

24-
std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
24+
absl::StatusOr<std::tuple<absl_nonnull std::unique_ptr<xla::PjRtClient>,
25+
std::unique_ptr<XlaCoordinator>>>
2526
InitializePjRt(const std::string& device_type);
2627

2728
} // namespace runtime

0 commit comments

Comments
 (0)