elephantmipt / bert-distillation

Distillation of BERT model with catalyst framework

License: MIT License

Topics: catalyst, bert, nlp, distillation, distilbert, rubert

Bert Distillation

For a more general pipeline, please see the compressors library and its BERT distillation example. This project is no longer supported by me.

This project is about BERT distillation.

The goal is to distill any BERT-based model for any language with a convenient high-level API, reproducibility, and support for the latest GPU features.

Features

  • various losses
  • distributed training
  • fp16
  • logging with TensorBoard, wandb, etc.
  • built on the Catalyst framework

A Brief Inquiry

Not long ago, the Hugging Face team published a paper about the DistilBERT model. The idea is to transfer knowledge from a big teacher model to a smaller student model.

First of all, we need a well-trained teacher model.

Let's take 6 encoder layers instead of 12! We initialize the smaller model's layers with the teacher's layers.

Instead of training the student for a long time on the masked language modeling task alone, we can add a KL divergence loss and a cosine loss between the student and the teacher to the usual MLM loss, since we know that the teacher is well trained.

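As shown in the paper, this method costs only a small drop in quality while reducing model size and speeding up inference, especially on mobile devices: DistilBERT is reported to be 40% smaller and 60% faster than BERT while retaining 97% of its language understanding capabilities.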
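The extra loss terms described above can be sketched in plain Python for a single token position. This is a minimal illustration under stated assumptions, not the project's `KLDivLossCallback` or `CosineLossCallback` code: the softmax temperature (2.0 here) and the `T**2` scaling are borrowed from classic knowledge distillation and are assumptions, and the function names are hypothetical. Real training code would operate on batched PyTorch tensors rather than lists.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max logit for numerical stability."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_div_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    temperature: float = 2.0,  # hypothetical default, not taken from the project
) -> float:
    """KL(teacher || student) between temperature-softened distributions.

    Multiplied by temperature**2, the scaling used in classic knowledge
    distillation (an assumption; the project's callback may differ).
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * (math.log(p) - math.log(q)) for p, q in zip(p_teacher, p_student))
    return kl * temperature ** 2


def cosine_loss(student_hidden: Sequence[float], teacher_hidden: Sequence[float]) -> float:
    """1 - cosine similarity: 0 when the two hidden states point the same way."""
    dot = sum(s * t for s, t in zip(student_hidden, teacher_hidden))
    norm_s = math.sqrt(sum(s * s for s in student_hidden))
    norm_t = math.sqrt(sum(t * t for t in teacher_hidden))
    return 1.0 - dot / (norm_s * norm_t)


if __name__ == "__main__":
    teacher = [2.0, 1.0, 0.1]  # toy logits over a 3-token vocabulary
    student = [1.5, 1.2, 0.3]
    print("KL:", kl_div_loss(student, teacher))
    print("cosine:", cosine_loss([0.2, 0.9], [0.3, 0.8]))
```

Both terms are zero when the student exactly matches the teacher, so they only push the student toward the teacher's outputs and hidden states.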

Usage

In the Catalyst framework, there are two ways to run your experiment: the Notebook API and the Config API. If you want to run a quick, flexible experiment, use the Notebook API; if you want to build a production-ready solution, use the Config API.

Notebook API

Let's briefly look at the Notebook API. First of all, we do all the necessary imports:

from catalyst import dl
from catalyst.contrib.data.nlp import LanguageModelingDataset
from catalyst.contrib.nn.optimizers import RAdam
from catalyst.core import MetricAggregationCallback
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoTokenizer,
    BertForMaskedLM,
    DistilBertForMaskedLM,
)
from transformers.data.data_collator import DataCollatorForLanguageModeling

from src.callbacks import (
    CosineLossCallback,
    KLDivLossCallback,
    MaskedLanguageModelCallback,
    MSELossCallback,
    PerplexityMetricCallbackDistillation,
)
from src.data import MLMDataset
from src.runners import DistilMLMRunner
from src.models import DistilbertStudentModel, BertForMLM

Then we should load our training data, for example:

train_df = pd.read_csv("data/train.csv")
valid_df = pd.read_csv("data/valid.csv")

Next, we initialize our data loaders:

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # your teacher's model tokenizer

train_dataset = LanguageModelingDataset(train_df["text"], tokenizer)
valid_dataset = LanguageModelingDataset(valid_df["text"], tokenizer)

collate_fn = DataCollatorForLanguageModeling(tokenizer).collate_batch  # transformers 2.x API; newer releases make the collator object itself callable
train_dataloader = DataLoader(
    train_dataset, collate_fn=collate_fn, batch_size=2
)
valid_dataloader = DataLoader(
    valid_dataset, collate_fn=collate_fn, batch_size=2
)
loaders = {"train": train_dataloader, "valid": valid_dataloader}
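The language modeling collator turns each batch into masked inputs plus MLM targets. The function below is a plain-Python sketch of the standard BERT masking scheme it is based on, not the transformers implementation: roughly 15% of non-special tokens become prediction targets, and of those, 80% are replaced by `[MASK]`, 10% by a random token, and 10% are left unchanged. Positions that are not targets get the label `-100`, which PyTorch's cross-entropy loss ignores by default. The name `mask_tokens` and its parameters are illustrative.

```python
import random
from typing import FrozenSet, List, Optional, Tuple

IGNORE_INDEX = -100  # label value that PyTorch's CrossEntropyLoss ignores by default


def mask_tokens(
    input_ids: List[int],
    mask_token_id: int,
    vocab_size: int,
    special_ids: FrozenSet[int] = frozenset(),
    mlm_probability: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """BERT-style masking for one sequence; returns (masked inputs, labels)."""
    rng = rng or random.Random()
    inputs = list(input_ids)
    labels = [IGNORE_INDEX] * len(inputs)
    for i, token in enumerate(input_ids):
        # Special tokens ([CLS], [SEP], ...) are never prediction targets.
        if token in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = token  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_token_id  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels
```

For `bert-base-uncased`, `[CLS]`, `[SEP]` and `[MASK]` have ids 101, 102 and 103 and the vocabulary has 30522 tokens; in practice you would read these from `tokenizer.mask_token_id` and `len(tokenizer)` instead of hard-coding them.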

The most important thing is to define our models.

teacher = BertForMLM("bert-base-uncased")
student = DistilbertStudentModel(
    teacher_model_name="bert-base-uncased",
    layers=[0, 2, 4, 7, 9, 11],  # teacher layers used to initialize the student
)
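The `layers=[0, 2, 4, 7, 9, 11]` argument selects which teacher encoder layers initialize the six student layers. As a rough illustration of that idea (not the project's `DistilbertStudentModel` code), the sketch below builds a student state dict by keeping only the selected layers and renumbering them 0 to 5. It assumes the student uses the same parameter names as the teacher, i.e. a BERT with fewer layers; mapping into DistilBERT's own parameter names would need an extra renaming step. `build_student_state_dict` is a hypothetical helper.

```python
import re
from typing import Dict, List, TypeVar

T = TypeVar("T")  # e.g. torch.Tensor; any value works for illustration

# Matches keys such as "bert.encoder.layer.7.attention.self.query.weight".
LAYER_KEY = re.compile(r"^(.*\.layer\.)(\d+)(\..*)$")


def build_student_state_dict(teacher_state: Dict[str, T], layers: List[int]) -> Dict[str, T]:
    """Keep the teacher layers listed in `layers`, renumbered 0..len(layers)-1.

    Weights outside the encoder layers (embeddings, LM head) are copied as is.
    """
    new_index = {teacher_idx: student_idx for student_idx, teacher_idx in enumerate(layers)}
    student_state: Dict[str, T] = {}
    for key, value in teacher_state.items():
        match = LAYER_KEY.match(key)
        if match is None:
            student_state[key] = value  # not an encoder layer: copy unchanged
            continue
        prefix, idx, suffix = match.groups()
        if int(idx) in new_index:
            student_state[f"{prefix}{new_index[int(idx)]}{suffix}"] = value
        # teacher layers not listed in `layers` are dropped
    return student_state
```

With PyTorch modules this could be applied as `student.load_state_dict(build_student_state_dict(teacher.state_dict(), layers))`, where `student` and `teacher` are hypothetical models sharing the same naming scheme. Spreading the selected layers across the whole depth (rather than taking the first six) lets the student start from both low-level and high-level teacher representations.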
model = torch.nn.ModuleDict({"teacher": teacher, "student": student})

Next, we define the callbacks:

callbacks = {
    "masked_lm_loss": MaskedLanguageModelCallback(),  # standard MLM loss
    "mse_loss": MSELossCallback(),  # MSE loss between student and teacher distributions on masked positions
    "cosine_loss": CosineLossCallback(),  # cosine loss between hidden states
    "kl_div_loss": KLDivLossCallback(),  # KL divergence between student and teacher distributions on masked positions
    "loss": MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics={  # weights for the final loss
            "cosine_loss": 1.0,
            "masked_lm_loss": 1.0,
            "kl_div_loss": 1.0,
            "mse_loss": 1.0,
        },
    ),
    "optimizer": dl.OptimizerCallback(),  # runs loss.backward() and optimizer.step()
    "perplexity": PerplexityMetricCallbackDistillation(),  # perplexity metric
}
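`MetricAggregationCallback` with `mode="weighted_sum"` combines the individual losses into the single `loss` that the optimizer callback backpropagates. The arithmetic is sketched below in plain Python, together with the usual definition of perplexity as the exponential of the mean MLM cross-entropy; this assumes `PerplexityMetricCallbackDistillation` follows that standard definition. The function names and the loss values are illustrative, not Catalyst APIs or real results.

```python
import math
from typing import Dict


def weighted_sum(batch_metrics: Dict[str, float], weights: Dict[str, float]) -> float:
    """sum(weight * metric) over the metrics named in `weights`."""
    return sum(weight * batch_metrics[name] for name, weight in weights.items())


def perplexity(mean_mlm_loss: float) -> float:
    """exp of the mean token-level cross-entropy (natural log)."""
    return math.exp(mean_mlm_loss)


# Toy per-batch losses; the names match the callback keys above.
batch_metrics = {"masked_lm_loss": 2.0, "mse_loss": 0.5, "cosine_loss": 0.25, "kl_div_loss": 1.0}
weights = {"cosine_loss": 1.0, "masked_lm_loss": 1.0, "kl_div_loss": 1.0, "mse_loss": 1.0}

print(weighted_sum(batch_metrics, weights))  # → 3.75
print(perplexity(batch_metrics["masked_lm_loss"]))  # exp(2.0), about 7.39
```

Changing a weight in the `metrics` dictionary rescales that term's contribution to the gradient; setting it to 0.0 effectively disables the loss without removing its callback.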

Finally, run an experiment!

runner = DistilMLMRunner()
optimizer = RAdam(model.parameters(), lr=5e-5)
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    verbose=True,
    num_epochs=10,  # epochs number
    callbacks=callbacks,
)

Config API

But what about a more production-ready solution?

Here is a minimal example for the Config API. All you need to do is write your config.yml file:

model_params:  # defining our models
  _key_value: true
  teacher:
    model: BertForMLM
    model_name: "bert-base-cased"  # hugging face hub model name
  student:
    model: DistilbertStudentModel
    teacher_model_name: "bert-base-cased"

args:
  # where to look for __init__.py file
  expdir: "src"
  # store logs in this subfolder
  baselogdir: "./logs/distilbert"

# common settings for all stages
stages:
  # PyTorch loader params
  data_params:
    batch_size: 2
    num_workers: 0
    path_to_data: "./data"
    train_filename: "train.csv"
    valid_filename: "valid.csv"
    text_field: "text"
    model_name: "bert-base-cased"  # tokenizer; should match the teacher model
    max_sequence_length: 300
    shuffle: True

  state_params:
    main_metric: &reduced_metric loss
    minimize_metric: True

  # scheduler controls learning rate during training
  scheduler_params:
    scheduler: ReduceLROnPlateau

  # callbacks serve to calculate loss and metric,
  # update model weights, save checkpoint etc.
  callbacks_params:
    loss_aggregator:
      callback: MetricAggregationCallback
      mode: weighted_sum
      metrics:
        cosine_loss: 1.0
        masked_lm_loss: 1.0
        kl_div_loss: 1.0
        mse_loss: 1.0
      prefix: loss
    cosine_loss:
      callback: CosineLossCallback
      prefix: cosine_loss
    masked_lm_loss:
      callback: MaskedLanguageModelCallback
      prefix: masked_lm_loss
    kl_div_loss:
      callback: KLDivLossCallback
      prefix: kl_div_loss
    mse_loss:
      callback: MSELossCallback
      prefix: mse_loss
    perplexity:
      callback: PerplexityMetricCallbackDistillation
    optimizer:
      callback: OptimizerCallback
    scheduler:
      callback: SchedulerCallback
      reduced_metric: *reduced_metric

  # params specific for stage 1 called "train_val"
  train_val:
    state_params:
      num_epochs: 1
    optimizer_params:
      optimizer: RAdam
      lr: 0.00005

Then run it with Catalyst:

catalyst-dl run -C config.yml --verbose

We can add distributed training and fp16:

catalyst-dl run -C config.yml --verbose --distributed --fp16

Folders

  1. bin - bash files for running pipelines
  2. configs - just place configs here
  3. docker - project Docker files for pure reproducibility
  4. examples - examples of using this project
  5. requirements - different project python requirements for docker, tests, CI, etc
  6. scripts - data preprocessing scripts, utils, everything like python scripts/.py
  7. src - model, experiment, etc - research

Contribution

First, let's discuss the feature you want to see in this project. You can use the feature request issue template.

After that, you can write your code following these simple steps:

  1. Clone the repository
  2. Run pip install -r requirements/requirements-dev.txt -r requirements/requirements.txt
  3. Write some code
  4. Run catalyst-make-codestyle
  5. Run catalyst-check-codestyle
  6. If the exit code is not 0, refactor your code
  7. Commit!

Contributors

elephantmipt, kostet499

bert-distillation's Issues

KeyError: masked_lm_labels

KeyError Traceback (most recent call last)
in
7 verbose=True,
8 num_epochs=10, # epochs number
----> 9 callbacks=callbacks,
10 )

/opt/conda/lib/python3.6/site-packages/catalyst/dl/runner/runner.py in train(self, model, criterion, optimizer, scheduler, datasets, loaders, callbacks, logdir, resume, num_epochs, valid_loader, main_metric, minimize_metric, verbose, stage_kwargs, checkpoint_data, fp16, distributed, check, timeit, load_best_on_end, initial_seed, state_kwargs)
151 )
152 self.experiment = experiment
--> 153 utils.distributed_cmd_run(self.run_experiment, distributed)
154
155 def infer(

/opt/conda/lib/python3.6/site-packages/catalyst/utils/scripts.py in distributed_cmd_run(worker_fn, distributed, *args, **kwargs)
128 or world_size <= 1
129 ):
--> 130 worker_fn(*args, **kwargs)
131 elif local_rank is not None:
132 torch.cuda.set_device(int(local_rank))

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in run_experiment(self, experiment)
932 if _exception_handler_check(getattr(self, "callbacks", None)):
933 self.exception = ex
--> 934 self._run_event("on_exception")
935 else:
936 raise ex

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in _run_event(self, event)
744 """
745 for callback in self.callbacks.values():
--> 746 getattr(callback, event)(self)
747
748 def _batch2device(

/opt/conda/lib/python3.6/site-packages/catalyst/core/callbacks/exception.py in on_exception(self, runner)
20
21 if runner.need_exception_reraise:
---> 22 raise exception

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in run_experiment(self, experiment)
920 try:
921 for stage in self.experiment.stages:
--> 922 self._run_stage(stage)
923 except (Exception, KeyboardInterrupt) as ex:
924 from catalyst.core.callbacks.exception import ExceptionCallback

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in _run_stage(self, stage)
896 )
897 self._run_event("on_epoch_start")
--> 898 self._run_epoch(stage=stage, epoch=self.epoch)
899 self._run_event("on_epoch_end")
900

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in _run_epoch(self, stage, epoch)
875 self._run_event("on_loader_start")
876 with torch.set_grad_enabled(self.is_train_loader):
--> 877 self._run_loader(loader)
878 self._run_event("on_loader_end")
879

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in _run_loader(self, loader)
816 self.global_batch_step += 1
817 self.loader_batch_step = i + 1
--> 818 self._run_batch(batch)
819 if self.need_early_stop:
820 self.need_early_stop = False

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in _run_batch(self, batch)
796 self._run_event("on_batch_start")
797 self._handle_batch(batch=batch)
--> 798 self._run_event("on_batch_end")
799
800 def _run_loader(self, loader: DataLoader) -> None:

/opt/conda/lib/python3.6/site-packages/catalyst/core/runner.py in _run_event(self, event)
744 """
745 for callback in self.callbacks.values():
--> 746 getattr(callback, event)(self)
747
748 def _batch2device(

/opt/conda/lib/python3.6/site-packages/catalyst/core/callbacks/metrics.py in on_batch_end(self, runner)
84 def on_batch_end(self, runner: IRunner) -> None:
85 """Computes the metric and add it to batch metrics."""
---> 86 metric = self._compute_metric(runner) * self.multiplier
87 runner.batch_metrics[self.prefix] = metric
88

/opt/conda/lib/python3.6/site-packages/catalyst/core/callbacks/metrics.py in _compute_metric_value(self, runner)
70 def _compute_metric_value(self, runner: IRunner):
71 output = self._get_output(runner.output, self.output_key)
---> 72 input = self._get_input(runner.input, self.input_key)
73
74 metric = self.metric_fn(output, input, **self.metrics_kwargs)

/opt/conda/lib/python3.6/site-packages/catalyst/utils/dict.py in get_key_str(dictionary, key)
16 value
17 """
---> 18 return dictionary[key]
19
20

KeyError: 'masked_lm_labels'

While running the Notebook API example, I encountered this error. Please help me resolve it. All I'm trying to do is text classification.
PS: I could not find a solution on Stack Overflow.
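A possible workaround, offered as an assumption rather than a confirmed diagnosis: the callbacks in this project read the MLM targets from the batch key `masked_lm_labels`, while newer transformers releases name the targets produced by `DataCollatorForLanguageModeling` `labels`. If that mismatch is the cause, a thin wrapper around the collate function can rename the key, as sketched here (`with_legacy_mlm_key` is a hypothetical helper). Note that this pipeline trains a masked language model; a text classification dataset does not provide MLM targets at all, so it would hit the same error for a different reason.

```python
from typing import Any, Callable, Dict, List

Batch = Dict[str, Any]


def with_legacy_mlm_key(
    collate_fn: Callable[[List[Any]], Batch],
    new_key: str = "labels",
    legacy_key: str = "masked_lm_labels",
) -> Callable[[List[Any]], Batch]:
    """Wrap `collate_fn` so MLM targets are exposed under `legacy_key`."""

    def wrapped(examples: List[Any]) -> Batch:
        batch = dict(collate_fn(examples))
        # Rename only when the legacy key is missing, so old collators pass through.
        if legacy_key not in batch and new_key in batch:
            batch[legacy_key] = batch.pop(new_key)
        return batch

    return wrapped
```

With a callable collator, usage might look like `collate_fn = with_legacy_mlm_key(DataCollatorForLanguageModeling(tokenizer))`.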
