
pytorch-bertflow's Introduction

Pytorch-bertflow

This is a re-implementation of BERT-flow in PyTorch, which can reproduce the results of the original repo. This code was used to reproduce the results in the TSDAE paper.

Usage

Please refer to the simple example ./example.py

python example.py

Note

  • Please shuffle your training data; it makes a huge difference.
  • The pooling function makes a huge difference on some datasets (especially the ones used in the paper). To reproduce the results, please use 'first-last-avg'.
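For reference, 'first-last-avg' is described in this repo as mean pooling over the first and the last Transformer layers. A minimal sketch of that pooling, assuming `hidden_states` is the tuple returned by a Hugging Face model called with `output_hidden_states=True` (index 0 is the embedding-layer output, index -1 the last layer) — the function name here is illustrative, not the repo's API:

```python
import torch

def first_last_avg(hidden_states, attention_mask):
    """Mean-pool the average of the first and last layers over valid tokens."""
    avg_layers = (hidden_states[0] + hidden_states[-1]) / 2.0  # (B, T, H)
    mask = attention_mask.unsqueeze(-1).float()                # (B, T, 1)
    summed = (avg_layers * mask).sum(dim=1)                    # mask out padding
    counts = mask.sum(dim=1).clamp(min=1e-9)                   # valid-token counts
    return summed / counts                                     # (B, H)

# Toy check: two "layers" of constant hidden states, no padding.
hs = (torch.full((2, 3, 4), 1.0), torch.full((2, 3, 4), 3.0))
mask = torch.ones(2, 3, dtype=torch.long)
print(first_last_avg(hs, mask).shape)  # torch.Size([2, 4])
```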

Contact

Contact person and main contributor: Kexin Wang, [email protected]

https://www.ukp.tu-darmstadt.de/

https://www.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.

This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.


pytorch-bertflow's Issues

Reproducing the results of bertflow in the original paper

Hi, thank you for your great work! I am grateful for your pytorch-bertflow framework and I am using it to reproduce the original BERT-flow experiments, but the result (SRCC on STS-B) is always lower than reported in the paper. I guess there are some details I am missing when reproducing.

loss: -1.180362 [473600/551204] corrcoef_dev: 0.223951
loss: -1.119098 [480000/551204] corrcoef_dev: 0.223691
loss: -1.132908 [486400/551204] corrcoef_dev: 0.224357
loss: -1.211618 [486400/551204] corrcoef_dev: 0.225161

I chose WNLI as the training set and STS-B as the dev set. As the result above shows, the SRCC is about 22, which is quite low.
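For context, SRCC here is the Spearman rank correlation between predicted similarities and gold STS-B scores, usually reported as a percentage. A hedged, self-contained sketch with toy values (in practice `scipy.stats.spearmanr` is typically used; this simple version does not handle ties):

```python
def rank(values):
    # 1-based rank positions; ties are not handled, for simplicity.
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    for r, i in enumerate(order, start=1):
        ranks[i] = float(r)
    return ranks

def spearman(a, b):
    # Spearman rho via the classic sum-of-squared-rank-differences formula.
    ra, rb = rank(a), rank(b)
    n = len(a)
    d2 = sum((x - y) ** 2 for x, y in zip(ra, rb))
    return 1.0 - 6.0 * d2 / (n * (n * n - 1))

pred_sims = [0.1, 0.4, 0.35, 0.8]  # e.g. cosine similarities of sentence embeddings
gold = [1.0, 2.0, 2.5, 4.5]        # STS-B human similarity scores (0-5 scale)
print(spearman(pred_sims, gold))   # → 0.8 (would be reported as an SRCC of 80)
```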

Finetuning all-MiniLM-L6-v2 ValueError

Hello, thank you for your contribution! I am trying to fine-tune all-MiniLM-L6-v2 (https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) on my data, but after the first batch I get a ValueError and a loss of inf.

ValueError: Expected value argument (Tensor of shape (4, 192, 1, 1)) to be within the support (Real()) of the distribution Normal(loc: torch.Size([4, 192, 1, 1]), scale: torch.Size([4, 192, 1, 1])), but found invalid values:
tensor([[[[nan]],

     [[nan]],

.....

Here is my very simple script (I just replaced the data and put the training in a loop):

import pandas as pd
import numpy as np
from tflow_utils import TransformerGlow, AdamWeightDecayOptimizer
from transformers import AutoTokenizer,AutoModel

model_name_or_path = '/tmp/all-MiniLM-L6-v2'
bertflow = TransformerGlow(model_name_or_path, pooling='mean')  # pooling could be 'mean', 'max', 'cls' or 'first-last-avg' (mean pooling over the first and the last layers)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters= [
    {
        "params": [p for n, p in bertflow.glow.named_parameters()  \
                        if not any(nd in n for nd in no_decay)],  # Note: only the parameters within bertflow.glow will be updated; the Transformer stays frozen during training.
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in bertflow.glow.named_parameters()  \
                        if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamWeightDecayOptimizer(
    params=optimizer_grouped_parameters, 
    lr=1e-5, 
    eps=1e-6,
)
# Important: Remember to shuffle your training data!!! This makes a huge difference!!!

np.random.seed(0)
df = pd.read_csv("data/classification/data_small.csv")
data = df.text.to_list().copy()
np.random.shuffle(data)


bertflow.train()
batch_size = 4
nb_batch = int(np.ceil(len(data) / batch_size))
print(nb_batch)
for batch_id in range(nb_batch):
    batch = data[batch_id*batch_size:(batch_id+1)*batch_size]
    model_inputs = tokenizer(
        batch,
        add_special_tokens=True,
        return_tensors='pt',
        max_length=256,
        padding='longest',
        truncation=True
    )
    z, loss = bertflow(model_inputs['input_ids'], model_inputs['attention_mask'], return_loss=True)  # Here z is the sentence embedding
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(batch_id, loss)

Do you have any idea where this could come from? I have tried different learning rates, but that doesn't solve the problem.
