Skip to content

Commit 009cd4e

Browse files
Allow cuda custom ops allocate deferred cpu mem (#17893)
Expose a new allocator from cuda stream. The allocator manages deferred cpu memory which only get recycled before stream destruction. --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
1 parent 2f57625 commit 009cd4e

9 files changed

Lines changed: 100 additions & 17 deletions

File tree

include/onnxruntime/core/providers/cuda/cuda_context.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
1919
cudaStream_t cuda_stream = {};
2020
cudnnHandle_t cudnn_handle = {};
2121
cublasHandle_t cublas_handle = {};
22+
OrtAllocator* deferred_cpu_allocator = {};
2223

2324
void Init(const OrtKernelContext& kernel_ctx) override {
2425
const auto& ort_api = Ort::GetApi();
@@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext {
4445
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
4546
}
4647
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
48+
49+
resource = {};
50+
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
51+
if (status) {
52+
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
53+
}
54+
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
55+
}
56+
57+
void* AllocDeferredCpuMem(size_t size) const {
58+
if (0 == size) {
59+
return {};
60+
}
61+
const auto& ort_api = Ort::GetApi();
62+
void* mem = {};
63+
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
64+
if (status) {
65+
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
66+
}
67+
return mem;
68+
}
69+
70+
void FreeDeferredCpuMem(void* mem) const {
71+
if (mem) {
72+
const auto& ort_api = Ort::GetApi();
73+
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
74+
if (status) {
75+
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
76+
}
77+
}
4778
}
4879
};
4980

include/onnxruntime/core/providers/cuda/cuda_resource.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
#include "core/providers/resource.h"
55

6-
#define ORT_CUDA_RESOUCE_VERSION 1
6+
#define ORT_CUDA_RESOUCE_VERSION 2
77

88
enum CudaResource : int {
99
cuda_stream_t = cuda_resource_offset,
1010
cudnn_handle_t,
11-
cublas_handle_t
11+
cublas_handle_t,
12+
deferred_cpu_allocator_t,
1213
};

onnxruntime/core/providers/cuda/cuda_stream_handle.cc

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77

88
namespace onnxruntime {
99

10+
DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) {
11+
OrtAllocator::version = ORT_API_VERSION;
12+
OrtAllocator::Alloc =
13+
[](OrtAllocator* this_, size_t size) {
14+
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
15+
return self->cuda_stream_.GetCpuAllocator()->Alloc(size);
16+
};
17+
OrtAllocator::Free =
18+
[](OrtAllocator* this_, void* p) {
19+
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
20+
self->cuda_stream_.EnqueDeferredCPUBuffer(p);
21+
};
22+
OrtAllocator::Info =
23+
[](const OrtAllocator* this_) {
24+
auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
25+
return &self->cuda_stream_.GetCpuAllocator()->Info();
26+
};
27+
}
28+
1029
struct CudaNotification : public synchronize::Notification {
1130
CudaNotification(Stream& s) : Notification(s) {
1231
CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
@@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream,
4665
cublasHandle_t external_cublas_handle) : Stream(stream, device),
4766
own_stream_(own_flag),
4867
cpu_allocator_(cpu_allocator),
49-
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) {
68+
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
69+
deferred_cpu_allocator_(*this) {
5070
if (own_flag) {
5171
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
5272
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
@@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const {
162182
case CudaResource::cublas_handle_t:
163183
return reinterpret_cast<void*>(cublas_handle_);
164184
break;
185+
case CudaResource::deferred_cpu_allocator_t:
186+
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
187+
break;
165188
default:
166189
break;
167190
}

onnxruntime/core/providers/cuda/cuda_stream_handle.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99

1010
namespace onnxruntime {
1111

12+
struct CudaStream;
13+
14+
struct DeferredCpuAllocator : public OrtAllocator {
15+
DeferredCpuAllocator(CudaStream&);
16+
CudaStream& cuda_stream_;
17+
};
18+
1219
struct CudaStream : Stream {
1320
CudaStream(cudaStream_t stream,
1421
const OrtDevice& device,
@@ -36,10 +43,13 @@ struct CudaStream : Stream {
3643

3744
void* GetResource(int version, int id) const override;
3845

46+
onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
47+
3948
private:
4049
std::vector<void*> deferred_cpu_buffers_;
4150
AllocatorPtr cpu_allocator_;
4251
bool release_cpu_buffer_on_cuda_stream_{true};
52+
DeferredCpuAllocator deferred_cpu_allocator_;
4353
};
4454

4555
void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,

onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#ifdef USE_CUDA
4+
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
55

66
#define ORT_API_MANUAL_INIT
77
#include "onnxruntime_cxx_api.h"
@@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
3232
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
3333
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
3434
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
35+
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
36+
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
37+
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
3538
auto z_raw = Z.Allocate(input_shape);
3639
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
3740
}
@@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
4346

4447
} // namespace Cuda
4548

46-
#else
47-
48-
void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
49-
5049
#endif

onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55

66
namespace Cuda {
77

8+
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
9+
810
void RegisterOps(Ort::CustomOpDomain& domain);
911

10-
}
12+
#else
13+
14+
void RegisterOps(Ort::CustomOpDomain&) {}
15+
16+
#endif
17+
18+
} // namespace Cuda

onnxruntime/test/testdata/custom_op_library/custom_op_library.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "core/framework/ortdevice.h"
1414
#include "core/framework/ortmemoryinfo.h"
1515
#include "cpu/cpu_ops.h"
16+
#include "cuda/cuda_ops.h"
17+
#include "rocm/rocm_ops.h"
1618
#include "onnxruntime_lite_custom_op.h"
1719

1820
static const char* c_OpDomain = "test.customop";
@@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
3133
ORT_TRY {
3234
Ort::CustomOpDomain domain{c_OpDomain};
3335
Cpu::RegisterOps(domain);
34-
3536
Ort::CustomOpDomain domain_v2{"v2"};
3637
Cpu::RegisterOps(domain_v2);
3738

39+
Cuda::RegisterOps(domain);
40+
Cuda::RegisterOps(domain_v2);
41+
42+
Rocm::RegisterOps(domain);
43+
Rocm::RegisterOps(domain_v2);
44+
3845
Ort::UnownedSessionOptions session_options(options);
3946
session_options.Add(domain);
4047
session_options.Add(domain_v2);

onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using namespace Ort::Custom;
1919
throw std::runtime_error(msg); \
2020
}
2121

22-
namespace Cuda {
22+
namespace Rocm {
2323

2424
void KernelOne(const Ort::Custom::RocmContext& rocm_ctx,
2525
const Ort::Custom::Tensor<float>& X,
@@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
3838
domain.Add(c_CustomOpOne.get());
3939
}
4040

41-
} // namespace Cuda
42-
43-
#else
44-
45-
void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
41+
} // namespace Rocm
4642

4743
#endif

onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55

66
namespace Rocm {
77

8+
#ifdef USE_ROCM
9+
810
void RegisterOps(Ort::CustomOpDomain& domain);
911

10-
}
12+
#else
13+
14+
inline void RegisterOps(Ort::CustomOpDomain&) {}
15+
16+
#endif
17+
18+
} // namespace Rocm

0 commit comments

Comments
 (0)