Skip to content

Commit a2c38b5

Browse files
committed
Address reviews.
1 parent a1b2d82 commit a2c38b5

File tree

8 files changed

+107
-53
lines changed

8 files changed

+107
-53
lines changed

test/cpp/test_status.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,26 @@ TEST(StatusTest, MaybeWithNewMessageNonEmptyNewMessage) {
3030
EXPECT_EQ(result.message(), new_err_string);
3131
}
3232

33-
TEST(StatusTest, ConsumeAndMaybeThrowWithOkStatus) {
33+
TEST(StatusTest, MaybeThrowWithOkStatus) {
3434
absl::Status ok_status = absl::OkStatus();
35-
EXPECT_NO_THROW(ConsumeAndMaybeThrow(ok_status));
35+
EXPECT_NO_THROW(MaybeThrow(ok_status));
3636
}
3737

38-
TEST(StatusTest, ConsumeAndMaybeThrowWithErrorStatus) {
38+
TEST(StatusTest, MaybeThrowWithErrorStatus) {
3939
absl::Status error_status = absl::InvalidArgumentError("Test error");
40-
EXPECT_THROW(ConsumeAndMaybeThrow(error_status), std::runtime_error);
40+
EXPECT_THROW(MaybeThrow(error_status), std::runtime_error);
4141
}
4242

43-
TEST(StatusTest, ConsumeAndMaybeThrowWithOkStatusOr) {
43+
TEST(StatusTest, GetValueOrThrowWithOkStatusOr) {
4444
int value = 42;
4545
absl::StatusOr<int> status_or = value;
46-
int result = ConsumeAndMaybeThrow(std::move(status_or));
46+
int result = GetValueOrThrow(std::move(status_or));
4747
EXPECT_EQ(result, value);
4848
}
4949

50-
TEST(StatusTest, ConsumeAndMaybeThrowWithErrorStatusOr) {
50+
TEST(StatusTest, GetValueOrThrowWithErrorStatusOr) {
5151
absl::StatusOr<int> status_or = absl::InvalidArgumentError("Test error");
52-
EXPECT_THROW(ConsumeAndMaybeThrow(std::move(status_or)), std::runtime_error);
52+
EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error);
5353
}
5454

5555
TEST(StatusTest, MacroReturnIfError) {

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "pybind11/pytypes.h"
3737
#include "pybind11/stl.h"
3838
#include "pybind11/stl_bind.h"
39+
#include "status.h"
3940
#include "torch_xla/csrc/XLANativeFunctions.h"
4041
#include "torch_xla/csrc/aten_autograd_ops.h"
4142
#include "torch_xla/csrc/aten_fallback.h"
@@ -1683,7 +1684,7 @@ void InitXlaModuleBindings(py::module m) {
16831684
})
16841685
.def("_init_computation_client",
16851686
[]() {
1686-
ConsumeAndMaybeThrow(runtime::GetComputationClient());
1687+
GetValueOrThrow(runtime::GetComputationClient());
16871688
})
16881689
.def("_xla_get_device_hw_type",
16891690
[](const at::Tensor& tensor) {

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ void IfrtComputationClient::InitializeCoordinator(int global_rank,
164164
std::string port) {
165165
XLA_CHECK(coordinator_ == nullptr)
166166
<< "Can only initialize the XlaCoordinator once.";
167-
coordinator_ = ConsumeAndMaybeThrow(
167+
coordinator_ = GetValueOrThrow(
168168
XlaCoordinator::Create(global_rank, world_size, master_addr, port));
169169
}
170170

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ void PjRtComputationClient::InitializeCoordinator(int global_rank,
155155
std::string port) {
156156
XLA_CHECK(coordinator_ == nullptr)
157157
<< "Can only initialize the XlaCoordinator once.";
158-
coordinator_ = ConsumeAndMaybeThrow(
158+
coordinator_ = GetValueOrThrow(
159159
XlaCoordinator::Create(global_rank, world_size, master_addr, port));
160160
}
161161

torch_xla/csrc/runtime/pjrt_registry.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ InitializePjRt(const std::string& device_type) {
110110
<< ", coordinator address=" << master_addr << ":" << port;
111111

112112
// Use the XlaCoordinator as the distributed key-value store.
113-
coordinator = ConsumeAndMaybeThrow(XlaCoordinator::Create(
113+
coordinator = GetValueOrThrow(XlaCoordinator::Create(
114114
global_process_rank, global_world_size, master_addr, port));
115115
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
116116
coordinator->GetClient();
@@ -183,7 +183,7 @@ InitializePjRt(const std::string& device_type) {
183183
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
184184
std::string port = runtime::sys_util::GetEnvString(
185185
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
186-
coordinator = ConsumeAndMaybeThrow(XlaCoordinator::Create(
186+
coordinator = GetValueOrThrow(XlaCoordinator::Create(
187187
global_process_rank, global_world_size, master_addr, port));
188188
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
189189
coordinator->GetClient();

torch_xla/csrc/runtime/xla_coordinator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class XlaCoordinator {
1717
// Private struct for making the constructor private, but still callable
1818
// as: `std::make_unique<XlaCoordinator>(PrivateUse())`.
1919
struct PrivateUse {
20-
// Constructor needs to be explicit for allowing only instanciation
21-
// within a private context.
20+
// Constructor needs to be explicit for disallowing implicit construction
21+
// from `{}`.
2222
explicit PrivateUse() = default;
2323
};
2424

torch_xla/csrc/status.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
namespace torch_xla {
88

9-
static bool ShowCppErrorContext() {
9+
// Returns whether we should show C++ error context.
10+
//
11+
// More specifically, whether the `XLA_SHOW_CPP_ERROR_CONTEXT` environment
12+
// variable is set or not.
13+
static bool ShouldShowCppErrorContext() {
1014
static const bool show_cpp_error_context = runtime::sys_util::GetEnvBool(
1115
runtime::env::kEnvShowCppErrorContext, false);
1216
return show_cpp_error_context;
@@ -23,7 +27,7 @@ absl::Status MaybeWithLocation(const absl::Status& status, const char* file,
2327
ABSL_CHECK(!status.ok());
2428

2529
// Return the same status if we don't need to add the C++ source location.
26-
if (!ShowCppErrorContext()) {
30+
if (!ShouldShowCppErrorContext()) {
2731
return status;
2832
}
2933

@@ -32,8 +36,6 @@ absl::Status MaybeWithLocation(const absl::Status& status, const char* file,
3236
absl::StrCat(status.message(), LocationStrWithSpace(file, line)));
3337
}
3438

35-
const absl::Status& GetStatus(const absl::Status& status) { return status; }
36-
3739
absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
3840
const int32_t line,
3941
const std::string_view new_message) {
@@ -42,7 +44,7 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
4244
// Return the same status if:
4345
// 1. we don't need to add the C++ source location.
4446
// 2. there's no new message to replace the old one.
45-
if (!ShowCppErrorContext() && new_message.empty()) {
47+
if (!ShouldShowCppErrorContext() && new_message.empty()) {
4648
return status;
4749
}
4850

@@ -54,15 +56,15 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
5456
// context to give a better error message to the user.
5557
std::string_view message = new_message.empty() ? old_message : new_message;
5658

57-
// If `kEnvShowCppErrorContext` is set, show the context of this error.
59+
// If `XLA_SHOW_CPP_ERROR_CONTEXT` is set, show the context of this error.
5860
// In other words, show:
5961
// - The error location
6062
// - The old messages that were replaced by `new_message`.
6163
//
6264
// This should give more context for developers. Showing the older error
6365
// messages alongside their debug information.
6466
std::string context;
65-
if (ShowCppErrorContext()) {
67+
if (ShouldShowCppErrorContext()) {
6668
context = LocationStrWithSpace(file, line);
6769
if (!new_message.empty()) {
6870
context = absl::StrCat(context, "\nFrom Error: ", old_message);
@@ -72,7 +74,7 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
7274
return absl::Status(status.code(), absl::StrCat(message, context));
7375
}
7476

75-
void ConsumeAndMaybeThrow(const absl::Status& status) {
77+
void MaybeThrow(const absl::Status& status) {
7678
if (!status.ok()) {
7779
throw std::runtime_error(std::string(status.message()));
7880
}

torch_xla/csrc/status.h

Lines changed: 81 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,54 +14,95 @@
1414

1515
namespace torch_xla {
1616

17-
// Creates a new Status instance, appending the current location (e.g. file and
18-
// line information) to the status message.
17+
// If `XLA_SHOW_CPP_ERROR_CONTEXT` is set, creates a new Status instance,
18+
// appending the current location (e.g. file and line information) to the
19+
// status message.
20+
//
21+
// This should be used whenever we are returning new error status.
22+
//
23+
// Example:
24+
//
25+
// XLA_ERROR_WITH_LOCATION(
26+
// absl::InvalidArgumentError("Error message.")
27+
// );
28+
//
29+
// If `XLA_SHOW_CPP_ERROR_CONTEXT` is set, the error shown will be:
30+
//
31+
// Error message. (at <cpp-source-file>:<line>)
1932
//
20-
// This should be used whenever we are returning new error status, instead of
21-
// propagating. Then, if `XLA_SHOW_CPP_ERROR_CONTEXT` environment variable is
22-
// set, the location information will be shown.
2333
#define XLA_ERROR_WITH_LOCATION(status) \
2434
::torch_xla::MaybeWithLocation(status, __FILE__, __LINE__)
2535

26-
#define XLA_CONCAT(a, b) XLA_CONCAT_IMPL(a, b)
27-
#define XLA_CONCAT_IMPL(a, b) a##b
36+
#define XLA_CONCAT_(a, b) XLA_CONCAT_IMPL_(a, b)
37+
#define XLA_CONCAT_IMPL_(a, b) a##b
2838

2939
// Unique identifier for the status variable for the current line.
30-
#define XLA_STATUS_VAR XLA_CONCAT(status__, __LINE__)
40+
#define XLA_STATUS_VAR_ XLA_CONCAT_(status_, __LINE__)
3141

3242
// Provides a flexible way to handle error checking with optional message
3343
// modification. It evaluates `expr`, checks if it's OK, and either:
3444
// 1. Returns early with an error status (potentially modified by the provided
3545
// additional messages)
3646
// 2. Proceeds with the given `then` block if successful
37-
#define XLA_RETURN_IF_ERROR_IMPL(expr, var, then, ...) \
47+
#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \
3848
auto var = (expr); \
3949
if (!var.ok()) { \
4050
return ::torch_xla::MaybeWithNewMessage( \
4151
::torch_xla::GetStatus(var), __FILE__, __LINE__, ##__VA_ARGS__); \
4252
} \
43-
then;
53+
then
4454

4555
// Propagates `rexpr`, in case it's a non-ok status.
46-
#define XLA_RETURN_IF_ERROR(rexpr, ...) \
47-
do { \
48-
XLA_RETURN_IF_ERROR_IMPL(rexpr, XLA_STATUS_VAR, {}, ##__VA_ARGS__) \
56+
//
57+
// Example:
58+
//
59+
// XLA_RETURN_IF_ERROR(
60+
// FnThatReturnsStatus(),
61+
// "New error message."
62+
// );
63+
//
64+
// If the function call results in an ok status, execution continues. Otherwise,
65+
// we early return a non-ok status. Then, if `XLA_SHOW_CPP_ERROR_CONTEXT` is
66+
// set, the error shown will be:
67+
//
68+
// New error message. (at <cpp-source-file>:<line>)
69+
// Previous error message. (at <cpp-source-file>:<line>)
70+
// ...
71+
//
72+
#define XLA_RETURN_IF_ERROR(rexpr, ...) \
73+
do { \
74+
XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR, {}, ##__VA_ARGS__) \
4975
} while (false)
5076

5177
// Propagates `rexpr`, in case it's a non-ok status. Otherwise, assign
5278
// its result to `lhs`.
5379
//
5480
// Note 1: `lhs` might be a variable declarate, e.g:
5581
//
56-
// XLA_ASSIGN_OR_RETURN(int value, FnThatReturnsStatus(), ...);
57-
//
5882
// Note 2: this macro will be replaced by multiple statements that live on
5983
// the scope it was called (see XLA_RETURN_IF_ERROR_IMPL).
6084
//
61-
#define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \
62-
XLA_RETURN_IF_ERROR_IMPL(rexpr, XLA_STATUS_VAR, \
63-
lhs = std::move(XLA_STATUS_VAR).value(); \
64-
, ##__VA_ARGS__)
85+
// Example:
86+
//
87+
// XLA_ASSIGN_OR_RETURN(
88+
// int result,
89+
// FnThatReturnsStatus(),
90+
// "New error message."
91+
// );
92+
//
93+
// If the function call results in an ok status, execution continues with
94+
// `result` set to `ret.value()`, where `ret` is the returned value of the
95+
// function. Otherwise, we early return a non-ok status. Then, if
96+
// `XLA_SHOW_CPP_ERROR_CONTEXT` is set, the error shown will be:
97+
//
98+
// New error message. (at <cpp-source-file>:<line>)
99+
// Previous error message. (at <cpp-source-file>:<line>)
100+
// ...
101+
//
102+
#define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \
103+
XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR, \
104+
lhs = std::move(XLA_STATUS_VAR).value(), \
105+
##__VA_ARGS__)
65106

66107
// Maybe shows location information in the status message.
67108
//
@@ -73,8 +114,13 @@ namespace torch_xla {
73114
absl::Status MaybeWithLocation(const absl::Status& status, const char* file,
74115
int32_t line);
75116

76-
const absl::Status& GetStatus(const absl::Status& status);
117+
// Returns an `absl::Status` from an `absl::Status`.
118+
// In this case, this function is a no-op. It simply returns the argument.
119+
inline const absl::Status& GetStatus(const absl::Status& status) {
120+
return status;
121+
}
77122

123+
// Returns an `absl::Status` from an `absl::StatusOr<T>`.
78124
template <class T>
79125
const absl::Status& GetStatus(const absl::StatusOr<T>& status) {
80126
return status.status();
@@ -97,27 +143,32 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
97143
int32_t line,
98144
std::string_view new_message = "");
99145

100-
// Consumes the `status` and maybe throws an exception if `status` has
101-
// a non-ok code.
146+
// Maybe throws an exception if `status` has a non-ok code.
102147
//
103148
// Ideally, this function should be used only used in the project's
104149
// boundary, e.g. when we need to throw an exception for the user to see.
105-
void ConsumeAndMaybeThrow(const absl::Status& status);
150+
void MaybeThrow(const absl::Status& status);
106151

107-
// Consumes the `status`, either returning the value it holds (for
108-
// ok status), or throwing an exception.
152+
// Either returns the value `status` holds, if it's an ok-status, or throw an
153+
// exception from its error status.
109154
template <class T>
110-
T ConsumeAndMaybeThrow(absl::StatusOr<T>&& status) {
111-
ConsumeAndMaybeThrow(status.status());
112-
return std::move(status).value();
155+
T& GetValueOrThrow(absl::StatusOr<T>& status) {
156+
MaybeThrow(status.status());
157+
return status.value();
113158
}
114159

115160
template <class T>
116-
T ConsumeAndMaybeThrow(const absl::StatusOr<T>& status) {
117-
ConsumeAndMaybeThrow(status.status());
161+
const T& GetValueOrThrow(const absl::StatusOr<T>& status) {
162+
MaybeThrow(status.status());
118163
return status.value();
119164
}
120165

166+
template <class T>
167+
T GetValueOrThrow(absl::StatusOr<T>&& status) {
168+
MaybeThrow(status.status());
169+
return std::move(status).value();
170+
}
171+
121172
} // namespace torch_xla
122173

123174
#endif // XLA_TORCH_XLA_CSRC_STATUS_H_

0 commit comments

Comments
 (0)