Skip to content

Commit 0e9b393

Browse files
authored
"derived cudnnDevice context" (#6585)
* "derived cudnnDevice context" * "leave remove cudnn handle from CUDADeviceContext" * "fix math function error"
1 parent 49b8ac8 commit 0e9b393

File tree

5 files changed

+61
-1
lines changed

5 files changed

+61
-1
lines changed

paddle/operators/math/math_function.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>(
273273
TensorSetConstantGPU(context, tensor, value));
274274
}
275275

276+
// Explicit specialization of set_constant_with_place for CudnnPlace.
// It forwards its arguments unchanged to the GPUPlace specialization.
template <>
277+
void set_constant_with_place<platform::CudnnPlace>(
278+
const platform::DeviceContext& context, framework::Tensor* tensor,
279+
float value) {
280+
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
281+
}
282+
276283
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
277284
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
278285
template struct ColwiseSum<platform::CUDADeviceContext, float>;

paddle/platform/device_context.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
125125

126126
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
127127

128+
// Constructs the base CUDADeviceContext for `place`, remembers `place` in
// place_, then creates a cuDNN handle and binds it to the stream returned by
// stream(). Both cuDNN calls are checked with PADDLE_ENFORCE.
CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
129+
: CUDADeviceContext(place), place_(place) {
130+
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
131+
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
132+
}
133+
134+
// Selects the device recorded in place_, calls Wait(), and then destroys the
// cuDNN handle held in cudnn_handle_.
// NOTE(review): PADDLE_ENFORCE is used inside a destructor here; confirm it
// cannot throw while the stack is already unwinding.
CudnnDeviceContext::~CudnnDeviceContext() {
135+
SetDeviceId(place_.device);
136+
Wait();
137+
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
138+
}
139+
140+
// Returns the place this context was constructed with. Returning place_
// (rather than a default-constructed CudnnPlace) preserves the device index
// that was passed to the constructor.
Place CudnnDeviceContext::GetPlace() const { return place_; }
141+
142+
// Accessor for the cuDNN handle stored in this context (cudnn_handle_).
cudnnHandle_t CudnnDeviceContext::cudnn_handle() const {
  return cudnn_handle_;
}
143+
128144
#endif
129145

130146
} // namespace platform

paddle/platform/device_context.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
8686
cublasHandle_t cublas_handle_;
8787
};
8888

89+
/*! \brief Device context derived from CUDADeviceContext that additionally
 *  stores a cuDNN handle and the CudnnPlace it was constructed with.
 *
 *  NOTE(review): CUDADeviceContext also defines a cudnn_handle() accessor;
 *  the non-virtual cudnn_handle() declared here hides it. Confirm this is
 *  intended. */
class CudnnDeviceContext : public CUDADeviceContext {
90+
public:
91+
explicit CudnnDeviceContext(CudnnPlace place);
92+
virtual ~CudnnDeviceContext();
93+
94+
/*! \brief Return place in the device context. */
95+
Place GetPlace() const final;
96+
97+
/*! \brief Return cudnn handle in the device context. */
98+
cudnnHandle_t cudnn_handle() const;
99+
100+
private:
101+
cudnnHandle_t cudnn_handle_;
102+
CudnnPlace place_;
103+
};
104+
89105
#endif
90106

91107
} // namespace platform

paddle/platform/device_context_test.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) {
4646
delete device_context;
4747
}
4848
}
49+
50+
// When HasCUDNN() reports true, creates a CudnnDeviceContext for every index
// below GetCUDADeviceCount() and checks that its cuDNN handle and stream are
// both non-null before deleting the context.
// NOTE(review): a failing ASSERT_* returns from the test body before
// `delete device_context` runs, so the context would leak on failure.
TEST(Device, CudnnDeviceContext) {
51+
using paddle::platform::CudnnDeviceContext;
52+
using paddle::platform::CudnnPlace;
53+
if (paddle::platform::dynload::HasCUDNN()) {
54+
int count = paddle::platform::GetCUDADeviceCount();
55+
for (int i = 0; i < count; ++i) {
56+
CudnnDeviceContext* device_context =
57+
new CudnnDeviceContext(CudnnPlace(i));
58+
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
59+
ASSERT_NE(nullptr, cudnn_handle);
60+
ASSERT_NE(nullptr, device_context->stream());
61+
delete device_context;
62+
}
63+
}
64+
}

paddle/platform/place.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ struct GPUPlace {
4343
int device;
4444
};
4545

46+
// Place type derived from GPUPlace; it adds no members of its own. Both
// constructors simply delegate to the corresponding GPUPlace constructor.
struct CudnnPlace : public GPUPlace {
47+
CudnnPlace() : GPUPlace() {}
48+
explicit CudnnPlace(int d) : GPUPlace(d) {}
49+
};
50+
4651
struct IsGPUPlace : public boost::static_visitor<bool> {
4752
bool operator()(const CPUPlace &) const { return false; }
4853
bool operator()(const GPUPlace &gpu) const { return true; }
@@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
5257
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
5358
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
5459

55-
typedef boost::variant<GPUPlace, CPUPlace> Place;
60+
// Variant over every supported place type.
// NOTE(review): boost::variant default-constructs its first bounded type, so
// a default-constructed Place now holds a CudnnPlace rather than a GPUPlace.
// Confirm that no caller relies on the previous default.
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;
5661

5762
// static check number of place types is less equal than
5863
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)

0 commit comments

Comments
 (0)