Skip to content

Commit b5b0dcf

Browse files
authored
[Fix] Support onnxruntime-1.13 (open-mmlab#1407)
* support onnxruntime-1.13 * fix lint
1 parent 4dd4d48 commit b5b0dcf

File tree

7 files changed

+21
-15
lines changed

7 files changed

+21
-15
lines changed

csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace mmdeploy {
1313
#define MAX(a, b) (((a) < (b)) ? (b) : (a))
1414
#define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))
1515

16-
GridSampleKernel::GridSampleKernel(OrtApi api, const OrtKernelInfo *info)
17-
: api_(api), ort_(api_), info_(info) {
16+
GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info)
17+
: ort_(api), info_(info) {
1818
align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
1919
interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
2020
padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");

csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
namespace mmdeploy {
88

99
struct GridSampleKernel {
10-
GridSampleKernel(OrtApi api, const OrtKernelInfo *info);
10+
GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info);
1111

1212
void Compute(OrtKernelContext *context);
1313

1414
protected:
15-
OrtApi api_;
1615
Ort::CustomOpApi ort_;
1716
const OrtKernelInfo *info_;
1817
Ort::AllocatorWithDefaultOptions allocator_;
@@ -23,7 +22,7 @@ struct GridSampleKernel {
2322
};
2423

2524
struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
26-
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
25+
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
2726
return new GridSampleKernel(api, info);
2827
};
2928

csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
109109
}
110110
}
111111

112-
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info)
113-
: api_(api), ort_(api_), info_(info) {
112+
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api,
113+
const OrtKernelInfo *info)
114+
: ort_(api), info_(info) {
114115
std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
115116
stride_height_ = stride[0];
116117
stride_width_ = stride[1];

csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
namespace mmdeploy {
88

99
struct MMCVModulatedDeformConvKernel {
10-
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
10+
MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info);
1111

1212
void Compute(OrtKernelContext *context);
1313

1414
protected:
15-
OrtApi api_;
1615
Ort::CustomOpApi ort_;
1716
const OrtKernelInfo *info_;
1817
Ort::AllocatorWithDefaultOptions allocator_;
@@ -29,7 +28,7 @@ struct MMCVModulatedDeformConvKernel {
2928

3029
struct MMCVModulatedDeformConvOp
3130
: Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
32-
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
31+
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
3332
return new MMCVModulatedDeformConvKernel(api, info);
3433
}
3534

csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
261261
return polygon_area(orderedPts, num_convex);
262262
}
263263

264-
NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
265-
: api_(api), ort_(api_), info_(info) {
264+
NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info)
265+
: ort_(api), info_(info) {
266266
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
267267
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
268268

csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212

1313
namespace mmdeploy {
1414
struct NMSRotatedKernel {
15-
NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info);
15+
NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
1616

1717
void Compute(OrtKernelContext* context);
1818

1919
private:
20-
OrtApi api_;
2120
Ort::CustomOpApi ort_;
2221
const OrtKernelInfo* info_;
2322
Ort::AllocatorWithDefaultOptions allocator_;
@@ -26,7 +25,7 @@ struct NMSRotatedKernel {
2625
};
2726

2827
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
29-
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
28+
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
3029
return new NMSRotatedKernel(api, info);
3130
}
3231
const char* GetName() const { return "NMSRotated"; }

csrc/mmdeploy/net/ort/ort_net.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ Result<void> OrtNet::Init(const Value& args) {
7474
};
7575

7676
for (int i = 0; i < n_inputs; ++i) {
77+
#if ORT_API_VERSION >= 13
78+
auto input_name = session_.GetInputNameAllocated(i, allocator).release();
79+
#else
7780
auto input_name = session_.GetInputName(i, allocator);
81+
#endif
7882
auto type_info = session_.GetInputTypeInfo(i);
7983
auto shape = to_shape(type_info);
8084
MMDEPLOY_DEBUG("input {}, shape = {}", i, shape);
@@ -88,7 +92,11 @@ Result<void> OrtNet::Init(const Value& args) {
8892
auto n_outputs = session_.GetOutputCount();
8993

9094
for (int i = 0; i < n_outputs; ++i) {
95+
#if ORT_API_VERSION >= 13
96+
auto output_name = session_.GetOutputNameAllocated(i, allocator).release();
97+
#else
9198
auto output_name = session_.GetOutputName(i, allocator);
99+
#endif
92100
auto type_info = session_.GetOutputTypeInfo(i);
93101
auto shape = to_shape(type_info);
94102
MMDEPLOY_DEBUG("output {}, shape = {}", i, shape);

0 commit comments

Comments
 (0)