Trainable Initial State RNN


Treat the initial state(s) of TensorFlow Keras recurrent neural network (RNN) layers as trainable parameters learned during training, as recommended in, e.g., [1].

Ordinary RNNs use an all-zero initial state by default. Why not let the neural network learn a smarter initial state?
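The all-zero default can be checked directly. In tf.keras, RNN layers accept an optional `initial_state` argument when called, and passing zeros explicitly should reproduce the default output. The following is a small sketch; the layer size and the random input batch are arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf

# Arbitrary sizes for illustration: a batch of 2 sequences, 5 time steps, 3 features.
lstm = tf.keras.layers.LSTM(4)
x = tf.random.normal((2, 5, 3))

# Default call: no initial_state is given, so the layer starts from zeros.
default_out = lstm(x)

# Same layer (same weights), now with an explicit all-zero [h0, c0],
# each of shape (batch_size, units).
zeros = tf.zeros((2, 4))
explicit_out = lstm(x, initial_state=[zeros, zeros])

# Both calls share weights and start from the same state, so the outputs match.
same = np.allclose(default_out.numpy(), explicit_out.numpy(), atol=1e-6)
print(same)
```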

The trainable_initial_state_rnn package provides a class, TrainableInitialStateRNN, that wraps any tf.keras RNN layer (including a bidirectional RNN) and creates and manages trainable initial-state variables in addition to the RNN's own weights.

Typical usage looks as follows.

import tensorflow as tf
from trainable_initial_state_rnn import TrainableInitialStateRNN

base_rnn = tf.keras.layers.LSTM(256)
rnn = TrainableInitialStateRNN(base_rnn)  # Treats initial state as a variable!

model = tf.keras.Model(...)  # Use rnn like any other tf.keras layer in your model
model.compile(...)
history = model.fit(...)
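To make the idea concrete, here is a minimal sketch of what such a wrapper can do for a single LSTM layer. It is not the package's implementation: the class name `LearnedInitialStateLSTM` is hypothetical, and unlike TrainableInitialStateRNN it only handles a unidirectional LSTM with its two state tensors (h and c). The learned states are stored as ordinary layer weights and tiled to the batch size on every call.

```python
import tensorflow as tf


class LearnedInitialStateLSTM(tf.keras.layers.Layer):
    """Hypothetical sketch: an LSTM whose initial (h, c) state is learned.

    Illustrative only; not the TrainableInitialStateRNN implementation.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # The wrapped RNN; its weights are built on the first call.
        self.lstm = tf.keras.layers.LSTM(units)

    def build(self, input_shape):
        # One learned vector per state tensor, shared by all sequences in a batch.
        # Zero initialization means training starts from the ordinary zero state.
        self.h0 = self.add_weight(
            name="h0", shape=(1, self.units), initializer="zeros", trainable=True
        )
        self.c0 = self.add_weight(
            name="c0", shape=(1, self.units), initializer="zeros", trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        # Repeat the learned (1, units) states to (batch_size, units).
        h0 = tf.tile(self.h0, [batch_size, 1])
        c0 = tf.tile(self.c0, [batch_size, 1])
        # Gradients flow through initial_state back into h0 and c0.
        return self.lstm(inputs, initial_state=[h0, c0])
```

Such a layer can be used in a model like any other, e.g. `outputs = LearnedInitialStateLSTM(256)(inputs)`. Because one learned state is shared across the batch, it adds only 2 × units parameters per layer.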

References

[1] Felix A. Gers, Nicol N. Schraudolph, Jürgen Schmidhuber. Learning Precise Timing with LSTM Recurrent Networks. Journal of Machine Learning Research 3 (2002) 115-143.