Skip to content

Commit 706e13e

Browse files
authored
implement affinegrid cpu kernel (#17777)
1 parent 2c6b31c commit 706e13e

6 files changed

Lines changed: 457 additions & 0 deletions

File tree

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Do not modify directly.*
2525
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
2626
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
2727
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
28+
|AffineGrid|*in* theta:**T1**<br> *in* size:**T2**<br> *out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
2829
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
2930
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
3031
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
960960

961961
// Opset 20
962962
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
963+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid);
964+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid);
963965
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
964966
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
965967
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
@@ -2399,6 +2401,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
23992401

24002402
// Opset 20
24012403
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
2404+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid)>,
2405+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid)>,
24022406
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN)>,
24032407
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
24042408
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN)>,
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/cpu/tensor/affine_grid.h"
5+
6+
#include "core/common/common.h"
7+
#include "core/providers/op_kernel_type_control.h"
8+
#include "core/util/math_cpuonly.h"
9+
#include <iostream>
10+
#include "Eigen/src/Core/Map.h"
11+
#include <Eigen/Dense>
12+
#include "core/common/eigen_common_wrapper.h"
13+
14+
namespace onnxruntime {
15+
16+
// Opset-20 AffineGrid CPU kernel registrations, one per supported element type.
// "T1" is constrained to the kernel's element type and "T2" to int64.

ONNX_CPU_OPERATOR_TYPED_KERNEL(
    AffineGrid,
    20,
    float,
    KernelDefBuilder()
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()),
    AffineGrid<float>);

ONNX_CPU_OPERATOR_TYPED_KERNEL(
    AffineGrid,
    20,
    double,
    KernelDefBuilder()
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<double>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()),
    AffineGrid<double>);
