@@ -11,39 +11,129 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
14-
1514#include " paddle/phi/backends/custom/custom_context.h"
1615
16+ #include " paddle/common/exception.h"
1717#include " paddle/phi/backends/device_guard.h"
18+ #include " paddle/phi/backends/device_manager.h"
1819#include " paddle/phi/backends/stream.h"
20+ #include " paddle/phi/common/place.h"
21+ #include " paddle/phi/core/enforce.h"
22+ #include " paddle/phi/core/memory/allocation/allocator_facade.h"
23+ #include " unsupported/Eigen/CXX11/Tensor"
1924
2025namespace phi {
2126
2227struct CustomContext ::Impl {
2328 explicit Impl (const CustomPlace& place) : place_(place) {}
2429
25- ~Impl () {}
30+ ~Impl () {
31+ phi::DeviceGuard guard (place_);
32+ if (owned_) {
33+ DeviceManager::DestroyEigenDevice (place_, eigen_device_);
34+ }
35+ if (stream_owned_ && stream_) {
36+ stream_->Destroy ();
37+ }
38+ }
2639
2740 void Init () {
41+ owned_ = true ;
42+ phi::DeviceGuard guard (place_);
43+ compute_capability_ = DeviceManager::GetComputeCapability (place_);
44+ runtime_version_ = DeviceManager::GetRuntimeVersion (place_);
45+ driver_version_ = DeviceManager::GetDriverVersion (place_);
46+ multi_process_ = DeviceManager::GetMultiProcessors (place_);
47+ max_threads_per_mp_ = DeviceManager::GetMaxThreadsPerMultiProcessor (place_);
48+ max_threads_per_block_ = DeviceManager::GetMaxThreadsPerBlock (place_);
49+ max_grid_dim_size_ = DeviceManager::GetMaxGridDimSize (place_);
50+ eigen_device_ =
51+ reinterpret_cast <Eigen::GpuDevice*>(DeviceManager::InitEigenDevice (
52+ place_, stream_->raw_stream (), allocator_));
53+
54+ stream_.reset (new phi::stream::Stream ());
55+ stream_->Init (place_);
56+ }
57+
58+ void PartialInitWithoutAllocator () {
59+ owned_ = true ;
60+ stream_owned_ = true ;
2861 phi::DeviceGuard guard (place_);
62+ compute_capability_ = DeviceManager::GetComputeCapability (place_);
63+ runtime_version_ = DeviceManager::GetRuntimeVersion (place_);
64+ driver_version_ = DeviceManager::GetDriverVersion (place_);
65+ multi_process_ = DeviceManager::GetMultiProcessors (place_);
66+ max_threads_per_mp_ = DeviceManager::GetMaxThreadsPerMultiProcessor (place_);
67+ max_threads_per_block_ = DeviceManager::GetMaxThreadsPerBlock (place_);
68+ max_grid_dim_size_ = DeviceManager::GetMaxGridDimSize (place_);
69+
2970 stream_.reset (new phi::stream::Stream ());
3071 stream_->Init (place_);
3172 }
3273
74+ void PartialInitWithAllocator () {
75+ owned_ = true ;
76+ stream_owned_ = true ;
77+ phi::DeviceGuard guard (place_);
78+ }
79+
3380 const Place& GetPlace () const { return place_; }
3481
35- void * stream () const {
36- return reinterpret_cast <void * >(stream_->raw_stream ());
82+ phi::stream:: stream_t stream () const {
83+ return reinterpret_cast <phi::stream:: stream_t >(stream_->raw_stream ());
3784 }
3885
3986 std::shared_ptr<phi::stream::Stream> GetStream () const { return stream_; }
4087
4188 void SetStream (std::shared_ptr<phi::stream::Stream> stream) {
89+ stream_owned_ = true ;
4290 stream_ = stream;
4391 }
4492
93+ void SetEigenDevice (Eigen::GpuDevice* device) { eigen_device_ = device; }
94+
95+ void SetEigenDevice (std::function<Eigen::GpuDevice*()>&& creator) {
96+ eigen_device_creator_ = std::move (creator);
97+ }
98+
99+ Eigen::GpuDevice* eigen_device () {
100+ std::call_once (flag_eigen_device_, [&]() {
101+ if (!eigen_device_) {
102+ if (!eigen_device_creator_) {
103+ // use default initial
104+ eigen_device_ = reinterpret_cast <Eigen::GpuDevice*>(
105+ DeviceManager::InitEigenDevice (
106+ place_, stream_->raw_stream (), allocator_));
107+ } else {
108+ eigen_device_ = eigen_device_creator_ ();
109+ }
110+ }
111+ });
112+ PADDLE_ENFORCE_NOT_NULL (
113+ eigen_device_,
114+ common::errors::InvalidArgument (
115+ " The custom eigen_device is nullptr. It must not be null." ));
116+ return eigen_device_;
117+ }
118+
45119 void Wait () const { stream_->Wait (); }
46120
121+ void WaitEvent (phi::event::event_t ev) const {
122+ event::Event event_ (place_, ev);
123+ stream_->WaitEvent (&event_);
124+ }
125+
126+ void RecordEvent (phi::event::event_t ev,
127+ const std::function<void ()>& callback) const {
128+ event::Event event_ (place_, ev);
129+ stream_->RecordEvent (&event_, callback);
130+ }
131+
132+ void RecordEvent (phi::event::event_t ev) const {
133+ event::Event event_ (place_, ev);
134+ stream_->RecordEvent (&event_);
135+ }
136+
47137 phi::ccl::CCLComm xccl_comm () const { return comm_; }
48138
49139 void set_xccl_comm (phi::ccl::CCLComm comm) { comm_ = comm; }
@@ -52,31 +142,87 @@ struct CustomContext::Impl {
52142
53143 std::shared_ptr<phi::stream::Stream> stream_;
54144
145+ Allocator* allocator_{nullptr };
146+
55147 phi::ccl::CCLComm comm_;
148+
149+ bool owned_{false };
150+ bool stream_owned_{false };
151+ int compute_capability_ = 0 ;
152+ int runtime_version_ = 0 ;
153+ int driver_version_ = 0 ;
154+ int multi_process_ = 0 ;
155+ int max_threads_per_mp_ = 0 ;
156+ int max_threads_per_block_ = 0 ;
157+ std::array<unsigned int , 3 > max_grid_dim_size_;
158+
159+ Eigen::GpuDevice* eigen_device_{nullptr };
160+ std::function<Eigen::GpuDevice*()> eigen_device_creator_{nullptr };
161+ std::once_flag flag_eigen_device_;
56162};
57163
58- void CustomContext::Init () { impl_->Init (); }
164+ CustomContext::CustomContext (const CustomPlace& place)
165+ : DeviceContext(), impl_(std::make_unique<Impl>(place)) {
166+ impl_->PartialInitWithoutAllocator ();
167+ }
168+
169+ CustomContext::~CustomContext () { impl_.reset (); }
170+
171+ void CustomContext::Init () {
172+ impl_->allocator_ = const_cast <Allocator*>(&this ->GetAllocator ());
173+ impl_->Init ();
174+ }
175+
176+ void CustomContext::PartialInitWithoutAllocator () {
177+ impl_->PartialInitWithoutAllocator ();
178+ }
179+
180+ void CustomContext::PartialInitWithAllocator () {
181+ impl_->allocator_ = const_cast <Allocator*>(&this ->GetAllocator ()); // NOLINT
182+ impl_->PartialInitWithAllocator ();
183+ }
59184
60185const Place& CustomContext::GetPlace () const { return impl_->GetPlace (); }
61186
62- void * CustomContext::stream () const { return impl_->stream (); }
187+ phi::stream:: stream_t CustomContext::stream () const { return impl_->stream (); }
63188
64189std::shared_ptr<phi::stream::Stream> CustomContext::GetStream () const {
65190 return impl_->GetStream ();
66191}
67192
68193void CustomContext::SetStream (std::shared_ptr<phi::stream::Stream> stream) {
194+ #if !defined(_WIN32)
195+ this ->SetAllocator (paddle::memory::allocation::AllocatorFacade::Instance ()
196+ .GetAllocator (impl_->GetPlace (), stream->raw_stream ())
197+ .get ());
198+ #endif
199+ impl_->allocator_ = const_cast <Allocator*>(&this ->GetAllocator ()); // NOLINT
69200 impl_->SetStream (stream);
70201}
71202
72203void CustomContext::Wait () const { return impl_->Wait (); }
73204
74- CustomContext::CustomContext ( const CustomPlace& place)
75- : DeviceContext(), impl_( std::make_unique<Impl>(place)) {
76- impl_->Init ( );
205+ void CustomContext::RecordEvent (phi::event:: event_t ev,
206+ const std::function< void ()>& callback) const {
207+ impl_->RecordEvent (ev, callback );
77208}
78209
79- CustomContext::~CustomContext () { impl_.reset (); }
210+ void CustomContext::RecordEvent (phi::event::event_t ev) const {
211+ impl_->RecordEvent (ev);
212+ }
213+
214+ Eigen::GpuDevice* CustomContext::eigen_device () const {
215+ return impl_->eigen_device ();
216+ }
217+
218+ void CustomContext::SetEigenDevice (Eigen::GpuDevice* device) {
219+ impl_->SetEigenDevice (device);
220+ }
221+
222+ void CustomContext::SetEigenDevice (
223+ std::function<Eigen::GpuDevice*()>&& creator) {
224+ impl_->SetEigenDevice (std::move (creator));
225+ }
80226
81227phi::ccl::CCLComm CustomContext::xccl_comm () const {
82228 return impl_->xccl_comm ();
@@ -85,4 +231,46 @@ phi::ccl::CCLComm CustomContext::xccl_comm() const {
85231void CustomContext::set_xccl_comm (phi::ccl::CCLComm comm) {
86232 impl_->set_xccl_comm (comm);
87233}
234+
235+ int CustomContext::GetComputeCapability () const {
236+ return impl_->compute_capability_ ;
237+ }
238+
239+ int CustomContext::GetMaxThreadsPerBlock () const {
240+ return impl_->max_threads_per_block_ ;
241+ }
242+
243+ int CustomContext::GetSMCount () const { return impl_->multi_process_ ; }
244+
245+ std::array<unsigned int , 3 > CustomContext::GetCUDAMaxGridDimSize () const {
246+ return impl_->max_grid_dim_size_ ;
247+ }
248+
249+ int CustomContext::GetMaxPhysicalThreadCount () const {
250+ return impl_->multi_process_ * impl_->max_threads_per_mp_ ;
251+ }
252+
253+ void CustomContext::SetComputeCapability (int val) {
254+ impl_->compute_capability_ = val;
255+ }
256+
257+ void CustomContext::SetMaxThreadsPerMultiProcessor (int val) {
258+ impl_->max_threads_per_mp_ = val;
259+ }
260+
261+ void CustomContext::SetMultiProcessors (int val) { impl_->multi_process_ = val; }
262+
263+ void CustomContext::SetMaxThreadsPerBlock (int val) {
264+ impl_->max_threads_per_block_ = val;
265+ }
266+
267+ void CustomContext::SetMaxGridDimSize (const std::array<unsigned int , 3 >& val) {
268+ impl_->max_grid_dim_size_ = val;
269+ }
270+
271+ void CustomContext::SetDriverVersion (int val) { impl_->driver_version_ = val; }
272+
273+ void CustomContext::SetRuntimeVersion (int val) {
274+ impl_->runtime_version_ = val;
275+ }
88276} // namespace phi
0 commit comments