2
2
3
3
#include < c10/util/Exception.h>
4
4
5
+ #include " absl/log/absl_check.h"
5
6
#include " absl/log/initialize.h"
6
7
#include " torch_xla/csrc/runtime/debug_macros.h"
7
8
#include " torch_xla/csrc/runtime/env_vars.h"
@@ -74,7 +75,8 @@ void RegisterPjRtPlugin(std::string name,
74
75
pjrt_plugins_[name] = plugin;
75
76
}
76
77
77
- std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
78
+ absl::StatusOr<std::tuple<absl_nonnull std::unique_ptr<xla::PjRtClient>,
79
+ std::unique_ptr<XlaCoordinator>>>
78
80
InitializePjRt (const std::string& device_type) {
79
81
std::unique_ptr<xla::PjRtClient> client;
80
82
std::unique_ptr<XlaCoordinator> coordinator;
@@ -110,28 +112,34 @@ InitializePjRt(const std::string& device_type) {
110
112
<< " , coordinator address=" << master_addr << " :" << port;
111
113
112
114
// Use the XlaCoordinator as the distributed key-value store.
113
- coordinator = GetValueOrThrow (XlaCoordinator::Create (
114
- global_process_rank, global_world_size, master_addr, port));
115
+ XLA_ASSIGN_OR_RETURN (
116
+ coordinator,
117
+ XlaCoordinator::Create (global_process_rank, global_world_size,
118
+ master_addr, port));
115
119
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
116
120
coordinator->GetClient ();
117
121
kv_store = xla::GetDistributedKeyValueStore (distributed_client,
118
122
/* key_prefix=*/ " pjrt:" );
119
123
}
120
- const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin (
121
- absl::AsciiStrToLower (device_type), plugin->library_path ());
122
- XLA_CHECK_OK (pjrt::InitializePjrtPlugin (device_type));
124
+ XLA_ASSIGN_OR_RETURN (
125
+ const PJRT_Api* c_api,
126
+ pjrt::LoadPjrtPlugin (absl::AsciiStrToLower (device_type),
127
+ plugin->library_path ()));
128
+ XLA_RETURN_IF_ERROR (pjrt::InitializePjrtPlugin (device_type));
123
129
auto create_options = plugin->client_create_options ();
124
- client = xla::GetCApiClient (
125
- absl::AsciiStrToUpper (device_type),
126
- {create_options.begin (), create_options.end ()}, kv_store)
127
- .value ();
130
+ XLA_ASSIGN_OR_RETURN (
131
+ client,
132
+ xla::GetCApiClient (absl::AsciiStrToUpper (device_type),
133
+ {create_options.begin (), create_options.end ()},
134
+ kv_store));
128
135
profiler::RegisterProfilerForPlugin (c_api);
129
136
}
130
137
} else if (device_type == " CPU" ) {
131
138
TF_VLOG (1 ) << " Initializing PjRt CPU client..." ;
132
139
bool async = sys_util::GetEnvBool (env::kEnvPjrtAsyncCpuClient , true );
133
140
int cpu_device_count = sys_util::GetEnvInt (env::kEnvNumCpu , 1 );
134
- client = std::move (xla::GetPjRtCpuClient (async, cpu_device_count).value ());
141
+ XLA_ASSIGN_OR_RETURN (client,
142
+ xla::GetPjRtCpuClient (async, cpu_device_count));
135
143
} else if (device_type == " TPU" ) {
136
144
TF_VLOG (1 ) << " Initializing TFRT TPU client..." ;
137
145
// Init the absl logging to avoid the log spam.
@@ -140,15 +148,14 @@ InitializePjRt(const std::string& device_type) {
140
148
auto tpu_library_path = sys_util::GetEnvString (
141
149
env::kEnvTpuLibraryPath ,
142
150
sys_util::GetEnvString (env::kEnvInferredTpuLibraryPath , " libtpu.so" ));
143
- XLA_CHECK_OK (pjrt::LoadPjrtPlugin (" tpu" , tpu_library_path).status ());
144
- absl::Status tpu_status = pjrt::InitializePjrtPlugin (" tpu" );
145
- XLA_CHECK_OK (tpu_status);
146
- client = std::move (xla::GetCApiClient (" TPU" ).value ());
147
- const PJRT_Api* c_api =
148
- static_cast <xla::PjRtCApiClient*>(client.get ())->pjrt_c_api ();
151
+ XLA_ASSIGN_OR_RETURN (const PJRT_Api* c_api,
152
+ pjrt::LoadPjrtPlugin (" tpu" , tpu_library_path));
153
+ XLA_RETURN_IF_ERROR (pjrt::InitializePjrtPlugin (" tpu" ));
154
+ XLA_ASSIGN_OR_RETURN (client, xla::GetCApiClient (" TPU" ));
149
155
profiler::RegisterProfilerForPlugin (c_api);
150
156
} else if (device_type == " TPU_LEGACY" ) {
151
- XLA_ERROR () << " TPU_LEGACY client is no longer available." ;
157
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (
158
+ " TPU_LEGACY client is no longer available." ));
152
159
} else if (device_type == " CUDA" ) {
153
160
TORCH_WARN (" The XLA:CUDA device is deprecated in release 2.8. " ,
154
161
" Future releases might remove XLA:CUDA support entirely. " ,
@@ -183,8 +190,10 @@ InitializePjRt(const std::string& device_type) {
183
190
runtime::sys_util::GetEnvString (" MASTER_ADDR" , " localhost" );
184
191
std::string port = runtime::sys_util::GetEnvString (
185
192
" XLA_COORDINATOR_PORT" , XlaCoordinator::kDefaultCoordinatorPort );
186
- coordinator = GetValueOrThrow (XlaCoordinator::Create (
187
- global_process_rank, global_world_size, master_addr, port));
193
+ XLA_ASSIGN_OR_RETURN (
194
+ coordinator,
195
+ XlaCoordinator::Create (global_process_rank, global_world_size,
196
+ master_addr, port));
188
197
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
189
198
coordinator->GetClient ();
190
199
kv_store = xla::GetDistributedKeyValueStore (distributed_client,
@@ -199,27 +208,25 @@ InitializePjRt(const std::string& device_type) {
199
208
options.platform_name = " gpu" ;
200
209
options.should_stage_host_to_device_transfers = true ;
201
210
options.kv_store = kv_store;
202
- client = std::move ( xla::GetStreamExecutorGpuClient (options). value ( ));
211
+ XLA_ASSIGN_OR_RETURN ( client, xla::GetStreamExecutorGpuClient (options));
203
212
} else if (device_type == " XPU" ) {
204
213
TF_VLOG (1 ) << " Initializing PjRt XPU client..." ;
205
- XLA_CHECK_OK (
206
- pjrt::LoadPjrtPlugin (
207
- " xpu" , sys_util::GetEnvString (env::kEnvXpuLibraryPath , " libxpu.so" ))
208
- .status ());
209
- client = std::move (xla::GetCApiClient (" XPU" ).value ());
214
+ XLA_RETURN_IF_ERROR (pjrt::LoadPjrtPlugin (
215
+ " xpu" , sys_util::GetEnvString (env::kEnvXpuLibraryPath , " libxpu.so" )));
216
+ XLA_ASSIGN_OR_RETURN (client, xla::GetCApiClient (" XPU" ));
210
217
} else if (device_type == " NEURON" ) {
211
218
TF_VLOG (1 ) << " Initializing PjRt NEURON client..." ;
212
- XLA_CHECK_OK (pjrt::LoadPjrtPlugin (" NEURON" , sys_util::GetEnvString (
213
- env::kEnvNeuronLibraryPath ,
214
- " libneuronpjrt.so" ))
215
- .status ());
216
- client = std::move (xla::GetCApiClient (" NEURON" ).value ());
219
+ XLA_RETURN_IF_ERROR (pjrt::LoadPjrtPlugin (
220
+ " NEURON" , sys_util::GetEnvString (env::kEnvNeuronLibraryPath ,
221
+ " libneuronpjrt.so" )));
222
+ XLA_ASSIGN_OR_RETURN (client, xla::GetCApiClient (" NEURON" ));
223
+ } else {
224
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
225
+ " Unknown " , env::kEnvPjRtDevice , " : '" , device_type, " '" )));
217
226
}
218
227
219
- XLA_CHECK (client) << absl::StrFormat (" Unknown %s '%s'" , env::kEnvPjRtDevice ,
220
- device_type);
221
-
222
- return {std::move (client), std::move (coordinator)};
228
+ ABSL_CHECK (client);
229
+ return std::make_tuple (std::move (client), std::move (coordinator));
223
230
}
224
231
225
232
} // namespace runtime
0 commit comments