Skip to content

Commit e68bb7b

Browse files
qjia7kleiti
authored and committed
[js/webgpu] Fix the transpose error when dims > 4D (microsoft#18027)
### Description

<!-- Describe your changes. -->

Currently, uniform support is buggy when the rank of dims is larger than 4 (see microsoft#17860, item 1), so this PR enables shape uniforms for transpose only when the shape rank is <= 4. Otherwise, the following compilation errors are thrown:

```
1 error(s) generated while compiling the shader:
:3:50 error: uniform storage requires that array elements are aligned to 16 bytes, but array element of type 'u32' has a stride of 4 bytes. Consider using a vector or struct as the element type instead.
      struct Uniforms { output_size:u32, a_shape:array<u32, 5>, a_strides:array<u32, 5>, output_shape:array<u32, 5>, output_strides:array<u32, 5> };
                                                 ^^^^^^^^^^^^^
:3:7 note: see layout of struct:
/*            align(4) size(84) */ struct Uniforms {
/* offset( 0) align(4) size( 4) */   output_size : u32;
/* offset( 4) align(4) size(20) */   a_shape : array<u32, 5>;
/* offset(24) align(4) size(20) */   a_strides : array<u32, 5>;
/* offset(44) align(4) size(20) */   output_shape : array<u32, 5>;
/* offset(64) align(4) size(20) */   output_strides : array<u32, 5>;
/*                              */ };
      struct Uniforms { output_size:u32, a_shape:array<u32, 5>, a_strides:array<u32, 5>, output_shape:array<u32, 5>, output_strides:array<u32, 5> };
      ^^^^^^
:4:42 note: 'Uniforms' used in address space 'uniform' here
      @group(0) @binding(2) var<uniform> uniforms: Uniforms;
                                         ^^^^^^^^
```
1 parent 1e0f4fa commit e68bb7b

5 files changed

Lines changed: 59 additions & 25 deletions

File tree

js/web/lib/wasm/jsep/webgpu/ops/common.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
803803
}
804804
return dims;
805805
};
806+
807+
// TODO: remove this limitation once >4D dims are supported by uniform.
808+
/**
 * Reports whether shape uniforms are enabled for a tensor of the given rank.
 *
 * Shape uniforms are currently limited to ranks of at most 4; higher ranks
 * must use a different way of supplying shape information.
 *
 * @param rank - Number of dimensions of the tensor.
 * @returns `true` when `rank` is less than or equal to 4, otherwise `false`.
 */
export const enableShapesUniforms = (rank: number): boolean => {
  const maxRankForShapeUniforms = 4;
  return rank <= maxRankForShapeUniforms;
};

js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ const convTranspose2d =
232232
// STEP.1: transpose weight
233233
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
234234
context.compute(
235-
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm),
235+
createTransposeProgramInfo(inputs[1], weightTransposePerm),
236236
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
237237
if (attributes.wIsConst && !context.kernelCustomData.wT) {
238238
context.kernelCustomData.wT = transposedWeight;

js/web/lib/wasm/jsep/webgpu/ops/conv.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
168168
if (isChannelsLast) {
169169
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
170170
context.compute(
171-
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
171+
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
172172
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
173173
if (attributes.wIsConst && !context.kernelCustomData.wT) {
174174
context.kernelCustomData.wT = transposedWeight;
@@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
208208
// STEP.1: transpose weight
209209
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
210210
context.compute(
211-
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
211+
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
212212
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
213213
if (attributes.wIsConst && !context.kernelCustomData.wT) {
214214
context.kernelCustomData.wT = transposedWeight;

js/web/lib/wasm/jsep/webgpu/ops/transpose.ts

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
66
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
77
import {ComputeContext, ProgramInfo} from '../types';
88

9-
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
export interface TransposeAttributes extends AttributeWithCacheKey {
1212
readonly perm: number[];
@@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
3535
return reverseFunc.join('\n');
3636
};
3737

38-
export const createTransposeProgramInfo =
39-
(inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => {
40-
const perm = getAdjustedPerm(inputRank, permAttr);
41-
const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank);
42-
const input = inputVariable('a', inputDataType, inputRank);
38+
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
39+
const inputDataType = inputTensor.dataType;
40+
const inputRank = inputTensor.dims.length;
41+
const perm = getAdjustedPerm(inputRank, permAttr);
42+
const useShapesUniforms = enableShapesUniforms(inputRank);
43+
const outputShape = getOutputShape(inputTensor.dims, perm);
44+
const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
45+
const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
46+
const output = outputVariable('output', inputDataType, outShapeOrRank);
47+
const input = inputVariable('a', inputDataType, inShapeOrRank);
4348

44-
const getShaderSource = (shaderHelper: ShaderHelper) => `
49+
const getShaderSource = (shaderHelper: ShaderHelper) => `
4550
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
4651
4752
${permFunctionBody(perm, inputRank, input, output)}
@@ -54,30 +59,32 @@ export const createTransposeProgramInfo =
5459
5560
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
5661
}`;
62+
return {
63+
name: 'Transpose',
64+
shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
65+
getRunData: (inputs) => {
66+
const outputSize = ShapeUtil.size(outputShape);
5767
return {
58-
name: 'Transpose',
59-
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
60-
getRunData: (inputs) => {
61-
const outputShape = getOutputShape(inputs[0].dims, perm);
62-
const outputSize = ShapeUtil.size(outputShape);
63-
return {
64-
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
65-
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
66-
programUniforms: [
68+
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
69+
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
70+
programUniforms: useShapesUniforms ?
71+
[
6772
{type: 'uint32', data: outputSize},
6873
...createTensorShapeVariables(inputs[0].dims),
6974
...createTensorShapeVariables(outputShape),
75+
] :
76+
[
77+
{type: 'uint32', data: outputSize},
7078
],
71-
};
72-
},
73-
getShaderSource,
7479
};
75-
};
80+
},
81+
getShaderSource,
82+
};
83+
};
7684

7785
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
7886
validateInputs(context.inputs);
79-
context.compute(
80-
createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm));
87+
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
8188
};
8289

8390
export const parseTransposeAttributes = (attributes: Record<string, unknown>): TransposeAttributes =>

js/web/test/data/ops/transpose.jsonc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,29 @@
166166
]
167167
}
168168
]
169+
},
170+
{
171+
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
172+
"operator": "Transpose",
173+
"attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }],
174+
"cases": [
175+
{
176+
"name": "T[3, 1, 2, 1, 4]",
177+
"inputs": [
178+
{
179+
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
180+
"dims": [3, 1, 2, 1, 4],
181+
"type": "float32"
182+
}
183+
],
184+
"outputs": [
185+
{
186+
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
187+
"dims": [4, 1, 1, 3, 2],
188+
"type": "float32"
189+
}
190+
]
191+
}
192+
]
169193
}
170194
]

0 commit comments

Comments
 (0)