Skip to content

Commit 3208914

Browse files
authored
Merge pull request #2805 from QiJune/tensor_to_EigenTensor
Add method converting Tensor to Eigen TensorMap
2 parents fb48cb1 + d344f67 commit 3208914

File tree

12 files changed

+273
-38
lines changed

12 files changed

+273
-38
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
# ddim lib
12
cc_library(enforce SRCS enforce.cc DEPS glog)
23
cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
3-
cc_library(ddim SRCS ddim.cc)
4+
cc_library(ddim SRCS ddim.cc DEPS eigen3)
45
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
56
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
67
cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)

paddle/framework/ddim.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
115
#pragma once
216

317
#include <boost/variant.hpp>
418
#include <initializer_list>
519
#include <stdexcept>
620
#include <vector>
7-
821
#include "paddle/framework/dim.h"
22+
#include "paddle/framework/enforce.h"
23+
#include "unsupported/Eigen/CXX11/Tensor"
924

1025
namespace paddle {
1126
namespace framework {
@@ -104,6 +119,17 @@ int arity(const DDim& ddim);
104119

105120
std::ostream& operator<<(std::ostream&, const DDim&);
106121

122+
template <int NDIMS>
123+
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
124+
int rank = arity(dims);
125+
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
126+
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
127+
for (int d = 0; d < rank; d++) {
128+
dsizes[d] = dims[d];
129+
}
130+
return dsizes;
131+
}
132+
107133
} // namespace framework
108134
} // namespace paddle
109135

