Skip to content

Commit 9b91fe5

Browse files
nihui and Baiyuetribe
authored
implement flip layer and pnnx torch.flip conversion (#6233)
Co-authored-by: 佰阅 <[email protected]>
1 parent 1ad2bc6 commit 9b91fe5

20 files changed

+730
-0
lines changed

docs/developer-guide/operators.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
* [Embed](#embed)
3434
* [Exp](#exp)
3535
* [Flatten](#flatten)
36+
* [Flip](#flip)
3637
* [Fold](#fold)
3738
* [GELU](#gelu)
3839
* [GLU](#glu)
@@ -870,6 +871,14 @@ Reshape blob to 1 dimension
870871

871872
* one_blob_only
872873

874+
# Flip
875+
876+
* one_blob_only
877+
878+
| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | axes | array | [ ] | axes to flip along |
881+
873882
# Fold
874883
```
875884
y = fold(x)

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ ncnn_add_layer(Shrink)
170170
ncnn_add_layer(RMSNorm)
171171
ncnn_add_layer(Spectrogram)
172172
ncnn_add_layer(InverseSpectrogram)
173+
ncnn_add_layer(Flip)
173174

174175
if(NCNN_VULKAN)
175176
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)

src/layer/flip.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "flip.h"
5+
6+
namespace ncnn {
7+
8+
// Flip layer: reverses blob data along the axes given by param 0 (axes).
Flip::Flip()
{
    // one input blob in, one output blob out
    one_blob_only = true;
}
12+
13+
int Flip::load_param(const ParamDict& pd)
14+
{
15+
axes = pd.get(0, Mat());
16+
17+
if (axes.w > 4)
18+
{
19+
// only handle up to 4-dim
20+
return -1;
21+
}
22+
23+
return 0;
24+
}
25+
26+
int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
27+
{
28+
if (axes.empty())
29+
{
30+
top_blob = bottom_blob;
31+
return 0;
32+
}
33+
34+
const int dims = bottom_blob.dims;
35+
const int w = bottom_blob.w;
36+
const int h = bottom_blob.h;
37+
const int d = bottom_blob.d;
38+
const int channels = bottom_blob.c;
39+
40+
int axes_flag[4] = {0};
41+
bool flip_w = false;
42+
bool flip_h = false;
43+
bool flip_d = false;
44+
bool flip_c = false;
45+
{
46+
const int* axes_ptr = axes;
47+
for (int i = 0; i < axes.w; i++)
48+
{
49+
int axis = axes_ptr[i];
50+
// handle negative axis
51+
if (axis < 0)
52+
axis += dims;
53+
axes_flag[axis] = 1;
54+
}
55+
56+
if (dims == 1)
57+
{
58+
flip_w = true;
59+
}
60+
else if (dims == 2)
61+
{
62+
if (axes_flag[0] == 1) flip_h = true;
63+
if (axes_flag[1] == 1) flip_w = true;
64+
}
65+
else if (dims == 3)
66+
{
67+
if (axes_flag[0] == 1) flip_c = true;
68+
if (axes_flag[1] == 1) flip_h = true;
69+
if (axes_flag[2] == 1) flip_w = true;
70+
}
71+
else if (dims == 4)
72+
{
73+
if (axes_flag[0] == 1) flip_c = true;
74+
if (axes_flag[1] == 1) flip_d = true;
75+
if (axes_flag[2] == 1) flip_h = true;
76+
if (axes_flag[3] == 1) flip_w = true;
77+
}
78+
}
79+
80+
top_blob.create_like(bottom_blob, opt.blob_allocator);
81+
if (top_blob.empty())
82+
return -100;
83+
84+
#pragma omp parallel for num_threads(opt.num_threads)
85+
for (int q = 0; q < channels; q++)
86+
{
87+
for (int z = 0; z < d; z++)
88+
{
89+
for (int i = 0; i < h; i++)
90+
{
91+
int q2 = flip_c ? channels - 1 - q : q;
92+
int z2 = flip_d ? d - 1 - z : z;
93+
int i2 = flip_h ? h - 1 - i : i;
94+
95+
const float* ptr = bottom_blob.channel(q2).depth(z2).row(i2);
96+
float* outptr = top_blob.channel(q).depth(z).row(i);
97+
98+
if (flip_w)
99+
{
100+
ptr += w - 1;
101+
for (int j = 0; j < w; j++)
102+
{
103+
*outptr++ = *ptr--;
104+
}
105+
}
106+
else
107+
{
108+
memcpy(outptr, ptr, w * sizeof(float));
109+
}
110+
}
111+
}
112+
}
113+
114+
return 0;
115+
}
116+
117+
} // namespace ncnn

src/layer/flip.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#ifndef LAYER_FLIP_H
5+
#define LAYER_FLIP_H
6+
7+
#include "layer.h"
8+
9+
namespace ncnn {
10+
11+
// Flip layer: reverses the order of elements along the axes listed in
// param 0 (axes).
class Flip : public Layer
{
public:
    Flip();

    // reads param 0 (axes); fails when more than 4 axes are given
    virtual int load_param(const ParamDict& pd);

    // writes a flipped copy of bottom_blob into top_blob
    virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
    // axes to flip; empty means identity pass-through
    Mat axes;
};
23+
24+
} // namespace ncnn
25+
26+
#endif // LAYER_FLIP_H

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ ncnn_add_layer_test(Embed)
107107
ncnn_add_layer_test(Erf)
108108
ncnn_add_layer_test(ExpandDims)
109109
ncnn_add_layer_test(Flatten)
110+
ncnn_add_layer_test(Flip)
110111
ncnn_add_layer_test(Fold)
111112
ncnn_add_layer_test(GELU)
112113
ncnn_add_layer_test(GLU)

tests/test_flip.cpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "testutil.h"
5+
6+
// Convenience builders for the small int vectors used as flip axes lists.
static std::vector<int> IntArray(int a0)
{
    std::vector<int> m;
    m.push_back(a0);
    return m;
}

static std::vector<int> IntArray(int a0, int a1)
{
    std::vector<int> m;
    m.push_back(a0);
    m.push_back(a1);
    return m;
}

static std::vector<int> IntArray(int a0, int a1, int a2)
{
    std::vector<int> m;
    m.push_back(a0);
    m.push_back(a1);
    m.push_back(a2);
    return m;
}

static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
    std::vector<int> m;
    m.push_back(a0);
    m.push_back(a1);
    m.push_back(a2);
    m.push_back(a3);
    return m;
}
39+
40+
// Print an int vector to stderr as "[ a b c ]" (no trailing newline).
static void print_int_array(const std::vector<int>& a)
{
    fprintf(stderr, "[");
    for (size_t k = 0; k != a.size(); ++k)
    {
        fprintf(stderr, " %d", a[k]);
    }
    fprintf(stderr, " ]");
}
49+
50+
// Run a single Flip layer test on blob a with the given axes list.
// Returns the test_layer result; logs the failing configuration on error.
static int test_flip(const ncnn::Mat& a, const std::vector<int>& axes_array)
{
    const int axes_count = (int)axes_array.size();

    // pack the axes list into a 1-d ncnn::Mat param
    ncnn::Mat axes(axes_count);
    int* p = axes;
    for (int i = 0; i < axes_count; i++)
    {
        p[i] = axes_array[i];
    }

    ncnn::ParamDict pd;
    pd.set(0, axes);

    std::vector<ncnn::Mat> weights(0);

    int ret = test_layer("Flip", pd, weights, a);
    if (ret != 0)
    {
        fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
        fprintf(stderr, " axes=");
        print_int_array(axes_array);
        fprintf(stderr, "\n");
    }

    return ret;
}
77+
78+
// Exercise test_flip with every axes combination valid for a's rank.
// Proceeds rank by rank (1-d, then 2-d, 3-d, 4-d combinations), stopping
// at a's own rank or at the first failing combination.
static int test_flip_nd(const ncnn::Mat& a)
{
    // rank >= 1: the single leading axis
    int ret1 = test_flip(a, IntArray(0));

    if (a.dims == 1 || ret1 != 0)
        return ret1;

    // rank >= 2: all non-empty subsets of {0, 1}
    int ret2 = 0
               || test_flip(a, IntArray(0))
               || test_flip(a, IntArray(1))
               || test_flip(a, IntArray(0, 1));

    if (a.dims == 2 || ret2 != 0)
        return ret2;

    // rank >= 3: all non-empty subsets of {0, 1, 2}
    int ret3 = 0
               || test_flip(a, IntArray(0))
               || test_flip(a, IntArray(1))
               || test_flip(a, IntArray(2))
               || test_flip(a, IntArray(0, 1))
               || test_flip(a, IntArray(0, 2))
               || test_flip(a, IntArray(1, 2))
               || test_flip(a, IntArray(0, 1, 2));

    if (a.dims == 3 || ret3 != 0)
        return ret3;

    // rank 4: all non-empty subsets of {0, 1, 2, 3}
    int ret4 = 0
               || test_flip(a, IntArray(0))
               || test_flip(a, IntArray(1))
               || test_flip(a, IntArray(2))
               || test_flip(a, IntArray(3))
               || test_flip(a, IntArray(0, 1))
               || test_flip(a, IntArray(0, 2))
               || test_flip(a, IntArray(0, 3))
               || test_flip(a, IntArray(1, 2))
               || test_flip(a, IntArray(1, 3))
               || test_flip(a, IntArray(2, 3))
               || test_flip(a, IntArray(0, 1, 2))
               || test_flip(a, IntArray(0, 1, 3))
               || test_flip(a, IntArray(0, 2, 3))
               || test_flip(a, IntArray(1, 2, 3))
               || test_flip(a, IntArray(0, 1, 2, 3));

    return ret4;
}
124+
125+
static int test_flip_0()
126+
{
127+
ncnn::Mat a = RandomMat(5, 6, 7, 24);
128+
ncnn::Mat b = RandomMat(7, 8, 9, 12);
129+
ncnn::Mat c = RandomMat(3, 4, 5, 13);
130+
131+
return 0
132+
|| test_flip_nd(a)
133+
|| test_flip_nd(b)
134+
|| test_flip_nd(c);
135+
}
136+
137+
static int test_flip_1()
138+
{
139+
ncnn::Mat a = RandomMat(5, 7, 24);
140+
ncnn::Mat b = RandomMat(7, 9, 12);
141+
ncnn::Mat c = RandomMat(3, 5, 13);
142+
143+
return 0
144+
|| test_flip_nd(a)
145+
|| test_flip_nd(b)
146+
|| test_flip_nd(c);
147+
}
148+
149+
static int test_flip_2()
150+
{
151+
ncnn::Mat a = RandomMat(15, 24);
152+
ncnn::Mat b = RandomMat(17, 12);
153+
ncnn::Mat c = RandomMat(19, 15);
154+
155+
return 0
156+
|| test_flip_nd(a)
157+
|| test_flip_nd(b)
158+
|| test_flip_nd(c);
159+
}
160+
161+
static int test_flip_3()
162+
{
163+
ncnn::Mat a = RandomMat(128);
164+
ncnn::Mat b = RandomMat(124);
165+
ncnn::Mat c = RandomMat(127);
166+
167+
return 0
168+
|| test_flip_nd(a)
169+
|| test_flip_nd(b)
170+
|| test_flip_nd(c);
171+
}
172+
173+
int main()
174+
{
175+
SRAND(7767517);
176+
177+
return 0
178+
|| test_flip_0()
179+
|| test_flip_1()
180+
|| test_flip_2()
181+
|| test_flip_3();
182+
}

tools/pnnx/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,7 @@ set(pnnx_pass_ncnn_SRCS
592592
pass_ncnn/torch_cumsum.cpp
593593
pass_ncnn/torch_diag.cpp
594594
pass_ncnn/torch_flatten.cpp
595+
pass_ncnn/torch_flip.cpp
595596
pass_ncnn/torch_istft.cpp
596597
pass_ncnn/torch_logsumexp.cpp
597598
pass_ncnn/torch_matmul.cpp

0 commit comments

Comments
 (0)