Skip to content

Commit a5f5090

Browse files
authored
feat: tgpu.resolveWithContext (#1410)
1 parent 78a9fcd commit a5f5090

File tree

6 files changed

+85
-35
lines changed

6 files changed

+85
-35
lines changed

packages/typegpu/src/core/pipeline/computePipeline.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ type TgpuComputePipelinePriors = {
7474

7575
type Memo = {
7676
pipeline: GPUComputePipeline;
77-
bindGroupLayouts: TgpuBindGroupLayout[];
78-
catchall: [number, TgpuBindGroup] | null;
77+
usedBindGroupLayouts: TgpuBindGroupLayout[];
78+
catchall: [number, TgpuBindGroup] | undefined;
7979
};
8080

8181
class TgpuComputePipelineImpl implements TgpuComputePipeline {
@@ -156,9 +156,9 @@ class TgpuComputePipelineImpl implements TgpuComputePipeline {
156156

157157
pass.setPipeline(memo.pipeline);
158158

159-
const missingBindGroups = new Set(memo.bindGroupLayouts);
159+
const missingBindGroups = new Set(memo.usedBindGroupLayouts);
160160

161-
memo.bindGroupLayouts.forEach((layout, idx) => {
161+
memo.usedBindGroupLayouts.forEach((layout, idx) => {
162162
if (memo.catchall && idx === memo.catchall[0]) {
163163
// Catch-all
164164
pass.setBindGroup(idx, branch.unwrap(memo.catchall[1]));
@@ -207,7 +207,7 @@ class ComputePipelineCore {
207207
const device = this.branch.device;
208208

209209
// Resolving code
210-
const { code, bindGroupLayouts, catchall } = resolve(
210+
const { code, usedBindGroupLayouts, catchall } = resolve(
211211
{
212212
'~resolve': (ctx) => {
213213
ctx.withSlots(this._slotBindings, () => {
@@ -224,8 +224,8 @@ class ComputePipelineCore {
224224
},
225225
);
226226

227-
if (catchall !== null) {
228-
bindGroupLayouts[catchall[0]]?.$name(
227+
if (catchall !== undefined) {
228+
usedBindGroupLayouts[catchall[0]]?.$name(
229229
`${getName(this) ?? '<unnamed>'} - Automatic Bind Group & Layout`,
230230
);
231231
}
@@ -235,7 +235,7 @@ class ComputePipelineCore {
235235
label: getName(this) ?? '<unnamed>',
236236
layout: device.createPipelineLayout({
237237
label: `${getName(this) ?? '<unnamed>'} - Pipeline Layout`,
238-
bindGroupLayouts: bindGroupLayouts.map((l) =>
238+
bindGroupLayouts: usedBindGroupLayouts.map((l) =>
239239
this.branch.unwrap(l)
240240
),
241241
}),
@@ -246,7 +246,7 @@ class ComputePipelineCore {
246246
}),
247247
},
248248
}),
249-
bindGroupLayouts,
249+
usedBindGroupLayouts,
250250
catchall,
251251
};
252252
}

packages/typegpu/src/core/pipeline/renderPipeline.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ type TgpuRenderPipelinePriors = {
232232

233233
type Memo = {
234234
pipeline: GPURenderPipeline;
235-
bindGroupLayouts: TgpuBindGroupLayout[];
236-
catchall: [number, TgpuBindGroup] | null;
235+
usedBindGroupLayouts: TgpuBindGroupLayout[];
236+
catchall: [number, TgpuBindGroup] | undefined;
237237
};
238238

239239
class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
@@ -390,9 +390,9 @@ class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
390390

391391
pass.setPipeline(memo.pipeline);
392392

393-
const missingBindGroups = new Set(memo.bindGroupLayouts);
393+
const missingBindGroups = new Set(memo.usedBindGroupLayouts);
394394

395-
memo.bindGroupLayouts.forEach((layout, idx) => {
395+
memo.usedBindGroupLayouts.forEach((layout, idx) => {
396396
if (memo.catchall && idx === memo.catchall[0]) {
397397
// Catch-all
398398
pass.setBindGroup(idx, branch.unwrap(memo.catchall[1]));
@@ -473,7 +473,7 @@ class RenderPipelineCore {
473473
} = this.options;
474474

475475
// Resolving code
476-
const { code, bindGroupLayouts, catchall } = resolve(
476+
const { code, usedBindGroupLayouts, catchall } = resolve(
477477
{
478478
'~resolve': (ctx) => {
479479
ctx.withSlots(slotBindings, () => {
@@ -491,8 +491,8 @@ class RenderPipelineCore {
491491
},
492492
);
493493

494-
if (catchall !== null) {
495-
bindGroupLayouts[catchall[0]]?.$name(
494+
if (catchall !== undefined) {
495+
usedBindGroupLayouts[catchall[0]]?.$name(
496496
`${getName(this) ?? '<unnamed>'} - Automatic Bind Group & Layout`,
497497
);
498498
}
@@ -507,7 +507,7 @@ class RenderPipelineCore {
507507
const descriptor: GPURenderPipelineDescriptor = {
508508
layout: device.createPipelineLayout({
509509
label: `${getName(this) ?? '<unnamed>'} - Pipeline Layout`,
510-
bindGroupLayouts: bindGroupLayouts.map((l) => branch.unwrap(l)),
510+
bindGroupLayouts: usedBindGroupLayouts.map((l) => branch.unwrap(l)),
511511
}),
512512
vertex: {
513513
module,
@@ -538,7 +538,7 @@ class RenderPipelineCore {
538538

539539
this._memo = {
540540
pipeline: device.createRenderPipeline(descriptor),
541-
bindGroupLayouts,
541+
usedBindGroupLayouts,
542542
catchall,
543543
};
544544
}

packages/typegpu/src/core/resolve/tgpuResolve.ts

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import type { JitTranspiler } from '../../jitTranspiler.ts';
22
import { RandomNameRegistry, StrictNameRegistry } from '../../nameRegistry.ts';
3-
import { resolve as resolveImpl } from '../../resolutionCtx.ts';
3+
import {
4+
type ResolutionResult,
5+
resolve as resolveImpl,
6+
} from '../../resolutionCtx.ts';
47
import type { SelfResolvable, Wgsl } from '../../types.ts';
58
import { applyExternals, replaceExternalsInWgsl } from './externals.ts';
69

@@ -31,7 +34,7 @@ export interface TgpuResolveOptions {
3134
* Any dependencies of the externals will also be resolved and included in the output.
3235
* @param options - The options for the resolution.
3336
*
34-
* @returns The resolved code.
37+
* @returns {ResolutionResult}
3538
*
3639
* @example
3740
* ```ts
@@ -40,7 +43,7 @@ export interface TgpuResolveOptions {
4043
* to: d.vec3f,
4144
* });
4245
*
43-
* const resolved = tgpu.resolve({
46+
* const { code, usedBindGroupLayouts, catchall } = tgpu.resolveWithContext({
4447
* template: `
4548
* fn getGradientAngle(gradient: Gradient) -> f32 {
4649
* return atan(gradient.to.y - gradient.from.y, gradient.to.x - gradient.from.x);
@@ -51,7 +54,7 @@ export interface TgpuResolveOptions {
5154
* },
5255
* });
5356
*
54-
* console.log(resolved);
57+
* console.log(code);
5558
* // struct Gradient_0 {
5659
* // from: vec3f,
5760
* // to: vec3f,
@@ -61,7 +64,9 @@ export interface TgpuResolveOptions {
6164
* // }
6265
* ```
6366
*/
64-
export function resolve(options: TgpuResolveOptions): string {
67+
export function resolveWithContext(
68+
options: TgpuResolveOptions,
69+
): ResolutionResult {
6570
const {
6671
externals,
6772
template,
@@ -80,12 +85,49 @@ export function resolve(options: TgpuResolveOptions): string {
8085
toString: () => '<root>',
8186
};
8287

83-
const { code } = resolveImpl(resolutionObj, {
88+
return resolveImpl(resolutionObj, {
8489
names: names === 'strict'
8590
? new StrictNameRegistry()
8691
: new RandomNameRegistry(),
8792
jitTranspiler,
8893
});
94+
}
8995

90-
return code;
96+
/**
97+
* Resolves a template with external values. Each external will get resolved to a code string and replaced in the template.
98+
* Any dependencies of the externals will also be resolved and included in the output.
99+
* @param options - The options for the resolution.
100+
*
101+
* @returns The resolved code.
102+
*
103+
* @example
104+
* ```ts
105+
* const Gradient = d.struct({
106+
* from: d.vec3f,
107+
* to: d.vec3f,
108+
* });
109+
*
110+
* const resolved = tgpu.resolve({
111+
* template: `
112+
* fn getGradientAngle(gradient: Gradient) -> f32 {
113+
* return atan(gradient.to.y - gradient.from.y, gradient.to.x - gradient.from.x);
114+
* }
115+
* `,
116+
* externals: {
117+
* Gradient,
118+
* },
119+
* });
120+
*
121+
* console.log(resolved);
122+
* // struct Gradient_0 {
123+
* // from: vec3f,
124+
* // to: vec3f,
125+
* // }
126+
* // fn getGradientAngle(gradient: Gradient_0) -> f32 {
127+
* // return atan(gradient.to.y - gradient.from.y, gradient.to.x - gradient.from.x);
128+
* // }
129+
* ```
130+
*/
131+
export function resolve(options: TgpuResolveOptions): string {
132+
return resolveWithContext(options).code;
91133
}

packages/typegpu/src/core/root/init.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ class TgpuRootImpl extends WithBindingImpl
510510

511511
pass.setPipeline(memo.pipeline);
512512

513-
const missingBindGroups = new Set(memo.bindGroupLayouts);
514-
memo.bindGroupLayouts.forEach((layout, idx) => {
513+
const missingBindGroups = new Set(memo.usedBindGroupLayouts);
514+
memo.usedBindGroupLayouts.forEach((layout, idx) => {
515515
if (memo.catchall && idx === memo.catchall[0]) {
516516
// Catch-all
517517
pass.setBindGroup(idx, this.unwrap(memo.catchall[1]));

packages/typegpu/src/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { computeFn } from './core/function/tgpuComputeFn.ts';
99
import { fn } from './core/function/tgpuFn.ts';
1010
import { fragmentFn } from './core/function/tgpuFragmentFn.ts';
1111
import { vertexFn } from './core/function/tgpuVertexFn.ts';
12-
import { resolve } from './core/resolve/tgpuResolve.ts';
12+
import { resolve, resolveWithContext } from './core/resolve/tgpuResolve.ts';
1313
import { init, initFromDevice } from './core/root/init.ts';
1414
import { comparisonSampler, sampler } from './core/sampler/sampler.ts';
1515
import { accessor } from './core/slot/accessor.ts';
@@ -27,6 +27,7 @@ export const tgpu = {
2727
initFromDevice,
2828

2929
resolve,
30+
resolveWithContext,
3031

3132
'~unstable': {
3233
fn,

packages/typegpu/src/resolutionCtx.ts

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -641,10 +641,17 @@ export class ResolutionCtxImpl implements ResolutionCtx {
641641
}
642642
}
643643

644+
/**
645+
* The results of a WGSL resolution.
646+
*
647+
* @param code - The resolved code.
648+
* @param usedBindGroupLayouts - List of used `tgpu.bindGroupLayout`s.
649+
* @param catchall - Automatically constructed bind group for buffer usages and buffer shorthands, preceded by its index.
650+
*/
644651
export interface ResolutionResult {
645652
code: string;
646-
bindGroupLayouts: TgpuBindGroupLayout[];
647-
catchall: [number, TgpuBindGroup] | null;
653+
usedBindGroupLayouts: TgpuBindGroupLayout[];
654+
catchall: [number, TgpuBindGroup] | undefined;
648655
}
649656

650657
export function resolve(
@@ -655,7 +662,7 @@ export function resolve(
655662
let code = ctx.resolve(item);
656663

657664
const memoMap = ctx.bindGroupLayoutsToPlaceholderMap;
658-
const bindGroupLayouts: TgpuBindGroupLayout[] = [];
665+
const usedBindGroupLayouts: TgpuBindGroupLayout[] = [];
659666
const takenIndices = new Set<number>(
660667
[...memoMap.keys()]
661668
.map((layout) => layout.index)
@@ -672,7 +679,7 @@ export function resolve(
672679
const createCatchallGroup = () => {
673680
const catchallIdx = automaticIds.next().value;
674681
const catchallLayout = bindGroupLayout(Object.fromEntries(layoutEntries));
675-
bindGroupLayouts[catchallIdx] = catchallLayout;
682+
usedBindGroupLayouts[catchallIdx] = catchallLayout;
676683
code = code.replaceAll(CATCHALL_BIND_GROUP_IDX_MARKER, String(catchallIdx));
677684

678685
return [
@@ -692,17 +699,17 @@ export function resolve(
692699

693700
// Retrieving the catch-all binding index first, because it's inherently
694701
// the least swapped bind group (fixed and cannot be swapped).
695-
const catchall = layoutEntries.length > 0 ? createCatchallGroup() : null;
702+
const catchall = layoutEntries.length > 0 ? createCatchallGroup() : undefined;
696703

697704
for (const [layout, placeholder] of memoMap.entries()) {
698705
const idx = layout.index ?? automaticIds.next().value;
699-
bindGroupLayouts[idx] = layout;
706+
usedBindGroupLayouts[idx] = layout;
700707
code = code.replaceAll(placeholder, String(idx));
701708
}
702709

703710
return {
704711
code,
705-
bindGroupLayouts,
712+
usedBindGroupLayouts,
706713
catchall,
707714
};
708715
}

0 commit comments

Comments
 (0)