zimmerrol / keras-utility-layer-collection
Collection of custom layers and utility functions for Keras which are missing in the main framework.
License: MIT License
While trying out MultiHeadAttention, I got this error:
TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (100, Dimension(100)). Consider casting elements to a supported type.
I fixed it like this:
def build(self, input_shape):
    self._validate_input_shape(input_shape)

    d_k = self._d_k if self._d_k else input_shape[1][-1]
    d_model = self._d_model if self._d_model else input_shape[1][-1]
    d_v = self._d_v

    # In TF 1.x the entries of input_shape are tf.Dimension objects;
    # unwrap them to plain ints before they are used as layer sizes.
    if isinstance(d_k, tf.Dimension):
        d_k = d_k.value
    if isinstance(d_model, tf.Dimension):
        d_model = d_model.value

    self._q_layers = []
    self._k_layers = []
    self._v_layers = []
    self._sdp_layer = ScaledDotProductAttention(return_attention=self._return_attention)
Great project!
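The fix boils down to unwrapping a TF 1.x `tf.Dimension` into a plain int before it is used as a layer size. Below is a minimal, framework-free sketch of that step. `dim_value`, `resolve_sizes` and `FakeDimension` are hypothetical names. The duck-typed `.value` lookup assumes the attribute that `tf.Dimension` exposes in TF 1.x. Newer TensorFlow releases should also provide `tf.compat.dimension_value` for the same purpose.

```python
from typing import Any, Optional, Tuple


def dim_value(dim: Any) -> Optional[int]:
    """Return a plain int (or None) whether `dim` is an int or a TF 1.x Dimension.

    tf.Dimension keeps its size in `.value`; plain ints have no such
    attribute, so getattr falls back to returning them unchanged.
    """
    return getattr(dim, "value", dim)


def resolve_sizes(d_k: Optional[Any], d_model: Optional[Any],
                  last_dim: Any) -> Tuple[Optional[int], Optional[int]]:
    """Mirror the fallback in build(): use the configured size if set,
    otherwise the last dimension of the second input, then unwrap it."""
    d_k = d_k if d_k else last_dim
    d_model = d_model if d_model else last_dim
    return dim_value(d_k), dim_value(d_model)


class FakeDimension:
    """Hypothetical stand-in for tf.Dimension so the sketch runs without TF 1.x."""

    def __init__(self, value: Optional[int]) -> None:
        self.value = value


print(resolve_sizes(None, None, FakeDimension(100)))  # (100, 100)
print(resolve_sizes(64, None, FakeDimension(100)))    # (64, 100)
```

Inside the layer, a single `d_k = dim_value(d_k)` per size would cover both the `tf.Dimension` and the plain-int case.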
Hello,
It seems you may have developed this against a version of Keras with a different API. When I try to use it on Keras 2.2.0, I get an error where the code accesses the fourth element (states[3]) of the states list:
File "model.py", line 485, in <module>
model.create_models()
File "model.py", line 262, in create_models
initial_state=[encoder_output])
File "/root/anaconda3/lib/python3.6/site-packages/keras/engine/base_layer.py", line 460, in __call__
output = self.call(inputs, **kwargs)
File "/root/anaconda3/lib/python3.6/site-packages/kulc-0.0.5-py3.6.egg/kulc/attention.py", line 395, in call
input_length=input_shape[1]
File "/root/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2868, in rnn
outputs, _ = step_function(inputs[0], initial_states + constants)
File "/root/anaconda3/lib/python3.6/site-packages/kulc-0.0.5-py3.6.egg/kulc/attention.py", line 356, in step
total_x_prod = states[3]
IndexError: list index out of range
The problem appears to be here: https://github.com/FlashTek/keras-utility-layer-collection/blob/master/kulc/attention.py#L350-L356
If you intended this to be used with a particular version of Keras, you could pin that version in a requirements.txt file.
Thanks for working on this; it seems super useful!
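Pinning Keras in requirements.txt (a line such as `keras==<version>`) is the usual remedy. A runtime guard can complement it by failing early with a clear message instead of an `IndexError` deep inside the RNN step. The sketch below uses only the standard library (`importlib.metadata`, Python 3.8+). `parse_version`, `installed_version` and `check_pinned` are hypothetical helpers. The issue does not say which Keras version the library targets, so the pinned version stays a parameter.

```python
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '2.2.0' into (2, 2, 0); stop at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # e.g. '6rc1': keep 6, ignore the pre-release suffix
    return tuple(parts)


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution's version, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pinned(package: str, pinned: str) -> str:
    """Raise if `package` is missing or its version differs from `pinned`."""
    found = installed_version(package)
    if found is None:
        raise ImportError(f"{package} is not installed (expected {pinned})")
    if parse_version(found) != parse_version(pinned):
        raise RuntimeError(f"{package} {found} found, but this code expects {pinned}")
    return found
```

Calling `check_pinned("keras", "<version>")` at the top of a model script would surface a mismatch before any layer is built.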
The training model works well for training. Could you give an example of how to use the trained model for prediction?
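For a sequence-to-sequence model, prediction usually means greedy (or beam) decoding. Feed the start token, pick the highest-scoring next token, append it, and repeat until an end token or a length limit. Below is a framework-agnostic sketch of greedy decoding. `step_fn` stands in for whatever inference-time decoder you build, for example a wrapper around its `predict` call. `greedy_decode`, `toy_step` and the token ids are illustrative assumptions, not part of this library.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    step_fn: Callable[[List[int]], Sequence[float]],
    start_id: int,
    end_id: int,
    max_len: int,
) -> List[int]:
    """Generate tokens one at a time, always taking the highest-scoring one.

    step_fn receives the tokens generated so far (including the start token)
    and returns one score per vocabulary entry for the next position.
    """
    tokens = [start_id]
    for _ in range(max_len):
        scores = step_fn(tokens)
        next_id = max(range(len(scores)), key=lambda i: scores[i])
        tokens.append(next_id)
        if next_id == end_id:
            break
    return tokens[1:]  # drop the start token


def toy_step(tokens: List[int], vocab_size: int = 5) -> List[float]:
    """Hypothetical decoder that always prefers (last token + 1) mod vocab_size."""
    scores = [0.0] * vocab_size
    scores[(tokens[-1] + 1) % vocab_size] = 1.0
    return scores


print(greedy_decode(toy_step, start_id=0, end_id=3, max_len=10))  # [1, 2, 3]
```

In a real model, the encoder output and the decoder states would also be threaded through `step_fn`. Here it just sees the token prefix to keep the sketch short.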
I used:
encoder_input = ks.layers.Input(shape=(90,))
embed = Embedding(input_dim=598, output_dim=512, input_length=90, mask_zero=True)
encoder_inputs = embed(encoder_input)
……
I tried SequenceAttention and AttentionRNNWrapper, and both fail with "Layer does not support masking……"
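With `mask_zero=True`, the Embedding layer emits a mask that marks padded positions (token id 0). Every downstream layer must then accept that mask. In Keras, a custom layer opts in by setting `self.supports_masking = True` and overriding `compute_mask` if it changes the mask. For an attention layer, using the mask typically means excluding padded steps before the softmax. The sketch below shows that idea in plain Python. `masked_softmax` is a hypothetical helper, not a function of this library or of Keras.

```python
import math
from typing import List, Sequence


def masked_softmax(scores: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Softmax over attention scores that gives padded positions zero weight.

    mask[i] is True for real tokens and False for padding (e.g. token id 0
    when Embedding(mask_zero=True) is used).
    """
    kept = [s for s, m in zip(scores, mask) if m]
    if not kept:
        return [0.0] * len(scores)  # fully padded sequence: attend to nothing
    top = max(kept)  # subtract the max over real tokens for numerical stability
    exps = [math.exp(s - top) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


token_ids = [5, 12, 7, 0, 0]           # two trailing padding tokens
mask = [t != 0 for t in token_ids]     # what mask_zero=True would mark as valid
weights = masked_softmax([1.0, 2.0, 0.5, 9.0, 9.0], mask)
print([round(w, 3) for w in weights])  # padded positions get 0.0
```

The large scores on the padded positions are ignored entirely, which is what an attention layer has to do once it declares mask support.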
I am trying to use
encoder = GRU(embedding_size, return_sequences=True, return_state=True, recurrent_dropout=0.1)
attented_encoder = ExternalAttentionRNNWrapper(encoder, return_attention=True)
but got an error in the wrapper class:
    super(ExternalAttentionRNNWrapper, self).__init__(layer, **kwargs)
  File "C:\Users\NITS\Anaconda3\lib\site-packages\tensorflow\python\keras\layers\wrappers.py", line 52, in __init__
    assert isinstance(layer, Layer)
AssertionError
print(type(encoder))
This prints keras.layers.recurrent.LSTM, yet the assert isinstance(layer, Layer) check in Wrapper fails. The relevant part of Wrapper.__init__ is:
self.layer = layer
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
self._input_map = {}
super(Wrapper, self).__init__(**kwargs)
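The traceback points to a likely cause. The wrapper's `__init__` runs into `tensorflow\python\keras\layers\wrappers.py`, while `type(encoder)` reports the standalone `keras.layers.recurrent.LSTM`. The two packages define separate `Layer` base classes. As a result, `isinstance(layer, Layer)` can fail even though the object is a valid Keras layer. The sketch below reproduces that mismatch with stand-in classes; every name in it is hypothetical. If this diagnosis is right, the remedy is to build the encoder from the same package the wrapper imports, e.g. `tensorflow.keras.layers` throughout, instead of mixing `keras` and `tf.keras`.

```python
class StandaloneLayer:
    """Stand-in for the Layer base class of standalone Keras."""


class TfKerasLayer:
    """Stand-in for the Layer base class inside tensorflow.python.keras."""


class StandaloneGRU(StandaloneLayer):
    """Like keras.layers.GRU: a real layer, but from the standalone package."""


class TfKerasGRU(TfKerasLayer):
    """Like tf.keras.layers.GRU: derives from the tf.keras Layer class."""


class TfKerasWrapper(TfKerasLayer):
    """Mimics the type check at the top of tf.keras Wrapper.__init__."""

    def __init__(self, layer: object) -> None:
        # Same effect as tf.keras's `assert isinstance(layer, Layer)`,
        # written as an explicit raise so it survives `python -O`.
        if not isinstance(layer, TfKerasLayer):
            raise AssertionError(f"{type(layer).__name__} is not a tf.keras Layer")
        self.layer = layer


def can_wrap(layer: object) -> bool:
    """Return True if TfKerasWrapper accepts the layer, False on AssertionError."""
    try:
        TfKerasWrapper(layer)
        return True
    except AssertionError:
        return False


print(can_wrap(StandaloneGRU()))  # False: different Layer base class
print(can_wrap(TfKerasGRU()))     # True: same package as the wrapper
```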