Merged

Changes from 8 commits
2 changes: 2 additions & 0 deletions paddle/platform/CMakeLists.txt
@@ -4,3 +4,5 @@ nv_test(cuda_test SRCS cuda_test.cu)

cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)

nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags)
170 changes: 170 additions & 0 deletions paddle/platform/device_context.h
@@ -0,0 +1,170 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#define EIGEN_USE_GPU
Contributor: Where is EIGEN_USE_GPU used?

Member Author: EIGEN_USE_GPU is used by the Eigen library itself. If we want to use Eigen's Tensor expressions on the GPU, we have to define this macro.
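For reference, a minimal sketch of the pattern this enables (not part of this PR; assumes x, y, z are device pointers of length n and the file is compiled by nvcc):

#define EIGEN_USE_GPU
#include "unsupported/Eigen/CXX11/Tensor"

// Illustrative helper (hypothetical), not part of device_context.h.
void AddOnGpu(float* x, float* y, float* z, int n, cudaStream_t stream) {
  Eigen::CudaStreamDevice stream_device(&stream);
  Eigen::GpuDevice gpu_device(&stream_device);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> tx(x, n), ty(y, n), tz(z, n);
  tz.device(gpu_device) = tx + ty;  // expression is evaluated on the GPU stream
}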

#endif

#include "paddle/framework/enforce.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
Contributor: The above three includes should also be placed between #ifndef PADDLE_ONLY_CPU and #endif.

Member Author: Yes, logically the above three header files should be included between those macros.
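That is, roughly (a sketch of the suggested arrangement, not what this diff currently does):

// Suggested guard: keep all CUDA-only headers inside the macro block.
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif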

#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace platform {

class DeviceContext {
 public:
  virtual ~DeviceContext() {}
};

class CpuDeviceContext : public DeviceContext {
Collaborator: Cpu => CPU

Member Author: Done

 public:
  Eigen::DefaultDevice eigen_device() {
    if (!eigen_device_) {
      eigen_device_ = new Eigen::DefaultDevice();
    }
Contributor (@qingqing01, Jul 5, 2017): Where is Eigen::DefaultDevice used in our design? I found Eigen::DefaultDevice in the directory eigen/unsupported/Eigen/CXX11/src/Tensor, but I could not find its usage in Eigen's Tensor documentation.

Member Author: Eigen::DefaultDevice is defined in unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h.
For the usage of Eigen::DefaultDevice, please refer to https://github.com/QiJune/RefEigen/blob/master/main.cu
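In short, the usage pattern looks like this (illustrative only, assuming x, y, z point to host buffers of length n):

Eigen::DefaultDevice cpu_device;
Eigen::TensorMap<Eigen::Tensor<float, 1>> tx(x, n), ty(y, n), tz(z, n);
tz.device(cpu_device) = tx + ty;  // expression is evaluated on the CPU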

    return *eigen_device_;
  }

 private:
  Eigen::DefaultDevice* eigen_device_{nullptr};
};

#ifndef PADDLE_ONLY_CPU
Contributor: Guarding this with PADDLE_ONLY_CPU introduces a problem: any code that calls DeviceGuard or CudaDeviceContext also needs to be separated by PADDLE_ONLY_CPU.

Member Author: The code that calls DeviceGuard or CudaDeviceContext must be built with WITH_GPU set to 1.
Yes, this raises the question of how to organize our CPU/GPU code cleanly.
We can either use macros or provide fake stub header files.

Contributor: I think there are a few things to consider when dealing with mixed GPU and CPU code.

  1. If not necessary, try not to put GPU and CPU code in the same file. That way no extra macro is needed to separate the code. (For instance, context.h could contain only the CPU context, and cuda_context.h the GPU context.)
  2. Do not use PADDLE_ONLY_CPU; it should be replaced by PADDLE_WITH_CUDA. The default should be the CPU code path, and CUDA code should be added under PADDLE_WITH_CUDA (see the sketch below).
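A sketch of that convention, using the suggested file names (not part of this PR):

// context.h holds only the CPU context and is always compiled.
// cuda_context.h holds the GPU context and is included only when CUDA is enabled.
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/cuda_context.h"
#endif
#include "paddle/platform/context.h"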

Member: I think this suggestion is useful.

Member Author: Merging this PR for now. I will consider the design of DeviceContext together with the Operator interface, and will follow @hedaoyuan's advice later.

class DeviceGuard {
 public:
  explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
    if (previous_ != new_place) {
      paddle::platform::SetDeviceId(new_place.device);
    }
  }

  ~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); }

 private:
  GPUPlace previous_;
};

class CudaDeviceContext : public DeviceContext {
Collaborator: Cuda => CUDA

Member Author: Done

 public:
  explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
    DeviceGuard guard(gpu_place_);
    paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
                                     "cudaStreamCreate failed");
    eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
    eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
Contributor: I'm not sure we will use the CUDA implementation in Eigen. If we decide not to use it, I think eigen_stream_ and eigen_device_ can be removed.

Member Author: If we do not use the CUDA implementation in Eigen, then we will have to write CUDA kernels for every operator, just like Caffe2.
TensorFlow uses the CUDA implementation in Eigen, and @hedaoyuan once mentioned that the efficiency of Eigen's expression templates on the GPU is acceptable.
So we may have a discussion about this offline.

  }

  void Wait() {
    paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
                                     "cudaStreamSynchronize failed");
  }

  cudaStream_t stream() { return stream_; }
Collaborator (@reyoung, Jul 6, 2017): Lack of const on all these methods?

But maybe it is not important, because every device context is passed to Op::Run as a mutable pointer.
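(For illustration only: an accessor that does not mutate state could be declared as

  cudaStream_t stream() const { return stream_; }  // sketch, not in this diff

though the lazily created handles below would then need mutable members or eager initialization.)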


  Eigen::GpuDevice eigen_device() { return *eigen_device_; }

  cublasHandle_t cublas_handle() {
    if (!blas_handle_) {
      DeviceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
Collaborator: Tooooooooo long for the namespace.

Maybe we can add using namespace paddle::platform; in this class's private section, like

class GPUDeviceContext {
 private:
  using namespace paddle::platform;   // only use the namespace in this class.
};

Collaborator: Or maybe an alias is better, like

using dynload = paddle::platform::dynload;

Member Author (@QiJune, Jul 6, 2017): It seems that we cannot add an alias or a using namespace inside a class.
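(A namespace alias at namespace scope does work, though; a sketch, not part of this diff. Note that namespace aliases use the namespace keyword rather than using:

namespace paddle {
namespace platform {
namespace dyn = paddle::platform::dynload;  // sketch: valid at namespace scope
}  // namespace platform
}  // namespace paddle

Calls inside CudaDeviceContext would then shorten to dyn::cublasCreate(&blas_handle_), and so on.)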

                         CUBLAS_STATUS_SUCCESS,
                     "cublasCreate failed");
      PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
                         blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
                     "cublasSetStream failed");
    }
    return blas_handle_;
  }

  cudnnHandle_t cudnn_handle() {
    if (!dnn_handle_) {
      DeviceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
                         CUDNN_STATUS_SUCCESS,
                     "cudnnCreate failed");
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
                         dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
                     "cudnnSetStream failed");
    }
    return dnn_handle_;
  }

  curandGenerator_t curand_generator() {
    if (!rand_generator_) {
      DeviceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
                         CURAND_STATUS_SUCCESS,
                     "curandCreateGenerator failed");
      PADDLE_ENFORCE(
          paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
              rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
          "curandSetPseudoRandomGeneratorSeed failed");
      PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
                         rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
                     "curandSetStream failed");
    }
    return rand_generator_;
  }

  ~CudaDeviceContext() {
    Wait();
    if (blas_handle_) {
      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
                         CUBLAS_STATUS_SUCCESS,
                     "cublasDestroy failed");
    }

    if (dnn_handle_) {
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
                         CUDNN_STATUS_SUCCESS,
                     "cudnnDestroy failed");
    }

    if (rand_generator_) {
      PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
                         rand_generator_) == CURAND_STATUS_SUCCESS,
                     "curandDestroyGenerator failed");
    }

    delete eigen_stream_;
    delete eigen_device_;

    paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
                                     "cudaStreamDestroy failed");
  }

 private:
  GPUPlace gpu_place_;
  cudaStream_t stream_;

  Eigen::CudaStreamDevice* eigen_stream_;
  Eigen::GpuDevice* eigen_device_;

  cublasHandle_t blas_handle_{nullptr};

  cudnnHandle_t dnn_handle_{nullptr};

  int random_seed_;
  curandGenerator_t rand_generator_{nullptr};
};
#endif
} // namespace platform
} // namespace paddle
33 changes: 33 additions & 0 deletions paddle/platform/device_context_test.cc
@@ -0,0 +1,33 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"

TEST(DeviceContext, CudaDevice) {
  int count = paddle::platform::GetDeviceCount();
  for (int i = 0; i < count; i++) {
    paddle::platform::CudaDeviceContext* device_context =
        new paddle::platform::CudaDeviceContext(i);
    __attribute__((unused)) Eigen::GpuDevice gpu_device =
        device_context->eigen_device();
    __attribute__((unused)) cudnnHandle_t cudnn_handle =
        device_context->cudnn_handle();
    __attribute__((unused)) cublasHandle_t cublas_handle =
        device_context->cublas_handle();
    __attribute__((unused)) curandGenerator_t curand_handle =
        device_context->curand_generator();
    delete device_context;
  }
}