Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions paddle/fluid/operators/print_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ namespace operators {

#define CLOG std::cout

const std::string kForward = "FORWARD";
const std::string kBackward = "BACKWARD";
const std::string kBoth = "BOTH";
const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";

struct Formater {
std::string message;
std::string name;
std::vector<int> dims;
std::type_index dtype{typeid(char)};
std::type_index dtype{typeid(const char)};
framework::LoD lod;
int summarize;
void* data{nullptr};
Expand Down Expand Up @@ -62,7 +62,7 @@ struct Formater {
}
}
void PrintDtype() {
if (dtype.hash_code() != typeid(char).hash_code()) {
if (dtype.hash_code() != typeid(const char).hash_code()) {
CLOG << "\tdtype: " << dtype.name() << std::endl;
}
}
Expand All @@ -83,15 +83,15 @@ struct Formater {
void PrintData(size_t size) {
PADDLE_ENFORCE_NOT_NULL(data);
// print float
if (dtype.hash_code() == typeid(float).hash_code()) {
if (dtype.hash_code() == typeid(const float).hash_code()) {
Display<float>(size);
} else if (dtype.hash_code() == typeid(double).hash_code()) {
} else if (dtype.hash_code() == typeid(const double).hash_code()) {
Display<double>(size);
} else if (dtype.hash_code() == typeid(int).hash_code()) {
} else if (dtype.hash_code() == typeid(const int).hash_code()) {
Display<int>(size);
} else if (dtype.hash_code() == typeid(int64_t).hash_code()) {
} else if (dtype.hash_code() == typeid(const int64_t).hash_code()) {
Display<int64_t>(size);
} else if (dtype.hash_code() == typeid(bool).hash_code()) {
} else if (dtype.hash_code() == typeid(const bool).hash_code()) {
Display<bool>(size);
} else {
CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
Expand All @@ -100,7 +100,7 @@ struct Formater {

template <typename T>
void Display(size_t size) {
auto* d = (T*)data;
auto* d = reinterpret_cast<T*>(data);
CLOG << "\tdata: ";
if (summarize != -1) {
summarize = std::min(size, (size_t)summarize);
Expand Down Expand Up @@ -135,7 +135,7 @@ class TensorPrintOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
const framework::Variable* in_var_ptr = nullptr;
std::string phase = kForward;
std::string phase(kForward);
std::string printed_var_name = "";

auto& inputs = Inputs();
Expand All @@ -146,7 +146,7 @@ class TensorPrintOp : public framework::OperatorBase {
!Inputs("In@GRAD").empty()) {
in_var_ptr = scope.FindVar(Input("In@GRAD"));
printed_var_name = Inputs("In@GRAD").front();
phase = kBackward;
phase = std::string(kBackward);
} else {
PADDLE_THROW("Unknown phase, should be forward or backward.");
}
Expand All @@ -163,7 +163,7 @@ class TensorPrintOp : public framework::OperatorBase {
out_tensor.set_lod(in_tensor.lod());

std::string print_phase = Attr<std::string>("print_phase");
if (print_phase != phase && print_phase != kBoth) {
if (print_phase != phase && print_phase != std::string(kBoth)) {
return;
}

Expand Down Expand Up @@ -199,7 +199,7 @@ class TensorPrintOp : public framework::OperatorBase {
formater.lod = printed_tensor.lod();
}
formater.summarize = Attr<int>("summarize");
formater.data = (void*)printed_tensor.data<void>();
formater.data = reinterpret_cast<void*>(printed_tensor.data<void>());
formater(printed_tensor.numel());
}

Expand All @@ -223,8 +223,9 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
"print_phase",
"(string, default 'BOTH') Which phase to display including 'FORWARD' "
"'BACKWARD' and 'BOTH'.")
.SetDefault(kBoth)
.InEnum({kForward, kBackward, kBoth});
.SetDefault(std::string(kBoth))
.InEnum({std::string(kForward), std::string(kBackward),
std::string(kBoth)});
AddOutput("Out", "Output tensor with same data as input tensor.");
AddComment(R"DOC(
Creates a print op that will print when a tensor is accessed.
Expand Down