Skip to content

Commit b2594a9

Browse files
committed
Apply feedback from ML meeting
1 parent 2d2f180 commit b2594a9

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

wit/wasi-nn.wit

Lines changed: 25 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,20 @@ world ml {
1515
import errors;
1616
}
1717

18+
/// Inference is performed on a specific `device`. This interface only declares the `location`
/// type; the `tensor` and `graph` interfaces bring it into scope with `use device.{location}`.
19+
interface device {
20+
/// Define where tensors reside and graphs execute: one of `cpu`, `gpu`, or `tpu`.
21+
enum location {
22+
cpu,
23+
gpu,
24+
tpu
25+
}
26+
}
27+
1828
/// All inputs and outputs to an ML inference are represented as `tensor`s.
1929
interface tensor {
30+
use device.{location};
31+
2032
/// The dimensions of a tensor.
2133
///
2234
/// The array length matches the tensor rank and each element in the array describes the size of
@@ -44,8 +56,8 @@ interface tensor {
4456
type tensor-data = list<u8>;
4557

4658
resource tensor {
47-
constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data,
48-
location: option<execution-target>);
59+
/// Construct a tensor that lives on the host CPU.
60+
constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
4961

5062
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
5163
// containing a single value, use `[1]` for the tensor dimensions.
@@ -55,7 +67,7 @@ interface tensor {
5567
ty: func() -> tensor-type;
5668

5769
// Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
58-
location: func() -> execution-target;
70+
location: func() -> location;
5971

6072
// Return the tensor data. If the tensor is located on a device other than the CPU, this
6173
// operation may result in an expensive data copy operation.
@@ -74,8 +86,9 @@ interface tensor {
7486
/// framework (e.g., TensorFlow):
7587
interface graph {
7688
use errors.{error};
77-
use tensor.{tensor};
89+
use device.{location};
7890
use inference.{graph-execution-context};
91+
use tensor.{tensor};
7992

8093
/// An execution graph for performing inference (i.e., a model).
8194
resource graph {
@@ -93,21 +106,15 @@ interface graph {
93106
autodetect,
94107
}
95108

96-
/// Define where the graph should be executed.
97-
enum execution-target {
98-
cpu,
99-
gpu,
100-
tpu
101-
}
102-
103109
/// The graph initialization data.
104110
///
105111
/// This gets bundled up into an array of buffers because implementing backends may encode their
106112
/// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
107113
type graph-builder = list<u8>;
108114

109-
/// Load a `graph` from an opaque sequence of bytes to use for inference.
110-
load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
115+
/// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
116+
/// `location`.
117+
load: func(builder: list<graph-builder>, encoding: graph-encoding, location: location) -> result<graph, error>;
111118

112119
/// Load a `graph` by name.
113120
///
@@ -128,6 +135,11 @@ interface inference {
128135
/// TODO: this may no longer be necessary in WIT
129136
/// (https://github.com/WebAssembly/wasi-nn/issues/43)
130137
resource graph-execution-context {
138+
/// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
139+
/// will co-locate the tensor data on a specific device using the graph's underlying
140+
/// backend; this may avoid some copies, improving performance.
141+
load-tensor: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data) -> result<tensor, error>;
142+
131143
/// Define the inputs to use for inference.
132144
set-input: func(name: string, tensor: tensor) -> result<_, error>;
133145

0 commit comments

Comments
 (0)