Skip to content

Commit 46a7f39

Browse files
author
ktlichkid
committed
Implemented has_data_op
1 parent 8b1918f commit 46a7f39

File tree

3 files changed

+147
-0
lines changed

3 files changed

+147
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/has_data_op.h"
16+
#include <string>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
#include <iostream>
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
class HasDataOpMaker : public framework::OpProtoAndCheckerMaker {
25+
public:
26+
HasDataOpMaker(OpProto *proto, OpAttrChecker *op_checker)
27+
: OpProtoAndCheckerMaker(proto, op_checker) {
28+
// inputs and outputs stored in proto
29+
AddInput("X", "(LoDTensor) the LoDTensor to check");
30+
AddOutput("Out", "(LoDTensor) the ouput of has_data_op");
31+
AddComment(R"DOC(
32+
Has Data Operator.
33+
34+
This operator tests whether the input tensor has data or not.
35+
Out is a boolean scalar.
36+
)DOC");
37+
}
38+
};
39+
40+
class HasDataOp : public framework::OperatorWithKernel {
41+
public:
42+
using framework::OperatorWithKernel::OperatorWithKernel;
43+
44+
protected:
45+
void InferShape(framework::InferShapeContext *ctx) const override {
46+
PADDLE_ENFORCE(ctx->HasInput("X"),
47+
"Input(X) of HasDataOp should not be null.");
48+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
49+
"Output(Out) of HasDataOp should not be null.");
50+
std::cout << "Before set\n";
51+
ctx->SetOutputDim("Out", {1});
52+
std::cout << "After set\n";
53+
ctx->ShareLoD("X", "Out");
54+
}
55+
56+
framework::OpKernelType GetExpectedKernelType(
57+
const framework::ExecutionContext &ctx) const override {
58+
framework::OpKernelType kt = framework::OpKernelType(
59+
framework::ToDataType(
60+
ctx.Input<framework::LoDTensor>("X")->type()),
61+
platform::CPUPlace());
62+
return kt;
63+
}
64+
};
65+
66+
} // namespace operators
67+
} // namespace paddle
68+
69+
namespace ops = paddle::operators;
70+
71+
REGISTER_OPERATOR(has_data, ops::HasDataOp, ops::HasDataOpMaker);
72+
REGISTER_OP_CPU_KERNEL(
73+
has_data,
74+
ops::HasDataOpKernel<paddle::platform::CPUDeviceContext, float>,
75+
ops::HasDataOpKernel<paddle::platform::CPUDeviceContext, double>,
76+
ops::HasDataOpKernel<paddle::platform::CPUDeviceContext, int>,
77+
ops::HasDataOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <math.h>
17+
#include <type_traits>
18+
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/elementwise_op_function.h"
20+
#include "paddle/fluid/platform/transform.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
template <typename DeviceContext, typename T>
26+
class HasDataOpKernel : public framework::OpKernel<T> {
27+
public:
28+
void Compute(const framework::ExecutionContext& context) const override {
29+
auto* input = context.Input<framework::LoDTensor>("X");
30+
auto* output = context.Output<framework::LoDTensor>("Out");
31+
size_t mem_size = input->memory_size();
32+
33+
auto* output_data =
34+
output->mutable_data<bool>(platform::CPUPlace());
35+
if (mem_size > 0) {
36+
output_data[0] = true;
37+
} else {
38+
output_data[0] = false;
39+
}
40+
}
41+
};
42+
43+
} // namespace operators
44+
} // namespace paddle
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from paddle.fluid.op import Operator
2+
import paddle.fluid.core as core
3+
import unittest
4+
import numpy as np
5+
6+
7+
class BeamSearchOpTester(unittest.TestCase):
8+
def setUp(self):
9+
self.scope = core.Scope()
10+
self.scope.var('X')
11+
self.scope.var('Out')
12+
self.place = core.CUDAPlace(0)
13+
x_data = np.array([])
14+
x_tensor = self.scope.var('X').get_tensor()
15+
x_tensor.set(x_data, self.place)
16+
out_tensor = self.scope.var('Out').get_tensor()
17+
18+
def test_run(self):
19+
op = Operator('has_data', X='X', Out='Out')
20+
op.run(self.scope, self.place)
21+
out_tensor = self.scope.find_var('Out').get_tensor()
22+
print 'output: ', np.array(out_tensor)
23+
24+
25+
if __name__ == '__main__':
26+
unittest.main()

0 commit comments

Comments
 (0)