Skip to content

Commit 1ad2bc6

Browse files
authored
pnnx ncnn handles squeeze/unsqueeze the batch index (#6232)
1 parent e207b3b commit 1ad2bc6

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

tools/pnnx/src/pass_ncnn/solve_batch_index.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,14 @@ static void solve_batch_index_forward(Operand* operand)
389389
dim += input_rank0;
390390

391391
int batch_index_squeezed = batch_index;
392-
if (dim >= 0 && dim < batch_index)
393-
{
394-
batch_index_squeezed = batch_index - 1;
395-
}
396392
if (dim >= 0 && dim == batch_index)
397393
{
398394
batch_index_squeezed = 233;
399395
}
396+
else if (dim >= 0 && dim < batch_index)
397+
{
398+
batch_index_squeezed = batch_index - 1;
399+
}
400400

401401
Operand* r = op->outputs[0];
402402
if (r->params.find("__batch_index") == r->params.end())
@@ -413,6 +413,12 @@ static void solve_batch_index_forward(Operand* operand)
413413
if (dim < 0)
414414
dim += input_rank0;
415415

416+
if (batch_index == 233)
417+
{
418+
// give up
419+
return;
420+
}
421+
416422
int batch_index_unsqueezed = batch_index;
417423
if (dim >= 0 && dim <= batch_index)
418424
{
@@ -703,6 +709,12 @@ static void solve_batch_index_backward(Operand* operand)
703709
if (dim < 0)
704710
dim += input_rank0;
705711

712+
if (batch_index == 233)
713+
{
714+
// give up
715+
return;
716+
}
717+
706718
int batch_index_unsqueezed = batch_index;
707719
if (dim >= 0 && dim <= batch_index)
708720
{
@@ -725,7 +737,11 @@ static void solve_batch_index_backward(Operand* operand)
725737
dim += input_rank0;
726738

727739
int batch_index_squeezed = batch_index;
728-
if (dim >= 0 && dim <= batch_index)
740+
if (dim >= 0 && dim == batch_index)
741+
{
742+
batch_index_squeezed = 233;
743+
}
744+
else if (dim >= 0 && dim <= batch_index)
729745
{
730746
batch_index_squeezed = batch_index - 1;
731747
}

0 commit comments

Comments
 (0)