Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5757aa1
add isfinitev2 op
joey12300 Aug 13, 2020
0255d58
finish all visitor
joey12300 Aug 14, 2020
d47d19d
modify cuda kernel of bothfalse visitor
joey12300 Aug 14, 2020
2f45e46
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into add_isf…
joey12300 Aug 16, 2020
a973644
add isnan, isinf, isfinite api
joey12300 Aug 16, 2020
ecac473
add comment for isfinite, isnan, isinf
joey12300 Aug 16, 2020
c57659e
add alias for isnan, isinf, isfinite
joey12300 Aug 17, 2020
1963a09
add int64 support
joey12300 Aug 17, 2020
cc1f486
fix typo
joey12300 Aug 17, 2020
22b6693
fix isfinite_v2 op can't evaluate tensor with dim larger than 1
joey12300 Aug 17, 2020
c36a80b
add unittest and register int64_t, float16 op kernel
joey12300 Aug 18, 2020
831f14d
add error testcase
joey12300 Aug 18, 2020
848baf6
update comment of test_isfinite_v2_op.py
joey12300 Aug 18, 2020
856ca1e
remove @skipIf
joey12300 Aug 18, 2020
b4d0ded
add specified type of error
joey12300 Aug 18, 2020
4b6ac65
remove alias comment
joey12300 Aug 18, 2020
1255817
data->mutable_data, use GetMaxPhysicalThreadCount to calculate MAX_GR…
joey12300 Aug 18, 2020
59043f9
rolback
joey12300 Aug 18, 2020
aaee7e3
to_variable->to_tensor
joey12300 Aug 18, 2020
423455c
add cpu only test
joey12300 Aug 19, 2020
f47ebe7
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into add_isf…
joey12300 Aug 19, 2020
15777f6
fix isfinite test case failed
joey12300 Aug 19, 2020
73fef26
fix conflict in paddle.tensor.math
joey12300 Aug 20, 2020
6caf966
1. remove useless macro
joey12300 Aug 22, 2020
d812983
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into add_isf…
joey12300 Aug 22, 2020
a84e213
add alias of isnan, isinf, isfinite
joey12300 Aug 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 100 additions & 10 deletions paddle/fluid/framework/tensor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,61 @@ inline void Any(const framework::Tensor& tensor, Predicate predicate,
platform::VisitPlace(place, visitor);
}

// Data-type visitor that evaluates an element-wise predicate over a tensor.
//
// apply<T>() views the input tensor as a flat vector of T and the output
// tensor as a flat vector of bool, then assigns predicate(input) to the
// output on the Eigen device owned by the device context.
template <typename Predicate, typename DevCtx>
struct AllDTypeVisitor {
  Predicate predicate_;
  const Tensor& tensor_;
  const DevCtx& ctx_;
  Tensor* out_;

  AllDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
                  Tensor* out)
      : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}

  // T is the element type of tensor_; the output element type is always bool.
  template <typename T>
  void apply() const {
    auto input_flat = EigenVector<T>::Flatten(tensor_);
    auto output_flat = EigenVector<bool>::Flatten(*out_);
    auto& eigen_dev = *ctx_.eigen_device();
    output_flat.device(eigen_dev) = predicate_(input_flat);
  }
};

// Dispatches on the runtime data type of `tensor` and runs AllDTypeVisitor
// for that type, writing the element-wise predicate result into `out`.
template <typename Predicate, typename DevCtx>
inline void AllImpl(Predicate predicate, const framework::Tensor& tensor,
                    const DevCtx& ctx, framework::Tensor* out) {
  AllDTypeVisitor<Predicate, DevCtx> dtype_visitor(predicate, tensor, ctx, out);
  VisitDataType(tensor.type(), dtype_visitor);
}

// Place visitor that produces an element-wise bool result tensor.
//
// For the visited place it fetches the matching device context, resizes
// `out` to the dimensions of the input tensor, allocates it as bool on that
// place, and then delegates the per-dtype evaluation to AllImpl.
template <typename Predicate>
class AllOutVisitor : public boost::static_visitor<> {
 private:
  const framework::Tensor& tensor_;
  // Pointer member: the pointee stays writable from const member functions,
  // so no `mutable` qualifier is needed.
  framework::Tensor* out_;
  Predicate predicate_;