28+
29+
template <typename T>
30+
void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix<T, Eigen::Dynamic, 2>& base_grid) {
31+
Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(W), -1, 1);
32+
if (!align_corners) {
33+
row_vec = row_vec * (W - 1) / W;
34+
}
35+
Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(H), -1, 1);
36+
if (!align_corners) {
37+
col_vec = col_vec * (H - 1) / H;
38+
}
39+
40+
base_grid.resize(static_cast<Eigen::Index>(H * W), 2);
41+
for (Eigen::Index j = 0; j < H; j++) {
42+
for (Eigen::Index i = 0; i < W; i++) {
43+
base_grid.row(j * static_cast<Eigen::Index>(W) + i) << row_vec(i), col_vec(j);
44+
}
45+
}
46+
}
47+
48+
template <typename T>
49+
void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix<T, Eigen::Dynamic, 3>& base_grid) {
50+
Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(W), -1, 1);
51+
if (!align_corners) {
52+
row_vec = row_vec * (W - 1) / W;
53+
}
54+
Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(H), -1, 1);
55+
if (!align_corners) {
56+
col_vec = col_vec * (H - 1) / H;
57+
}
58+
Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(D), -1, 1);
59+
if (!align_corners) {
60+
slice_vec = slice_vec * (D - 1) / D;
61+
}
62+
63+
base_grid.resize(static_cast<Eigen::Index>(D * H * W), 3);
64+
for (Eigen::Index k = 0; k < D; k++) {
65+
for (Eigen::Index j = 0; j < H; j++) {
66+
for (Eigen::Index i = 0; i < W; i++) {
67+
base_grid.row(k * static_cast<Eigen::Index>(H * W) + j * static_cast<Eigen::Index>(W) + i) << row_vec(i), col_vec(j), slice_vec(k);
68+
}
69+
}
70+
}
71+
}
72+
73+
template <typename T>
74+
void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix<T, 2, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) {
75+
const Eigen::StorageOptions option = Eigen::RowMajor;
76+
auto theta_batch_offset = batch_num * 2 * 3;
77+
const T* theta_data = theta->Data<T>() + theta_batch_offset;
78+
const Eigen::Matrix<T, 2, 2, option> theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}};
79+
const Eigen::Array<T, 2, 1> theta_T(theta_data[2], theta_data[5]);
80+
81+
auto grid_batch_offset = batch_num * H * W * 2;
82+
T* grid_data = grid->MutableData<T>() + grid_batch_offset;
83+
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 2, option>> grid_matrix(grid_data, narrow<size_t>(H * W), 2);
84+
grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
85+
}
86+
87+
template <typename T>
88+
void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix<T, 3, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) {
89+
const Eigen::StorageOptions option = Eigen::RowMajor;
90+
auto theta_batch_offset = batch_num * 3 * 4;
91+
const T* theta_data = theta->Data<T>() + theta_batch_offset;
92+
const Eigen::Matrix<T, 3, 3, option> theta_R{
93+
{theta_data[0], theta_data[1], theta_data[2]},
94+
{theta_data[4], theta_data[5], theta_data[6]},
95+
{theta_data[8], theta_data[9], theta_data[10]}};
96+
const Eigen::Array<T, 3, 1> theta_T(theta_data[3], theta_data[7], theta_data[11]);
97+
98+
auto grid_batch_offset = batch_num * D * H * W * 3;
99+
T* grid_data = grid->MutableData<T>() + grid_batch_offset;
100+
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 3, option>> grid_matrix(grid_data, narrow<size_t>(D * H * W), 3);
101+
grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
102+
}
103+
104+
template <typename T>
105+
Status AffineGrid<T>::Compute(OpKernelContext* context) const {
106+
const Tensor* theta = context->Input<Tensor>(0);
107+
const TensorShape& theta_shape = theta->Shape();
108+
if (theta_shape.NumDimensions() != 3) {
109+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3");
110+
}
111+
112+
const Tensor* size = context->Input<Tensor>(1);
113+
const TensorShape& size_shape = size->Shape();
114+
const int64_t* size_data = size->Data<int64_t>();
115+
116+
if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) {
117+
int64_t N = size_data[0], H = size_data[2], W = size_data[3];
118+
119+
TensorShape grid_shape{N, H, W, 2};
120+
auto grid = context->Output(0, grid_shape);
121+
122+
Eigen::Matrix<T, Eigen::Dynamic, 2> base_grid;
123+
generate_base_grid_2d(H, W, align_corners_, base_grid);
124+
Eigen::Matrix<T, 2, Eigen::Dynamic> base_grid_transposed = base_grid.transpose();
125+
126+
std::function<void(ptrdiff_t)> fn = [theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) {
127+
affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid);
128+
};
129+
130+
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow<size_t>(N), std::move(fn), 0);
131+
} else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) {
132+
int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4];
133+
134+
TensorShape grid_shape{N, D, H, W, 3};
135+
auto grid = context->Output(0, grid_shape);
136+
137+
Eigen::Matrix<T, Eigen::Dynamic, 3> base_grid;
138+
generate_base_grid_3d(D, H, W, align_corners_, base_grid);
139+
Eigen::Matrix<T, 3, Eigen::Dynamic> base_grid_transposed = base_grid.transpose();
140+
141+
std::function<void(ptrdiff_t)> fn = [theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) {
142+
affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid);
143+
};
144+
145+
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow<size_t>(N), std::move(fn), 0);
146+
} else {
147+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Invalidate size - length of size should be 4 or 5.");
148+
}
149+
return Status::OK();
150+
}
151+
} // namespace onnxruntime
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/common/common.h"
7+
#include "core/framework/op_kernel.h"
8+
9+
namespace onnxruntime {
10+
11+
template <typename T>
12+
class AffineGrid final : public OpKernel {
13+
public:
14+
AffineGrid(const OpKernelInfo& info) : OpKernel(info) {
15+
int64_t align_corners = info.GetAttrOrDefault<int64_t>("align_corners", 0);
16+
align_corners_ = (align_corners != 0);
17+
}
18+
19+
Status Compute(OpKernelContext* context) const override;
20+
21+
private:
22+
bool align_corners_;
23+
};
24+
25+
} // namespace onnxruntime

0 commit comments

Comments
 (0)