@@ -5,16 +5,18 @@ const SIZE = 28;
5
5
6
6
const root = await tgpu . init ( {
7
7
device : {
8
- optionalFeatures : [ 'timestamp-query' ] ,
8
+ optionalFeatures : [ 'timestamp-query' , 'subgroups' as GPUFeatureName ] ,
9
9
} ,
10
10
} ) ;
11
11
const hasTimestampQuery = root . enabledFeatures . has ( 'timestamp-query' ) ;
12
+ const hasSubgroups = root . enabledFeatures . has ( 'subgroups' as GPUFeatureName ) ;
12
13
const device = root . device ;
14
+
13
15
const canvasData = new Array < number > ( SIZE ** 2 ) . fill ( 0 ) ;
14
16
15
17
// Shader code
16
18
17
- const layerShader = `
19
+ const fallbackShader = `
18
20
@binding(0) @group(0) var<storage, read> input: array<f32>;
19
21
@binding(1) @group(0) var<storage, read_write> output: array<f32>;
20
22
@@ -38,8 +40,62 @@ const layerShader = `
38
40
sum = sum + input[j] * weights[weightsOffset + j];
39
41
}
40
42
41
- sum = sum + biases[i];
42
- output[i] = relu(sum);
43
+ let total = sum + biases[i];
44
+ output[i] = relu(total);
45
+ }
46
+ ` ;
47
+
48
+ const subgroupShader = `
49
+ enable subgroups;
50
+
51
+ @binding(0) @group(0) var<storage, read> input: array<f32>;
52
+ @binding(1) @group(0) var<storage, read_write> output: array<f32>;
53
+
54
+ @binding(0) @group(1) var<storage, read> weights: array<f32>;
55
+ @binding(1) @group(1) var<storage, read> biases: array<f32>;
56
+
57
+ fn relu(x: f32) -> f32 {
58
+ return max(0.0, x);
59
+ }
60
+
61
+ // WebGPU guarantees a subgroup size of at least 4
62
+ var<workgroup> subgroupSums: array<f32, 64 / 4>;
63
+
64
+ @compute @workgroup_size(64)
65
+ fn main(
66
+ @builtin(local_invocation_id) lid: vec3u,
67
+ @builtin(workgroup_id) wid: vec3u,
68
+ @builtin(subgroup_invocation_id) sid: u32,
69
+ @builtin(subgroup_size) ssize: u32
70
+ ) {
71
+ let neuronIndex = wid.x;
72
+ let inputSize = arrayLength(&input);
73
+ let weightsOffset = neuronIndex * inputSize;
74
+
75
+ var partial: f32 = 0.0;
76
+ for (var j = lid.x; j < inputSize; j = j + 64) {
77
+ partial = partial + input[j] * weights[weightsOffset + j];
78
+ }
79
+
80
+ let subgroupSum = subgroupAdd(partial);
81
+ let subgroupId = lid.x / ssize;
82
+
83
+ let numSubgroups = 64 / ssize;
84
+
85
+ if (sid == 0u) {
86
+ subgroupSums[subgroupId] = subgroupSum;
87
+ }
88
+
89
+ workgroupBarrier();
90
+
91
+ var total: f32 = 0.0;
92
+ if (lid.x == 0u) {
93
+ for (var i = 0u; i < numSubgroups; i = i + 1u) {
94
+ total = total + subgroupSums[i];
95
+ }
96
+ total = total + biases[neuronIndex];
97
+ output[neuronIndex] = relu(total);
98
+ }
43
99
}
44
100
` ;
45
101
@@ -69,7 +125,7 @@ const pipeline = device.createComputePipeline({
69
125
} ) ,
70
126
compute : {
71
127
module : device . createShaderModule ( {
72
- code : layerShader ,
128
+ code : hasSubgroups ? subgroupShader : fallbackShader ,
73
129
} ) ,
74
130
} ,
75
131
} ) ;
0 commit comments