Skip to content

Commit cef8c1e

Browse files
committed
Propagate status on OOM crashes and exception.
1 parent bac2768 commit cef8c1e

14 files changed

+35
-37
lines changed

test/cpp/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ ptxla_cc_library(
4040
"//torch_xla/csrc/runtime:runtime",
4141
"//torch_xla/csrc/runtime:debug_macros",
4242
"//torch_xla/csrc/runtime:sys_util",
43+
"//torch_xla/csrc:status",
4344
"//torch_xla/csrc:tensor",
4445
"@com_google_absl//absl/types:span",
4546
"@com_google_googletest//:gtest",

test/cpp/cpp_test_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "torch_xla/csrc/runtime/debug_macros.h"
1515
#include "torch_xla/csrc/runtime/runtime.h"
1616
#include "torch_xla/csrc/runtime/sys_util.h"
17+
#include "torch_xla/csrc/status.h"
1718
#include "torch_xla/csrc/tensor_impl.h"
1819
#include "torch_xla/csrc/tensor_util.h"
1920
#include "torch_xla/csrc/torch_util.h"
@@ -301,9 +302,8 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
301302
std::vector<at::Tensor> Fetch(
302303
absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
303304
device_data) {
304-
std::vector<xla::Literal> literals =
305-
torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
306-
device_data);
305+
std::vector<xla::Literal> literals = GetValueOrThrow(
306+
runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
307307
std::vector<at::Tensor> tensors;
308308
for (auto& literal : literals) {
309309
tensors.push_back(MakeTensorFromXlaLiteral(

test/cpp/test_replication.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "torch_xla/csrc/helpers.h"
1111
#include "torch_xla/csrc/runtime/debug_macros.h"
1212
#include "torch_xla/csrc/runtime/runtime.h"
13+
#include "torch_xla/csrc/status.h"
1314
#include "torch_xla/csrc/tensor_util.h"
1415
#include "torch_xla/csrc/thread_pool.h"
1516
#include "torch_xla/csrc/torch_util.h"
@@ -78,9 +79,8 @@ void TestSingleReplication(
7879
counter.Wait();
7980

8081
for (size_t i = 0; i < results.size(); ++i) {
81-
std::vector<xla::Literal> literals =
82-
torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
83-
results[i]);
82+
std::vector<xla::Literal> literals = GetValueOrThrow(
83+
runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
8484
ASSERT_EQ(literals.size(), 1);
8585

8686
// The result must be the original tensor value, multiplied by the number of

torch_xla/csrc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ ptxla_cc_library(
125125
":layout_manager",
126126
":shape_builder",
127127
":shape_helper",
128+
":status",
128129
":version",
129130
"//torch_xla/csrc:hash_util",
130131
"//torch_xla/csrc:thread_pool",

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,9 +1229,9 @@ class PyLoweringContext {
12291229
lowering_ctx.GetParametersData();
12301230

12311231
// Fetch this parameter data
1232-
std::vector<xla::Literal> literals =
1232+
std::vector<xla::Literal> literals = GetValueOrThrow(
12331233
runtime::GetComputationClientOrDie()->TransferFromDevice(
1234-
UnwrapXlaData(device_data));
1234+
UnwrapXlaData(device_data)));
12351235

12361236
// Create a mapping from paramater id to the tensor data
12371237
std::unordered_map<int64_t, at::Tensor> results;

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ cc_library(
121121
":tf_logging",
122122
":xla_coordinator",
123123
"//torch_xla/csrc:status",
124+
"@com_google_absl//absl/log:absl_check",
124125
"@com_google_absl//absl/strings",
125126
"@com_google_absl//absl/synchronization",
126127
"@com_google_absl//absl/types:span",

torch_xla/csrc/runtime/computation_client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ class ComputationClient {
317317
// Note: `TransferFromDevice` call will block until the `DataPtrs` are ready
318318
// if they were created by `TransferToDevice` or `Execute*`. Calling this from
319319
// python while holding the GIL can cause deadlocks!
320-
virtual std::vector<xla::Literal> TransferFromDevice(
320+
virtual absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
321321
absl::Span<const DataPtr> handles) = 0;
322322

323323
virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0;

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,8 @@ std::shared_ptr<xla::PjRtBuffer> IfrtComputationClient::GetPjRtBuffer(
423423
XLA_ERROR() << __FUNCTION__ << " not implemented";
424424
}
425425

426-
std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
427-
absl::Span<const DataPtr> handles) {
426+
absl::StatusOr<std::vector<xla::Literal>>
427+
IfrtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
428428
metrics::TimedSection timed(TransferFromDeviceMetric());
429429
tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice",
430430
tsl::profiler::TraceMeLevel::kInfo);
@@ -442,9 +442,9 @@ std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
442442
auto& literal = literals.emplace_back(
443443
xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape()));
444444
std::vector<int64_t> byte_strides(literal.shape().dimensions_size());
445-
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(),
446-
absl::MakeSpan(byte_strides)));
447-
XLA_CHECK_OK(
445+
XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides(
446+
literal.shape(), absl::MakeSpan(byte_strides)));
447+
XLA_RETURN_IF_ERROR(
448448
replicated_array
449449
->CopyToHostBuffer(literal.untyped_data(), byte_strides,
450450
xla::ifrt::ArrayCopySemantics::kAlwaysCopy)

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class IfrtComputationClient : public ComputationClient {
5353
XLA_ERROR() << __FUNCTION__ << " not implemented";
5454
}
5555

56-
std::vector<xla::Literal> TransferFromDevice(
56+
absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
5757
absl::Span<const DataPtr> handles) override;
5858

5959
std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ TEST(PjRtComputationClientTest, Init) {
7070

7171
// Copy the output from device back to host and assert correctness..
7272
ASSERT_EQ(results.size(), 1);
73-
auto result_literals = client->TransferFromDevice(results);
73+
auto result_literals = GetValueOrThrow(client->TransferFromDevice(results));
7474
ASSERT_THAT(result_literals, ::testing::SizeIs(1));
7575
EXPECT_TRUE(xla::LiteralTestUtil::Equal(
7676
xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),

0 commit comments

Comments
 (0)