@@ -8,7 +8,44 @@ import {createView, TensorView} from './tensor-view';
88import { createGpuDataManager , downloadGpuData , GpuDataManager } from './webgpu/gpu-data-manager' ;
99import { RunFunction , WEBGPU_OP_RESOLVE_RULES } from './webgpu/op-resolve-rules' ;
1010import { ProgramManager } from './webgpu/program-manager' ;
11- import { ComputeContext , GpuData , ProgramInfo , ProgramInfoLoader } from './webgpu/types' ;
11+ import { ComputeContext , GpuData , ProgramInfo , ProgramInputTensorInfoDependency } from './webgpu/types' ;
12+
/**
 * Build the input-dependent part of a program cache key.
 *
 * Every input tensor contributes one segment, selected by its dependency level:
 *  - 'none': an empty segment (the input does not affect the key)
 *  - 'type': `<dataType>`
 *  - 'rank': `<dataType>;<rank>`
 *  - 'dims': `<dataType>;<d0>,<d1>,...`
 * The segments are joined with '|'.
 *
 * @param inputTensors the input tensors of the program run.
 * @param inputDependencies one dependency level per input tensor, in the same order as `inputTensors`.
 * @returns the '|'-joined key segments.
 * @throws Error if the two arrays differ in length, or if a dependency value is not one of the four above.
 */
