clkd's Introduction

Cross-Lingual Knowledge Distillation

PyTorch Lightning Config: Hydra Template

📌  Introduction

This project distills multilingual transformers into language-specific students. Features include:

  • Adjust any distillation loss
  • Any number of students and languages per student
  • Change the teacher and student architecture
  • Choose between monolingual, bilingual, or multilingual distillation setup
  • Component Sharing across students
  • Initialization of students from teacher layers

🚀  Quickstart

Configure your environment first.

# clone project
git clone https://github.com/MinhDucBui/clkd.git
cd clkd

# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv

# install requirements
pip install -r requirements.txt

# Make sure the right PyTorch version is installed; follow the instructions at
# https://pytorch.org/get-started/

Download the English and Turkish datasets from CC-100 using the URLs below. Alternatively, to speed up the download, use the much smaller Urdu-Swahili pair.

# change to data folder
cd data/cc100

# Download English Data (82GB) and Turkish Data (5.4GB)
wget http://data.statmt.org/cc-100/en.txt.xz
wget http://data.statmt.org/cc-100/tr.txt.xz

# Alternative: Urdu (884MB) and Swahili (332MB)
wget http://data.statmt.org/cc-100/ur.txt.xz
wget http://data.statmt.org/cc-100/sw.txt.xz

# Change back to original folder
cd ..
cd ..
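
Before starting a long run, it can help to confirm that a download is intact. The following is a minimal sketch, not part of the repository: `peek_lines` is a hypothetical helper that reads the first few lines of an `.xz` file using only the Python standard library, without decompressing the whole archive. The demo writes a tiny stand-in file so it runs without the real data.

```python
import lzma
import os
import tempfile
from itertools import islice


def peek_lines(path, limit=5):
    """Return the first `limit` lines of an .xz-compressed text file."""
    # lzma.open in text mode decompresses lazily, so only the start is read.
    with lzma.open(path, mode="rt", encoding="utf-8") as handle:
        return [line.rstrip("\n") for line in islice(handle, limit)]


# Demo on a tiny stand-in file (hypothetical content, not real CC-100 data).
with tempfile.TemporaryDirectory() as tmp:
    sample_path = os.path.join(tmp, "sample.txt.xz")
    with lzma.open(sample_path, mode="wt", encoding="utf-8") as handle:
        handle.write("first document line\nsecond line\n\nnext document\n")
    preview = peek_lines(sample_path, limit=3)

print(preview)  # ['first document line', 'second line', '']
```

For the real data, the same helper could be called from the project root as `peek_lines("data/cc100/sw.txt.xz")`, assuming the file was downloaded into `data/cc100` as shown above.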

Execute the main script. The default setting uses the same strategy as MonoShot.

# Choose GPU Device
export CUDA_VISIBLE_DEVICES=0

# execute main script
python run.py

# Alternative: Urdu-Swahili pair
python run.py experiment=monolingual_urdu_swahili

⚡  Your Superpowers

Change Distillation Loss

Hydra allows you to easily override any parameter defined in your config. See students/individual/loss for all loss functions.

python run.py students/individual/loss=monoalignment

To construct your own distillation loss, we provide base losses that can be combined into the final loss. Furthermore, we provide all distillation losses used in this thesis here.

Example of constructing the distillation loss from the MLM loss and logit distillation with CE loss, weighted equally:

_target_: src.loss.loss.GeneralLoss
defaults:
  - base_loss@base_loss.mlm: mlm.yaml
  - base_loss@base_loss.softtargets_ce: softtargets_ce.yaml

base_loss:
  softtargets_ce:
    temperature: 4.0

loss_weighting:
  mlm: 0.5
  softtargets_ce: 0.5
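
To make the arithmetic behind this config concrete, here is a minimal pure-Python sketch of a weighted sum of an MLM cross-entropy and a temperature-softened cross-entropy against teacher logits. It is an illustration only and not the repository's `GeneralLoss`: the function names, the toy logits, and the choice to omit any temperature-squared rescaling are assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_ce(student_logits, teacher_logits, temperature):
    """Cross-entropy between softened teacher and student distributions."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))


def mlm_ce(student_logits, target_index):
    """Standard cross-entropy against the true masked token."""
    return -math.log(softmax(student_logits)[target_index])


def combine(losses, weights):
    """Weighted sum of named base losses, mirroring `loss_weighting`."""
    return sum(weights[name] * value for name, value in losses.items())


# Toy logits over a three-token vocabulary (hypothetical values).
student = [2.0, 0.5, -1.0]
teacher = [3.0, 1.0, -2.0]

losses = {
    "mlm": mlm_ce(student, target_index=0),
    "softtargets_ce": soft_target_ce(student, teacher, temperature=4.0),
}
total = combine(losses, {"mlm": 0.5, "softtargets_ce": 0.5})
print(f"mlm={losses['mlm']:.4f} soft={losses['softtargets_ce']:.4f} total={total:.4f}")
```

A higher temperature flattens both distributions, so the student is also trained on the teacher's relative preferences among low-probability tokens rather than only on its top prediction.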

Change Student Number and Languages

We constructed some default configs for different scenarios:

# monolingual setting with english-turkish language pair
python run.py experiment=monolingual

# monolingual setting with english-basque language pair
python run.py experiment=monolingual_eu

# monolingual setting with english-swahili language pair
python run.py experiment=monolingual_sw

# monolingual setting with english-urdu language pair
python run.py experiment=monolingual_ur

# bilingual setting with english-turkish language pair
python run.py experiment=monolingual_bilingual

To construct a custom setting, please see the documentation here.

Embedding Sharing across Students

# Share language embeddings only in each student, not across students.
python run.py students.embed_sharing="in_each_model"

To construct a custom setting, please see the documentation here.

Layer Sharing across Students

Please see the documentation here.

Change Student Architecture

# Use the same architecture as the teacher
python run.py students/individual/model=from_teacher

More architectures can be found here.

Student Initialization

By default, students are initialized with weights from the teacher.

# Randomly Initialize Embedding Weights
python run.py students.individual.model.weights_from_teacher.embeddings=False
  
# Randomly Initialize Layer Weights
python run.py students.individual.model.weights_from_teacher.transformer_blocks=False
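
The dotted keys in these commands address a path inside the nested config. The sketch below is a simplified illustration of that idea, not Hydra's actual implementation: `apply_override` and `parse_scalar` are hypothetical helpers, and the starting config with both flags set to `True` is an assumed default based on the statement above.

```python
def parse_scalar(text):
    """Convert a few common literal spellings; everything else stays a string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return text.strip('"')


def apply_override(config, override):
    """Set `a.b.c=value` inside a nested dict, creating levels as needed."""
    dotted_key, _, raw_value = override.partition("=")
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = parse_scalar(raw_value)
    return config


# Assumed defaults: both weight groups are copied from the teacher.
config = {
    "students": {
        "individual": {
            "model": {
                "weights_from_teacher": {
                    "embeddings": True,
                    "transformer_blocks": True,
                }
            }
        }
    }
}

apply_override(
    config, "students.individual.model.weights_from_teacher.embeddings=False"
)
print(config["students"]["individual"]["model"]["weights_from_teacher"])
# {'embeddings': False, 'transformer_blocks': True}
```

Overrides written with slashes, such as `students/individual/model=from_teacher`, are different: they select a whole config file from a config group rather than setting a single value.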

ℹ️  Project Structure

The directory structure of the project looks like this:


├── configs                 <- Hydra configuration files
│   ├── callbacks               <- Callbacks configs
│   ├── collate_fn              <- Collate functions configs
│   ├── datamodule              <- Datamodule configs
│   ├── distillation_setup      <- Distillation configs
│   ├── evaluation              <- Evaluation configs
│   ├── experiment              <- Experiment configs
│   ├── hydra                   <- Hydra related configs
│   ├── logger                  <- Logger configs
│   ├── students                <- Student configs
│   ├── teacher                 <- Teacher configs
│   ├── trainer                 <- Trainer configs
│   │
│   └── config.yaml             <- Main project configuration file
│
├── data                    <- Project data
│
├── logs                    <- Logs generated by Hydra and PyTorch Lightning loggers
│
├── src
│   ├── callbacks               <- Lightning callbacks
│   ├── datamodules             <- Lightning datamodules
│   ├── distillation            <- Distillation Setup Files
│   ├── evaluation              <- Evaluation Files
│   ├── loss                    <- Loss Files
│   ├── models                  <- Lightning models
│   ├── utils                   <- Utility scripts
│   │
│   └── train.py                <- Training pipeline
│
├── run.py                  <- Run pipeline with chosen experiment configuration
│
├── .env.example            <- Template of the file for storing private environment variables
├── .gitignore              <- List of files/folders ignored by git
├── .pre-commit-config.yaml <- Configuration of automatic code formatting
├── setup.cfg               <- Configurations of linters and pytest
├── Dockerfile              <- File for building docker container
├── requirements.txt        <- File for installing python dependencies
├── LICENSE
└── README.md

Contributors

ashleve, marlenakuc95, minhducbui, adizx12, luciennnnnnn, hotthoughts, zhengyu-yang, sirtris
