@@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status
79
79
80
80
#define TRT_TRY (...) OUTCOME_TRY(trt_try(__VA_ARGS__))
81
81
82
- TRTNet::~TRTNet () = default ;
82
+ TRTNet::~TRTNet () {
83
+ CudaDeviceGuard guard (device_);
84
+ context_.reset ();
85
+ engine_.reset ();
86
+ }
83
87
84
88
static Result<DataType> MapDataType (nvinfer1::DataType dtype) {
85
89
switch (dtype) {
@@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) {
106
110
MMDEPLOY_ERROR (" TRTNet: device must be a GPU!" );
107
111
return Status (eNotSupported);
108
112
}
113
+ CudaDeviceGuard guard (device_);
109
114
stream_ = context[" stream" ].get <Stream>();
110
115
111
116
event_ = Event (device_);
@@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) {
156
161
return success ();
157
162
}
158
163
159
- Result<void > TRTNet::Deinit () {
160
- context_.reset ();
161
- engine_.reset ();
162
- return success ();
163
- }
164
+ Result<void > TRTNet::Deinit () { return success (); }
164
165
165
166
Result<void > TRTNet::Reshape (Span<TensorShape> input_shapes) {
167
+ CudaDeviceGuard guard (device_);
166
168
using namespace trt_detail ;
167
169
if (input_shapes.size () != input_tensors_.size ()) {
168
170
return Status (eInvalidArgument);
@@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
190
192
Result<Span<Tensor>> TRTNet::GetOutputTensors () { return output_tensors_; }
191
193
192
194
Result<void > TRTNet::Forward () {
195
+ CudaDeviceGuard guard (device_);
193
196
using namespace trt_detail ;
194
197
std::vector<void *> bindings (engine_->getNbBindings ());
195
198
0 commit comments