Error Handling: refactor ExecuteComputation and ExecuteReplicated to propagate status. #9445

Draft · wants to merge 1 commit into base: ysiraichi/status-for-oom-errors
3 changes: 2 additions & 1 deletion test/test_operations.py
@@ -2448,7 +2448,8 @@ def test_isneginf_no_fallback(self):
def test_construct_large_tensor_raises_error(self):
a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())

with self.assertRaisesRegex(RuntimeError, r"Out of memory allocating \d* bytes"):
with self.assertRaisesRegex(RuntimeError,
r"Out of memory allocating \d* bytes"):
a.cpu()


5 changes: 3 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
@@ -16,6 +16,7 @@
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/device.h"
@@ -345,7 +346,7 @@ class ComputationClient {
// The passed device must match the common device of the arguments Data.
// If options.explode_tuple is true, the output tuple will be decomposed into
// its single elements.
virtual std::vector<DataPtr> ExecuteComputation(
virtual absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
const Computation& computation, absl::Span<const DataPtr> arguments,
const std::string& device,
const ExecuteComputationOptions& options =
@@ -356,7 +357,7 @@
// as `devices`. If options.explode_tuple is true, the output tuples will be
// decomposed into their single elements. Returns a vector of outputs, each
// of which is sharded in the same order as `devices`.
virtual std::vector<DataPtr> ExecuteReplicated(
virtual absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
const Computation& computation, absl::Span<const DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) = 0;
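Note on the interface change above: callers of ExecuteComputation and ExecuteReplicated now receive an absl::StatusOr instead of a bare vector, so they can either keep the status flowing or unwrap it at an exception boundary. A minimal caller sketch, not part of this PR (RunAndPropagate and RunOrThrow are hypothetical wrappers; the include path for GetValueOrThrow is assumed from the torch_xla/csrc/status.h include added later in this diff):

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/status.h"  // assumed home of GetValueOrThrow

using torch_xla::runtime::ComputationClient;

// Propagating form: runtime failures such as OOM come back as a status
// that the caller can inspect or return further up the stack.
absl::StatusOr<std::vector<ComputationClient::DataPtr>> RunAndPropagate(
    ComputationClient& client,
    const ComputationClient::Computation& computation,
    absl::Span<const ComputationClient::DataPtr> args,
    const std::string& device) {
  return client.ExecuteComputation(computation, args, device);
}

// Throwing form: for call sites that still expect exceptions, unwrap with
// GetValueOrThrow, mirroring the updated call sites later in this diff.
std::vector<ComputationClient::DataPtr> RunOrThrow(
    ComputationClient& client,
    const ComputationClient::Computation& computation,
    absl::Span<const ComputationClient::DataPtr> args,
    const std::string& device) {
  return GetValueOrThrow(client.ExecuteComputation(computation, args, device));
}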
22 changes: 11 additions & 11 deletions torch_xla/csrc/runtime/ifrt_computation_client.cpp
@@ -4,6 +4,7 @@
#include <unordered_set>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/strings/ascii.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
@@ -403,8 +404,8 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
execute_options;

auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}},
GetLocalDevices(), execute_options);
auto sharded_results = GetValueOrThrow(ExecuteReplicated(
*computations.front(), {{handle}}, GetLocalDevices(), execute_options));
auto replicated_output =
std::dynamic_pointer_cast<IfrtData>(sharded_results[0])
->buffer->FullyReplicatedShard(
@@ -524,16 +525,16 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
return computations;
}

std::vector<ComputationClient::DataPtr>
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
IfrtComputationClient::ExecuteComputation(
const ComputationClient::Computation& computation,
absl::Span<const ComputationClient::DataPtr> arguments,
const std::string& device, const ExecuteComputationOptions& options) {
// TODO: Implement sharded exec in IFRT
XLA_ERROR() << __FUNCTION__ << " not implemented";
return absl::UnimplementedError("ExecuteComputation not implemented");
}

std::vector<ComputationClient::DataPtr>
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
IfrtComputationClient::ExecuteReplicated(
const ComputationClient::Computation& computation,
const absl::Span<const ComputationClient::DataPtr> arguments,
@@ -578,11 +579,10 @@ IfrtComputationClient::ExecuteReplicated(
TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
<< spmd_device_str << " Done";

xla::ifrt::LoadedExecutable::ExecuteResult result =
ifrt_computation.executable
->Execute(absl::MakeSpan(argument_handles), execute_options,
std::nullopt)
.value();
XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
xla::ifrt::LoadedExecutable::ExecuteResult result,
ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles),
execute_options, std::nullopt));

result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)](
absl::Status status) mutable {
@@ -599,7 +599,7 @@
? *ifrt_computation.output_shardings_
: std::vector(outputs.size(),
xla::HloSharding::Replicate().ToProto());
XLA_CHECK_EQ(output_shardings.size(), outputs.size());
ABSL_CHECK_EQ(output_shardings.size(), outputs.size());

std::vector<ComputationClient::DataPtr> data_handles(outputs.size());
{
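The hunk above swaps the old `.value()` unwrap of the IFRT Execute() result for XLA_ASSIGN_OR_RETURN_WITH_LOCATION, so an execution failure now returns through the StatusOr instead of terminating the process. A rough illustration of the pattern, assuming the macro behaves like the usual assign-or-return helpers (LoadSizes is a hypothetical StatusOr-returning callee, and the macro's header is assumed to be torch_xla/csrc/status.h):

#include <vector>

#include "absl/status/statusor.h"
#include "torch_xla/csrc/status.h"  // assumed location of the macro

absl::StatusOr<int> SumSizes(absl::StatusOr<std::vector<int>> (*LoadSizes)()) {
  // On failure, return early with the callee's status (presumably annotated
  // with this call site, per the macro name) instead of aborting via .value().
  XLA_ASSIGN_OR_RETURN_WITH_LOCATION(std::vector<int> sizes, LoadSizes());
  int total = 0;
  for (int size : sizes) total += size;
  return total;
}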
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -69,12 +69,12 @@ class IfrtComputationClient : public ComputationClient {
std::vector<ComputationPtr> Compile(
std::vector<CompileInstance> instances) override;

std::vector<DataPtr> ExecuteComputation(
absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
const Computation& computation, absl::Span<const DataPtr> arguments,
const std::string& device,
const ExecuteComputationOptions& options) override;

std::vector<DataPtr> ExecuteReplicated(
absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
const Computation& computation, const absl::Span<const DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) override;
7 changes: 4 additions & 3 deletions torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
@@ -64,9 +64,10 @@ TEST(PjRtComputationClientTest, Init) {
std::make_shared<LiteralSource>(std::move(literal_y), device)};

// Execute the graph.
std::vector<ComputationClient::DataPtr> results = client->ExecuteReplicated(
*computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
{device}, options);
std::vector<ComputationClient::DataPtr> results =
GetValueOrThrow(client->ExecuteReplicated(
*computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
{device}, options));

// Copy the output from device back to host and assert correctness..
ASSERT_EQ(results.size(), 1);
37 changes: 20 additions & 17 deletions torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -375,8 +375,8 @@ PjRtComputationClient::ReplicateShardedData(
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
execute_options;
auto sharded_results =
ExecuteReplicated(*computations.front(), {sharded_data},
GetLocalDevices(), execute_options);
GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
GetLocalDevices(), execute_options));
XLA_CHECK(sharded_results.size() > 0)
<< "empty ExecuteReplicated results returned.";
XLA_CHECK(sharded_results.size() == 1)
@@ -462,8 +462,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(

torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
execute_options;
auto resharded_results = ExecuteReplicated(
*computation, handles, GetLocalDevices(), execute_options);
auto resharded_results = GetValueOrThrow(ExecuteReplicated(
*computation, handles, GetLocalDevices(), execute_options));
return resharded_results;
}

@@ -711,7 +711,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
return comp_env_hash_;
}

std::vector<ComputationClient::DataPtr>
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
PjRtComputationClient::ExecuteComputation(
const ComputationClient::Computation& computation,
absl::Span<const ComputationClient::DataPtr> arguments,
@@ -731,14 +731,14 @@
dynamic_cast<const PjRtComputation&>(computation);

xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

std::vector<xla::PjRtBuffer*> buffers;
buffers.reserve(arguments.size());
for (auto& argument : arguments) {
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());

XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
ABSL_CHECK(pjrt_device == pjrt_data->buffer->device())
<< "The device currently being used : " << pjrt_device->DebugString()
<< " is different from the device where the buffer resides: "
<< pjrt_data->buffer->device()->DebugString();
@@ -758,8 +758,9 @@
<< " Done";

std::optional<xla::PjRtFuture<>> returned_future;
std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
std::vector<std::unique_ptr<xla::PjRtBuffer>> results,
pjrt_computation.executable->ExecuteSharded(
buffers, pjrt_device, execute_options, returned_future));

returned_future->OnReady(std::move(
@@ -784,7 +785,7 @@
return datas;
}

std::vector<ComputationClient::DataPtr>
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
PjRtComputationClient::ExecuteReplicated(
const ComputationClient::Computation& computation,
absl::Span<const ComputationClient::DataPtr> arguments,
@@ -818,15 +819,15 @@
for (int32_t i = start; i < end; ++i) {
auto pjrt_data =
std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size())
<< "Expected one shard per device";

for (int32_t d = 0; d < devices.size(); d++) {
std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];

xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
XLA_CHECK(pjrt_device->IsAddressable())
ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device);
ABSL_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString();

argument_handles[d][i] = shard->buffer.get();
@@ -862,8 +863,10 @@
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_execute",
tsl::profiler::TraceMeLevel::kInfo);
results = GetValueOrThrow(pjrt_computation.executable->Execute(
std::move(argument_handles), execute_options, returned_futures));
XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
results,
pjrt_computation.executable->Execute(
std::move(argument_handles), execute_options, returned_futures));

(*returned_futures)[0].OnReady(
std::move([timed, op_tracker = std::move(op_tracker)](
@@ -886,7 +889,7 @@
const std::vector<xla::Shape>& output_shapes =
result_shape.IsTuple() ? result_shape.tuple_shapes()
: std::vector<xla::Shape>({result_shape});
XLA_CHECK_EQ(output_shapes.size(), num_outputs);
ABSL_CHECK_EQ(output_shapes.size(), num_outputs);

const std::vector<xla::OpSharding>& output_shardings =
pjrt_computation.output_shardings_.has_value() && num_outputs > 0
@@ -895,7 +898,7 @@
// Without an explicit sharding annotation, the output is implicitly
// replicated, and we mark explicitly replicated here.
std::vector<xla::OpSharding>(num_outputs);
XLA_CHECK_EQ(output_shardings.size(), num_outputs);
ABSL_CHECK_EQ(output_shardings.size(), num_outputs);

absl::BlockingCounter counter(num_outputs);

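The PjRt changes above switch XLA_CHECK to ABSL_CHECK but keep those checks as hard aborts: mismatched shard counts and non-addressable devices are treated as internal bugs, while executable failures travel through the new status path. A sketch of that split as I read it from the diff (PickShards and its arguments are hypothetical, not part of the PR):

#include <string>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"

absl::StatusOr<std::vector<int>> PickShards(
    const std::vector<int>& shards, const std::vector<std::string>& devices) {
  // Invariant violation indicates a caller bug: fail fast in every build.
  ABSL_CHECK_EQ(shards.size(), devices.size()) << "Expected one shard per device";
  if (shards.empty()) {
    // Recoverable runtime condition: report it through the status channel so
    // the caller decides how to handle it.
    return absl::InvalidArgumentError("no shards to execute");
  }
  return shards;
}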
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -76,12 +76,12 @@ class PjRtComputationClient : public ComputationClient {

ComputationPtr DeserializeComputation(const std::string& serialized) override;

std::vector<DataPtr> ExecuteComputation(
absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
const Computation& computation, absl::Span<const DataPtr> arguments,
const std::string& device,
const ExecuteComputationOptions& options) override;

std::vector<DataPtr> ExecuteReplicated(
absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
const Computation& computation, absl::Span<const DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) override;
8 changes: 5 additions & 3 deletions torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
@@ -115,9 +115,11 @@ TEST_F(PjRtComputationClientTest, Init) {
std::make_shared<LiteralSource>(std::move(literal_y), device_)};

// Execute the graph.
std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
*computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
device_, options);
std::vector<ComputationClient::DataPtr> results =
GetValueOrThrow(client_->ExecuteComputation(
*computations[0],
client_->TransferToDevice(absl::MakeConstSpan(args)), device_,
options));

// Copy the output from device back to host and assert correctness.
ASSERT_EQ(results.size(), 1);
5 changes: 3 additions & 2 deletions torch_xla/csrc/xla_backend_impl.cpp
@@ -10,6 +10,7 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/status.h"

namespace at {
// This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -161,11 +162,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
torch::lazy::ComputationPtr computation,
c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
const torch::lazy::BackendDevice& device) const override {
std::vector<runtime::ComputationClient::DataPtr> results =
std::vector<runtime::ComputationClient::DataPtr> results = GetValueOrThrow(
runtime::GetComputationClientOrDie()->ExecuteComputation(
*std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
computation),
UnwrapXlaData(arguments), device.toString());
UnwrapXlaData(arguments), device.toString()));
return WrapXlaData(results);
}

35 changes: 20 additions & 15 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -52,6 +52,7 @@
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/torch_util.h"
@@ -843,10 +844,11 @@
// tensor results. Both sharded and unsharded results should be
// "Assign"ed to the corresponding data placeholders.
std::vector<runtime::ComputationClient::DataPtr> outputs =
runtime::GetComputationClientOrDie()->ExecuteReplicated(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), devices,
execute_options);
GetValueOrThrow(
runtime::GetComputationClientOrDie()->ExecuteReplicated(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), devices,
execute_options));
results = WrapXlaData(outputs);
TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
<< torch::lazy::HashToString(hash) << " on devices "
@@ -940,8 +942,8 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
}

std::vector<runtime::ComputationClient::DataPtr> result_data =
runtime::GetComputationClientOrDie()->ExecuteComputation(
*computations[0], UnwrapXlaData(arguments), device.toString());
GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
*computations[0], UnwrapXlaData(arguments), device.toString()));

return WrapXlaData(result_data);
}
@@ -1117,10 +1119,11 @@
// tensor results. Both sharded and unsharded results should be
// "Assign"ed to the corresponding data placeholders.
std::vector<runtime::ComputationClient::DataPtr> outputs =
runtime::GetComputationClientOrDie()->ExecuteReplicated(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), devices,
execute_options);
GetValueOrThrow(
runtime::GetComputationClientOrDie()->ExecuteReplicated(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), devices,
execute_options));
results = WrapXlaData(outputs);
TORCH_LAZY_COUNTER("ExecuteReplicated", 1);
TF_VLOG(3) << "Executing IR graph hash "
@@ -1132,11 +1135,13 @@
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " ...";
std::vector<runtime::ComputationClient::DataPtr> outputs =
runtime::GetComputationClientOrDie()->ExecuteComputation(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), async->device.toString(),
{/*explode_tuple=*/true,
/*eager_mode=*/use_eager_mode});
GetValueOrThrow(
runtime::GetComputationClientOrDie()->ExecuteComputation(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data),
async->device.toString(),
{/*explode_tuple=*/true,
/*eager_mode=*/use_eager_mode}));
results = WrapXlaData(outputs);
TORCH_LAZY_COUNTER("ExecuteComputation", 1);
TF_VLOG(3) << "Executing IR graph hash "