@@ -15,8 +15,20 @@ world ml {
     import errors;
 }

+/// Inference is performed on a specific `device`.
+interface device {
+    /// Define where tensors reside and graphs execute.
+    enum location {
+        cpu,
+        gpu,
+        tpu
+    }
+}
+
 /// All inputs and outputs to an ML inference are represented as `tensor`s.
 interface tensor {
+    use device.{location};
+
     /// The dimensions of a tensor.
     ///
     /// The array length matches the tensor rank and each element in the array describes the size of
@@ -44,8 +56,8 @@ interface tensor {
     type tensor-data = list<u8>;

     resource tensor {
-        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data,
-            location: option<execution-target>);
+        /// Construct a tensor that lives on the host CPU.
+        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);

         // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
         // containing a single value, use `[1]` for the tensor dimensions.
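
To aid review, here is a guest-side sketch of the revised constructor. The Rust in these notes assumes hypothetical wit-bindgen-generated bindings under a `wasi::nn` module; every generated name is illustrative, not something this change defines.

    // Hypothetical bindings: the constructor no longer takes an optional
    // `execution-target`, so tensors built this way always start in host
    // (CPU) memory.
    use wasi::nn::tensor::{Tensor, TensorType};

    fn cpu_tensor() -> Tensor {
        let data = vec![0u8; 16]; // four f32 values, zero-initialized
        Tensor::new(&[2, 2], TensorType::Fp32, &data)
    }
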
@@ -55,7 +67,7 @@ interface tensor {
         ty: func() -> tensor-type;

         // Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
-        location: func() -> execution-target;
+        location: func() -> location;

         // Return the tensor data. If the tensor is located on a device other than the CPU, this
         // operation may result in an expensive data copy operation.
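
Since `location` is now shared via the `device` interface, a caller can check where a tensor lives before paying for the copy described above. A minimal sketch under the same assumed bindings:

    use wasi::nn::device::Location;
    use wasi::nn::tensor::Tensor;

    // True when the tensor's backing memory is on an accelerator, in which
    // case reading its data back would trigger a device-to-host copy.
    fn on_accelerator(t: &Tensor) -> bool {
        matches!(t.location(), Location::Gpu | Location::Tpu)
    }
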
@@ -74,8 +86,9 @@ interface tensor {
 /// framework (e.g., TensorFlow):
 interface graph {
     use errors.{error};
-    use tensor.{tensor};
+    use device.{location};
     use inference.{graph-execution-context};
+    use tensor.{tensor};

     /// An execution graph for performing inference (i.e., a model).
     resource graph {
@@ -93,21 +106,15 @@ interface graph {
         autodetect,
     }

-    /// Define where the graph should be executed.
-    enum execution-target {
-        cpu,
-        gpu,
-        tpu
-    }
-
     /// The graph initialization data.
     ///
     /// This gets bundled up into an array of buffers because implementing backends may encode their
     /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
     type graph-builder = list<u8>;

-    /// Load a `graph` from an opaque sequence of bytes to use for inference.
-    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
+    /// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
+    /// `location`.
+    load: func(builder: list<graph-builder>, encoding: graph-encoding, location: location) -> result<graph, error>;

     /// Load a `graph` by name.
     ///
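
The new `load` signature, exercised under the same assumed bindings; ONNX is chosen purely for illustration:

    use wasi::nn::device::Location;
    use wasi::nn::errors::Error;
    use wasi::nn::graph::{load, Graph, GraphEncoding};

    // `load` now takes the shared `device.location` enum in place of the
    // removed graph-local `execution-target`.
    fn load_on_gpu(model: Vec<u8>) -> Result<Graph, Error> {
        load(&[model], GraphEncoding::Onnx, Location::Gpu)
    }
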
@@ -128,6 +135,11 @@ interface inference {
     /// TODO: this may no longer be necessary in WIT
     /// (https://github.com/WebAssembly/wasi-nn/issues/43)
     resource graph-execution-context {
+        /// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
+        /// will co-locate the tensor data on a specific device using the graph's underlying
+        /// backend; this may avoid some copies, improving performance.
+        load-tensor: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data) -> result<tensor, error>;
+
         /// Define the inputs to use for inference.
         set-input: func(name: string, tensor: tensor) -> result<_, error>;
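
Finally, an end-to-end sketch showing where `load-tensor` fits. `init-execution-context` and `compute` come from the existing `inference` interface; the Rust names remain assumed wit-bindgen output:

    use wasi::nn::device::Location;
    use wasi::nn::errors::Error;
    use wasi::nn::graph::{load, GraphEncoding};
    use wasi::nn::tensor::TensorType;

    fn infer(model: Vec<u8>, image: Vec<u8>) -> Result<(), Error> {
        let graph = load(&[model], GraphEncoding::Onnx, Location::Gpu)?;
        let ctx = graph.init_execution_context()?;
        // `load-tensor` lets the backend place the input next to the graph
        // (e.g., in GPU memory), potentially skipping the host-to-device copy
        // that a CPU-constructed tensor would incur at `set-input` time.
        let input = ctx.load_tensor(&[1, 3, 224, 224], TensorType::Fp32, &image)?;
        ctx.set_input("input", input)?;
        ctx.compute()
    }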