TVLARS - PyTorch Official Implementation

TVLARS: A Fast Convergence Optimizer for Large Batch Training

License: MIT


Abstract

LARS and LAMB have emerged as prominent techniques in Large Batch Learning (LBL), ensuring the stability of AI training. One of the primary challenges in LBL is stabilizing convergence, since training often gets trapped in sharp minimizers. To address this challenge, a relatively recent technique known as warmup has been employed. However, warmup lacks a strong theoretical foundation, leaving the door open for more effective algorithms. In light of this, we conducted empirical experiments to analyze the behavior of the two most popular optimizers in the LARS family, LARS and LAMB, with and without a warmup strategy. Our analysis provides a clearer understanding of LARS, LAMB, and the necessity of warmup in LBL. Building on these insights, we propose a novel algorithm called Time-Varying LARS (TVLARS), which facilitates robust training in the initial phase without the need for warmup. Experimental evaluation demonstrates that TVLARS achieves competitive results with LARS and LAMB when warmup is used, while surpassing their performance without it.

Experiment

Setup

This work can be run on any platform: Windows, Ubuntu, or Google Colab. On Windows or Ubuntu, use the following commands to clone the repository and create (and activate) a virtual environment.

git clone https://github.com/KhoiDOO/tvlars.git
cd tvlars
python -m venv .env
source .env/bin/activate    # on Windows: .env\Scripts\activate

The Python packages used in this project are listed below. Notably, parquet and pyarrow are used to write and save .parquet files, a strongly compressed format for storing DataFrames. All of the packages can be installed with pip install -r requirements.txt, as shown after the list below. If parquet does not work on your machine, consider using fastparquet instead.

matplotlib==3.7.1
numpy==1.24.3
pandas==2.0.1
parquet==1.3.1
pyarrow==12.0.0
seaborn==0.12.2
tqdm==4.65.0
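
For example (the fastparquet fallback assumes pandas' interchangeable Parquet engines, which both pyarrow and fastparquet implement):

pip install -r requirements.txt
# if parquet/pyarrow fails on your machine, install the alternative engine
pip install fastparquet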

PyTorch is the main package used for the optimization computations; the version used in this project is 2.0.1.
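
It can be installed as follows (a minimal sketch; for GPU training, pick the wheel matching your CUDA version from pytorch.org):

pip install torch==2.0.1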

Available Settings

Run python main.py -h to print all of the available settings for this project. The table below shows each flag together with its options and description.

| TAG | OPTIONS | DESCRIPTION |
|-----|---------|-------------|
| -h, --help | | show this help message and exit |
| --bs BS | | batch size |
| --workers WORKERS | | number of worker processes used in the data loader |
| --epochs EPOCHS | | number of training epochs |
| --lr LR | | initial learning rate |
| --seed SEED | | seed for initializing training |
| --port PORT | | multi-GPU training port |
| --wd W | | weight decay |
| --ds | cifar10, cifar100, tinyimagenet | dataset name |
| --model | resnet18, resnet34, resnet50, effb0 | model used in training |
| --opt | adam, adamw, adagrad, rmsprop, lars, tvlars, lamb | optimizer used in training |
| --sd | None, cosine, lars-warm | learning-rate scheduler used in training |
| --dv DV [DV ...] | | list of devices used in training |
| --lmbda LMBDA | | delay factor used in TVLARS |
| --cl_epochs CL_EPOCHS | | number of epochs in the Barlow Twins feature-redundancy-removal stage |
| --btlmbda BTLMBDA | | lambda factor used in Barlow Twins |
| --projector PROJECTOR | | dimensions of the top multilayer perceptron (projector) used in Barlow Twins |
| --lr_classifier LR_CLASSIFIER | | classifier learning rate used in Barlow Twins |
| --lr_backbone LR_BACKBONE | | backbone learning rate used in Barlow Twins |
| --mode | clf, bt | experiment mode: clf for classification, bt for the Barlow Twins experiment |

Running

For instance, the TVLARS experiments with a batch size ($\mathcal{B}$) of 512 and various delay factor ($\lambda$) values can be launched with the following commands.

Classification Experiment

python main.py --bs 512 --epochs 100 --lr 1.0 --port 7046 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 3675 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 6162 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 3930 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 7644 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 5794 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 3976 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 5895 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 5014 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 6423 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 5228 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 6169 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 5466 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 7422 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 6373 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 6592 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 4802 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 7327 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3
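
The grid above sweeps base learning rates lr ∈ {1.0, 2.0, 3.0} against delay factors λ ∈ {1e-06, 1e-05, 0.0001, 0.001, 0.005, 0.01}. The same sweep can be written as a shell loop (a sketch assuming, as the varied values above suggest, that --port only needs to be a distinct free port per run):

port=4000
for lr in 1.0 2.0 3.0; do
  for lmbda in 1e-06 1e-05 0.0001 0.001 0.005 0.01; do
    # any distinct free port works; increment to avoid clashes
    python main.py --bs 512 --epochs 100 --lr "$lr" --port "$port" \
      --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None \
      --lmbda "$lmbda" --dv 0 1 2 3
    port=$((port + 1))
  done
done

The Barlow Twins grid below is identical, except that each command additionally passes --cl_epochs 1000 --mode bt.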

Barlow Twins Experiment

python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 7186 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 4111 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 4356 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 7782 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 4353 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 6524 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 3979 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 4969 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 3517 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 7895 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 4434 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 7770 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 5348 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 4362 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 6193 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 6442 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 7169 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 7954 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3 --mode bt

Citation

@misc{do2023revisiting,
      title={Revisiting LARS for Large Batch Training Generalization of Neural Networks}, 
      author={Khoi Do and Duong Nguyen and Hoa Nguyen and Long Tran-Thanh and Quoc-Viet Pham},
      year={2023},
      eprint={2309.14053},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
