Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/framework/net_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
Expand Down
10 changes: 5 additions & 5 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "larger_than check fail";
const char* err_msg = err.what();
Expand Down Expand Up @@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what();
Expand All @@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what();
Expand Down Expand Up @@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}

class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
Expand All @@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
4 changes: 2 additions & 2 deletions paddle/framework/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false;
try {
src_tensor.data<double>();
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
Expand Down Expand Up @@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) {
bool caught = false;
try {
dst_tensor.ShareDataWith<float>(src_tensor);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
Expand Down
68 changes: 37 additions & 31 deletions paddle/platform/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ limitations under the License. */
namespace paddle {
namespace platform {

struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
std::string err_str_;

EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
try {
std::rethrow_exception(exp_);
} catch (const std::exception& exp) {
err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
}
}

const char* what() const noexcept { return err_str_.c_str(); }
};

// Because most enforce conditions would evaluate to true, we can use
// __builtin_expect to instruct the C++ compiler to generate code that
// always forces branch prediction of true.
Expand All @@ -52,9 +67,7 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
throw std::runtime_error(
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
throw std::runtime_error(string::Sprintf(args...));
}
}

Expand All @@ -64,25 +77,17 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) {
// clang-format off
throw thrust::system_error(
e, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
throw thrust::system_error(e, thrust::cuda_category(),
string::Sprintf(args...));
}
}

template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) {
// clang-format off
throw thrust::system_error(
cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...));
}
}

Expand All @@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if (stat == CUDNN_STATUS_SUCCESS) {
return;
} else {
// clang-format off
throw std::runtime_error(
platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...));
}
}

Expand Down Expand Up @@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
err = "CUBLAS: license error, ";
}
throw std::runtime_error(err + string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
throw std::runtime_error(err + string::Sprintf(args...));
}

#endif // PADDLE_ONLY_CPU

#define PADDLE_THROW(...) \
do { \
throw std::runtime_error( \
string::Sprintf(__VA_ARGS__) + \
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \
std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0)

#define PADDLE_ENFORCE(...) \
do { \
::paddle::platform::throw_on_error(__VA_ARGS__); \
#define PADDLE_ENFORCE(...) \
do { \
try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0)

} // namespace platform
Expand Down
2 changes: 1 addition & 1 deletion paddle/platform/enforce_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
bool in_catch = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (const std::runtime_error& error) {
} catch (paddle::platform::EnforceNotMet error) {
// your error handling code here
in_catch = true;
std::string msg = "Enforce is not ok 123 at all";
Expand Down