Trainable Initial State RNN
Treat the initial state(s) of TensorFlow Keras recurrent neural network (RNN) layers as parameters to be learned during training, as recommended in, e.g., [1].
Ordinary RNNs use an all-zero initial state by default. Why not let the neural network learn a smarter initial state?
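Written as equations for a simple recurrence with hidden state $h_t$, the change replaces the fixed $h_0 = \mathbf{0}$ with a parameter $\theta$. Backpropagation then updates $\theta$ like any other weight. The symbols $f_W$, $\theta$ and $\eta$ are introduced here only for illustration. An LSTM carries a pair $(h_t, c_t)$, and each component is handled the same way. The update is shown as plain gradient descent for simplicity; any optimizer applies.

$$
h_t = f_W(h_{t-1}, x_t), \quad t = 1, \dots, T, \qquad h_0 = \theta \quad (\text{instead of } h_0 = \mathbf{0})
$$

$$
\frac{\partial \mathcal{L}}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial h_1}\,\frac{\partial h_1}{\partial h_0},
\qquad
\theta \leftarrow \theta - \eta\,\frac{\partial \mathcal{L}}{\partial \theta}
$$

Here $\partial \mathcal{L} / \partial h_1$ is the full gradient that reaches $h_1$ from all later time steps. Because a single $\theta$ is shared by every sequence in a batch, its gradient sums the contributions of all sequences.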
The `trainable_initial_state_rnn` package provides a class, `TrainableInitialStateRNN`, that can wrap any `tf.keras` RNN (or bidirectional RNN) layer and manages the new initial state variables in addition to the RNN's weights.
Typical usage looks as follows.
```python
import tensorflow as tf
from trainable_initial_state_rnn import TrainableInitialStateRNN

base_rnn = tf.keras.layers.LSTM(256)
rnn = TrainableInitialStateRNN(base_rnn)  # Treats the initial state as a trainable variable!

model = tf.keras.Model(...)  # Use rnn like any other tf.keras layer in your model
model.compile(...)
history = model.fit(...)
```
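The package's internals are not described here. Below is a minimal sketch, assuming TensorFlow 2.x with `tf.keras`, of one common way such a wrapper can work. It reads the wrapped layer's `cell.state_size`, creates one zero-initialized trainable vector per state tensor, and tiles those vectors across the batch before passing them as `initial_state`. The class name `LearnedInitialStateWrapper` is hypothetical. It is not the `TrainableInitialStateRNN` API, and unlike the package it handles only unidirectional layers (no `Bidirectional`).

```python
import tensorflow as tf


class LearnedInitialStateWrapper(tf.keras.layers.Layer):
    """Hypothetical sketch: learns the initial state(s) of a wrapped Keras RNN layer.

    Illustration only, not the TrainableInitialStateRNN implementation.
    Assumes a unidirectional RNN layer (e.g. SimpleRNN, GRU, LSTM) whose cell
    reports an integer or a list of integers as its `state_size`.
    """

    def __init__(self, rnn, **kwargs):
        super().__init__(**kwargs)
        self.rnn = rnn
        state_size = rnn.cell.state_size
        # LSTM cells report [units, units] (h and c); GRU/SimpleRNN report one int.
        if hasattr(state_size, "__len__"):
            self.state_sizes = [int(s) for s in state_size]
        else:
            self.state_sizes = [int(state_size)]

    def build(self, input_shape):
        # One trainable row vector per state tensor. Zero initialization means
        # training starts from the usual all-zero initial state.
        self.initial_states = [
            self.add_weight(
                name=f"initial_state_{i}",
                shape=(1, size),
                initializer="zeros",
                trainable=True,
            )
            for i, size in enumerate(self.state_sizes)
        ]
        super().build(input_shape)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        # Repeat each learned vector once per sequence in the batch.
        states = [tf.tile(s, [batch_size, 1]) for s in self.initial_states]
        return self.rnn(inputs, initial_state=states)


# Usage: put the wrapper wherever the bare RNN layer would go.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(None, 8)),  # variable-length sequences of 8 features
    LearnedInitialStateWrapper(tf.keras.layers.LSTM(16)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
```

With zero initialization, the wrapped layer starts out behaving exactly like the unwrapped one. The optimizer moves the initial states away from zero only if doing so lowers the loss. A bidirectional version would need a separate set of learned states for each direction.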