const getProgramInputTensorInfoDependencyKey =
    (inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => {
      if (inputDependencies.length !== inputTensors.length) {
        throw new Error(`inputDependencies length ${inputDependencies.length} is not equal to inputTensors length ${
            inputTensors.length}.`);
      }

      const inputInfos: string[] = [];
      for (let i = 0; i < inputTensors.length; ++i) {
        const type = inputTensors[i].dataType;
        switch (inputDependencies[i]) {
          case 'none': {
            // still push a segment so that segment positions line up with input positions
            inputInfos.push('');
            break;
          }
          case 'type': {
            inputInfos.push(`${type}`);
            break;
          }
          case 'rank': {
            const rank = inputTensors[i].dims.length;
            inputInfos.push(`${type};${rank}`);
            break;
          }
          case 'dims': {
            const dims = inputTensors[i].dims.join(',');
            inputInfos.push(`${type};${dims}`);
            break;
          }
          default:
            throw new Error(`unsupported input dependency: ${inputDependencies[i]}`);
        }
      }

      return inputInfos.join('|');
    };
1249
1350/**
1451 * get a unique key representing the program from the program info, input shapes and types.
@@ -17,18 +54,20 @@ import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/
1754 * program. if the key is the same, the program shader source should be the same, so we can reuse the program.
1855 *
1956 */
/**
 * Compose the cache key for a program from its name, its optional shader cache hint and its inputs.
 *
 * @param programInfo the program description; `shaderCache.hint` and `shaderCache.inputDependencies` are read
 *     when present.
 * @param inputTensors the input tensors of the program run.
 * @returns a key of the form `<name>[<hint>]:<input info 0>|<input info 1>|...`; the `[<hint>]` part is omitted
 *     when the hint is missing or empty.
 */
const getProgramInfoUniqueKey = (programInfo: ProgramInfo, inputTensors: readonly TensorView[]): string => {
  // final key format:
  // <PROGRAM_NAME>[<PROGRAM_CUSTOM_CACHE_HINT>]:<INPUTS_INFO_0>|<INPUTS_INFO_1>|...
  let key = programInfo.name;
  if (programInfo.shaderCache?.hint) {
    key += '[' + programInfo.shaderCache.hint + ']';
  }
  // when the program does not declare per-input dependencies, every input is keyed by its full dims
  const inputDependencies = programInfo.shaderCache?.inputDependencies ??
      new Array<ProgramInputTensorInfoDependency>(inputTensors.length).fill('dims');
  key += `:${getProgramInputTensorInfoDependencyKey(inputTensors, inputDependencies)}`;
  return key;
};
3271
3372/**
3473 * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
@@ -208,55 +247,53 @@ export class WebGpuBackend {
208247
209248 /**
210249 * run a WebGPU program.
211- * @param program either a ProgramInfo instance containing metadata including the shader code, or a function that
212- * can be called and return a ProgramInfo instance
213- * @param inputs a TensorView array. each element represents a value already exists in GPU.
250+ * @param program a ProgramInfo instance
251+ * @param inputTensorViews a TensorView array. each element represents a value already exists in GPU.
214252 * @param outputIndices an indices array. each element can be either -1 (temporary data), -2 (persistent data) or an
215253 * index to the kernel's output.
216254 * @param createKernelOutput a callback function that create a value to kernel's output with the given index
217255 * @param createIntermediateOutput a callback function that create a value as a intermediate value, either temporary
218256 * or persistent (owned by the current kernel)
219257 * @returns a TensorView array representing the result.
220258 */
221- run ( program : ProgramInfoLoader | ProgramInfo , inputs : readonly TensorView [ ] , outputIndices : readonly number [ ] ,
259+ run ( program : ProgramInfo , inputTensorViews : readonly TensorView [ ] , outputIndices : readonly number [ ] ,
222260 createKernelOutput : ( index : number , dataType : number , dims : readonly number [ ] ) => TensorView ,
223261 createIntermediateOutput : ( dataType : number , dims : readonly number [ ] ) => TensorView ) : TensorView [ ] {
224- if ( inputs . length !== program . inputTypes . length ) {
262+ if ( inputTensorViews . length !== program . inputTypes . length ) {
225263 throw new Error ( `Input size must be equal to ${ program . inputTypes . length } .` ) ;
226264 }
227265
228266 // create info for inputs
229267 const inputDatas : GpuData [ ] = [ ] ;
230- for ( let i = 0 ; i < inputs . length ; ++ i ) {
231- const gpuData = this . gpuDataManager . get ( inputs [ i ] . data ) ;
268+ for ( let i = 0 ; i < inputTensorViews . length ; ++ i ) {
269+ const gpuData = this . gpuDataManager . get ( inputTensorViews [ i ] . data ) ;
232270 if ( ! gpuData ) {
233- throw new Error ( `no GPU data for input: ${ inputs [ i ] . data } ` ) ;
271+ throw new Error ( `no GPU data for input: ${ inputTensorViews [ i ] . data } ` ) ;
234272 }
235273 inputDatas [ i ] = gpuData ;
236274 }
237275
238- const key = getProgramInfoUniqueKey ( program , inputs ) ;
276+ // get program info
277+ const key = getProgramInfoUniqueKey ( program , inputTensorViews ) ;
239278 let artifact = this . programManager . getArtifact ( key ) ;
240- const programInfo = artifact ?
241- artifact . programInfo :
242- ( typeof ( program as ProgramInfoLoader ) . get === 'function' ? ( program as ProgramInfoLoader ) . get ( ) :
243- ( program as ProgramInfo ) ) ;
279+
280+ const { outputs, dispatchGroup, variables} = program . getRunData ( inputTensorViews ) ;
244281
245282 // check output indices
246- const validatedOutputIndices = outputIndices . length === 0 ? programInfo . outputs . map ( ( _ , i ) => i ) : outputIndices ;
247- if ( validatedOutputIndices . length !== programInfo . outputs . length ) {
248- throw new Error ( `Output size ${ validatedOutputIndices . length } must be equal to ${ programInfo . outputs . length } .` ) ;
283+ const validatedOutputIndices = outputIndices . length === 0 ? outputs . map ( ( _ , i ) => i ) : outputIndices ;
284+ if ( validatedOutputIndices . length !== outputs . length ) {
285+ throw new Error ( `Output size ${ validatedOutputIndices . length } must be equal to ${ outputs . length } .` ) ;
249286 }
250287
251288 // create info for outputs
252289 const outputTensorViews : TensorView [ ] = [ ] ;
253290 const outputDatas : GpuData [ ] = [ ] ;
254- for ( let i = 0 ; i < programInfo . outputs . length ; ++ i ) {
291+ for ( let i = 0 ; i < outputs . length ; ++ i ) {
255292 // value -1 and -2 are used for creating temporary and persistent outputs.
256293 // value -3 is used for placeholder output. So -3, -2, -1 and 0, 1, 2, ... are valid
257294 // output indices. see type definition of ComputeContextInputsOutputsMapping for more details.
258295 if ( ! Number . isInteger ( validatedOutputIndices [ i ] ) || validatedOutputIndices [ i ] < - 3 ||
259- validatedOutputIndices [ i ] >= programInfo . outputs . length ) {
296+ validatedOutputIndices [ i ] >= outputs . length ) {
260297 throw new Error ( `Invalid output index: ${ validatedOutputIndices [ i ] } ` ) ;
261298 }
262299 if ( validatedOutputIndices [ i ] === - 3 ) {
@@ -265,8 +302,8 @@ export class WebGpuBackend {
265302 const isTemporary = validatedOutputIndices [ i ] === - 1 ;
266303 const isPersistent = validatedOutputIndices [ i ] === - 2 ;
267304 const tensorView = ( isTemporary || isPersistent ) ?
268- createIntermediateOutput ( programInfo . outputs [ i ] . dataType , programInfo . outputs [ i ] . dims ) :
269- createKernelOutput ( validatedOutputIndices [ i ] , programInfo . outputs [ i ] . dataType , programInfo . outputs [ i ] . dims ) ;
305+ createIntermediateOutput ( outputs [ i ] . dataType , outputs [ i ] . dims ) :
306+ createKernelOutput ( validatedOutputIndices [ i ] , outputs [ i ] . dataType , outputs [ i ] . dims ) ;
270307 const gpuData = this . gpuDataManager . get ( tensorView . data ) ;
271308 if ( ! gpuData ) {
272309 throw new Error ( `no GPU data for output: ${ tensorView . data } ` ) ;
@@ -286,18 +323,92 @@ export class WebGpuBackend {
286323 outputDatas . push ( gpuData ) ;
287324 }
288325
289- const normalizedDispatchGroup = this . programManager . normalizeDispatchGroupSize ( programInfo . dispatchGroup ( inputs ) ) ;
326+
327+ // load uniforms
328+ // TODO: add cache for uniform (is it necessary?)
329+ //
330+ let uniformBufferBinding : GPUBindingResource | undefined ;
331+ if ( variables ) {
332+ let currentOffset = 0 ;
333+ let preLength = 0 ;
334+ const offsets : number [ ] = [ ] ;
335+ let maxAlignmentOfField = 1 ;
336+ variables . forEach ( v => {
337+ const data = typeof v . data === 'number' ? [ v . data ] : v . data ;
338+ // https://www.w3.org/TR/WGSL/#alignof
339+ let baseAlignment : number ;
340+ switch ( data . length ) {
341+ case 1 :
342+ baseAlignment = 4 ;
343+ break ;
344+ case 2 :
345+ baseAlignment = 8 ;
346+ break ;
347+ case 3 :
348+ baseAlignment = 16 ;
349+ break ;
350+ case 4 :
351+ baseAlignment = 16 ;
352+ break ;
353+ case 5 :
354+ baseAlignment = 16 ;
355+ break ;
356+ case 6 :
357+ baseAlignment = 16 ;
358+ break ;
359+ default :
360+ throw new Error ( `unsupported data length: ${ data . length } ` ) ;
361+ }
362+
363+ if ( preLength === 5 || preLength === 6 ) {
364+ baseAlignment = 16 ;
365+ }
366+ if ( baseAlignment > maxAlignmentOfField ) {
367+ maxAlignmentOfField = baseAlignment ;
368+ }
369+ currentOffset = Math . ceil ( currentOffset / baseAlignment ) * baseAlignment ;
370+ preLength = data . length ;
371+ offsets . push ( currentOffset ) ;
372+ currentOffset += data . length * 4 ;
373+ } ) ;
374+
375+ currentOffset = Math . ceil ( currentOffset / maxAlignmentOfField ) * maxAlignmentOfField ;
376+ const arrayBuffer = new ArrayBuffer ( currentOffset ) ;
377+ variables . forEach ( ( v , i ) => {
378+ const offset = offsets [ i ] ;
379+ const data = typeof v . data === 'number' ? [ v . data ] : v . data ;
380+ if ( v . type === 'int32' ) {
381+ new Int32Array ( arrayBuffer , offset , data . length ) . set ( data ) ;
382+ } else if ( v . type === 'uint32' ) {
383+ new Uint32Array ( arrayBuffer , offset , data . length ) . set ( data ) ;
384+ } else {
385+ new Float32Array ( arrayBuffer , offset , data . length ) . set ( data ) ;
386+ }
387+ } ) ;
388+
389+ const uniformBufferData =
390+ // eslint-disable-next-line no-bitwise
391+ this . gpuDataManager . create ( currentOffset , GPUBufferUsage . COPY_DST | GPUBufferUsage . UNIFORM ) ;
392+ this . device . queue . writeBuffer ( uniformBufferData . buffer , 0 , arrayBuffer , 0 , currentOffset ) ;
393+ this . gpuDataManager . release ( uniformBufferData . id ) ;
394+ uniformBufferBinding = { offset : 0 , size : currentOffset , buffer : uniformBufferData . buffer } ;
395+ }
396+
397+
398+ const normalizedDispatchGroup = this . programManager . normalizeDispatchGroupSize ( dispatchGroup ) ;
290399
291400 if ( ! artifact ) {
292- artifact = this . programManager . build ( programInfo , normalizedDispatchGroup ) ;
401+ artifact = this . programManager . build ( program , normalizedDispatchGroup ) ;
293402 this . programManager . setArtifact ( key , artifact ) ;
294403 }
295404
296405 LOG_DEBUG (
297406 'info' ,
298- ( ) => `[ProgramManager] run "${ programInfo . name } " (key=${ key } ) with ${ normalizedDispatchGroup [ 0 ] } x${
407+ ( ) => `[ProgramManager] run "${ program . name } " (key=${ key } ) with ${ normalizedDispatchGroup [ 0 ] } x${
299408 normalizedDispatchGroup [ 1 ] } x${ normalizedDispatchGroup [ 2 ] } `) ;
300- this . programManager . run ( artifact , inputs , inputDatas , outputDatas , normalizedDispatchGroup ) ;
409+ this . programManager . run (
410+ artifact , inputTensorViews , outputTensorViews , inputDatas , outputDatas , normalizedDispatchGroup ,
411+ uniformBufferBinding ) ;
301412
302413 return outputTensorViews ;
303414 }
0 commit comments