|
7 | 7 |
|
8 | 8 | namespace onnxruntime { |
9 | 9 |
|
| 10 | +DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) { | // Binds this allocator to cuda_stream and installs the OrtAllocator callbacks below.
| 11 | + OrtAllocator::version = ORT_API_VERSION; | // Records ORT_API_VERSION in the OrtAllocator version field.
| 12 | + OrtAllocator::Alloc = | // Alloc callback: forwards the request to the stream's CPU allocator.
| 13 | + [](OrtAllocator* this_, size_t size) { | // Capture-less lambda: the owning object is reachable only through this_.
| 14 | + auto self = reinterpret_cast<DeferredCpuAllocator*>(this_); | // NOTE(review): assumes this_ is the DeferredCpuAllocator these callbacks were installed on - confirm.
| 15 | + return self->cuda_stream_.GetCpuAllocator()->Alloc(size); |
| 16 | + }; |
| 17 | + OrtAllocator::Free = | // Free callback: does not release p here; it hands p to the stream instead.
| 18 | + [](OrtAllocator* this_, void* p) { |
| 19 | + auto self = reinterpret_cast<DeferredCpuAllocator*>(this_); |
| 20 | + self->cuda_stream_.EnqueDeferredCPUBuffer(p); | // Presumably the stream releases p later, once pending work no longer needs it - confirm in CudaStream.
| 21 | + }; |
| 22 | + OrtAllocator::Info = | // Info callback: reports the memory info of the stream's CPU allocator.
| 23 | + [](const OrtAllocator* this_) { |
| 24 | + auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_); |
| 25 | + return &self->cuda_stream_.GetCpuAllocator()->Info(); | // NOTE(review): taking the address assumes Info() returns a reference that outlives this call - confirm.
| 26 | + }; |
| 27 | +} |
| 28 | + |
10 | 29 | struct CudaNotification : public synchronize::Notification { |
11 | 30 | CudaNotification(Stream& s) : Notification(s) { |
12 | 31 | CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); |
@@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream, |
46 | 65 | cublasHandle_t external_cublas_handle) : Stream(stream, device), |
47 | 66 | own_stream_(own_flag), |
48 | 67 | cpu_allocator_(cpu_allocator), |
49 | | - release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) { |
| 68 | + release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), |
| 69 | + deferred_cpu_allocator_(*this) { |
50 | 70 | if (own_flag) { |
51 | 71 | CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); |
52 | 72 | CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); |
@@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const { |
162 | 182 | case CudaResource::cublas_handle_t: |
163 | 183 | return reinterpret_cast<void*>(cublas_handle_); |
164 | 184 | break; |
| 185 | + case CudaResource::deferred_cpu_allocator_t: |
| 186 | + return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_); |
| 187 | + break; |
165 | 188 | default: |
166 | 189 | break; |
167 | 190 | } |
|
0 commit comments