Skip to content

Commit 7c5f282

Browse files
authored
To Support CUDA kernel CustomDevice (PaddlePaddle#72604)
1 parent 95b3ef2 commit 7c5f282

25 files changed

+731
-59
lines changed

paddle/common/backend_header.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#if defined(PADDLE_WITH_CUDA)
18+
#include <cuda.h>
19+
#endif
20+
21+
#if defined(__CUDACC__) && CUDA_VERSION >= 11000
22+
#define PADDLE_CUDA_BF16
23+
#include <cuda_bf16.h>
24+
#endif
25+
26+
#ifndef PADDLE_WITH_HIP
27+
#if !defined(_WIN32)
28+
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
29+
#else
30+
#define PADDLE_ALIGN(x) __declspec(align(x))
31+
#endif
32+
#else
33+
#define PADDLE_ALIGN(x)
34+
#endif

paddle/fluid/custom_engine/custom_device_load.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#include <glog/logging.h>
1616

17-
#
1817
#include "paddle/fluid/custom_engine/custom_device_load.h"
1918
namespace paddle {
2019

paddle/phi/backends/custom/custom_context.cc

Lines changed: 198 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,129 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
1514
#include "paddle/phi/backends/custom/custom_context.h"
1615

16+
#include "paddle/common/exception.h"
1717
#include "paddle/phi/backends/device_guard.h"
18+
#include "paddle/phi/backends/device_manager.h"
1819
#include "paddle/phi/backends/stream.h"
20+
#include "paddle/phi/common/place.h"
21+
#include "paddle/phi/core/enforce.h"
22+
#include "paddle/phi/core/memory/allocation/allocator_facade.h"
23+
#include "unsupported/Eigen/CXX11/Tensor"
1924

2025
namespace phi {
2126

2227
struct CustomContext::Impl {
2328
explicit Impl(const CustomPlace& place) : place_(place) {}
2429

25-
~Impl() {}
30+
~Impl() {
31+
phi::DeviceGuard guard(place_);
32+
if (owned_) {
33+
DeviceManager::DestroyEigenDevice(place_, eigen_device_);
34+
}
35+
if (stream_owned_ && stream_) {
36+
stream_->Destroy();
37+
}
38+
}
2639

2740
void Init() {
41+
owned_ = true;
42+
phi::DeviceGuard guard(place_);
43+
compute_capability_ = DeviceManager::GetComputeCapability(place_);
44+
runtime_version_ = DeviceManager::GetRuntimeVersion(place_);
45+
driver_version_ = DeviceManager::GetDriverVersion(place_);
46+
multi_process_ = DeviceManager::GetMultiProcessors(place_);
47+
max_threads_per_mp_ = DeviceManager::GetMaxThreadsPerMultiProcessor(place_);
48+
max_threads_per_block_ = DeviceManager::GetMaxThreadsPerBlock(place_);
49+
max_grid_dim_size_ = DeviceManager::GetMaxGridDimSize(place_);
50+
eigen_device_ =
51+
reinterpret_cast<Eigen::GpuDevice*>(DeviceManager::InitEigenDevice(
52+
place_, stream_->raw_stream(), allocator_));
53+
54+
stream_.reset(new phi::stream::Stream());
55+
stream_->Init(place_);
56+
}
57+
58+
void PartialInitWithoutAllocator() {
59+
owned_ = true;
60+
stream_owned_ = true;
2861
phi::DeviceGuard guard(place_);
62+
compute_capability_ = DeviceManager::GetComputeCapability(place_);
63+
runtime_version_ = DeviceManager::GetRuntimeVersion(place_);
64+
driver_version_ = DeviceManager::GetDriverVersion(place_);
65+
multi_process_ = DeviceManager::GetMultiProcessors(place_);
66+
max_threads_per_mp_ = DeviceManager::GetMaxThreadsPerMultiProcessor(place_);
67+
max_threads_per_block_ = DeviceManager::GetMaxThreadsPerBlock(place_);
68+
max_grid_dim_size_ = DeviceManager::GetMaxGridDimSize(place_);
69+
2970
stream_.reset(new phi::stream::Stream());
3071
stream_->Init(place_);
3172
}
3273

74+
void PartialInitWithAllocator() {
75+
owned_ = true;
76+
stream_owned_ = true;
77+
phi::DeviceGuard guard(place_);
78+
}
79+
3380
const Place& GetPlace() const { return place_; }
3481

35-
void* stream() const {
36-
return reinterpret_cast<void*>(stream_->raw_stream());
82+
phi::stream::stream_t stream() const {
83+
return reinterpret_cast<phi::stream::stream_t>(stream_->raw_stream());
3784
}
3885

3986
std::shared_ptr<phi::stream::Stream> GetStream() const { return stream_; }
4087

4188
void SetStream(std::shared_ptr<phi::stream::Stream> stream) {
89+
stream_owned_ = true;
4290
stream_ = stream;
4391
}
4492

93+
void SetEigenDevice(Eigen::GpuDevice* device) { eigen_device_ = device; }
94+
95+
void SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
96+
eigen_device_creator_ = std::move(creator);
97+
}
98+
99+
Eigen::GpuDevice* eigen_device() {
100+
std::call_once(flag_eigen_device_, [&]() {
101+
if (!eigen_device_) {
102+
if (!eigen_device_creator_) {
103+
// use default initial
104+
eigen_device_ = reinterpret_cast<Eigen::GpuDevice*>(
105+
DeviceManager::InitEigenDevice(
106+
place_, stream_->raw_stream(), allocator_));
107+
} else {
108+
eigen_device_ = eigen_device_creator_();
109+
}
110+
}
111+
});
112+
PADDLE_ENFORCE_NOT_NULL(
113+
eigen_device_,
114+
common::errors::InvalidArgument(
115+
"The custom eigen_device is nullptr. It must not be null."));
116+
return eigen_device_;
117+
}
118+
45119
void Wait() const { stream_->Wait(); }
46120

121+
void WaitEvent(phi::event::event_t ev) const {
122+
event::Event event_(place_, ev);
123+
stream_->WaitEvent(&event_);
124+
}
125+
126+
void RecordEvent(phi::event::event_t ev,
127+
const std::function<void()>& callback) const {
128+
event::Event event_(place_, ev);
129+
stream_->RecordEvent(&event_, callback);
130+
}
131+
132+
void RecordEvent(phi::event::event_t ev) const {
133+
event::Event event_(place_, ev);
134+
stream_->RecordEvent(&event_);
135+
}
136+
47137
phi::ccl::CCLComm xccl_comm() const { return comm_; }
48138

49139
void set_xccl_comm(phi::ccl::CCLComm comm) { comm_ = comm; }
@@ -52,31 +142,87 @@ struct CustomContext::Impl {
52142

53143
std::shared_ptr<phi::stream::Stream> stream_;
54144

145+
Allocator* allocator_{nullptr};
146+
55147
phi::ccl::CCLComm comm_;
148+
149+
bool owned_{false};
150+
bool stream_owned_{false};
151+
int compute_capability_ = 0;
152+
int runtime_version_ = 0;
153+
int driver_version_ = 0;
154+
int multi_process_ = 0;
155+
int max_threads_per_mp_ = 0;
156+
int max_threads_per_block_ = 0;
157+
std::array<unsigned int, 3> max_grid_dim_size_;
158+
159+
Eigen::GpuDevice* eigen_device_{nullptr};
160+
std::function<Eigen::GpuDevice*()> eigen_device_creator_{nullptr};
161+
std::once_flag flag_eigen_device_;
56162
};
57163

58-
void CustomContext::Init() { impl_->Init(); }
164+
CustomContext::CustomContext(const CustomPlace& place)
165+
: DeviceContext(), impl_(std::make_unique<Impl>(place)) {
166+
impl_->PartialInitWithoutAllocator();
167+
}
168+
169+
CustomContext::~CustomContext() { impl_.reset(); }
170+
171+
void CustomContext::Init() {
172+
impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
173+
impl_->Init();
174+
}
175+
176+
void CustomContext::PartialInitWithoutAllocator() {
177+
impl_->PartialInitWithoutAllocator();
178+
}
179+
180+
void CustomContext::PartialInitWithAllocator() {
181+
impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator()); // NOLINT
182+
impl_->PartialInitWithAllocator();
183+
}
59184

60185
const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); }
61186

62-
void* CustomContext::stream() const { return impl_->stream(); }
187+
phi::stream::stream_t CustomContext::stream() const { return impl_->stream(); }
63188

64189
std::shared_ptr<phi::stream::Stream> CustomContext::GetStream() const {
65190
return impl_->GetStream();
66191
}
67192

68193
void CustomContext::SetStream(std::shared_ptr<phi::stream::Stream> stream) {
194+
#if !defined(_WIN32)
195+
this->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
196+
.GetAllocator(impl_->GetPlace(), stream->raw_stream())
197+
.get());
198+
#endif
199+
impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator()); // NOLINT
69200
impl_->SetStream(stream);
70201
}
71202