 public:
  // The initializer list follows the declaration order of the members
  // (tensor_, out_, predicate_), which is the order they are actually
  // initialized in.
  AllOutVisitor(const framework::Tensor& tensor, Predicate predicate,
                framework::Tensor* out)
      : tensor_(tensor), out_(out), predicate_(predicate) {}

  template <typename Place>
  void operator()(const Place& place) const {
    auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
    // The output has the same shape as the input and holds one bool per
    // input element.
    out_->Resize(tensor_.dims());
    out_->mutable_data<bool>(place);
    AllImpl(predicate_, tensor_, *ctx, out_);
  }
};

// Evaluates `predicate` element-wise over `tensor` on the place where the
// tensor lives and stores the bool results in `out`.
template <typename Predicate>
inline void All(const framework::Tensor& tensor, Predicate predicate,
                framework::Tensor* out) {
  AllOutVisitor<Predicate> out_visitor(tensor, predicate, out);
  platform::VisitPlace(tensor.place(), out_visitor);
}

struct ContainsNANPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
Expand All @@ -370,6 +425,12 @@ void TensorContainsNAN(const framework::Tensor& tensor,
Any(tensor, predicate, out);
}

// Applies ContainsNANPredicate element-wise to `tensor` via All() and stores
// the bool result tensor in `out`.
void TensorContainsNANV2(const framework::Tensor& tensor,
                         framework::Tensor* out) {
  All(tensor, ContainsNANPredicate(), out);
}

struct ContainsInfPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
Expand All @@ -390,6 +451,12 @@ void TensorContainsInf(const framework::Tensor& tensor,
Any(tensor, predicate, out);
}

// Applies ContainsInfPredicate element-wise to `tensor` via All() and stores
// the bool result tensor in `out`.
void TensorContainsInfV2(const framework::Tensor& tensor,
                         framework::Tensor* out) {
  All(tensor, ContainsInfPredicate(), out);
}

// NOTE(dzhwinter):
// Isfinite needs an AllVisitor to loop through all the elements.
// We choose two CUDA calls instead of one AllVisitor. The AllVisitor
Expand All @@ -402,8 +469,8 @@ bool TensorIsfinite(const framework::Tensor& tensor) {

#ifdef PADDLE_WITH_CUDA
// Kernel that overwrites out[i] with (!cmp[i]) && (!out[i]); the result is
// true only where both flags were false.
// NOTE(review): this span comes from a rendered diff, so the two-argument
// signature with its out[0]-only body appears next to the element_num
// version. Presumably the former is the removed side - confirm against the
// real file.
// NOTE(review): CUDA_KERNEL_LOOP is defined elsewhere; presumably it iterates
// i over [0, element_num) across the launched threads - TODO confirm.
template <typename T>
static inline void __global__ BothFalse(const T* cmp, T* out) {
out[0] = (!cmp[0]) && (!out[0]);
static inline void __global__ BothFalse(const T* cmp, T* out, int element_num) {
CUDA_KERNEL_LOOP(i, element_num) { out[i] = (!cmp[i]) && (!out[i]); }
}
#endif

Expand All @@ -421,22 +488,36 @@ struct BothFalseVisitor : public boost::static_visitor<> {
// CUDA overload: launches the BothFalse kernel on the stream of the device
// context that belongs to `gpu`. Compiles to an empty body when
// PADDLE_WITH_CUDA is not defined.
void VisitorImpl(const platform::CUDAPlace& gpu) const {
#ifdef PADDLE_WITH_CUDA
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu);
// NOTE(review): the <<<1, 1>>> launch below looks like the pre-change side of
// the diff (it passes no element count) - confirm against the real file.
BothFalse<bool><<<1, 1, 0, ctx->stream()>>>(in_.data<bool>(),
out_->mutable_data<bool>(gpu));
// Upper bounds applied to the launch configuration computed below.
constexpr int MAX_BLOCK_DIM = 512;
constexpr int MAX_GRID_DIM = 65535;
// NOTE(review): if numel() returns a type wider than int, this narrows it -
// confirm element counts stay within int range.
int element_num = in_.numel();
// Block size: MAX_BLOCK_DIM when there are at least that many elements,
// otherwise the largest power of two not exceeding element_num.
// NOTE(review): element_num == 0 is passed to std::log2 and then produces a
// grid_size of 0 - confirm empty tensors cannot reach this path.
int block_size = (element_num >= MAX_BLOCK_DIM)
? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(element_num)));
// Grid size: integer quotient of elements per block, capped at MAX_GRID_DIM.
int grid_size = element_num / block_size;
grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
BothFalse<bool><<<grid_size, block_size, 0, ctx->stream()>>>(
in_.data<bool>(), out_->mutable_data<bool>(gpu), element_num);
#endif
}

