Description
I'm trying to train RNNs with truncated BPTT using `tf.data` (a great API, by the way!), but I got tripped up by these lines. I had assumed that when an element's iterator is exhausted, a new input element would be opened directly at the same position in the cycle, so that RNN states could be passed along reliably.

Instead, what seems to happen is that whenever a sequence finishes, my sequences get shifted in the subsequent `.batch()` call. Could the default be changed so that a new element is consumed directly, as long as there are any left? That would make it much more straightforward to batch consecutive dataset elements for RNN training.

Alternatively, could we have a `tf.contrib.data.batched_interleave` or something similar?
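To make the requested behaviour concrete, here is a minimal pure-Python sketch (no TensorFlow involved). `lane_interleave` is a hypothetical name used only for illustration. It is not an existing `tf.data` or `tf.contrib.data` function, and it does not model what `interleave` actually does today. Each batch position ("lane") keeps reading from its own sequence. When that sequence runs out, the next unread sequence is opened in the same lane immediately, and a reset flag marks where that lane's RNN state should be zeroed.

```python
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel: the source of sequences is exhausted


def lane_interleave(
    sequences: Iterable[Sequence[T]], num_lanes: int
) -> Iterator[Tuple[List[T], List[bool]]]:
    """Hypothetical helper, not a tf.data API.

    Yields (batch, reset) pairs. batch[i] always comes from lane i. When the
    sequence in lane i is exhausted, the next unread sequence is opened in that
    same lane right away. reset[i] is True when batch[i] is the first step of a
    new sequence, i.e. when the RNN state for lane i should be zeroed.
    Iteration stops as soon as a lane cannot be refilled, so every batch is full.
    """
    source = iter(sequences)
    lanes: List[Optional[Iterator[T]]] = [None] * num_lanes
    while True:
        batch: List[T] = []
        reset: List[bool] = []
        for i in range(num_lanes):
            is_new = False
            while True:
                lane = lanes[i]
                if lane is not None:
                    try:
                        batch.append(next(lane))
                        break
                    except StopIteration:
                        pass
                # Lane is empty or its sequence is done: open the next one here.
                nxt = next(source, _DONE)
                if nxt is _DONE:
                    return  # nothing left to refill this lane; drop the partial batch
                lanes[i] = iter(nxt)
                is_new = True
            reset.append(is_new)
        yield batch, reset


seqs = [[1, 2, 3], [10, 11], [20, 21, 22, 23], [30]]
for batch, reset in lane_interleave(seqs, num_lanes=2):
    print(batch, reset)
```

With two lanes, the batches are `[1, 10]`, `[2, 11]`, `[3, 20]`, `[30, 21]`. Column 0 continues `1, 2, 3` and then starts `30`, while column 1 continues `10, 11` and then starts `20, 21`, so per-lane RNN state can be carried over between batches and reset only where the flag is set. This sketch drops the leftover tail (`22, 23`) to keep every batch the same size. That is a design choice for illustration; padding would be an alternative.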