5 files changed: +61 −1 lines changed
@@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>(
273273 TensorSetConstantGPU (context, tensor, value));
274274}
275275
276+ template <>
277+ void set_constant_with_place<platform::CudnnPlace>(
278+ const platform::DeviceContext& context, framework::Tensor* tensor,
279+ float value) {
280+ set_constant_with_place<platform::GPUPlace>(context, tensor, value);
281+ }
282+
276283template struct RowwiseAdd <platform::CUDADeviceContext, float >;
277284template struct RowwiseAdd <platform::CUDADeviceContext, double >;
278285template struct ColwiseSum <platform::CUDADeviceContext, float >;
@@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
125125
126126cudaStream_t CUDADeviceContext::stream () const { return stream_; }
127127
128+ CudnnDeviceContext::CudnnDeviceContext (CudnnPlace place)
129+ : CUDADeviceContext(place), place_(place) {
130+ PADDLE_ENFORCE (dynload::cudnnCreate (&cudnn_handle_));
131+ PADDLE_ENFORCE (dynload::cudnnSetStream (cudnn_handle_, stream ()));
132+ }
133+
134+ CudnnDeviceContext::~CudnnDeviceContext () {
135+ SetDeviceId (place_.device );
136+ Wait ();
137+ PADDLE_ENFORCE (dynload::cudnnDestroy (cudnn_handle_));
138+ }
139+
140+ Place CudnnDeviceContext::GetPlace () const { return CudnnPlace (); }
141+
142+ cudnnHandle_t CudnnDeviceContext::cudnn_handle () const { return cudnn_handle_; }
143+
128144#endif
129145
130146} // namespace platform
@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
8686 cublasHandle_t cublas_handle_;
8787};
8888
89+ class CudnnDeviceContext : public CUDADeviceContext {
90+ public:
91+ explicit CudnnDeviceContext (CudnnPlace place);
92+ virtual ~CudnnDeviceContext ();
93+
94+ /* ! \brief Return place in the device context. */
95+ Place GetPlace () const final ;
96+
97+ /* ! \brief Return cudnn handle in the device context. */
98+ cudnnHandle_t cudnn_handle () const ;
99+
100+ private:
101+ cudnnHandle_t cudnn_handle_;
102+ CudnnPlace place_;
103+ };
104+
89105#endif
90106
91107} // namespace platform
@@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) {
4646 delete device_context;
4747 }
4848}
49+
50+ TEST (Device, CudnnDeviceContext) {
51+ using paddle::platform::CudnnDeviceContext;
52+ using paddle::platform::CudnnPlace;
53+ if (paddle::platform::dynload::HasCUDNN ()) {
54+ int count = paddle::platform::GetCUDADeviceCount ();
55+ for (int i = 0 ; i < count; ++i) {
56+ CudnnDeviceContext* device_context =
57+ new CudnnDeviceContext (CudnnPlace (i));
58+ cudnnHandle_t cudnn_handle = device_context->cudnn_handle ();
59+ ASSERT_NE (nullptr , cudnn_handle);
60+ ASSERT_NE (nullptr , device_context->stream ());
61+ delete device_context;
62+ }
63+ }
64+ }
@@ -43,6 +43,11 @@ struct GPUPlace {
4343 int device;
4444};
4545
46+ struct CudnnPlace : public GPUPlace {
47+ CudnnPlace () : GPUPlace() {}
48+ explicit CudnnPlace (int d) : GPUPlace(d) {}
49+ };
50+
4651struct IsGPUPlace : public boost ::static_visitor<bool > {
4752 bool operator ()(const CPUPlace &) const { return false ; }
4853 bool operator ()(const GPUPlace &gpu) const { return true ; }
@@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
5257// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
5358#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
5459
55- typedef boost::variant<GPUPlace, CPUPlace> Place;
60+ typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;
5661
5762// static check number of place types is less equal than
5863// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
You can’t perform that action at this time.
0 commit comments