@@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
1818 if ( ! inputs || inputs . length !== 1 ) {
1919 throw new Error ( 'Pool ops requires 1 input.' ) ;
2020 }
21- if ( inputs [ 0 ] . dims . length !== 4 ) {
22- throw new Error ( 'Pool ops supports 2-D inputs only for now.' ) ;
21+ if ( inputs [ 0 ] . dims . length !== 4 && inputs [ 0 ] . dims . length !== 3 ) {
22+ throw new Error ( 'Pool ops supports 1-D or 2-D inputs only for now.' ) ;
2323 }
2424} ;
2525
2626const getAdjustedPoolAttributesAndOutputShape = < AttributeType extends AveragePoolAttributes | MaxPoolAttributes > (
2727 input : TensorView , attributes : AttributeType , isGlobalOperator : boolean ) : [ AttributeType , number [ ] ] => {
2828 const isChannelsLast = attributes . format === 'NHWC' ;
29- const inputShapeAsChannelFirst =
30- isChannelsLast ? [ input . dims [ 0 ] , input . dims [ 3 ] , input . dims [ 1 ] , input . dims [ 2 ] ] : input . dims . slice ( ) ;
29+ const inputShapeAsChannelFirst = input . dims . slice ( ) ;
30+ if ( isChannelsLast ) {
31+ inputShapeAsChannelFirst . splice ( 1 , 0 , inputShapeAsChannelFirst . pop ( ) ! ) ; // Move channel to the second position.
32+ }
3133 const hasDilations = Object . hasOwnProperty . call ( attributes , 'dilations' ) ;
3234 const kernelShape = attributes . kernelShape . slice ( ) ;
3335 const strides = attributes . strides . slice ( ) ;
@@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
4446 } else {
4547 Object . assign ( newAttributes , { kernelShape, strides, pads, cacheKey : attributes . cacheKey } ) ;
4648 }
47- return [
48- newAttributes ,
49- isChannelsLast ?
50- [
51- outputShapeAsChannelFirst [ 0 ] , outputShapeAsChannelFirst [ 2 ] , outputShapeAsChannelFirst [ 3 ] ,
52- outputShapeAsChannelFirst [ 1 ]
53- ] :
54- outputShapeAsChannelFirst
55- ] ;
49+ const outputShapeAsChannelLast = outputShapeAsChannelFirst . slice ( ) ;
50+ outputShapeAsChannelLast . push ( outputShapeAsChannelLast . splice ( 1 , 1 ) [ 0 ] ) ;
51+ return [ newAttributes , isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst ] ;
5652} ;
5753
5854const generatePoolingCode = < AttributeType extends AveragePoolAttributes | MaxPoolAttributes > (
@@ -76,22 +72,22 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
7672 let codeHEnd = '' ;
7773 if ( pwStart + pwEnd !== 0 ) {
7874 codeW = `
79- for (var i: u32 = 0u; i < ${ kw } u; i++) {
80- xIndices[${ dimIdxW } ] = indices[${ dimIdxW } ] * ${ sw } - ${ pwStart } + i;
81- if (xIndices[${ dimIdxW } ] < 0 || xIndices[${ dimIdxW } ] >= ${ inputDims [ dimIdxW ] } ) {
82- pad++;
83- continue;
84- }
85- let x_val = x[${ x . indicesToOffset ( 'xIndices' ) } ];
86- ${ op1 }
87- }` ;
75+ for (var i: u32 = 0u; i < ${ kw } u; i++) {
76+ xIndices[${ dimIdxW } ] = indices[${ dimIdxW } ] * ${ sw } - ${ pwStart } + i;
77+ if (xIndices[${ dimIdxW } ] < 0 || xIndices[${ dimIdxW } ] >= ${ inputDims [ dimIdxW ] } ) {
78+ pad++;
79+ continue;
80+ }
81+ let x_val = x[${ x . indicesToOffset ( 'xIndices' ) } ];
82+ ${ op1 }
83+ }` ;
8884 } else {
8985 codeW = `
90- for (var i: u32 = 0u; i < ${ kw } u; i++) {
91- xIndices[${ dimIdxW } ] = indices[${ dimIdxW } ] * ${ sw } - ${ pwStart } + i;
92- let x_val = x[${ x . indicesToOffset ( 'xIndices' ) } ];
93- ${ op1 }
94- }` ;
86+ for (var i: u32 = 0u; i < ${ kw } u; i++) {
87+ xIndices[${ dimIdxW } ] = indices[${ dimIdxW } ] * ${ sw } - ${ pwStart } + i;
88+ let x_val = x[${ x . indicesToOffset ( 'xIndices' ) } ];
89+ ${ op1 }
90+ }` ;
9591 }
9692
9793 if ( attributes . kernelShape . length === 2 ) {
0 commit comments