Recurrent Attention API: Cell wrapper base class #8296
Closed. andhus wants to merge 23 commits into keras-team:master from andhus:recurrent_attention_api_cell_wrapper_base.
Commits (23)
- 28795f1 (andhus) Added support for passing external constants to RNN, which will pass …
- c886e84 (andhus) added base class for attention cell wrapper
- 5213a16 (andhus) Merge branch 'master' of github.com:fchollet/keras into recurrent_att…
- e0dfb6a (andhus) added MoG1D attention and MultiLayerWrapperMixin
- 767df54 (andhus) added alignment example, debugging
- 08f0a04 (andhus) fixed dimension bug
- b5dfc3f (andhus) started refactoring constants handling
- d2470b8 (andhus) fixed step_function wrapping, cleaned up multi layer wrapper, removed…
- 4bf62f7 (andhus) fixed state_spec bug
- 205d057 (andhus) added training flag to attention cell
- 7265794 (andhus) removed multi layer wrapper mixin and refctored MoG attention cell ac…
- 5fb3c1b (andhus) Merge branch 'master' of github.com:fchollet/keras into recurrent_att…
- afd6ae4 (andhus) added error msg
- 7480a83 (andhus) merged master
- 0f6219e (andhus) merged master, added TODOs
- 3b2753b (andhus) detailed docs of attention, WIP
- 21e007b (andhus) complted docs of attention base class and some cleanup
- 9ccdc38 (andhus) removed dependence of distribution module
- e48f5cd (andhus) added support for multiple heads, added class docs
- e5d965b (andhus) completed majority of docs, added sigma_epsilon and removed use of ad…
- c0c8968 (andhus) Merge branch 'master' of github.com:fchollet/keras into recurrent_att…
- a90f0b6 (andhus) improved docs of recurrent_attention example
- ff1a8b5 (farizrahman4u) Merge branch 'master' into recurrent_attention_api_cell_wrapper_base
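Read together, the commit messages outline the design: the RNN layer gains support for externally passed constants, which it forwards to its cell at every step, and an attention cell wrapper base class uses those constants to compute an attention summary for the wrapped cell. A minimal sketch of that wrapper pattern, assuming the Keras 2 cell interface (the class name, the placeholder uniform attention, and the omitted weight handling are illustrative assumptions, not the PR's actual API):

import keras.backend as K
from keras.engine import Layer


class AttentionCellWrapperSketch(Layer):
    """Illustrative cell wrapper (sketch only): summarizes the attended
    constants at each step and feeds that summary to the wrapped cell
    alongside the regular input. `build`, weight creation and a real
    attention mechanism are omitted.
    """
    def __init__(self, cell, **kwargs):
        super(AttentionCellWrapperSketch, self).__init__(**kwargs)
        self.cell = cell

    @property
    def state_size(self):
        return self.cell.state_size

    def call(self, inputs, states, constants=None):
        attended = constants[0]  # (batch, attended_timesteps, features)
        # Placeholder attention: a uniform average over the attended
        # timesteps. A real mechanism (e.g. this PR's mixture of Gaussians
        # over the time axis) would compute data-dependent weights from
        # `states`.
        summary = K.mean(attended, axis=1)
        cell_input = K.concatenate([inputs, summary], axis=-1)
        return self.cell.call(cell_input, states)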
New file (102 lines):
'''Canonical example of using attention for sequence-to-sequence problems.

This script demonstrates how to use an RNNAttentionCell to implement attention
mechanisms. In the example, the model only has to learn to filter the attended
input to obtain the target: basically, it has to learn to "parse" the attended
input sequence and output only the relevant parts.

# Explanation of data

One sample of input data consists of a sequence of one-hot vectors separated
by randomly inserted "extra" zero-vectors:

0 0 0 0 1 0 0 0 0 0
1 0 0 1 0 0 0 0 0 1
0 0 0 0 0 0 1 0 0 0
0 0 1 0 0 0 0 0 0 0
^ ^
| |
| extra zero-vector
one-hot vector

The goal is to retrieve the one-hot-vector sequence _without_ the extra zeros:

0 0 0 1 0 0
1 0 1 0 0 1
0 0 0 0 1 0
0 1 0 0 0 0

# Summary of the algorithm

The task is carried out by letting a Mixture of Gaussian 1D attention
mechanism attend to the input sequence (with the extra zeros) and select what
information should be passed to the wrapped LSTM cell.

# Attention vs. Encoder-Decoder approach

This is a good example of a problem where attention mechanisms are suitable:
in this case attention clearly outperforms e.g. encoder-decoder approaches.
TODO add this comparison to the script
TODO add comparison heads=1 vs heads=2 (the latter converges faster)
'''
|
||
from __future__ import division, print_function | ||
|
||
import random | ||
|
||
import numpy as np | ||
|
||
from keras import Input | ||
from keras.engine import Model | ||
from keras.layers import Dense, TimeDistributed, LSTMCell, RNN | ||
|
||
from keras.layers.attention import MixtureOfGaussian1DAttention | ||
|
||
|
||
def get_training_data(n_samples, | ||
n_labels, | ||
n_timesteps_attended, | ||
n_timesteps_labels): | ||
labels = np.random.randint( | ||
n_labels, | ||
size=(n_samples, n_timesteps_labels) | ||
) | ||
attended_time_idx = range(n_timesteps_attended) | ||
label_time_idx = range(1, n_timesteps_labels + 1) | ||
|
||
labels_one_hot = np.zeros((n_samples, n_timesteps_labels + 1, n_labels)) | ||
attended = np.zeros((n_samples, n_timesteps_attended, n_labels)) | ||
for i in range(n_samples): | ||
labels_one_hot[i][label_time_idx, labels[i]] = 1 | ||
positions = sorted(random.sample(attended_time_idx, n_timesteps_labels)) | ||
attended[i][positions, labels[i]] = 1 | ||
|
||
return labels_one_hot, attended | ||
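# Illustrative note (not in the original script): get_training_data(2, 4, 30, 10)
# returns labels_one_hot of shape (2, 11, 4), where step 0 is the all-zero
# start step, and attended of shape (2, 30, 4), containing 10 one-hot rows
# among 20 zero rows per sample.
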
n_samples = 10000
n_timesteps_labels = 10
n_timesteps_attended = 30
n_labels = 4

input_labels = Input((n_timesteps_labels, n_labels))
attended = Input((n_timesteps_attended, n_labels))

# Wrap an LSTM cell in the Mixture of Gaussian 1D attention mechanism. The
# attended sequence is passed to the RNN as a "constant": a non-sequence
# input that is available to the cell at every timestep.
cell = MixtureOfGaussian1DAttention(LSTMCell(64), components=3, heads=2)
attention_lstm = RNN(cell, return_sequences=True)

attention_lstm_output = attention_lstm(input_labels, constants=attended)
output_layer = TimeDistributed(Dense(n_labels, activation='softmax'))
output = output_layer(attention_lstm_output)

model = Model(inputs=[input_labels, attended], outputs=output)

labels_data, attended_data = get_training_data(n_samples,
                                               n_labels,
                                               n_timesteps_attended,
                                               n_timesteps_labels)
# Teacher forcing: the input at step t is the target label from step t - 1.
input_labels_data = labels_data[:, :-1, :]
target_labels_data = labels_data[:, 1:, :]

model.compile(optimizer='Adam', loss='categorical_crossentropy')
model.fit(x=[input_labels_data, attended_data], y=target_labels_data, epochs=5)
output_data = model.predict([input_labels_data, attended_data])
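A quick way to check what the model learned (an illustrative addition, not part of the PR's script) is to compare the argmax of these teacher-forced predictions against the targets:

# Sanity check (illustrative, not part of the PR's script): per-timestep
# accuracy of the teacher-forced predictions against the one-hot targets.
pred_labels = output_data.argmax(axis=-1)
true_labels = target_labels_data.argmax(axis=-1)
print('per-timestep accuracy: %.3f' % (pred_labels == true_labels).mean())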
We won't have input_labels_data during prediction. You have to:
Sure, that makes the example more complete 👍
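For reference, one way the example could be made prediction-complete (a sketch assuming the model is kept exactly as built above; greedy_decode is a hypothetical helper, not code from the PR): since the wrapped RNN is unidirectional, the output at step t depends only on the inputs up to step t, so the ground-truth inputs can be replaced by feeding the model's own predictions back in.

import numpy as np


def greedy_decode(model, attended_batch, n_timesteps_labels, n_labels):
    """Decode without ground-truth inputs: start from the all-zero "start"
    step and feed each predicted label back as the next input. Simple but
    O(T^2), since the full sequence is re-run at every step."""
    n = attended_batch.shape[0]
    inputs = np.zeros((n, n_timesteps_labels, n_labels))
    decoded = np.zeros((n, n_timesteps_labels, n_labels))
    for t in range(n_timesteps_labels):
        # The RNN is causal, so the prediction at step t only depends on
        # inputs[:, :t + 1], which are already filled in at this point.
        probs = model.predict([inputs, attended_batch])[:, t, :]
        decoded[np.arange(n), t, probs.argmax(axis=-1)] = 1
        if t + 1 < n_timesteps_labels:
            inputs[:, t + 1, :] = decoded[:, t, :]
    return decoded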