72203
void CustomContext::Wait() const { return impl_->Wait(); }
73204

74-
CustomContext::CustomContext(const CustomPlace& place)
75-
: DeviceContext(), impl_(std::make_unique<Impl>(place)) {
76-
impl_->Init();
205+
void CustomContext::RecordEvent(phi::event::event_t ev,
206+
const std::function<void()>& callback) const {
207+
impl_->RecordEvent(ev, callback);
77208
}
78209

79-
CustomContext::~CustomContext() { impl_.reset(); }
210+
void CustomContext::RecordEvent(phi::event::event_t ev) const {
211+
impl_->RecordEvent(ev);
212+
}
213+
214+
Eigen::GpuDevice* CustomContext::eigen_device() const {
215+
return impl_->eigen_device();
216+
}
217+
218+
void CustomContext::SetEigenDevice(Eigen::GpuDevice* device) {
219+
impl_->SetEigenDevice(device);
220+
}
221+
222+
void CustomContext::SetEigenDevice(
223+
std::function<Eigen::GpuDevice*()>&& creator) {
224+
impl_->SetEigenDevice(std::move(creator));
225+
}
80226

81227
phi::ccl::CCLComm CustomContext::xccl_comm() const {
82228
return impl_->xccl_comm();
@@ -85,4 +231,46 @@ phi::ccl::CCLComm CustomContext::xccl_comm() const {
85231
void CustomContext::set_xccl_comm(phi::ccl::CCLComm comm) {
86232
impl_->set_xccl_comm(comm);
87233
}
234+
235+
int CustomContext::GetComputeCapability() const {
236+
return impl_->compute_capability_;
237+
}
238+
239+
int CustomContext::GetMaxThreadsPerBlock() const {
240+
return impl_->max_threads_per_block_;
241+
}
242+
243+
int CustomContext::GetSMCount() const { return impl_->multi_process_; }
244+
245+
std::array<unsigned int, 3> CustomContext::GetCUDAMaxGridDimSize() const {
246+
return impl_->max_grid_dim_size_;
247+
}
248+
249+
int CustomContext::GetMaxPhysicalThreadCount() const {
250+
return impl_->multi_process_ * impl_->max_threads_per_mp_;
251+
}
252+
253+
void CustomContext::SetComputeCapability(int val) {
254+
impl_->compute_capability_ = val;
255+
}
256+
257+
void CustomContext::SetMaxThreadsPerMultiProcessor(int val) {
258+
impl_->max_threads_per_mp_ = val;
259+
}
260+
261+
void CustomContext::SetMultiProcessors(int val) { impl_->multi_process_ = val; }
262+
263+
void CustomContext::SetMaxThreadsPerBlock(int val) {
264+
impl_->max_threads_per_block_ = val;
265+
}
266+
267+
void CustomContext::SetMaxGridDimSize(const std::array<unsigned int, 3>& val) {
268+
impl_->max_grid_dim_size_ = val;
269+
}
270+
271+
void CustomContext::SetDriverVersion(int val) { impl_->driver_version_ = val; }
272+
273+
void CustomContext::SetRuntimeVersion(int val) {
274+
impl_->runtime_version_ = val;
275+
}
88276
} // namespace phi

0 commit comments

Comments
 (0)