Add GRU and corresponding operators #83
Conversation
…utput to return the whole output sequence in addition to the last output.
…t atop the cell e.g. stacked cells, etc.
index.bs
Outdated
@@ -488,13 +490,238 @@ partial interface NeuralNetworkContext {
</div>
</div>

### gru ### {#api-neuralnetworkcontext-gru}
Gated Recurrent Unit (GRU) recurrent network using an update gate and a reset gate to compute the hidden state that rolls into the output across the temporal sequence of the network, as outlined in this [paper](https://arxiv.org/pdf/1406.1078.pdf).
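As background for readers of this thread, the recurrence being specified can be written as follows (a sketch of one common formulation from the referenced paper, with bias terms omitted; conventions differ on whether `z` weights the previous state or the candidate state):

```latex
z_t = \sigma(W_z x_t + U_z h_{t-1})
r_t = \sigma(W_r x_t + U_r h_{t-1})
\tilde{h}_t = \tanh(W x_t + U (r_t \odot h_{t-1}))
h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
```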
I suggest you make this link a biblio reference (add into `<pre class="biblio">`) and point to it instead. This allows the reader to see all the references at https://webmachinelearning.github.io/webnn/#references and makes a clear distinction between normative (`[[!foo]]`) and informative references (`[[foo]]`).
(I notice this "in this paper" convention is used in the batchNormalization prose as well, so that could be aligned similarly.)
Good idea. Fixed.
Other than the referencing style comment, LGTM. I'll defer to @huningxin for the API arguments and returns definitions.
@pyu10055 please have a look also.
@pyu10055 now has the superpower to assign himself as a reviewer for PRs. Also the editors and the chair can assign him. Thank you @pyu10055 for your help in reviewing changes to the API spec. Note to my future self: in order to be able to assign folks as reviewers with the built-in feature, an individual must be invited as a member of the GH org (here).
@pyu10055 I took the liberty to assign you as a reviewer to this PR without your explicit consent. Do you prefer to be added as a reviewer for the WebNN API changes going forward? I'd prefer not to burden you too much with review requests if you have limited bandwidth for reviews currently. One option is to keep you in the review loop in a non-blocking manner i.e. if we don't hear any concerns from you we consider you're OK with the proposed changes. Would that work for you? Thanks for your help!
index.bs
Outdated
let cellOutput = null;

for (let slot = 0; slot < slots; ++slot) {
  cellHidden.push(nn.squeeze(nn.slice(hidden, [slot, 0, 0], [slot + 1, -1, -1]), [0]));
`hidden` is not defined. Is it `hiddenState`?
Super. I did miss renaming that one.
Fixed.
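For readers following along, the loop after the rename would read roughly like this (a sketch; `slots`, `hiddenState`, and the `nn` builder come from the surrounding emulation example in the diff):

```js
let cellHidden = [];
for (let slot = 0; slot < slots; ++slot) {
  // Slice out one direction's [1, batch_size, hidden_size] plane of the
  // hidden state, then squeeze away the leading num_directions dimension.
  cellHidden.push(
      nn.squeeze(nn.slice(hiddenState, [slot, 0, 0], [slot + 1, -1, -1]), [0]));
}
```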
index.bs
Outdated
The behavior of this operation with default argument values can be generically emulated from the usage of other operations as follows. However, user agents typically have a more efficient implementation for it, therefore its usage is encouraged from the performance standpoint.
<pre highlight="js">
let hiddenState = initialHiddenState;
let slots = (direction == "both" ? 2 : 1);
Would it be good to be `const slots`?
Fixed.
index.bs
Outdated
### gemm ### {#api-neuralnetworkcontext-gemm}
Calculate the <a href="https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3">general matrixmultiplication of the Basic Liner Algebra Subprograms</a>. The calculation follows the expression `alpha * A * B + beta *C`, where `A`, `B`, and `C` are matrices, and `A` and `B` may optionally be transposed prior to the calculation.
Calculate the [general matrixmultiplication of the Basic Liner Algebra Subprograms](https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3). The calculation follows the expression `alpha * A * B + beta *C`, where `A`, `B`, and `C` are matrices, and `A` and `B` may optionally be transposed prior to the calculation.
Is having matrixmultiplication as one word intentional? Two words seems better to me.
It's a typo from the past. Fixed.
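As an aside for readers, the gemm expression can be emulated with more primitive operations; a minimal sketch, assuming element-wise `add`/`mul` with scalar broadcasting and a `matmul` operation along the lines of the rest of the spec (the helper name is made up here):

```js
// alpha * A * B + beta * C, where a and b may have been transposed by the
// caller beforehand; alphaOp and betaOp are scalar constant Operands.
function emulatedGemm(nn, a, b, c, alphaOp, betaOp) {
  return nn.add(nn.mul(nn.matmul(a, b), alphaOp), nn.mul(c, betaOp));
}
```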
index.bs
Outdated

partial interface NeuralNetworkContext {
  sequence<Operand> gru(Operand input, Operand weight,
                        Operand recurrentWeight, Operand initialHiddenState,
Should we make the initialHiddenState optional (with a value 0, if not specified)?
Good point. I'll look into it.
Please take a look. `initialHiddenState` is now optional.
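With that change, a call can leave the initial state out; a hedged sketch (the exact final argument order isn't shown in this excerpt, so the `hiddenSize`/`steps` placement below is assumed from the argument list under review, and the omitted state is assumed to default to a zero tensor of shape [num_directions, batch_size, hidden_size]):

```js
// `input`, `weight` and `recurrentWeight` are Operands built elsewhere.
const outputs = nn.gru(input, weight, recurrentWeight, hiddenSize, steps);
const finalHiddenState = outputs[0];  // [num_directions, batch_size, hidden_size]
```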
index.bs
Outdated
- *recurrentWeight*: an {{Operand}}. The 3-D recurrent weight tensor of shape [num_directions, 3 * hidden_size, hidden_size]. The ordering of the weight vectors in the second dimension of the tensor shape is specified according to the *layout* argument.
- *initialHiddenState*: an {{Operand}}. The 3-D initial hidden state tensor of shape [num_directions, batch_size, hidden_size].
- *hiddenSize*: a {{long}} scalar. The value of the third dimension of the cell output tensor shape. It indicates the number of features in the hidden state.
- *steps*: a {{long}} scalar. The number of time steps in the recurrent network.
What about varying length sequences? Do we want to support them?
@gramalingam I did some research and couldn't find variable-length sequence support in any popular framework other than ONNX. I looked at TensorFlow (and its JS sibling), PyTorch, CoreML and MXNet. Do you remember why ONNX added this? Can we leave it out for now for compactness?
It possibly comes from CNTK.
index.bs
Outdated
- *input*: an {{Operand}}. The input tensor.
- *starts*: a sequence of {{long}}. The starting indices of the corresponding axes of the input shape.
- *ends*: a sequence of {{long}}. The ending indices of the corresponding axes of the input shape.
  The ending index value of -1 selects all the remaining tensor values from the starting index of the given axis.
If we want to allow -1, it may be convenient to allow any negative value, interpreted as counting back from the end, and to allow it for starts as well.
@gramalingam I think we're going to have to choose which semantic we want here, numpy's or TensorFlow's. ONNX follows numpy's convention, which is what you suggested.

TensorFlow's convention is `(start, length)` where `start` is 0-based and `length` is 1-based. Here `length = -1` means the remaining number of elements to the end of the axis dimension.

Numpy's convention is `(start, end)` where both are 0-based indices, and both are allowed negative values. Here `end = -1` means `end = length - 1`, so a slice of `(0, -1)` selects all the elements in an axis except the last one. The drawback of this convention is that to select to the end of the axis dimension without knowing the dimension size (length n), one would need to pass an artificially big number like INT_MAX i.e. `(0, INT_MAX)`. That's fine, but we now need to agree on what that big number may be, e.g. 32 or 64 bit, signed or unsigned, etc. Numpy has the benefit of Python's missing-argument idiom, so one could simply do `(0,)` as absence of `end` implies length n.

The `(0, -1)` use case of TensorFlow's convention is quite handy and is used a lot in practice. I'm leaning more toward that convention. Note that the PyTorch `narrow` operator follows the `(start, length)` convention as well, but the spec does not explicitly state that it supports passing -1 as a length.
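To make the difference concrete, a small worked example on an axis of length 5 (illustrative values only):

```js
const x = [10, 20, 30, 40, 50];  // axis length n = 5

// TensorFlow-style (start, length), with length = -1 meaning "to the end":
//   (1, 3)  -> [20, 30, 40]
//   (1, -1) -> [20, 30, 40, 50]

// numpy-style (start, end), end exclusive, negatives counting from the end:
//   (1, 3)  -> [20, 30]
//   (0, -1) -> [10, 20, 30, 40]   // everything except the last element
```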
I agree using -1 to go until the end is convenient. If we treat `(start, end)` as being inclusive of `end`, we could support both features. But it could have the disadvantage of causing some confusion if people are used to particular interpretations.
Either way: if the convention is different from ONNX, it will complicate translating between the two. Sigh! I guess we can't make it work for both TensorFlow and ONNX. A compromise is unavoidable. So, I am okay with the proposal.
index.bs
Outdated
Produce a slice of the input tensor.
<script type=idl>
partial interface NeuralNetworkContext {
  Operand slice(Operand input, sequence<long> starts, sequence<long> ends);
Adding a `sequence<long>` axes argument may be helpful.
Added
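A sketch of how the added argument might be used, assuming `axes` is appended as an optional trailing parameter (the exact IDL isn't shown in this excerpt): slicing only along axis 0 of a [steps, batch_size, input_size] tensor while taking the other axes in full.

```js
// Take time steps 2, 3 and 4 along axis 0 only (ends is exclusive, per
// the [slot, slot + 1] pattern in the emulation example above).
const stepSlice = nn.slice(input, /*starts*/ [2], /*ends*/ [5], /*axes*/ [0]);
```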
}
}

return (sequence ? [hiddenState, sequence] : [hiddenState]);
nit: maybe use `returnSequence ? [hiddenState, sequence] : [hiddenState]`? There will be a difference only in the case where steps=0.
I guess that does raise another question: the returned value should probably be an empty tensor (of size 0) rather than null.
steps == 0 should not be allowed. I'll add the clarification text on the argument.
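For readers, a sketch of the validation implied by that resolution (a hypothetical check; the spec would state this as prose on the *steps* argument rather than as code):

```js
// steps must be a positive integer; steps == 0 is rejected up front, so
// `sequence` is non-null whenever the sequence output was requested.
if (steps <= 0) {
  throw new TypeError("steps must be a positive integer");
}
```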
…r slice, adding an optional axes argument.
Thanks for the update. Looks good to me.
Fill the operator gaps to support noise suppression first-wave models as outlined here.
Adding GRU and GRUcell operators to support the GRU recurrent network. Defining a cell operator in addition to the network operator for added customization flexibility, e.g. to support a stacked-cell recurrent network; see the sketch below.
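As an illustration of that flexibility argument, stacking two GRU layers out of cell operators might look like the following sketch (the `gruCell` argument order is assumed to mirror `gru`'s, and the per-step slicing reuses the pattern from the emulation example above):

```js
// Two stacked GRU layers: layer 1 consumes the input sequence, layer 2
// consumes layer 1's hidden state at each time step.
let h1 = initialHiddenState1;  // [batch_size, hidden_size]
let h2 = initialHiddenState2;
for (let step = 0; step < steps; ++step) {
  const x = nn.squeeze(nn.slice(input, [step, 0, 0], [step + 1, -1, -1]), [0]);
  h1 = nn.gruCell(x, weight1, recurrentWeight1, h1, hiddenSize);
  h2 = nn.gruCell(h1, weight2, recurrentWeight2, h2, hiddenSize);
}
```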
#66