
Commit 38b0ebf

Make ExecuteComputation and ExecuteReplicated return StatusOr<T>
Key changes:
- Updated the base `ComputationClient` interface to return `absl::StatusOr<std::vector<DataPtr>>`
- Modified the IFRT and PjRt implementations to use proper error propagation
- Replaced raw `.value()` calls with `XLA_ASSIGN_OR_RETURN_WITH_LOCATION` macros
- Updated all call sites to use `GetValueOrThrow` for exception-based error handling
1 parent 048459c commit 38b0ebf

9 files changed: +57 −50 lines
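The sketch below illustrates the error-propagation pattern this commit adopts, in miniature: implementations return `absl::StatusOr<T>` and propagate failures with an assign-or-return macro, while call sites that still rely on exceptions unwrap the result. This is a minimal, self-contained example, not torch_xla code; `ASSIGN_OR_RETURN`, `ValueOrThrow`, and `RawExecute` are hypothetical stand-ins for `XLA_ASSIGN_OR_RETURN_WITH_LOCATION`, `GetValueOrThrow` (torch_xla/csrc/status.h), and the underlying PJRT execute call, and it assumes Abseil is available.

#include <stdexcept>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Simplified, single-use-per-scope stand-in for XLA_ASSIGN_OR_RETURN_WITH_LOCATION:
// evaluate a StatusOr expression, return early on error, otherwise bind the value.
#define ASSIGN_OR_RETURN(lhs, expr)        \
  auto status_or_tmp = (expr);             \
  if (!status_or_tmp.ok()) {               \
    return status_or_tmp.status();         \
  }                                        \
  lhs = std::move(status_or_tmp).value();

// Simplified stand-in for GetValueOrThrow: bridge StatusOr into
// exception-based call sites.
template <typename T>
T ValueOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    throw std::runtime_error(result.status().ToString());
  }
  return std::move(result).value();
}

// Hypothetical lower-level call that can fail, standing in for
// PjRtLoadedExecutable::Execute / ExecuteSharded.
absl::StatusOr<std::vector<int>> RawExecute(bool fail) {
  if (fail) {
    return absl::InternalError("device execution failed");
  }
  return std::vector<int>{1, 2, 3};
}

// Analogue of the new ExecuteComputation signature: failures propagate as a
// status instead of crashing through .value() or a CHECK.
absl::StatusOr<std::vector<int>> ExecuteComputation(bool fail) {
  ASSIGN_OR_RETURN(std::vector<int> outputs, RawExecute(fail));
  return outputs;
}

int main() {
  // Analogue of the updated call sites: unwrap the StatusOr or throw.
  std::vector<int> outputs = ValueOrThrow(ExecuteComputation(/*fail=*/false));
  return outputs.size() == 3 ? 0 : 1;
}

The design choice mirrored here is that the runtime interface reports errors as values, so implementations no longer abort inside the client, while the lazy-tensor layers keep their existing exception-based flow by unwrapping at the boundary.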

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/device.h"
@@ -345,7 +346,7 @@ class ComputationClient {
   // The passed device must match the common device of the arguments Data.
   // If options.explode_tuple is true, the output tuple will be decomposed into
   // its single elements.
-  virtual std::vector<DataPtr> ExecuteComputation(
+  virtual absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
       const ExecuteComputationOptions& options =
@@ -356,7 +357,7 @@
   // as `devices`. If options.explode_tuple is true, the output tuples will be
   // decomposed into their single elements. Returns a vector of outputs, each
   // of which is sharded in the same order as `devices`.
-  virtual std::vector<DataPtr> ExecuteReplicated(
+  virtual absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) = 0;

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 10 additions & 11 deletions
@@ -403,8 +403,8 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
 
-  auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}},
-                                           GetLocalDevices(), execute_options);
+  auto sharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computations.front(), {{handle}}, GetLocalDevices(), execute_options));
   auto replicated_output =
       std::dynamic_pointer_cast<IfrtData>(sharded_results[0])
           ->buffer->FullyReplicatedShard(
@@ -524,16 +524,16 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
   return computations;
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 IfrtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
     const std::string& device, const ExecuteComputationOptions& options) {
   // TODO: Implement sharded exec in IFRT
-  XLA_ERROR() << __FUNCTION__ << " not implemented";
+  return absl::UnimplementedError("ExecuteComputation not implemented");
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 IfrtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     const absl::Span<const ComputationClient::DataPtr> arguments,
@@ -578,11 +578,10 @@ IfrtComputationClient::ExecuteReplicated(
   TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
              << spmd_device_str << " Done";
 
-  xla::ifrt::LoadedExecutable::ExecuteResult result =
-      ifrt_computation.executable
-          ->Execute(absl::MakeSpan(argument_handles), execute_options,
-                    std::nullopt)
-          .value();
+  XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+      xla::ifrt::LoadedExecutable::ExecuteResult result,
+      ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles),
+                                           execute_options, std::nullopt));
 
   result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)](
                                       absl::Status status) mutable {
@@ -599,7 +598,7 @@ IfrtComputationClient::ExecuteReplicated(
           ? *ifrt_computation.output_shardings_
           : std::vector(outputs.size(),
                         xla::HloSharding::Replicate().ToProto());
-  XLA_CHECK_EQ(output_shardings.size(), outputs.size());
+  ABSL_CHECK_EQ(output_shardings.size(), outputs.size());
 
   std::vector<ComputationClient::DataPtr> data_handles(outputs.size());
   {

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 2 deletions
@@ -69,12 +69,12 @@ class IfrtComputationClient : public ComputationClient {
   std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) override;
 
-  std::vector<DataPtr> ExecuteComputation(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
       const ExecuteComputationOptions& options) override;
 
-  std::vector<DataPtr> ExecuteReplicated(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
       const Computation& computation, const absl::Span<const DataPtr> arguments,
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) override;

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 4 additions & 3 deletions
@@ -64,9 +64,10 @@ TEST(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device)};
 
   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client->ExecuteReplicated(
-      *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
-      {device}, options);
+  std::vector<ComputationClient::DataPtr> results = GetValueOrThrow(
+      client->ExecuteReplicated(
+          *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
+          {device}, options));
 
   // Copy the output from device back to host and assert correctness..
   ASSERT_EQ(results.size(), 1);

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 20 additions & 17 deletions
@@ -375,8 +375,8 @@ PjRtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
   auto sharded_results =
-      ExecuteReplicated(*computations.front(), {sharded_data},
-                        GetLocalDevices(), execute_options);
+      GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
                                         GetLocalDevices(), execute_options));
   XLA_CHECK(sharded_results.size() > 0)
       << "empty ExecuteReplicated results returned.";
   XLA_CHECK(sharded_results.size() == 1)
@@ -462,8 +462,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
 
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto resharded_results = ExecuteReplicated(
-      *computation, handles, GetLocalDevices(), execute_options);
+  auto resharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computation, handles, GetLocalDevices(), execute_options));
   return resharded_results;
 }
 
@@ -711,7 +711,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
   return comp_env_hash_;
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -731,14 +731,14 @@ PjRtComputationClient::ExecuteComputation(
       dynamic_cast<const PjRtComputation&>(computation);
 
   xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
-  XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
+  ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
 
   std::vector<xla::PjRtBuffer*> buffers;
   buffers.reserve(arguments.size());
   for (auto& argument : arguments) {
     const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());
 
-    XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
+    ABSL_CHECK(pjrt_device == pjrt_data->buffer->device())
         << "The device currently being used : " << pjrt_device->DebugString()
         << " is different from the device where the buffer resides: "
         << pjrt_data->buffer->device()->DebugString();
@@ -758,8 +758,9 @@ PjRtComputationClient::ExecuteComputation(
              << " Done";
 
   std::optional<xla::PjRtFuture<>> returned_future;
-  std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
+  XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+      std::vector<std::unique_ptr<xla::PjRtBuffer>> results,
+      pjrt_computation.executable->ExecuteSharded(
           buffers, pjrt_device, execute_options, returned_future));
 
   returned_future->OnReady(std::move(
@@ -784,7 +785,7 @@ PjRtComputationClient::ExecuteComputation(
   return datas;
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -818,15 +819,15 @@ PjRtComputationClient::ExecuteReplicated(
     for (int32_t i = start; i < end; ++i) {
       auto pjrt_data =
           std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
-      XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
+      ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size())
          << "Expected one shard per device";
 
      for (int32_t d = 0; d < devices.size(); d++) {
        std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];
 
        xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
-        XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
-        XLA_CHECK(pjrt_device->IsAddressable())
+        ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device);
+        ABSL_CHECK(pjrt_device->IsAddressable())
            << pjrt_device->DebugString();
 
        argument_handles[d][i] = shard->buffer.get();
@@ -862,8 +863,10 @@ PjRtComputationClient::ExecuteReplicated(
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_execute",
         tsl::profiler::TraceMeLevel::kInfo);
-    results = GetValueOrThrow(pjrt_computation.executable->Execute(
-        std::move(argument_handles), execute_options, returned_futures));
+    XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+        results,
+        pjrt_computation.executable->Execute(
+            std::move(argument_handles), execute_options, returned_futures));
 
     (*returned_futures)[0].OnReady(
         std::move([timed, op_tracker = std::move(op_tracker)](
@@ -886,7 +889,7 @@ PjRtComputationClient::ExecuteReplicated(
   const std::vector<xla::Shape>& output_shapes =
       result_shape.IsTuple() ? result_shape.tuple_shapes()
                              : std::vector<xla::Shape>({result_shape});
-  XLA_CHECK_EQ(output_shapes.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shapes.size(), num_outputs);
 
   const std::vector<xla::OpSharding>& output_shardings =
       pjrt_computation.output_shardings_.has_value() && num_outputs > 0
@@ -895,7 +898,7 @@ PjRtComputationClient::ExecuteReplicated(
           // Without an explicit sharding annotation, the output is implicitly
           // replicated, and we mark explicitly replicated here.
           std::vector<xla::OpSharding>(num_outputs);
-  XLA_CHECK_EQ(output_shardings.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shardings.size(), num_outputs);
 
   absl::BlockingCounter counter(num_outputs);
 
torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 2 deletions
@@ -76,12 +76,12 @@ class PjRtComputationClient : public ComputationClient {
 
   ComputationPtr DeserializeComputation(const std::string& serialized) override;
 
-  std::vector<DataPtr> ExecuteComputation(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
      const Computation& computation, absl::Span<const DataPtr> arguments,
      const std::string& device,
      const ExecuteComputationOptions& options) override;
 
-  std::vector<DataPtr> ExecuteReplicated(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
      const Computation& computation, absl::Span<const DataPtr> arguments,
      absl::Span<const std::string> devices,
      const ExecuteReplicatedOptions& options) override;

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp

Lines changed: 4 additions & 3 deletions
@@ -115,9 +115,10 @@ TEST_F(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device_)};
 
   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
-      *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
-      device_, options);
+  std::vector<ComputationClient::DataPtr> results = GetValueOrThrow(
+      client_->ExecuteComputation(
+          *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
+          device_, options));
 
   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
 
 namespace at {
 // This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -162,10 +163,10 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
       const torch::lazy::BackendDevice& device) const override {
     std::vector<runtime::ComputationClient::DataPtr> results =
-        runtime::GetComputationClientOrDie()->ExecuteComputation(
+        GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
            *std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
                computation),
-            UnwrapXlaData(arguments), device.toString());
+            UnwrapXlaData(arguments), device.toString()));
     return WrapXlaData(results);
   }
 
torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 9 additions & 8 deletions
@@ -51,6 +51,7 @@
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/thread_pool.h"
@@ -843,10 +844,10 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     // tensor results. Both sharded and unsharded results should be
     // "Assign"ed to the corresponding data placeholders.
     std::vector<runtime::ComputationClient::DataPtr> outputs =
-        runtime::GetComputationClientOrDie()->ExecuteReplicated(
+        GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteReplicated(
            *async->cached_computation->computation,
            UnwrapXlaData(async->parameters_data), devices,
-            execute_options);
+            execute_options));
     results = WrapXlaData(outputs);
     TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
               << torch::lazy::HashToString(hash) << " on devices "
@@ -940,8 +941,8 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
   }
 
   std::vector<runtime::ComputationClient::DataPtr> result_data =
-      runtime::GetComputationClientOrDie()->ExecuteComputation(
-          *computations[0], UnwrapXlaData(arguments), device.toString());
+      GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
+          *computations[0], UnwrapXlaData(arguments), device.toString()));
 
   return WrapXlaData(result_data);
 }
@@ -1117,10 +1118,10 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
      // tensor results. Both sharded and unsharded results should be
      // "Assign"ed to the corresponding data placeholders.
      std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteReplicated(
+          GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteReplicated(
             *async->cached_computation->computation,
             UnwrapXlaData(async->parameters_data), devices,
-              execute_options);
+              execute_options));
      results = WrapXlaData(outputs);
      TORCH_LAZY_COUNTER("ExecuteReplicated", 1);
      TF_VLOG(3) << "Executing IR graph hash "
@@ -1132,11 +1133,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
                << torch::lazy::HashToString(hash) << " on device "
                << async->device << " ...";
      std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteComputation(
+          GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
             *async->cached_computation->computation,
             UnwrapXlaData(async->parameters_data), async->device.toString(),
             {/*explode_tuple=*/true,
-               /*eager_mode=*/use_eager_mode});
+               /*eager_mode=*/use_eager_mode}));
      results = WrapXlaData(outputs);
      TORCH_LAZY_COUNTER("ExecuteComputation", 1);
      TF_VLOG(3) << "Executing IR graph hash "
