@@ -389,14 +389,14 @@ static void solve_batch_index_forward(Operand* operand)
389
389
dim += input_rank0;
390
390
391
391
int batch_index_squeezed = batch_index;
392
- if (dim >= 0 && dim < batch_index)
393
- {
394
- batch_index_squeezed = batch_index - 1 ;
395
- }
396
392
if (dim >= 0 && dim == batch_index)
397
393
{
398
394
batch_index_squeezed = 233 ;
399
395
}
396
+ else if (dim >= 0 && dim < batch_index)
397
+ {
398
+ batch_index_squeezed = batch_index - 1 ;
399
+ }
400
400
401
401
Operand* r = op->outputs [0 ];
402
402
if (r->params .find (" __batch_index" ) == r->params .end ())
@@ -413,6 +413,12 @@ static void solve_batch_index_forward(Operand* operand)
413
413
if (dim < 0 )
414
414
dim += input_rank0;
415
415
416
+ if (batch_index == 233 )
417
+ {
418
+ // give up
419
+ return ;
420
+ }
421
+
416
422
int batch_index_unsqueezed = batch_index;
417
423
if (dim >= 0 && dim <= batch_index)
418
424
{
@@ -703,6 +709,12 @@ static void solve_batch_index_backward(Operand* operand)
703
709
if (dim < 0 )
704
710
dim += input_rank0;
705
711
712
+ if (batch_index == 233 )
713
+ {
714
+ // give up
715
+ return ;
716
+ }
717
+
706
718
int batch_index_unsqueezed = batch_index;
707
719
if (dim >= 0 && dim <= batch_index)
708
720
{
@@ -725,7 +737,11 @@ static void solve_batch_index_backward(Operand* operand)
725
737
dim += input_rank0;
726
738
727
739
int batch_index_squeezed = batch_index;
728
- if (dim >= 0 && dim <= batch_index)
740
+ if (dim >= 0 && dim == batch_index)
741
+ {
742
+ batch_index_squeezed = 233 ;
743
+ }
744
+ else if (dim >= 0 && dim <= batch_index)
729
745
{
730
746
batch_index_squeezed = batch_index - 1 ;
731
747
}
0 commit comments