paddle/framework/operator.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,20 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace framework {
2121

22+
template <>
23+
Eigen::DefaultDevice* KernelContext::GetEigenDevice<
24+
platform::CPUPlace, Eigen::DefaultDevice>() const {
25+
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
26+
}
27+
28+
#ifndef PADDLE_ONLY_CPU
29+
template <>
30+
Eigen::GpuDevice*
31+
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
32+
return device_context_.get_eigen_device<Eigen::GpuDevice>();
33+
}
34+
#endif
35+
2236
const std::string& OperatorBase::Input(const std::string& name) const {
2337
auto it = in_out_idxs_->find(name);
2438
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",

paddle/framework/operator.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ limitations under the License. */
3131
namespace paddle {
3232
namespace framework {
3333

34+
template <typename T>
35+
struct EigenDeviceConverter;
36+
37+
template <>
38+
struct EigenDeviceConverter<platform::CPUPlace> {
39+
using EigenDeviceType = Eigen::DefaultDevice;
40+
};
41+
42+
#ifndef PADDLE_ONLY_CPU
43+
template <>
44+
struct EigenDeviceConverter<platform::GPUPlace> {
45+
using EigenDeviceType = Eigen::GpuDevice;
46+
};
47+
#endif
48+
3449
class OperatorBase;
3550
using OperatorPtr = std::shared_ptr<OperatorBase>;
3651
/**
@@ -131,6 +146,13 @@ class KernelContext {
131146
return res;
132147
}
133148

149+
template <typename PlaceType,
150+
typename DeviceType =
151+
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
152+
DeviceType* GetEigenDevice() const;
153+
154+
platform::Place GetPlace() const { return device_context_.GetPlace(); }
155+
134156
const OperatorBase& op_;
135157
const std::shared_ptr<Scope>& scope_;
136158
const platform::DeviceContext& device_context_;
@@ -144,6 +166,7 @@ class OpKernel {
144166
* device resource such as CUDA stream, cublas handle, etc. from
145167
* KernelContext. User should construct it before run the Operator.
146168
*/
169+
147170
virtual void Compute(const KernelContext& context) const = 0;
148171

149172
virtual ~OpKernel() {}

paddle/framework/tensor.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ limitations under the License. */
2020
#include <typeindex>
2121
#include "paddle/framework/ddim.h"
2222
#include "paddle/framework/enforce.h"
23+
#include "paddle/framework/tensor_types.h"
2324
#include "paddle/memory/memory.h"
2425
#include "paddle/platform/place.h"
26+
#include "unsupported/Eigen/CXX11/Tensor"
2527

2628
namespace paddle {
2729
namespace pybind {
@@ -43,6 +45,13 @@ class Tensor {
4345
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
4446
}
4547

48+
template <typename T>
49+
T* raw_data() const {
50+
CheckDims<T>();
51+
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
52+
offset_);
53+
}
54+
4655
template <typename T>
4756
T* mutable_data(DDim dims, platform::Place place) {
4857
set_dims(dims);
@@ -77,6 +86,66 @@ class Tensor {
7786
offset_);
7887
}
7988

89+
template <typename T, size_t NDIMS>
90+
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
91+
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
92+
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
93+
return typename TTypes<T, NDIMS>::Tensor(raw_data<T>(), dims);
94+
}
95+
96+
template <typename T, size_t NDIMS>
97+
typename TTypes<T, NDIMS>::Tensor tensor() {
98+
return typename TTypes<T, NDIMS>::Tensor(
99+
raw_data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
100+
}
101+
102+
// flat to rank = 1
103+
template <typename T>
104+
typename TTypes<T>::Flat flat() {
105+
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
106+
}
107+
108+
// to TensorType Vec
109+
template <typename T>
110+
typename TTypes<T>::Vec vec() {
111+
return tensor<T, 1>();
112+
}
113+
114+
// to TensorType Matrix
115+
template <typename T>
116+
typename TTypes<T>::Matrix matrix() {
117+
return tensor<T, 2>();
118+
}
119+
120+
// const versions of all the methods above.
121+
template <typename T, size_t NDIMS>
122+
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) const {
123+
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
124+
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
125+
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
126+
}
127+
128+
template <typename T, size_t NDIMS>
129+
typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
130+
return typename TTypes<T, NDIMS>::Tensor(
131+
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
132+
}
133+
134+
template <typename T>
135+
typename TTypes<T>::ConstFlat flat() const {
136+
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
137+
}
138+
139+
template <typename T>
140+
typename TTypes<T>::ConstVec vec() const {
141+
return tensor<T, 1>();
142+
}
143+
144+
template <typename T>
145+
typename TTypes<T>::ConstMatrix matrix() const {
146+
return tensor<T, 2>();
147+
}
148+
80149
template <typename T>
81150
void ShareDataFrom(const Tensor& src) {
82151
src.CheckDims<T>();

paddle/framework/tensor_types.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "unsupported/Eigen/CXX11/Tensor"
18+
19+
namespace paddle {
20+
namespace framework {
21+
22+
// Helper to define Tensor types given that the scalar is of type T.
23+
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
24+
struct TTypes {
25+
// Rank-<NDIMS> tensor of scalar type T.
26+
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
27+
Eigen::Aligned>
28+
Tensor;
29+
typedef Eigen::TensorMap<
30+
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
31+
ConstTensor;
32+
33+
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
34+
typedef Eigen::TensorMap<
35+
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
36+
Eigen::Aligned>
37+
Scalar;
38+
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
39+
Eigen::RowMajor, IndexType>,
40+
Eigen::Aligned>
41+
ConstScalar;
42+
43+
// Rank-1 tensor (vector) of scalar type T.
44+
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
45+
Eigen::Aligned>
46+
Flat;
47+
typedef Eigen::TensorMap<
48+
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
49+
ConstFlat;
50+
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
51+
Eigen::Aligned>
52+
Vec;
53+
typedef Eigen::TensorMap<
54+
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
55+
ConstVec;
56+
57+
// Rank-2 tensor (matrix) of scalar type T.
58+
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
59+
Eigen::Aligned>
60+
Matrix;
61+
typedef Eigen::TensorMap<
62+
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
63+
ConstMatrix;
64+
};
65+
66+
} // namespace framework
67+
} // namespace paddle

paddle/operators/add_op.cc

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
22
3-
Licensed under the Apache License, Version 2.0 (the "License");
4-
you may not use this file except in compliance with the License.
5-
You may obtain a copy of the License at
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
66
7-
http://www.apache.org/licenses/LICENSE-2.0
7+
http://www.apache.org/licenses/LICENSE-2.0
88
9-
Unless required by applicable law or agreed to in writing, software
10-
distributed under the License is distributed on an "AS IS" BASIS,
11-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
See the License for the specific language governing permissions and
13-
limitations under the License. */
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
1414

15-
#include <paddle/framework/op_registry.h>
16-
#include <paddle/framework/tensor.h>
17-
#include <paddle/operators/add_op.h>
15+
#include "paddle/operators/add_op.h"
16+
#include "paddle/framework/op_registry.h"
17+
#include "paddle/framework/tensor.h"
1818

1919
namespace paddle {
2020
namespace operators {
@@ -53,5 +53,6 @@ The equation is: Out = X + Y
5353
} // namespace paddle
5454

5555
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
56-
REGISTER_OP_CPU_KERNEL(
57-
add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>);
56+
typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
57+
AddKernel_CPU_float;
58+
REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);

paddle/operators/add_op.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
#include <paddle/operators/add_op.h>
2-
#include <paddle/framework/op_registry.h>
1+
#include "paddle/operators/add_op.h"
2+
#include "paddle/framework/op_registry.h"
33

4+
typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float;
45
REGISTER_OP_GPU_KERNEL(add_two,
5-
paddle::operators::AddKernel<paddle::platform::GPUPlace>);
6+
AddKernel_GPU_float);

paddle/operators/add_op.h

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,36 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
115
#pragma once
2-
#include <glog/logging.h>
3-
#include <paddle/framework/operator.h>
16+
#include "glog/logging.h"
17+
#include "paddle/framework/operator.h"
418

519
namespace paddle {
620
namespace operators {
721

8-
template <typename Place>
22+
template <typename Place, typename T>
923
class AddKernel : public framework::OpKernel {
1024
public:
11-
void Compute(const framework::KernelContext &context) const override {
12-
LOG(INFO) << "Add kernel in " << typeid(Place).name();
25+
void Compute(const framework::KernelContext& context) const override {
26+
auto input0 = context.Input(0)->Get<framework::Tensor>();
27+
auto input1 = context.Input(1)->Get<framework::Tensor>();
28+
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
29+
30+
output->mutable_data<T>(context.GetPlace());
31+
32+
output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
33+
input0.flat<T>() + input1.flat<T>();
1334
}
1435
};
1536

paddle/operators/add_op_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
115
#include <gtest/gtest.h>
216
#define private public
317
#include <paddle/framework/op_registry.h>

0 commit comments

Comments
 (0)