// CPU overload: for each element, out[i] = (!in[i]) && (!out[i]).
// NOTE(review): this span comes from a rendered diff; the three [0]-indexed
// statements appear to be the removed single-element version shown next to
// the loop that replaces them - confirm against the real file.
void VisitorImpl(const platform::CPUPlace& cpu) const {
bool lhs = !in_.data<bool>()[0];
bool rhs = !out_->mutable_data<bool>(cpu)[0];
out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
// The element count is taken from the input tensor; out_ is indexed over the
// same range.
int num = in_.numel();
for (int i = 0; i < num; ++i) {
bool lhs = !in_.data<bool>()[i];
bool rhs = !out_->mutable_data<bool>(cpu)[i];
out_->mutable_data<bool>(cpu)[i] = lhs && rhs;
}
}

// CUDA-pinned overload: same host-side element loop as the CPUPlace overload,
// out[i] = (!in[i]) && (!out[i]).
// NOTE(review): this span comes from a rendered diff; the three [0]-indexed
// statements appear to be the removed single-element version shown next to
// the loop that replaces them - confirm against the real file.
void VisitorImpl(
const platform::CUDAPinnedPlace& cpu /* equals to cpu*/) const {
bool lhs = !in_.data<bool>()[0];
bool rhs = !out_->mutable_data<bool>(cpu)[0];
out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
// The element count is taken from the input tensor; out_ is indexed over the
// same range.
int num = in_.numel();
for (int i = 0; i < num; ++i) {
bool lhs = !in_.data<bool>()[i];
bool rhs = !out_->mutable_data<bool>(cpu)[i];
out_->mutable_data<bool>(cpu)[i] = lhs && rhs;
}
}
};

Expand All @@ -449,6 +530,15 @@ void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) {
platform::VisitPlace(place, visitor);
}

// Stores in `out` a bool tensor that is true for every element flagged by
// neither TensorContainsInfV2 nor TensorContainsNANV2.
void TensorIsfiniteV2(const framework::Tensor& tensor, framework::Tensor* out) {
  // Inf flags go to a scratch tensor; NaN flags are written directly to `out`.
  framework::Tensor inf_flags;
  TensorContainsInfV2(tensor, &inf_flags);
  TensorContainsNANV2(tensor, out);

  // Combine in place on the tensor's place: out = !inf_flags && !out.
  BothFalseVisitor combine(inf_flags, out);
  platform::VisitPlace(tensor.place(), combine);
}

void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/framework/tensor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
const platform::DeviceContext& dev_ctx,
const size_t& seek, const std::vector<int64_t>& shape);

// Element-wise checks: each function stores its bool result tensor in `out`.
//   TensorContainsNANV2 - runs the NaN predicate over `tensor`.
//   TensorContainsInfV2 - runs the Inf predicate over `tensor`.
//   TensorIsfiniteV2    - true where neither of the two checks above fired.
void TensorContainsNANV2(const framework::Tensor& tensor,
framework::Tensor* out);
void TensorContainsInfV2(const framework::Tensor& tensor,
framework::Tensor* out);
void TensorIsfiniteV2(const framework::Tensor& tensor, framework::Tensor* out);

// convert dlpack's DLTensor to tensor
void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst);

Expand Down
119 changes: 119 additions & 0 deletions paddle/fluid/operators/isfinite_v2_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/isfinite_v2_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/platform/float16.h"

namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Operator definition used for the element-wise overflow checks registered
// with REGISTER_V2OP_MAKER. It takes input "X" and produces output "Out";
// shape inference is delegated to UnaryOpUnchangedInferShape.
class OverflowV2Op : public framework::OperatorWithKernel {
 public:
  OverflowV2Op(const std::string &type,
               const framework::VariableNameMap &inputs,
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  // Requires both "X" and "Out" to be present before inferring the shape.
  // NOTE(review): the op name in these checks is hard-coded to "isfinitev2",
  // although this class is registered for several op types.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "isfinitev2");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "isfinitev2");
    UnaryOpUnchangedInferShape(ctx);
  }

 protected:
  // Chooses the kernel from the data type held by input "X" (either a
  // LoDTensor or the value tensor of a SelectedRows) and the execution place.
  // Throws InvalidArgument when "X" holds neither of those types.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto *x_var = ctx.InputVar("X");
    const bool holds_lod_tensor = x_var->IsType<framework::LoDTensor>();
    const bool holds_selected_rows =
        !holds_lod_tensor && x_var->IsType<framework::SelectedRows>();
    if (!holds_lod_tensor && !holds_selected_rows) {
      PADDLE_THROW(plat::errors::InvalidArgument(
          "Cannot find the input data type by all input data"));
    }
    const int dtype =
        holds_lod_tensor
            ? x_var->Get<framework::LoDTensor>().type()
            : x_var->Get<framework::SelectedRows>().value().type();
    return framework::OpKernelType(framework::proto::VarType::Type(dtype),
                                   ctx.GetPlace());
  }
};

