impr: Utilize subgroups in MNIST Inference when possible #986

Open · wants to merge 5 commits into main
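
This change adds an optional fast path that uses WGSL subgroup operations for the per-neuron dot products when the adapter exposes the subgroups feature, and keeps the existing single-invocation loop as a fallback otherwise. The UI gains a Subgroups/Inference readout, and a "Use Subgroups" toggle rebuilds the compute pipeline with the matching shader variant. As a rough sketch of the selection logic (paraphrasing the diff below, not a verbatim excerpt):

// Request 'subgroups' as an optional feature and check what the device actually enabled.
const root = await tgpu.init({
  device: {
    optionalFeatures: ['timestamp-query', 'subgroups' as GPUFeatureName],
  },
});
const hasSubgroups = root.enabledFeatures.has('subgroups' as GPUFeatureName);
let useSubgroups = hasSubgroups;

// Compile whichever shader variant matches the current setting; the
// "Use Subgroups" toggle later recreates the pipeline the same way.
const module = root.device.createShaderModule({
  code: useSubgroups ? subgroupShader : fallbackShader,
});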
@@ -16,6 +16,11 @@
<div class="bar">8</div>
<div class="bar">9</div>
</div>

<div class="info">
<div>Subgroups: <span id="subgroups-status">-</span></div>
<div>Inference: <span id="inference-time">-</span></div>
</div>
</div>
</div>

@@ -49,9 +54,6 @@
.predictions-label {
margin-bottom: 0.5rem;
font-size: 1.25rem;
@media (max-width: 1024px) {
font-size: 1rem;
}
}

.bars-container {
@@ -60,10 +62,66 @@
width: 100%;
justify-content: flex-start;
row-gap: 0.5rem;
@media (max-width: 1024px) {
width: calc(100% + 8rem);
}

.info {
width: 100%;
padding: 0.75rem;
background: #f8f9fa;
border-radius: 0.5rem;
font-size: 0.875rem;
display: flex;
flex-direction: column;
gap: 0.25rem;
}

@media (max-width: 1024px) {
.predictions-container {
width: 100%;
gap: 0.5rem;
}

.predictions-label {
font-size: 1rem;
}

.predictions-container .bars-container,
.predictions-container .info {
display: inline-block;
vertical-align: top;
}

.bars-container {
width: calc(100% - 9rem);
row-gap: 0.1rem;
}

.info {
width: 8rem;
padding: 0.5rem;
font-size: 0.65rem;
margin-left: 1rem;
}
}

.info div {
display: flex;
justify-content: space-between;
color: #64748b;
}

.info span {
font-family: 'Monaco', monospace;
font-weight: 600;
color: #1e293b;
}

.info .enabled {
color: #16a34a;
}

.info .disabled {
color: #dc2626;
}

.bar {
@@ -74,7 +132,10 @@
font-size: 1rem;
background: linear-gradient(to right, transparent, #e6e6f2);
border-radius: 9999px;
@media (max-width: 1024px) {
}

@media (max-width: 1024px) {
.bar {
height: 1rem;
font-size: 0.75rem;
}
@@ -5,22 +5,40 @@ const SIZE = 28;

const root = await tgpu.init({
device: {
optionalFeatures: ['timestamp-query'],
optionalFeatures: ['timestamp-query', 'subgroups' as GPUFeatureName],
},
});
const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
const hasSubgroups = root.enabledFeatures.has('subgroups' as GPUFeatureName);
let useSubgroups = hasSubgroups;
const device = root.device;

const canvasData = new Array<number>(SIZE ** 2).fill(0);

// Shader code
const ReadonlyFloats = {
storage: (n: number) => d.arrayOf(d.f32, n),
access: 'readonly',
} as const;

const layerShader = `
@binding(0) @group(0) var<storage, read> input: array<f32>;
@binding(1) @group(0) var<storage, read_write> output: array<f32>;
const MutableFloats = {
storage: (n: number) => d.arrayOf(d.f32, n),
access: 'mutable',
} as const;

@binding(0) @group(1) var<storage, read> weights: array<f32>;
@binding(1) @group(1) var<storage, read> biases: array<f32>;
const ioLayout = tgpu.bindGroupLayout({
input: ReadonlyFloats,
output: MutableFloats,
}).$idx(0);

const weightsBiasesLayout = tgpu.bindGroupLayout({
weights: ReadonlyFloats,
biases: ReadonlyFloats,
}).$idx(1);

// Shader code

const fallbackShader = tgpu.resolve({
template: `
fn relu(x: f32) -> f32 {
return max(0.0, x);
}
@@ -38,38 +56,78 @@ const layerShader = `
sum = sum + input[j] * weights[weightsOffset + j];
}

sum = sum + biases[i];
output[i] = relu(sum);
let total = sum + biases[i];
output[i] = relu(total);
}
`;
`,
externals: {
...weightsBiasesLayout.bound,
...ioLayout.bound,
},
});

const ReadonlyFloats = {
storage: (n: number) => d.arrayOf(d.f32, n),
access: 'readonly',
} as const;
const subgroupShader = `enable subgroups;
${
tgpu.resolve({
template: `
fn relu(x: f32) -> f32 {
return max(0.0, x);
}

const MutableFloats = {
storage: (n: number) => d.arrayOf(d.f32, n),
access: 'mutable',
} as const;
// WebGPU guarantees a subgroup size of at least 4
var<workgroup> subgroupSums: array<f32, 64 / 4>;

@compute @workgroup_size(64)
fn main(
@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u,
@builtin(subgroup_invocation_id) sid: u32,
@builtin(subgroup_size) ssize: u32
) {
let neuronIndex = wid.x;
let inputSize = arrayLength(&input);
let weightsOffset = neuronIndex * inputSize;

var partial: f32 = 0.0;
for (var j = lid.x; j < inputSize; j = j + 64) {
partial = partial + input[j] * weights[weightsOffset + j];
}

const ioLayout = tgpu.bindGroupLayout({
input: ReadonlyFloats,
output: MutableFloats,
});
let subgroupSum = subgroupAdd(partial);
let subgroupId = lid.x / ssize;

const weightsBiasesLayout = tgpu.bindGroupLayout({
weights: ReadonlyFloats,
biases: ReadonlyFloats,
});
let numSubgroups = 64 / ssize;

if (sid == 0u) {
subgroupSums[subgroupId] = subgroupSum;
}

workgroupBarrier();

const pipeline = device.createComputePipeline({
var total: f32 = 0.0;
if (lid.x == 0u) {
for (var i = 0u; i < numSubgroups; i = i + 1u) {
total = total + subgroupSums[i];
}
total = total + biases[neuronIndex];
output[neuronIndex] = relu(total);
}
}
`,
externals: {
...weightsBiasesLayout.bound,
...ioLayout.bound,
},
})
}`;

let pipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [root.unwrap(ioLayout), root.unwrap(weightsBiasesLayout)],
}),
compute: {
module: device.createShaderModule({
code: layerShader,
code: useSubgroups ? subgroupShader : fallbackShader,
}),
},
});
@@ -181,9 +239,12 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
if (querySet?.available) {
querySet.resolve();
const results = await querySet.read();
console.log(
`Inference took ${Number(results[1] - results[0]) / 1_000_000} ms`,
);
const inferenceTimeMs = Number(results[1] - results[0]) / 1_000_000;
console.log(`Inference took ${inferenceTimeMs} ms`);

inferenceTimeEl.textContent = `${inferenceTimeMs.toFixed(2)} ms`;
} else {
inferenceTimeEl.textContent = 'N/A';
}

// Read the output
Expand All @@ -198,6 +259,22 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
};
}

const recreatePipeline = () => {
pipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [
root.unwrap(ioLayout),
root.unwrap(weightsBiasesLayout),
],
}),
compute: {
module: device.createShaderModule({
code: useSubgroups ? subgroupShader : fallbackShader,
}),
},
});
};

const network = createNetwork(await downloadLayers());

// #region Downloading weights & biases
@@ -272,6 +349,12 @@ const canvas = document.querySelector('canvas') as HTMLCanvasElement;
const context = canvas.getContext('2d') as CanvasRenderingContext2D;

const bars = Array.from(document.querySelectorAll('.bar')) as HTMLDivElement[];
const subgroupsEl = document.getElementById(
'subgroups-status',
) as HTMLSpanElement;
const inferenceTimeEl = document.getElementById(
'inference-time',
) as HTMLSpanElement;

const uiState = {
isDrawing: false,
@@ -329,6 +412,20 @@ function run() {
}

document.querySelector('.loading')?.classList.add('loaded');

function updateSubgroupsStatus() {
const text = !hasSubgroups
? 'Not Supported'
: useSubgroups
? 'Enabled'
: 'Disabled';
const cls = !hasSubgroups || !useSubgroups ? 'disabled' : 'enabled';
subgroupsEl.textContent = text;
subgroupsEl.className = cls;
}

updateSubgroupsStatus();

run();

canvas.addEventListener('mousedown', () => {
@@ -455,6 +552,14 @@ export const controls = {
Reset: {
onButtonClick: resetDrawing,
},
'Use Subgroups': {
initial: hasSubgroups,
onToggleChange: (value: boolean) => {
useSubgroups = value;
recreatePipeline();
updateSubgroupsStatus();
},
},
};

// #endregion
@@ -1,5 +1,5 @@
{
"title": "MNIST Inference",
"category": "algorithms",
"tags": ["ai", "compute", "inference", "timestamp query"]
"tags": ["ai", "compute", "inference", "timestamp query", "subgroups"]
}