Skip to content

Commit 12a1302

Browse files
authored
add unified device guard (#1855)
1 parent bcb93ea commit 12a1302

File tree

11 files changed

+93
-10
lines changed

11 files changed

+93
-10
lines changed

csrc/mmdeploy/core/device.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
#include <functional>
88
#include <memory>
99
#include <optional>
10+
#include <ostream>
1011
#include <string>
1112
#include <vector>
1213

1314
#include "mmdeploy/core/macro.h"
1415
#include "mmdeploy/core/mpl/type_traits.h"
1516
#include "mmdeploy/core/status_code.h"
17+
#include "mmdeploy/core/utils/formatter.h"
1618

1719
namespace mmdeploy {
1820

@@ -97,6 +99,11 @@ class Device {
9799
return PlatformId(platform_id_);
98100
}
99101

102+
friend std::ostream& operator<<(std::ostream& os, const Device& device) {
103+
os << "(" << device.platform_id_ << ", " << device.device_id_ << ")";
104+
return os;
105+
}
106+
100107
private:
101108
int platform_id_{0};
102109
int device_id_{0};
@@ -112,6 +119,9 @@ class MMDEPLOY_API Platform {
112119
// throws if not found
113120
explicit Platform(int platform_id);
114121

122+
// bind device with the current thread
123+
Result<void> Bind(Device device, Device* prev);
124+
115125
// -1 if invalid
116126
int GetPlatformId() const;
117127

@@ -135,6 +145,27 @@ class MMDEPLOY_API Platform {
135145

136146
MMDEPLOY_API const char* GetPlatformName(PlatformId id);
137147

148+
class DeviceGuard {
149+
public:
150+
explicit DeviceGuard(Device device) : platform_(device.platform_id()) {
151+
auto r = platform_.Bind(device, &prev_);
152+
if (!r) {
153+
MMDEPLOY_ERROR("failed to bind device {}: {}", device, r.error().message().c_str());
154+
}
155+
}
156+
157+
~DeviceGuard() {
158+
auto r = platform_.Bind(prev_, nullptr);
159+
if (!r) {
160+
MMDEPLOY_ERROR("failed to unbind device {}: {}", prev_, r.error().message().c_str());
161+
}
162+
}
163+
164+
private:
165+
Platform platform_;
166+
Device prev_;
167+
};
168+
138169
class MMDEPLOY_API Stream {
139170
public:
140171
Stream() = default;

csrc/mmdeploy/core/device_impl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ Platform::Platform(int platform_id) {
5454
}
5555
}
5656

57+
Result<void> Platform::Bind(Device device, Device* prev) { return impl_->BindDevice(device, prev); }
58+
5759
const char* GetPlatformName(PlatformId id) {
5860
if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) {
5961
return impl->GetPlatformName();

csrc/mmdeploy/core/device_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class PlatformImpl {
2727

2828
virtual void SetPlatformId(int id) { platform_id_ = id; }
2929

30-
virtual Result<void> SetDevice(Device device) { return success(); };
30+
virtual Result<void> BindDevice(Device device, Device* prev) = 0;
3131

3232
virtual shared_ptr<BufferImpl> CreateBuffer(Device device) = 0;
3333

csrc/mmdeploy/device/cpu/cpu_device.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ class CpuHostMemory : public NonCopyable {
7070
////////////////////////////////////////////////////////////////////////////////
7171
/// CpuPlatformImpl
7272

73+
Result<void> CpuPlatformImpl::BindDevice(Device device, Device* prev) {
74+
// do nothing
75+
if (prev) {
76+
*prev = device;
77+
}
78+
return success();
79+
}
80+
7381
shared_ptr<BufferImpl> CpuPlatformImpl::CreateBuffer(Device device) {
7482
return std::make_shared<CpuBufferImpl>(device);
7583
}

csrc/mmdeploy/device/cpu/cpu_device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class CpuPlatformImpl : public PlatformImpl {
1717

1818
const char* GetPlatformName() const noexcept override;
1919

20+
Result<void> BindDevice(Device device, Device* prev) override;
21+
2022
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
2123

2224
shared_ptr<StreamImpl> CreateStream(Device device) override;

csrc/mmdeploy/device/cuda/cuda_device.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,32 @@ shared_ptr<EventImpl> CudaPlatformImpl::CreateEvent(Device device) {
127127
return std::make_shared<CudaEventImpl>(device);
128128
}
129129

130+
Result<void> CudaPlatformImpl::BindDevice(Device device, Device* prev) {
131+
if (device.platform_id() != platform_id_) {
132+
return Status(eInvalidArgument);
133+
}
134+
// skip null device
135+
if (device.device_id() == -1) {
136+
return success();
137+
}
138+
int prev_device_id = -1;
139+
if (prev) {
140+
CUcontext ctx{};
141+
cuCtxGetCurrent(&ctx);
142+
if (ctx) {
143+
cudaGetDevice(&prev_device_id);
144+
*prev = Device(platform_id_, prev_device_id);
145+
} else {
146+
// cuda is not initialized return a null device as previous
147+
*prev = Device(platform_id_, -1);
148+
}
149+
}
150+
if (device.device_id() != prev_device_id) {
151+
cudaSetDevice(device.device_id());
152+
}
153+
return success();
154+
}
155+
130156
bool CudaPlatformImpl::CheckCopyDevice(const Device& src, const Device& dst, const Device& st) {
131157
return st.is_device() && (src.is_host() || src == st) && (dst.is_host() || dst == st);
132158
}

csrc/mmdeploy/device/cuda/cuda_device.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class CudaPlatformImpl : public PlatformImpl {
2828

2929
const char* GetPlatformName() const noexcept override { return "cuda"; }
3030

31+
Result<void> BindDevice(Device device, Device* prev) override;
32+
3133
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
3234

3335
shared_ptr<StreamImpl> CreateStream(Device device) override;
@@ -178,7 +180,9 @@ class CudaDeviceGuard {
178180
if (ctx) {
179181
cudaGetDevice(&prev_device_id_);
180182
}
181-
if (prev_device_id_ != device_id_) cudaSetDevice(device_id_);
183+
if (prev_device_id_ != device_id_) {
184+
cudaSetDevice(device_id_);
185+
}
182186
}
183187
~CudaDeviceGuard() {
184188
if (prev_device_id_ >= 0 && prev_device_id_ != device_id_) {

csrc/mmdeploy/net/ort/ort_net.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Result<void> OrtNet::Init(const Value& args) {
4040
auto& context = args["context"];
4141
device_ = context["device"].get<Device>();
4242
stream_ = context["stream"].get<Stream>();
43-
43+
DeviceGuard guard(device_);
4444
auto name = args["name"].get<std::string>();
4545
auto model = context["model"].get<Model>();
4646

@@ -150,6 +150,7 @@ static Result<Tensor> AsTensor(Ort::Value& value, const Device& device) {
150150
}
151151

152152
Result<void> OrtNet::Forward() {
153+
DeviceGuard guard(device_);
153154
try {
154155
OUTCOME_TRY(stream_.Wait());
155156
Ort::IoBinding binding(session_);
@@ -186,6 +187,11 @@ Result<void> OrtNet::Forward() {
186187
return success();
187188
}
188189

190+
OrtNet::~OrtNet() {
191+
DeviceGuard guard(device_);
192+
session_ = Ort::Session{nullptr};
193+
}
194+
189195
static std::unique_ptr<Net> Create(const Value& args) {
190196
try {
191197
auto p = std::make_unique<OrtNet>();

csrc/mmdeploy/net/ort/ort_net.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace mmdeploy::framework {
1111

1212
class OrtNet : public Net {
1313
public:
14-
~OrtNet() override = default;
14+
~OrtNet() override;
1515
Result<void> Init(const Value& cfg) override;
1616
Result<void> Deinit() override;
1717
Result<Span<Tensor>> GetInputTensors() override;

csrc/mmdeploy/net/trt/trt_net.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status
7979

8080
#define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__))
8181

82-
TRTNet::~TRTNet() = default;
82+
TRTNet::~TRTNet() {
83+
CudaDeviceGuard guard(device_);
84+
context_.reset();
85+
engine_.reset();
86+
}
8387

8488
static Result<DataType> MapDataType(nvinfer1::DataType dtype) {
8589
switch (dtype) {
@@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) {
106110
MMDEPLOY_ERROR("TRTNet: device must be a GPU!");
107111
return Status(eNotSupported);
108112
}
113+
CudaDeviceGuard guard(device_);
109114
stream_ = context["stream"].get<Stream>();
110115

111116
event_ = Event(device_);
@@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) {
156161
return success();
157162
}
158163

159-
Result<void> TRTNet::Deinit() {
160-
context_.reset();
161-
engine_.reset();
162-
return success();
163-
}
164+
Result<void> TRTNet::Deinit() { return success(); }
164165

165166
Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) {
167+
CudaDeviceGuard guard(device_);
166168
using namespace trt_detail;
167169
if (input_shapes.size() != input_tensors_.size()) {
168170
return Status(eInvalidArgument);
@@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
190192
Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; }
191193

192194
Result<void> TRTNet::Forward() {
195+
CudaDeviceGuard guard(device_);
193196
using namespace trt_detail;
194197
std::vector<void*> bindings(engine_->getNbBindings());
195198

0 commit comments

Comments
 (0)