
gazelle93 / transformer-various-positional-encoding


This project aims to implement the Transformer Encoder blocks using various Positional Encoding methods.

nlp gensim natural-language-processing nltk pytorch relative-positional-encoding relative-positional-representation spacy t5 transformer

transformer-various-positional-encoding's Introduction

Overview

  • Since the emergence of attention, language models built on attention layers have achieved the best performance on a wide range of NLP tasks. Attention lets a model focus on the most relevant parts of the input sequence by computing attention scores, weighted combinations over all of the encoded input vectors, simultaneously. Attention layers can therefore increase training speed through parallelization, free of the restrictions inherent in sequential architectures. This project implements Transformer encoder blocks using absolute positional encoding, the relative position representation of Shaw et al. (2018), and the relative position representation of Raffel et al. (2019).

  • Attention score using Absolute Positional Encoding: $$\alpha_{ij}^{Abs} = \frac{1}{\sqrt{d}}((w_i+p_i)W^{Q,1})((w_j+p_j)W^{K,1})^T$$ where $w_i$ is the word embedding, $p_i$ is the absolute positional encoding, and $W^{Q,1}$ and $W^{K,1}$ are the corresponding query and key weights. Absolute positional encoding is how token positions are represented in the original Transformer architecture: the positional encoding is summed with the word embeddings at the input level.
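
    The score above can be sketched in PyTorch as follows. The sinusoidal encoding follows Vaswani et al. (2017); the function and variable names are illustrative, not taken from this repository's code.

    ```python
    import math
    import torch

    def absolute_pe_attention_score(w, p, W_q, W_k):
        """Attention logits with absolute positional encoding added at the input level.

        w: (seq_len, d) word embeddings; p: (seq_len, d) positional encodings;
        W_q, W_k: (d, d) query/key projection weights.
        """
        x = w + p                     # positions are summed into the input
        q = x @ W_q                   # (seq_len, d)
        k = x @ W_k                   # (seq_len, d)
        d = q.size(-1)
        return (q @ k.transpose(-2, -1)) / math.sqrt(d)   # alpha_ij

    seq_len, d = 4, 8
    w = torch.randn(seq_len, d)
    # sinusoidal absolute positional encoding, as in Vaswani et al. (2017)
    pos = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2, dtype=torch.float) * (-math.log(10000.0) / d))
    p = torch.zeros(seq_len, d)
    p[:, 0::2] = torch.sin(pos * div)
    p[:, 1::2] = torch.cos(pos * div)
    scores = absolute_pe_attention_score(w, p, torch.randn(d, d), torch.randn(d, d))
    print(scores.shape)  # torch.Size([4, 4])
    ```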

  • Attention score using the Relative Position representation of Shaw et al. (2018): $$\alpha_{ij}^{Rel} = \frac{1}{\sqrt{d}}((w_i+p_i)W^{Q,l})((w_j+p_j)W^{K,l}+a_{j-i}^l)^T$$ where $a_{j-i}^l$ is a learnable parameter representing the embedding of the relative position $j-i$ in layer $l$. To handle longer sequences in generalized autoregressive pretraining (XLNet), relative positional encoding is used to represent positions across multiple segments. This representation is applied inside the self-attention mechanism, not at the input level.
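
    A minimal sketch of the Shaw et al. (2018) score: a learnable table of relative-position embeddings, indexed by the clipped distance $j-i$, is added to the keys. Names and the clipping window are illustrative assumptions, not this repository's API.

    ```python
    import math
    import torch

    def shaw_relative_scores(q, k, rel_emb, max_rel=4):
        """Attention logits with relative position embeddings a_{j-i} added to the keys.

        q, k: (seq_len, d); rel_emb: (2*max_rel+1, d) learnable table indexed
        by the relative distance j - i, clipped to [-max_rel, max_rel].
        """
        seq_len, d = q.shape
        i = torch.arange(seq_len).unsqueeze(1)
        j = torch.arange(seq_len).unsqueeze(0)
        # clip relative distances, then shift to non-negative table indices
        rel = (j - i).clamp(-max_rel, max_rel) + max_rel      # (seq_len, seq_len)
        a = rel_emb[rel]                                      # (seq_len, seq_len, d)
        content = q @ k.transpose(-2, -1)                     # q_i . k_j
        position = torch.einsum('id,ijd->ij', q, a)           # q_i . a_{j-i}
        return (content + position) / math.sqrt(d)

    q = torch.randn(5, 8)
    k = torch.randn(5, 8)
    rel_emb = torch.randn(9, 8)   # 2 * max_rel + 1 distance buckets
    print(shaw_relative_scores(q, k, rel_emb).shape)  # torch.Size([5, 5])
    ```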

  • Attention score using the Relative Position representation of Raffel et al. (2019): $$\alpha_{ij}^{T5} = \frac{1}{\sqrt{d}}((w_i+p_i)W^{Q,l})((w_j+p_j)W^{K,l})^T+b_{j-i}$$ where $b_{j-i}$ is a learnable parameter representing the embedding of the relative position $j-i$, shared across all layers.
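
    The T5-style score can be sketched as follows: a scalar bias $b_{j-i}$, shared across layers, is added after the scaled dot product. Note that T5 itself buckets relative distances logarithmically; the simple clipping used here is a simplification for illustration.

    ```python
    import math
    import torch

    def t5_bias_scores(q, k, bias_table, max_rel=4):
        """Attention logits plus a learned scalar bias b_{j-i} per relative distance.

        q, k: (seq_len, d); bias_table: (2*max_rel+1,) one learnable scalar
        per clipped relative distance, shared across all layers.
        """
        seq_len, d = q.shape
        i = torch.arange(seq_len).unsqueeze(1)
        j = torch.arange(seq_len).unsqueeze(0)
        rel = (j - i).clamp(-max_rel, max_rel) + max_rel   # (seq_len, seq_len) indices
        b = bias_table[rel]                                # scalar bias per (i, j)
        return (q @ k.transpose(-2, -1)) / math.sqrt(d) + b

    q = torch.randn(5, 8)
    k = torch.randn(5, 8)
    bias_table = torch.randn(9)   # one scalar per clipped distance
    print(t5_bias_scores(q, k, bias_table).shape)  # torch.Size([5, 5])
    ```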

Brief description

  • text_processing.py
      • output: Tokenized result of a given text. (list)
  • my_onehot.py
      • output: List of tensors of input tokens. (Tensor)
  • attentions.py
      • output: List of tensors of attention results. (Tensor)
  • transformers.py
      • output: model (Transformer encoder model), output (last hidden states of the model, Tensor), output_list (hidden states of layers 1 to 12, Tensor), attn_score_list (attention scores of layers 1 to 12, Tensor)

Prerequisites

  • argparse
  • torch
  • stanza
  • spacy
  • nltk
  • gensim

Parameters

  • nlp_pipeline(str, defaults to "stanza"): NLP preprocessing pipeline.
  • unk_ignore(bool, defaults to True): Whether to ignore unseen words.
  • num_heads(int, defaults to 8): The number of heads for multi-head attention.
  • num_layers(int, defaults to 12): The number of transformer encoder blocks.
  • positional_encoding(str, defaults to "abs"): Type of positional encoding. (abs, rel, t5)
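
A sketch of how these parameters could be wired up with argparse. The flag names and defaults come from the list above, but the exact argparse setup is an assumption, not necessarily this repository's code.

```python
import argparse

# Hypothetical argument parser mirroring the documented parameters.
parser = argparse.ArgumentParser(description="Transformer encoder with various positional encodings.")
parser.add_argument("--nlp_pipeline", type=str, default="stanza",
                    help="NLP preprocessing pipeline.")
parser.add_argument("--unk_ignore", type=bool, default=True,
                    help="Whether to ignore unseen words.")
parser.add_argument("--num_heads", type=int, default=8,
                    help="Number of heads for multi-head attention.")
parser.add_argument("--num_layers", type=int, default=12,
                    help="Number of transformer encoder blocks.")
parser.add_argument("--positional_encoding", type=str, default="abs",
                    choices=["abs", "rel", "t5"],
                    help="Type of positional encoding.")

# Parse an empty argument list to demonstrate the defaults.
args = parser.parse_args([])
print(args.num_heads, args.positional_encoding)  # 8 abs
```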

References

  • Attention: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
  • Relative Position Representation: Shaw, P., Uszkoreit, J., & Vaswani, A. (2018). Self-attention with relative position representations. arXiv preprint arXiv:1803.02155.
  • Transformer-XL: Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q. V., & Salakhutdinov, R. (2019). Transformer-xl: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860.
  • XLNet: Yang, Z., Dai, Z., Yang, Y., Carbonell, J., Salakhutdinov, R. R., & Le, Q. V. (2019). Xlnet: Generalized autoregressive pretraining for language understanding. Advances in neural information processing systems, 32.
  • T5 Relative Position Representation: Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., ... & Liu, P. J. (2019). Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv preprint arXiv:1910.10683.
  • Stanza: Qi, P., Zhang, Y., Zhang, Y., Bolton, J., & Manning, C. D. (2020). Stanza: A Python natural language processing toolkit for many human languages. arXiv preprint arXiv:2003.07082.
  • spaCy: Honnibal, M., & Montani, I. (2017). spaCy 2: Natural language understanding with Bloom embeddings, convolutional neural networks and incremental parsing. To appear.
  • NLTK: Bird, Steven, Edward Loper and Ewan Klein (2009). Natural Language Processing with Python. O'Reilly Media Inc.
  • Gensim: Řehůřek, R., & Sojka, P. (2010). Software framework for topic modelling with large corpora. In Proceedings of the LREC 2010 Workshop on New Challenges for NLP Frameworks.

transformer-various-positional-encoding's People

Contributors

gazelle93


transformer-various-positional-encoding's Issues

Bug in MHA

Hi, I think there is a bug in the MHA implementation.

If we have the output of:
output, attn_score = self.relative_scaled_dot_attn(query, key, value, a_key, a_value, mask)

this is [B * NH, T, HD] (B - batch, NH - num heads, T - timestep, HD - head dim).

Line 189 of attentions.py then does:

output = output.view(self.num_heads, batch_size, -1, self.head_dim)
output = self.reshape_to_concat(batch_size, output)

But in the implementation of reshape_to_concat:

    def reshape_to_concat(self, batch_size, _tensor):
        # before shape: (batch size, number of heads, input length, head dimension)
        # after shape: (batch size, input length, number of heads, head dimension)
        _tensor = _tensor.permute(0, 2, 1, 3)
        return _tensor.contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

Here we're actually passing in [NH, B, T, HD] rather than [B, NH, T, HD].

This then gets permuted to [NH, T, B, HD] before being reshaped, which is wrong because we obviously want the tensor to be [B, T, NH, HD] here.
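
The fix described above can be sketched shape-by-shape: view the flattened tensor batch-first so that the subsequent permute produces [B, T, NH, HD] as the comment in reshape_to_concat intends. The dimensions below are illustrative stand-ins, not the repository's actual tensors.

```python
import torch

B, NH, T, HD = 2, 8, 5, 4
# stand-in for the attention output, shaped (B*NH, T, HD)
output = torch.randn(B * NH, T, HD)

# Buggy order: view(NH, B, T, HD) followed by permute(0, 2, 1, 3) mixes the
# batch and head axes. Viewing batch-first avoids that:
output = output.view(B, NH, T, HD)       # (B, NH, T, HD)
output = output.permute(0, 2, 1, 3)      # (B, T, NH, HD)
output = output.contiguous().view(B, T, NH * HD)
print(output.shape)  # torch.Size([2, 5, 32])
```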
