
2D attention (keras-attention) · CLOSED · 6 comments

philipperemy commented on June 15, 2024
2D attention


Comments (6)

philipperemy commented on June 15, 2024

@raghavgurbaxani The 3D block expects a 3D tensor with shape (batch_size, time_steps, input_dim).
You can always set input_dim=1 by using a Reshape layer, or a Lambda layer wrapping K.expand_dims(..., axis=-1).
That way you turn your 2D input into a 3D one whose last dimension is 1.
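
For reference, a minimal sketch of that reshaping step, assuming a hypothetical 2D input with 2048 features (the names and sizes below are illustrative, not taken from the issue):

from tensorflow.keras.layers import Input, Lambda, Reshape
import tensorflow.keras.backend as K

# hypothetical 2D input: (batch_size, 2048)
inputs = Input(shape=(2048,))

# option 1: wrap K.expand_dims in a Lambda layer -> (batch_size, 2048, 1)
x = Lambda(lambda z: K.expand_dims(z, axis=-1))(inputs)

# option 2: an explicit Reshape layer gives the same result
# x = Reshape((2048, 1))(inputs)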


raghavgurbaxani commented on June 15, 2024

Hi @philipperemy

Thank you so much for your help! I tried tf.expand_dims(..., axis=-1) and my code now compiles successfully; however, it doesn't train well.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
features (InputLayer)           [(None, 16, 1816)]   0                                            
__________________________________________________________________________________________________
lstm (LSTM)                     (None, 2048)         31662080    features[0][0]                   
__________________________________________________________________________________________________
tf_op_layer_ExpandDims (TensorF [(None, 2048, 1)]    0           lstm[0][0]                       
__________________________________________________________________________________________________
attention_score_vec (Dense)     (None, 2048, 1)      1           tf_op_layer_ExpandDims[0][0]     
__________________________________________________________________________________________________
last_hidden_state (Lambda)      (None, 1)            0           tf_op_layer_ExpandDims[0][0]     
__________________________________________________________________________________________________
attention_score (Dot)           (None, 2048)         0           attention_score_vec[0][0]        
                                                                 last_hidden_state[0][0]          
__________________________________________________________________________________________________
attention_weight (Activation)   (None, 2048)         0           attention_score[0][0]            
__________________________________________________________________________________________________
context_vector (Dot)            (None, 1)            0           tf_op_layer_ExpandDims[0][0]     
                                                                 attention_weight[0][0]           
__________________________________________________________________________________________________
attention_output (Concatenate)  (None, 2)            0           context_vector[0][0]             
                                                                 last_hidden_state[0][0]          
__________________________________________________________________________________________________
attention_vector (Dense)        (None, 128)          256         attention_output[0][0]           
__________________________________________________________________________________________________
dense (Dense)                   (None, 1024)         132096      attention_vector[0][0]           
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 1024)         0           dense[0][0]                      
__________________________________________________________________________________________________
dropout (Dropout)               (None, 1024)         0           leaky_re_lu[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 120)          123000      dropout[0][0]                    
__________________________________________________________________________________________________
feature_weights (InputLayer)    [(None, 120)]        0                                            
__________________________________________________________________________________________________
multiply (Multiply)             (None, 120)          0           dense_1[0][0]                    
                                                                 feature_weights[0][0]            
==================================================================================================
Total params: 31,917,433
Trainable params: 31,917,433
Non-trainable params: 0
__________________________________________________________________________________________________

While training, I get the following error:

File "/mnt/ext/raghav/conda/envs/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 113, in save_model model, filepath, overwrite, include_optimizer) File "/mnt/ext/raghav/conda/envs/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py", line 101, in save_model_to_hdf5 default=serialization.get_json_type).encode('utf8') File "/mnt/ext/raghav/conda/envs/lib/python3.7/json/__init__.py", line 238, in dumps **kw).encode(obj) File "/mnt/ext/raghav/conda/envs/lib/python3.7/json/encoder.py", line 199, in encode chunks = self.iterencode(o, _one_shot=True) File "/mnt/ext/raghav/conda/envs/lib/python3.7/json/encoder.py", line 257, in iterencode return _iterencode(o, 0) File "/mnt/ext/raghav/conda/envs/lib/python3.7/site-packages/tensorflow/python/util/serialization.py", line 69, in get_json_type raise TypeError('Not JSON Serializable:', obj) TypeError: ('Not JSON Serializable:', b'\n\nExpandDims\x12\nExpandDims\x1a\x14lstm/strided_slice_7\x1a\x0eExpandDims/dim*\x07\n\x01T\x12\x020\x01*\n\n\x04Tdim\x12\x020\x03')

Any idea why this occurs?


philipperemy commented on June 15, 2024

Yes, you have to use Lambda(lambda z: K.expand_dims(z, axis=-1)). K.expand_dims(z, axis=-1) by itself is not a Keras layer, which is why Keras complains when it tries to serialize the model. Wrap it in a Lambda and treat that Lambda like any other layer (it works inside a Sequential model as well).
With the imports:

from tensorflow.keras.layers import Lambda
import tensorflow.keras.backend as K
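
Putting it together, a minimal sketch of the fix, assuming the same features input and 2048-unit LSTM shown in the summary above (the attention block and remaining layers are elided, and the file name is just an example):

from tensorflow.keras.layers import Input, LSTM, Lambda
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

features = Input(shape=(16, 1816), name='features')
x = LSTM(2048)(features)                             # (batch_size, 2048)

# Wrap expand_dims in a Lambda layer instead of calling tf.expand_dims on the
# tensor directly, so the op is a proper Keras layer and the model can be
# saved without the "Not JSON Serializable" error.
x = Lambda(lambda z: K.expand_dims(z, axis=-1))(x)   # (batch_size, 2048, 1)

# ... attention block and the remaining Dense layers go here ...
model = Model(inputs=features, outputs=x)
model.save('my_model.h5')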


philipperemy commented on June 15, 2024

I'll close this issue for now. Let me know if you need more help.


raghavgurbaxani commented on June 15, 2024

@philipperemy thank you so much :) It worked


philipperemy commented on June 15, 2024

@raghavgurbaxani GREAT!

