Skip to content

In pjrt runtime client, raise a Python exception if XLA compilation fails. #9138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ ptxla_cc_library(
":shape_builder",
":shape_helper",
":version",
"//torch_xla/csrc:thread_pool",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:stablehlo_helper",
"//torch_xla/csrc/runtime:xla_util",
Expand Down
10 changes: 2 additions & 8 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,10 @@ cc_library(
":env_vars",
":operation_manager",
":pjrt_registry",
":profiler",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":xla_coordinator",
"//torch_xla/csrc:thread_pool",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -485,15 +483,10 @@ ptxla_cc_test(
deps = [
":computation_client",
":pjrt_computation_client",
":operation_manager",
":tensor_source",
"@com_google_absl//absl/status",
"@xla//xla/tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
"@tsl//tsl/platform:statusor",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
Expand All @@ -502,6 +495,7 @@ ptxla_cc_test(
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
],
timeout = "short",
)

# ptxla_cc_test(
Expand Down
33 changes: 18 additions & 15 deletions torch_xla/csrc/runtime/pjrt_computation_client.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"

#include <algorithm>
#include <future>
#include <unordered_set>
#include <stdexcept>
#include <vector>

#include "absl/status/status.h"
Expand All @@ -13,34 +12,30 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_hash.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/operation_manager.h"
#include "torch_xla/csrc/runtime/pjrt_registry.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/thread_pool.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/protobuf_util.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/shape.h"

using xla::internal::XlaBuilderFriend;

namespace torch_xla {
namespace runtime {

using xla::internal::XlaBuilderFriend;

namespace {

// Builds a map from the device's global ordinal to its index in the `devices`
Expand Down Expand Up @@ -625,21 +620,29 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
device_assignment);
}

std::unique_ptr<xla::PjRtLoadedExecutable> executable;
// Compile the computation to an executible. For better user experience, if
// the XLA compiler fails for any reason, we raise a Python exception.
std::function<absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>>()>
compile;
if (runtime::sys_util::GetEnvBool("XLA_STABLEHLO_COMPILE", false)) {
// Convert HLO to StableHLO for PjRt client compilation.
mlir::MLIRContext context;
mlir::ModuleOp mlir_module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
executable =
client_->CompileAndLoad(mlir_module, compile_options).value();
compile = [mlir_module, &compile_options, this] {
return client_->CompileAndLoad(mlir_module, compile_options);
};
StableHloCompileCounter()->AddValue(1);
} else {
executable =
client_->CompileAndLoad(instance.computation, compile_options)
.value();
compile = [&] {
return client_->CompileAndLoad(instance.computation, compile_options);
};
}
std::unique_ptr<xla::PjRtLoadedExecutable> executable =
util::RaisePythonValueErrorOnFailure([&] {
return fake_xla_compile_ ? fake_xla_compile_() : compile();
});

auto memory_stats_status_or = executable->GetCompiledMemoryStats();
if (memory_stats_status_or.ok()) {
Expand Down
14 changes: 13 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace runtime {
class PjRtComputationClient : public ComputationClient {
public:
PjRtComputationClient();
~PjRtComputationClient();
~PjRtComputationClient() override;

DataPtr CreateDataPlaceholder(
std::string device, xla::Shape shape,
Expand Down Expand Up @@ -162,6 +162,14 @@ class PjRtComputationClient : public ComputationClient {
const std::function<void()>& callback) override;

private:
friend class PjRtComputationClientTest;

// If `function` is not nullptr, makes the client call it instead of the real
// XLA compiler when compiling. Used for injecting fault for testing.
void FakeXlaCompileForTesting(std::function<absl::Status()> function) {
fake_xla_compile_ = std::move(function);
}

std::unique_ptr<xla::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
// global_ordinals_ tracks a map from PjRtDeviceId to the device's
Expand All @@ -174,6 +182,10 @@ class PjRtComputationClient : public ComputationClient {
tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
torch::lazy::hash_t comp_env_hash_;

// If not nullptr, invoke this instead of the actual XLA compilation. Used
// only for testing.
std::function<absl::Status()> fake_xla_compile_ = nullptr;

xla::PjRtDevice* StringToPjRtDevice(const std::string& device);

struct PjRtData : public Data {
Expand Down
100 changes: 75 additions & 25 deletions torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,48 @@

#include <gtest/gtest.h>

#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/status_test_util.h"

namespace torch_xla {
namespace runtime {

absl::StatusOr<xla::XlaComputation> MakeComputation() {
xla::Shape input_shape =
class PjRtComputationClientTest : public ::testing::Test {
protected:
PjRtComputationClientTest() {
// Get a CPU client.
tsl::setenv("PJRT_DEVICE", "CPU", true);
client_ = std::make_unique<PjRtComputationClient>();
device_ = client_->GetDefaultDevice();
}

static void FakeXlaCompileForTesting(
PjRtComputationClient* client,
std::function<absl::Status()> fake_compile) {
client->FakeXlaCompileForTesting(std::move(fake_compile));
}

std::unique_ptr<PjRtComputationClient> client_;
std::string device_;
};

// Returns a computation to compute x + y where x and y are both F32[2,2]
// arrays.
absl::StatusOr<xla::XlaComputation> MakeAddComputation() {
const xla::Shape input_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
xla::XlaBuilder builder("AddComputation");
xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x");
Expand All @@ -34,19 +52,51 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
return builder.Build();
}

TEST(PjRtComputationClientTest, Init) {
// Get a CPU client.
tsl::setenv("PJRT_DEVICE", "CPU", true);
auto client = std::make_unique<PjRtComputationClient>();
std::string device = client->GetDefaultDevice();
TEST_F(PjRtComputationClientTest, ThrowsExpectedExceptionWhenCompileFails) {
// Compose a computation to add two matrices.
xla::Shape out_shape(xla::F32, {2, 2},
/*dynamic_dimensions=*/{});
std::vector<ComputationClient::CompileInstance> instances;
instances.push_back(ComputationClient::CompileInstance(
std::move(MakeAddComputation().value()), device_,
client_->GetCompilationDevices(device_, client_->GetLocalDevices()),
&out_shape));

// Force XLA to fail with the given error when invoked by Compile() below.
FakeXlaCompileForTesting(
client_.get(), [] { return absl::InvalidArgumentError("invalid arg"); });

// Compiling the graph should fail, which should throw instead of crashing.
EXPECT_THROW(client_->Compile(std::move(instances)), std::invalid_argument);
}

TEST_F(PjRtComputationClientTest, ThrowsExpectedExceptionWhenCompileThrows) {
// Compose a computation to add two matrices.
xla::Shape out_shape(xla::F32, {2, 2},
/*dynamic_dimensions=*/{});
std::vector<ComputationClient::CompileInstance> instances;
instances.push_back(ComputationClient::CompileInstance(
std::move(MakeAddComputation().value()), device_,
client_->GetCompilationDevices(device_, client_->GetLocalDevices()),
&out_shape));

// Force XLA to throw with the given error when invoked by Compile() below.
FakeXlaCompileForTesting(client_.get(), []() -> absl::Status {
throw absl::BadStatusOrAccess(absl::InvalidArgumentError("invalid arg"));
});

// Compiling the graph should fail, which should throw instead of crashing.
EXPECT_THROW(client_->Compile(std::move(instances)), std::invalid_argument);
}

// Compose a computation.
auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
TEST_F(PjRtComputationClientTest, Init) {
// Compose a computation to add two 2x2 matrices.
auto out_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
std::vector<ComputationClient::CompileInstance> instances;
instances.push_back(ComputationClient::CompileInstance(
std::move(MakeComputation().value()), device,
client->GetCompilationDevices(device, client->GetLocalDevices()),
&shape));
std::move(MakeAddComputation().value()), device_,
client_->GetCompilationDevices(device_, client_->GetLocalDevices()),
&out_shape));

// Prepare inputs.
xla::Literal literal_x =
Expand All @@ -56,22 +106,22 @@ TEST(PjRtComputationClientTest, Init) {

// Compile the graph.
std::vector<ComputationClient::ComputationPtr> computations =
client->Compile(std::move(instances));
client_->Compile(std::move(instances));

// Copy inputs to device.
ComputationClient::ExecuteComputationOptions options{};
std::vector<std::shared_ptr<const TensorSource>> args = {
std::make_shared<LiteralSource>(std::move(literal_x), device),
std::make_shared<LiteralSource>(std::move(literal_y), device)};
std::make_shared<LiteralSource>(std::move(literal_x), device_),
std::make_shared<LiteralSource>(std::move(literal_y), device_)};

// Execute the graph.
std::vector<ComputationClient::DataPtr> results = client->ExecuteComputation(
*computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
device, options);
std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
*computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
device_, options);

// Copy the output from device back to host and assert correctness..
// Copy the output from device back to host and assert correctness.
ASSERT_EQ(results.size(), 1);
auto result_literals = client->TransferFromDevice(results);
auto result_literals = client_->TransferFromDevice(results);
ASSERT_THAT(result_literals, ::testing::SizeIs(1));
EXPECT_TRUE(xla::LiteralTestUtil::Equal(
xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
Expand Down
49 changes: 49 additions & 0 deletions torch_xla/csrc/runtime/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/types.h"
Expand Down Expand Up @@ -128,6 +130,53 @@ T Multiply(const S& input) {
std::multiplies<T>());
}

namespace internal {

// ExtractStatusOrValue<U>::type is T if U is absl::StatusOr<T>, and is
// undefined otherwise.
template <typename U>
struct ExtractStatusOrValue;
template <typename T>
struct ExtractStatusOrValue<absl::StatusOr<T>> {
using type = T;
};

} // namespace internal

// RaisePythonValueErrorOnFailure(func) requires `func` to be a functor that
// takes no argument and returns an absl::StatusOr<T>. It's a wrapper of
// `func()` that translates any failure in `func()` to a Python ValueError
// exception. In particular:
//
// - if `func()` returns an error, throws an std::invalid_argument,
// which is translated to a Python ValueError exception;
// (https://pybind11.readthedocs.io/en/stable/advanced/exceptions.html).
// - if `func()` throws any exception, rethrows it as an
// std::invalid_argument so that we get a Python ValueError;
// - if `func()` successfully returns a value of type T, returns the value;
// - however, if `func()` crashes (e.g. due to a CHECK), we cannot
// catch it; therefore we should ensure that `func()` never
// crashes (and fix any crash as a bug).
template <typename Func>
typename internal::ExtractStatusOrValue<decltype(std::declval<Func>()())>::type
RaisePythonValueErrorOnFailure(const Func& func) {
decltype(std::declval<Func>()()) result;
try {
result = func();
} catch (const std::exception& e) {
throw std::invalid_argument(e.what());
} catch (...) {
throw std::invalid_argument(
"Function threw an unknown exception. Please file a bug at "
"https://github.com/pytorch/xla/issues with details on how to "
"reproduce the error.");
}
if (result.ok()) {
return *std::move(result);
}
throw std::invalid_argument(std::string(result.status().message()));
}

} // namespace util
} // namespace runtime
} // namespace torch_xla
Expand Down
Loading