// Proto/checker maker shared by the overflow-v2 operators. Subclasses supply
// the operator name and a trailing comment via GetName() / GetComments();
// Make() declares input "X", output "Out" and the generated doc string.
class OverflowV2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The input tensors of overflowv2 operator.");
    AddOutput("Out",
              "(Tensor) The output tensor of overflowv2 operator. "
              "Same size compare to input tensor");
    // The template below contains three %s placeholders: the operator name
    // (title), the operator name again (formula) and the extra comments.
    // All three must be supplied, so GetName() is passed twice.
    AddComment(string::Sprintf(R"DOC(
Overflow %s operator.

$$Out = %s(X)$$

Check whether each element of X is Inf or Nan, return the bool result of each
element of X as a tensor.

%s
)DOC",
                               GetName(), GetName(), GetComments()));
  }

 protected:
  // Operator name inserted into the doc string.
  virtual std::string GetName() const = 0;
  // Free-form text appended at the end of the doc string.
  virtual std::string GetComments() const = 0;
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// Defines paddle::operators::_<op_type>OverflowV2OpMaker, a concrete
// OverflowV2OpMaker whose GetName() returns the stringified op_type and whose
// GetComments() returns `comment`, then registers <op_type> with OverflowV2Op,
// that maker and empty gradient-op makers.
// The generated overrides are marked `override` so that a signature drift in
// the base class becomes a compile error.
#define REGISTER_V2OP_MAKER(op_type, comment) \
  namespace paddle { \
  namespace operators { \
  class _##op_type##OverflowV2OpMaker \
      : public ::paddle::operators::OverflowV2OpMaker { \
   protected: \
    std::string GetName() const override { return #op_type; } \
    std::string GetComments() const override { return comment; } \
  }; \
  } \
  } \
  REGISTER_OPERATOR( \
      op_type, ops::OverflowV2Op, ops::_##op_type##OverflowV2OpMaker, \
      paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
      paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

// Registers the CPU OverflowKernel of `op_type` for int, int64_t, float,
// double and float16 inputs, each instantiated with ops::<functor>.
#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \
  REGISTER_OP_CPU_KERNEL( \
      op_type, \
      ops::OverflowKernel<plat::CPUDeviceContext, int, ops::functor>, \
      ops::OverflowKernel<plat::CPUDeviceContext, int64_t, ops::functor>, \
      ops::OverflowKernel<plat::CPUDeviceContext, float, ops::functor>, \
      ops::OverflowKernel<plat::CPUDeviceContext, double, ops::functor>, \
      ops::OverflowKernel<plat::CPUDeviceContext, plat::float16, \
                          ops::functor>);

// Register the three operator types together with their doc-string makers,
// then register the CPU kernels for every (op, functor) pair listed by
// FOR_EACH_KERNEL_V2FUNCTOR.
REGISTER_V2OP_MAKER(isinf_v2, "isinfv2(X)");
REGISTER_V2OP_MAKER(isnan_v2, "isnanv2(X)");
REGISTER_V2OP_MAKER(isfinite_v2, "isfinitev2(X)");
FOR_EACH_KERNEL_V2FUNCTOR(REGISTER_OVERFLOW_CPU_KERNEL);
34 changes: 34 additions & 0 deletions paddle/fluid/operators/isfinite_v2_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/isfinite_v2_op.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the CUDA OverflowKernel of `op_type` for int, int64_t, float,
// double and float16 inputs, each instantiated with ops::<functor>.
#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \
  REGISTER_OP_CUDA_KERNEL( \
      op_type, \
      ops::OverflowKernel<plat::CUDADeviceContext, int, ops::functor>, \
      ops::OverflowKernel<plat::CUDADeviceContext, int64_t, ops::functor>, \
      ops::OverflowKernel<plat::CUDADeviceContext, float, ops::functor>, \
      ops::OverflowKernel<plat::CUDADeviceContext, double, ops::functor>, \
      ops::OverflowKernel<plat::CUDADeviceContext, plat::float16, \
                          ops::functor>);

// Instantiate the CUDA registration for every (op, functor) pair listed by
// FOR_EACH_KERNEL_V2FUNCTOR.
FOR_EACH_KERNEL_V2FUNCTOR(REGISTER_OVERFLOW_CUDA_KERNEL);
52 changes: 52 additions & 0 deletions paddle/fluid/operators/isfinite_v2_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

// Stateless functor that forwards to framework::TensorContainsInfV2, which
// stores its bool result tensor in `out`.
// The call operator is const so the functor can also be invoked through a
// const object or reference.
struct InfinityV2Functor {
  void operator()(const framework::Tensor& tensor,
                  framework::Tensor* out) const {
    framework::TensorContainsInfV2(tensor, out);
  }
};

// Stateless functor that forwards to framework::TensorContainsNANV2, which
// stores its bool result tensor in `out`.
// The call operator is const so the functor can also be invoked through a
// const object or reference.
struct NANV2Functor {
  void operator()(const framework::Tensor& tensor,
                  framework::Tensor* out) const {
    framework::TensorContainsNANV2(tensor, out);
  }
};

// Stateless functor that forwards to framework::TensorIsfiniteV2, which
// stores its bool result tensor in `out`.
// The call operator is const so the functor can also be invoked through a
// const object or reference.
struct IsfiniteV2Functor {
  void operator()(const framework::Tensor& tensor,
                  framework::Tensor* out) const {
    framework::TensorIsfiniteV2(tensor, out);
  }
};

} // namespace operators
} // namespace paddle

// Invokes __macro(op_name, FunctorType) once for each of the three overflow
// ops: isinf_v2, isnan_v2 and isfinite_v2.
// NOTE(review): the text between the #define line and the __macro lines is
// review-thread residue from the rendered page, not part of the macro.
#define FOR_EACH_KERNEL_V2FUNCTOR(__macro) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个宏定义建议去掉,节省的代码不多,但增加了代码层级。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

__macro(isinf_v2, InfinityV2Functor); \
__macro(isnan_v2, NANV2Functor); \
__macro(isfinite_v2, IsfiniteV2Functor);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/nll_loss_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class NLLLossOp : public framework::OperatorWithKernel {
"Input(Weight) should be a 1D tensor."));
PADDLE_ENFORCE_EQ(x_dims[1], w_dims[0],
platform::errors::InvalidArgument(
"Input(Weight) Tensor's size should match"
"to the class numer."));
"Input(Weight) Tensor's size should match "
"to the the total number of classes."));
}
}
if (x_dims.size() == 2) {
Expand Down
Loading