Skip to content

Commit 3f8d212

Browse files
committed
add subgroup mnist
1 parent 7c38412 commit 3f8d212

File tree

1 file changed

+61
-5
lines changed
  • apps/typegpu-docs/src/content/examples/algorithms/mnist-inference

1 file changed

+61
-5
lines changed

apps/typegpu-docs/src/content/examples/algorithms/mnist-inference/index.ts

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@ const SIZE = 28;
55

66
const root = await tgpu.init({
77
device: {
8-
optionalFeatures: ['timestamp-query'],
8+
optionalFeatures: ['timestamp-query', 'subgroups' as GPUFeatureName],
99
},
1010
});
1111
const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
12+
const hasSubgroups = root.enabledFeatures.has('subgroups' as GPUFeatureName);
1213
const device = root.device;
14+
1315
const canvasData = new Array<number>(SIZE ** 2).fill(0);
1416

1517
// Shader code
1618

17-
const layerShader = `
19+
const fallbackShader = `
1820
@binding(0) @group(0) var<storage, read> input: array<f32>;
1921
@binding(1) @group(0) var<storage, read_write> output: array<f32>;
2022
@@ -38,8 +40,62 @@ const layerShader = `
3840
sum = sum + input[j] * weights[weightsOffset + j];
3941
}
4042
41-
sum = sum + biases[i];
42-
output[i] = relu(sum);
43+
let total = sum + biases[i];
44+
output[i] = relu(total);
45+
}
46+
`;
47+
48+
/**
 * WGSL compute shader for one dense layer, accelerated with subgroup
 * operations. Selected only when the device exposes the `subgroups` feature.
 *
 * Work layout: one 64-invocation workgroup per output neuron
 * (`neuronIndex = workgroup_id.x`). NOTE(review): the dispatch call is not
 * visible here; confirm it launches one workgroup per neuron for this shader.
 *
 * Reduction strategy:
 *  1. Each invocation accumulates a strided partial dot product.
 *  2. `subgroupAdd` folds the partials of every subgroup.
 *  3. Exactly one elected lane per subgroup publishes that sum into the slot
 *     indexed by its own local invocation id; all other lanes publish 0.
 *  4. After a workgroup barrier, invocation 0 adds up all slots, applies the
 *     bias and ReLU, and writes the neuron's output.
 *
 * Step 3 deliberately avoids deriving a "subgroup index" from the local id and
 * the subgroup size. That derivation assumes subgroups occupy contiguous,
 * aligned ranges of local ids (not promised by WGSL) and divides the workgroup
 * size by the subgroup size, which truncates to zero whenever the reported
 * subgroup size is larger than the workgroup, silently dropping the whole sum.
 */
const subgroupShader = `
  enable subgroups;

  @binding(0) @group(0) var<storage, read> input: array<f32>;
  @binding(1) @group(0) var<storage, read_write> output: array<f32>;

  @binding(0) @group(1) var<storage, read> weights: array<f32>;
  @binding(1) @group(1) var<storage, read> biases: array<f32>;

  const WORKGROUP_SIZE: u32 = 64u;

  fn relu(x: f32) -> f32 {
    return max(0.0, x);
  }

  // One slot per invocation, so the layout of subgroups inside the
  // workgroup (and the subgroup size itself) never has to be known.
  var<workgroup> subgroupSums: array<f32, WORKGROUP_SIZE>;

  @compute @workgroup_size(WORKGROUP_SIZE)
  fn main(
    @builtin(local_invocation_id) lid: vec3u,
    @builtin(workgroup_id) wid: vec3u
  ) {
    let neuronIndex = wid.x;
    let inputSize = arrayLength(&input);
    let weightsOffset = neuronIndex * inputSize;

    // Strided partial dot product: invocation k handles inputs k, k + 64, ...
    var partial: f32 = 0.0;
    for (var j = lid.x; j < inputSize; j = j + WORKGROUP_SIZE) {
      partial = partial + input[j] * weights[weightsOffset + j];
    }

    // Every lane receives the sum of its subgroup; only the elected lane
    // publishes it so that each subgroup is counted exactly once.
    let subgroupSum = subgroupAdd(partial);
    let isLeader = subgroupElect();
    subgroupSums[lid.x] = select(0.0, subgroupSum, isLeader);

    workgroupBarrier();

    if (lid.x == 0u) {
      var total: f32 = 0.0;
      for (var i = 0u; i < WORKGROUP_SIZE; i = i + 1u) {
        total = total + subgroupSums[i];
      }
      total = total + biases[neuronIndex];
      output[neuronIndex] = relu(total);
    }
  }
`;
45101

@@ -69,7 +125,7 @@ const pipeline = device.createComputePipeline({
69125
}),
70126
compute: {
71127
module: device.createShaderModule({
72-
code: layerShader,
128+
code: hasSubgroups ? subgroupShader : fallbackShader,
73129
}),
74130
},
75131
});

0 commit comments

Comments
 (0)