# The ELMo module's variables must be registered as the layer's trainable weights so that Keras trains them.
import tensorflow_hub as hub
from keras import backend as K
from keras.engine import Layer
class ElmoEmbeddingLayer(Layer):
    """Keras layer that embeds padded string-token sequences with TF-Hub ELMo.

    The input is a tensor of string tokens padded with ``PAD_TOKEN``. The output
    is the ``'elmo'`` result of the Hub module's ``tokens`` signature, with one
    ``dimensions``-sized vector appended per input position. The Hub module's
    variables are registered as this layer's trainable weights so that Keras
    trains them.
    """

    # Padding token: positions holding it are masked out and excluded from the
    # sequence length passed to ELMo.
    PAD_TOKEN = '--PAD--'

    def __init__(self, **kwargs):
        """Create the layer.

        Args:
            **kwargs: Standard Keras ``Layer`` keyword arguments. ``trainable``
                defaults to ``True``; a caller-supplied value is respected and
                is also forwarded to the Hub module in ``build``.
        """
        # Size of each ELMo vector; appended as the last output dimension.
        self.dimensions = 1024
        # setdefault (not a hard-coded keyword) so callers may pass `trainable`
        # without triggering a duplicate-keyword TypeError.
        kwargs.setdefault('trainable', True)
        super(ElmoEmbeddingLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Instantiate the ELMo Hub module and register its trainable variables."""
        module_name = "{}_module".format(self.name)
        # Use the Keras `trainable` flag (the previously referenced
        # `embed_trainable` attribute is never defined on this layer).
        self.elmo = hub.Module('https://tfhub.dev/google/elmo/2',
                               trainable=self.trainable,
                               name=module_name)
        # Keras only trains weights it knows about, so collect the variables
        # the module created under its `<layer name>_module/` scope.
        self.trainable_weights += K.tf.trainable_variables(
            scope="^{}/.*".format(module_name))
        super(ElmoEmbeddingLayer, self).build(input_shape)

    def call(self, x, mask=None):
        """Run ELMo over a batch of padded token strings.

        Args:
            x: String tensor of tokens, padded at the end with ``PAD_TOKEN``
                (presumably shape ``(batch, max_len)`` -- TODO confirm against
                the caller).
            mask: Ignored; the padding mask is derived from ``x`` instead.

        Returns:
            The ``'elmo'`` entry of the module's ``tokens``-signature outputs.
        """
        # Sequence length = number of non-pad tokens per row. This equals the
        # index of the first pad when padding is trailing, and unlike an argmax
        # over "is pad" it also gives the full length for rows with no padding
        # (argmax of an all-zero row would be 0).
        not_pad = K.cast(K.not_equal(x, self.PAD_TOKEN), 'int32')
        lengths = K.sum(not_pad, axis=-1)
        outputs = self.elmo(inputs=dict(tokens=x, sequence_len=lengths),
                            as_dict=True,
                            signature='tokens')
        return outputs['elmo']

    def compute_mask(self, inputs, mask=None):
        """Return a boolean mask that is False where the input is ``PAD_TOKEN``."""
        return K.not_equal(inputs, self.PAD_TOKEN)

    def compute_output_shape(self, input_shape):
        """Return the input shape with the ELMo vector size appended."""
        return input_shape + (self.dimensions,)