Skip to content

Commit f3cfe08

Browse files
[JS/Web] Enabled 1D spatial input to GlobalAveragePool (#17973)
### Description Enable one-dimensional spatial input for GlobalAveragePool. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Currently only 2D input is supported.
1 parent 780ee18 commit f3cfe08

1 file changed

Lines changed: 23 additions & 27 deletions

File tree

  • js/web/lib/wasm/jsep/webgpu/ops

js/web/lib/wasm/jsep/webgpu/ops/pool.ts

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
1818
if (!inputs || inputs.length !== 1) {
1919
throw new Error('Pool ops requires 1 input.');
2020
}
21-
if (inputs[0].dims.length !== 4) {
22-
throw new Error('Pool ops supports 2-D inputs only for now.');
21+
if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
22+
throw new Error('Pool ops supports 1-D or 2-D inputs only for now.');
2323
}
2424
};
2525

2626
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
2727
input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
2828
const isChannelsLast = attributes.format === 'NHWC';
29-
const inputShapeAsChannelFirst =
30-
isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice();
29+
const inputShapeAsChannelFirst = input.dims.slice();
30+
if (isChannelsLast) {
31+
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
32+
}
3133
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
3234
const kernelShape = attributes.kernelShape.slice();
3335
const strides = attributes.strides.slice();
@@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
4446
} else {
4547
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
4648
}
47-
return [
48-
newAttributes,
49-
isChannelsLast ?
50-
[
51-
outputShapeAsChannelFirst[0], outputShapeAsChannelFirst[2], outputShapeAsChannelFirst[3],
52-
outputShapeAsChannelFirst[1]
53-
] :
54-
outputShapeAsChannelFirst
55-
];
49+
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
50+
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
51+
return [newAttributes, isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst];
5652
};
5753

5854
const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
@@ -76,22 +72,22 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
7672
let codeHEnd = '';
7773
if (pwStart + pwEnd !== 0) {
7874
codeW = `
79-
for (var i: u32 = 0u; i < ${kw}u; i++) {
80-
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
81-
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
82-
pad++;
83-
continue;
84-
}
85-
let x_val = x[${x.indicesToOffset('xIndices')}];
86-
${op1}
87-
}`;
75+
for (var i: u32 = 0u; i < ${kw}u; i++) {
76+
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
77+
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
78+
pad++;
79+
continue;
80+
}
81+
let x_val = x[${x.indicesToOffset('xIndices')}];
82+
${op1}
83+
}`;
8884
} else {
8985
codeW = `
90-
for (var i: u32 = 0u; i < ${kw}u; i++) {
91-
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
92-
let x_val = x[${x.indicesToOffset('xIndices')}];
93-
${op1}
94-
}`;
86+
for (var i: u32 = 0u; i < ${kw}u; i++) {
87+
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
88+
let x_val = x[${x.indicesToOffset('xIndices')}];
89+
${op1}
90+
}`;
9591
}
9692

9793
if (attributes.kernelShape.length === 2) {

0 commit comments

Comments
 (0)