diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 03985260241689..bd6ba3499139ba 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -55,5 +55,6 @@ cc_library(paddle_pybind SHARED
     recurrent_op
     uniform_random_op
     gaussian_random_op
-    fill_zeros_like_op)
+    fill_zeros_like_op
+    fill_op)
 endif(WITH_PYTHON)
diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index 07b42c83717652..c99dd22dd86dc9 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -41,6 +41,7 @@ USE_OP(fill_zeros_like);
 USE_OP_ITSELF(recurrent_op);
 USE_OP(gaussian_random);
 USE_OP(uniform_random);
+USE_OP(fill);
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index b8c779f4e5fc7b..579287a9c1360d 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -107,6 +107,8 @@ class Tensor {
 
   platform::Place place() const { return holder_->place(); }
 
+  bool IsHoldingMemory() const { return holder_ != nullptr; }
+
  private:
   template <typename T>
   inline void check_memory_size() const;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 7d7263b899afb7..bf18b695d31a57 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -67,12 +67,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
   } else if (platform::is_gpu_place(place)) {
 #ifdef PADDLE_ONLY_CPU
     PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
-  }
 #else
     holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
         boost::get<platform::GPUPlace>(place), size));
-  }
 #endif
+  }
     offset_ = 0;
   }
   return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc
index aaab1142ca18d3..ef0b7a7e288342 100644
--- a/paddle/memory/memcpy.cc
+++ b/paddle/memory/memcpy.cc
@@ -29,6 +29,16 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
 }
 
 #ifndef PADDLE_ONLY_CPU
+
+template <>
+void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
+                                                  void*
dst,
+                                                  platform::GPUPlace src_place,
+                                                  const void* src, size_t num) {
+  platform::SetDeviceId(src_place.device);
+  platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
+}
+
 template <>
 void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
                                                   void* dst,
@@ -39,6 +49,15 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
 }
 
+template <>
+void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::CPUPlace src_place,
+                                                  const void* src, size_t num) {
+  platform::SetDeviceId(dst_place.device);
+  platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
+}
+
 template <>
 void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
                                                   void* dst,
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 373611cc0ee952..ae3fb72b4b104d 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -67,3 +67,5 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
 cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
 
 op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
+
+op_library(fill_op SRCS fill_op.cc fill_op.cu)
diff --git a/paddle/operators/fill_op.cc b/paddle/operators/fill_op.cc
new file mode 100644
index 00000000000000..a1b21af49c6120
--- /dev/null
+++ b/paddle/operators/fill_op.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
*/
+
+#include "paddle/operators/fill_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class FillOp : public framework::OperatorWithKernel {
+ public:
+  FillOp(const std::string &type, const VarNameMap &inputs,
+         const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto &shape = GetAttr<std::vector<int>>("shape");
+    auto dim = framework::make_ddim(shape);
+    auto numel = framework::product(dim);
+    PADDLE_ENFORCE_EQ(numel, GetAttr<std::vector<T>>("data").size(),
+                      "Shape's numel should be as same as data element count");
+    ctx.Output<framework::Tensor>("Out")->Resize(dim);
+  }
+};
+
+template <typename T>
+class FillOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FillOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+    AddOutput("Out", "Output of Fill Op");
+    AddComment("Fill a variable with shape and buffer each time.");
+    AddAttr<bool>("run_once", "Set it once or each time when run")
+        .SetDefault(false)
+        .InEnum({true, false});
+    AddAttr<std::vector<int>>("shape", "The shape of fill parameter");
+    AddAttr<std::vector<T>>("data", "The data will be filled");
+  }
+};
+
+template <typename T>
+class FillOpCPUKernel : public FillOpKernelBase<T> {
+ public:
+  void Copy(const platform::Place &place, const std::vector<T> &src,
+            T *dst) const override {
+    std::copy(src.begin(), src.end(), dst);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(fill, ops::FillOp<float>, ops::FillOpMaker<float>);
+REGISTER_OP_CPU_KERNEL(fill, ops::FillOpCPUKernel<float>);
diff --git a/paddle/operators/fill_op.cu b/paddle/operators/fill_op.cu
new file mode 100644
index 00000000000000..df3fc4e9e472d7
--- /dev/null
+++ b/paddle/operators/fill_op.cu
@@ -0,0 +1,32 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/memory/memcpy.h"
+#include "paddle/operators/fill_op.h"
+namespace paddle {
+namespace operators {
+template <typename T>
+class FillOpGPUKernel : public FillOpKernelBase<T> {
+ public:
+  void Copy(const platform::Place &place, const std::vector<T> &src,
+            T *dst) const override {
+    auto &gpu_place = boost::get<platform::GPUPlace>(place);
+    auto &cpu_place = platform::default_cpu();
+    memory::Copy(gpu_place, dst, cpu_place, src.data(), src.size() * sizeof(T));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_GPU_KERNEL(fill, paddle::operators::FillOpGPUKernel<float>);
diff --git a/paddle/operators/fill_op.h b/paddle/operators/fill_op.h
new file mode 100644
index 00000000000000..3c53130c48ec93
--- /dev/null
+++ b/paddle/operators/fill_op.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
*/
+
+#pragma once
+#include <vector>
+
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+template <typename T>
+class FillOpKernelBase : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using namespace paddle::framework;
+    auto* tensor = context.Output<Tensor>("Out");
+    auto run_once = static_cast<bool>(context.op_.GetAttr<bool>("run_once"));
+    if (run_once && tensor->IsHoldingMemory()) {
+      return;
+    }
+    T* dst = tensor->mutable_data<T>(context.GetPlace());
+    auto& src = context.op_.GetAttr<std::vector<T>>("data");
+    this->Copy(context.GetPlace(), src, dst);
+  }
+
+  virtual void Copy(const platform::Place& place, const std::vector<T>& src,
+                    T* dst) const = 0;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 96fad9b42e04a8..8b0b4f31b5a22d 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -25,3 +25,4 @@ py_test(test_operator SRCS test_operator.py)
 # py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
 py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
 py_test(test_recurrent_op SRCS test_recurrent_op.py)
+py_test(test_fill_op SRCS test_fill_op.py)
diff --git a/python/paddle/v2/framework/tests/test_fill_op.py b/python/paddle/v2/framework/tests/test_fill_op.py
new file mode 100644
index 00000000000000..77c9096124ecd6
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_fill_op.py
@@ -0,0 +1,21 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy
+
+
+class TestFillOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "fill"
+        data = [0.1, 0.2, 0.3, 0.4]
+
+        self.attrs = {'data': data, 'shape': [2, 2], 'run_once': True}
+        self.outputs = {
+            'Out': numpy.array(
+                [[0.1, 0.2], [0.3, 0.4]], dtype=numpy.float32)
+        }
+
+
+if __name__ == '__main__':
+    unittest.main()