2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
@@ -1,5 +1,5 @@
# ddim lib
cc_library(ddim SRCS ddim.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim)
28 changes: 27 additions & 1 deletion paddle/framework/ddim.h
@@ -1,11 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <boost/variant.hpp>
#include <initializer_list>
#include <stdexcept>
#include <vector>

#include "paddle/framework/dim.h"
#include "paddle/framework/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {
@@ -91,6 +106,17 @@ int arity(const DDim& ddim);

std::ostream& operator<<(std::ostream&, const DDim&);

template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
int rank = arity(dims);
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < rank; d++) {
dsizes[d] = dims[d];
}
return dsizes;
}

} // namespace framework
} // namespace paddle
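
ToEigenDSizes converts a runtime DDim into a fixed-rank Eigen::DSizes so kernels can build Eigen tensors from it. A minimal standalone sketch of the same conversion pattern follows; ToDSizes and the use of std::vector<int> in place of DDim are illustrative stand-ins, not part of this PR.

#include <cassert>
#include <vector>
#include "unsupported/Eigen/CXX11/Tensor"

// Illustrative stand-in for ToEigenDSizes: convert a runtime shape into a
// fixed-rank Eigen::DSizes, enforcing that the rank matches NDIMS.
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToDSizes(const std::vector<int>& dims) {
  assert(static_cast<int>(dims.size()) == NDIMS);
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
  for (int d = 0; d < NDIMS; ++d) {
    dsizes[d] = dims[d];
  }
  return dsizes;
}

int main() {
  Eigen::DSizes<Eigen::DenseIndex, 2> shape = ToDSizes<2>({3, 4});
  Eigen::Tensor<float, 2> t(shape);  // a 3x4 tensor built from the converted shape
  t.setZero();
  return 0;
}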

14 changes: 14 additions & 0 deletions paddle/framework/operator.cc
@@ -17,6 +17,20 @@ limitations under the License. */
namespace paddle {
namespace framework {

template <>
Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}

#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device<
platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
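
These specializations let a kernel fetch the Eigen device that matches its place and evaluate Eigen expressions on it. Below is a hedged, CPU-only sketch of that evaluation pattern; the variable names are illustrative and nothing here is the PR's API.

#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  // Eigen::DefaultDevice is what the CPU specialization above returns; the
  // GPU specialization returns an Eigen::GpuDevice, used the same way.
  Eigen::DefaultDevice cpu_device;

  Eigen::Tensor<float, 1> x(4), y(4), z(4);
  x.setConstant(1.0f);
  y.setConstant(2.0f);

  // A kernel evaluates its Eigen expression on the device handed back by
  // get_eigen_device<>(); on CPU the expression runs immediately.
  z.device(cpu_device) = x + y;
  return z(0) == 3.0f ? 0 : 1;
}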

std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
20 changes: 20 additions & 0 deletions paddle/framework/operator.h
@@ -29,6 +29,21 @@ limitations under the License. */
namespace paddle {
namespace framework {

template <typename T>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
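
EigenDeviceConverter is a small traits struct that maps a place type to the Eigen device type used to run kernels on it. A minimal standalone sketch of the same traits pattern follows; CpuPlace and DeviceConverter are illustrative stand-ins for platform::CPUPlace and the PR's EigenDeviceConverter.

#include <type_traits>
#include "unsupported/Eigen/CXX11/Tensor"

struct CpuPlace {};  // illustrative stand-in for platform::CPUPlace

template <typename PlaceType>
struct DeviceConverter;  // primary template left undefined, as above

template <>
struct DeviceConverter<CpuPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

// The real GPU specialization maps platform::GPUPlace to Eigen::GpuDevice and
// is compiled only when PADDLE_ONLY_CPU is not defined (Eigen::GpuDevice also
// requires EIGEN_USE_GPU), so it is omitted from this host-only sketch.

static_assert(std::is_same<DeviceConverter<CpuPlace>::EigenDeviceType,
                           Eigen::DefaultDevice>::value,
              "CpuPlace selects Eigen::DefaultDevice at compile time");

int main() { return 0; }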

class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>;
/**
@@ -91,6 +106,11 @@ class OpKernel {
return scope_->GetVariable(op_.outputs_[index]);
}

template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device() const;

const OperatorBase& op_;
const ScopePtr& scope_;
const platform::DeviceContext& device_context_;
99 changes: 92 additions & 7 deletions paddle/framework/tensor.h
@@ -19,8 +19,10 @@ limitations under the License. */
#include <memory>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/framework/tensor_types.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {
@@ -38,6 +40,13 @@ class Tensor {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}

template <typename T>
T* raw_data() const {
Review comment (Collaborator): Since we always use raw_data in OpKernel and mutable_data is only used for allocation, maybe we could rename mutable_data to allocation.

CheckDims<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}

template <typename T>
T* mutable_data(DDim dims, paddle::platform::Place place) {
set_dims(dims);
@@ -53,13 +62,88 @@
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T)));
#ifdef __CUDACC__
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;

case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
#else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
#endif
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
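
mutable_data now dispatches on the boost::variant index of place so the holder is constructed with the concrete place type. A small standalone sketch of that variant dispatch is below, assuming the GPU alternative is index 0 and the CPU alternative index 1, as the switch above does; the Fake* names are illustrative stand-ins.

#include <boost/variant.hpp>
#include <iostream>

struct FakeGpuPlace {};  // stand-in for platform::GPUPlace (variant index 0)
struct FakeCpuPlace {};  // stand-in for platform::CPUPlace (variant index 1)
using Place = boost::variant<FakeGpuPlace, FakeCpuPlace>;

void Allocate(const Place& place) {
  switch (place.which()) {  // index of the alternative currently held
    case 0:
      std::cout << "allocate through the GPU allocator\n";
      break;
    case 1:
      boost::get<FakeCpuPlace>(place);  // recover the concrete place, as above
      std::cout << "allocate through the CPU allocator\n";
      break;
  }
}

int main() {
  Allocate(Place(FakeCpuPlace{}));  // takes the CPU branch
  return 0;
}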

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(raw_data<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor(
raw_data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}

// flat to rank = 1
template <typename T>
typename TTypes<T>::Flat flat() {
return shaped<T, 1>(make_ddim({static_cast<int>(numel_)}));
}

// to TensorType Vec
template <typename T>
typename TTypes<T>::Vec vec() {
return tensor<T, 1>();
}

// to TensorType Matrix
template <typename T>
typename TTypes<T>::Matrix matrix() {
return tensor<T, 2>();
}

// const versions of all the methods above.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor shaped(DDim new_dims) const {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::ConstTensor(data<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor tensor() const {
return typename TTypes<T, NDIMS>::ConstTensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}

template <typename T>
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>(make_ddim({static_cast<int>(numel_)}));
}

template <typename T>
typename TTypes<T>::ConstVec vec() const {
return tensor<T, 1>();
}

template <typename T>
typename TTypes<T>::ConstMatrix matrix() const {
return tensor<T, 2>();
}
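
flat(), vec(), matrix() and shaped() all wrap the tensor's existing buffer in an Eigen::TensorMap of the requested rank without copying. A standalone sketch of that underlying mapping (raw buffer plus shape in, Eigen view out; the buffer here stands in for holder_->ptr()):

#include <vector>
#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  std::vector<float> buf(6, 1.0f);  // stands in for the memory behind holder_->ptr()

  // flat(): a rank-1 view over every element of the buffer.
  Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>> flat(buf.data(), 6);

  // matrix(): a rank-2 view over the same memory, here shaped 2x3.
  Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> mat(buf.data(), 2, 3);

  mat(1, 2) = 7.0f;                // writes through to the shared buffer
  return flat(5) == 7.0f ? 0 : 1;  // the same element seen through both views
}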

template <typename T>
void ShareDataFrom(const Tensor& src) {
src.CheckDims<T>();
@@ -123,32 +207,33 @@ class Tensor {
virtual size_t size() const = 0;
};

template <typename T>
template <typename T, typename PlaceType>
struct PlaceholderImpl : public Placeholder {
private:
template <typename PType>
class Deleter {
public:
Deleter(platform::Place place) : place_(place) {}
Deleter(PType place) : place_(place) {}
void operator()(T* ptr) {
paddle::memory::Free(place_, static_cast<void*>(ptr));
}

private:
paddle::platform::Place place_;
PType place_;
};

public:
PlaceholderImpl(paddle::platform::Place place, size_t size)
PlaceholderImpl(PlaceType place, size_t size)
: ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)),
Deleter(place)),
Deleter<PlaceType>(place)),
place_(place),
size_(size) {}

virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t size() const { return size_; }
virtual paddle::platform::Place place() const { return place_; }

std::unique_ptr<T, Deleter> ptr_;
std::unique_ptr<T, Deleter<PlaceType>> ptr_;
paddle::platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
};
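
PlaceholderImpl is now parameterized on the concrete place so the deleter frees through the matching allocator. A hedged standalone sketch of the unique_ptr-with-typed-deleter ownership pattern follows; malloc/free stand in for paddle::memory::Alloc/Free and Holder/CpuPlace are illustrative names.

#include <cstdlib>
#include <memory>

struct CpuPlace {};  // illustrative stand-in for platform::CPUPlace

template <typename T, typename PlaceType>
struct Holder {
  // The deleter remembers the place it must free on, mirroring Deleter<PType>.
  struct Deleter {
    explicit Deleter(PlaceType place) : place_(place) {}
    void operator()(T* ptr) { std::free(ptr); }  // real code calls memory::Free(place_, ptr)
    PlaceType place_;
  };

  Holder(PlaceType place, size_t n)
      : ptr_(static_cast<T*>(std::malloc(n * sizeof(T))), Deleter(place)),
        size_(n) {}

  std::unique_ptr<T, Deleter> ptr_;  // released through the typed deleter
  size_t size_;
};

int main() {
  Holder<float, CpuPlace> h(CpuPlace{}, 16);  // freed automatically at scope exit
  return h.ptr_ != nullptr ? 0 : 1;
}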
67 changes: 67 additions & 0 deletions paddle/framework/tensor_types.h
@@ -0,0 +1,67 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {

// Helper to define Tensor types given that the scalar is of type T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
// Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Tensor;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstTensor;

// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
typedef Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Scalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType>,
Eigen::Aligned>
ConstScalar;

// Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Flat;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Vec;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstVec;

// Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Matrix;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstMatrix;
};

} // namespace framework
} // namespace paddle
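
TTypes only bundles TensorMap typedefs; callers pick a rank and constness by name. A small standalone sketch of consuming such a typedef is below; MiniTTypes is a reduced, illustrative stand-in that keeps only the Matrix alias.

#include "unsupported/Eigen/CXX11/Tensor"

// Reduced stand-in for framework::TTypes, keeping only the Matrix alias.
template <typename T, typename IndexType = Eigen::DenseIndex>
struct MiniTTypes {
  typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
                           Eigen::Aligned>
      Matrix;
};

int main() {
  // Eigen::Aligned maps expect suitably aligned storage; alignas(64) is ample.
  alignas(64) float buf[4] = {1.f, 2.f, 3.f, 4.f};
  MiniTTypes<float>::Matrix m(buf, 2, 2);  // rank-2 view over the buffer
  m(0, 1) = 5.f;
  return m(0, 1) == buf[1] ? 0 : 1;  // row-major: element (0,1) is buf[1]
}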
21 changes: 13 additions & 8 deletions paddle/function/RowConvOpGpu.cu
@@ -32,7 +32,7 @@ __global__ void KeRowConv(real* y, const real* x, const real* w,
for (int i = tidy; i < context; i += blky) {
sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
}

__syncthreads();

for (int i = 0; i < numSeq; ++i) {
@@ -144,12 +144,15 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
int yoff = start + j;

// transpose
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0;
sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
x[yoff * width + xoff] : 0.0;
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ?
dy[yoff * width + xoff] : 0.0;
__syncthreads();
if (tidy < (context - 1)) {
yoff = yoff - context + 1;
sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0;
sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ?
dy[yoff * width + xoff] : 0.0;
}
__syncthreads();

@@ -199,11 +202,13 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
int yoff = start + j;

// transpose
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
x[yoff * width + xoff] : 0.0;
__syncthreads();

for (int t = 0; t < context; t++) {
sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start &&
yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
__syncthreads();

real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
@@ -239,7 +244,7 @@ __global__ void KeRowConvBwData(real* dx, const real* w, const real* dy,
for (int i = tidy; i < context; i += blky) {
sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
}

__syncthreads();

for (int i = 0; i < numSeq; ++i) {
@@ -312,7 +317,7 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
dim3 dimBlock(32, 32);
dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
real* dw = filterG.getData();
if (contextLength <= 32) {
if (contextLength <= 32) {
KeRowConvBwWeight<32, 32, 32>
<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(dw, x, dy, starts, height, width, numSeq, contextLength);