Skip to content

Commit 7bb919b

Browse files
JakeStevensfacebook-github-bot
authored andcommitted
Restore Conv1d Channel Last Fix (#18290)
Summary: D93658558 accidentally dropped the conv1d check which was added shortly before it, affecting indexing and resulting in incorrect results Differential Revision: D96488193
1 parent b40d6fe commit 7bb919b

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

backends/cadence/generic/operators/op_quantized_conv2d.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,13 @@ void quantized_conv2d_nhwc(
510510
const int c = static_cast<int>(conv1d ? input.size(2) : input.size(3));
511511
// Depthwise is defined by in_channels == groups; depthwise weights have one
512512
// fewer dim than regular weights because the IC dim (always 1) was squeezed.
513-
const bool is_depthwise =
514-
!conv1d && c == groups && weight.dim() < input.dim();
513+
const bool is_depthwise = c == groups && weight.dim() < input.dim();
515514
int oc, wh, ww, wc;
516515
if (is_depthwise) {
517-
// Depthwise weight is [KH, KW, OC]
518-
wh = static_cast<int>(weight.size(0));
519-
ww = static_cast<int>(weight.size(1));
520-
oc = static_cast<int>(weight.size(2));
516+
// Depthwise weight: conv2d=[KH, KW, OC], conv1d=[K, OC]
517+
wh = static_cast<int>(conv1d ? 1 : weight.size(0));
518+
ww = static_cast<int>(conv1d ? weight.size(0) : weight.size(1));
519+
oc = static_cast<int>(conv1d ? weight.size(1) : weight.size(2));
521520
wc = 1;
522521
} else {
523522
// Regular weight is [OC, WH, WW, WC] or for conv1d [OC, WW, WC]

backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,16 @@ void xa_opt_quantized_conv2d_nhwc(
183183
// Depthwise is defined by in_channels == groups; depthwise weights have one
184184
// fewer dim than regular weights because the IC dim (always 1) was
185185
// squeezed.
186-
bool is_depthwise =
187-
!conv1d && input_channels == groups && weight.dim() < input.dim();
186+
bool is_depthwise = input_channels == groups && weight.dim() < input.dim();
188187
WORD32 kernel_height;
189188
WORD32 kernel_width;
190189
WORD32 kernel_channels;
191190
WORD32 out_channels;
192191
if (is_depthwise) {
193-
// Depthwise weight is [KH, KW, OC]
194-
kernel_height = weight.size(0);
195-
kernel_width = weight.size(1);
196-
out_channels = weight.size(2);
192+
// Depthwise weight: conv2d=[KH, KW, OC], conv1d=[K, OC]
193+
kernel_height = conv1d ? 1 : weight.size(0);
194+
kernel_width = conv1d ? weight.size(0) : weight.size(1);
195+
out_channels = conv1d ? weight.size(1) : weight.size(2);
197196
kernel_channels = 1;
198197
} else {
199198
// Regular weight is [OC, IC, KH, KW] or for conv1d [OC, K, IC]
@@ -384,14 +383,13 @@ void quantized_conv2d_nhwc(
384383
const int c = conv1d ? input.size(2) : input.size(3);
385384
// Depthwise is defined by in_channels == groups; depthwise weights have one
386385
// fewer dim than regular weights because the IC dim (always 1) was squeezed.
387-
const bool is_depthwise =
388-
!conv1d && c == groups && weight.dim() < input.dim();
386+
const bool is_depthwise = c == groups && weight.dim() < input.dim();
389387
int oc, wh, ww, wc;
390388
if (is_depthwise) {
391-
// Depthwise weight is [KH, KW, OC]
392-
wh = weight.size(0);
393-
ww = weight.size(1);
394-
oc = weight.size(2);
389+
// Depthwise weight: conv2d=[KH, KW, OC], conv1d=[K, OC]
390+
wh = conv1d ? 1 : weight.size(0);
391+
ww = conv1d ? weight.size(0) : weight.size(1);
392+
oc = conv1d ? weight.size(1) : weight.size(2);
395393
wc = 1;
396394
} else {
397395
// Regular weight is [OC, WH, WW, WC] or for conv1d [OC, WW, WC]

0 commit comments

Comments
 (0)