theivaprakasham / self_supervised

This project forked from keremturgutlu/self_supervised

Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.

License: Apache License 2.0

Jupyter Notebook 97.68% Python 2.31% Makefile 0.01%

self_supervised's Introduction

Self Supervised Learning with Fastai

Install

pip install self-supervised

Documentation

Please read the documentation here.

Algorithms

Please read the papers or blog posts before getting started with an algorithm. You may also check the documentation page of each algorithm for a better understanding.

Here is the list of implemented self_supervised.vision algorithms:

  • SimCLR
  • MoCo
  • BYOL
  • SwAV
  • Barlow Twins
  • DINO
  • SupCon

Here is the list of implemented self_supervised.multimodal algorithms:

  • CLIP
  • CLIP-MoCo (No paper, own idea)

For vision algorithms, all models from timm and fastai can be used as encoders.

For multimodal training, CLIP currently supports ViT-B/32 and ViT-L/14, following the best architectures from the paper.

Simple Usage

Vision

SimCLR

from self_supervised.vision.simclr import *
# get_dls, resize, bs, and size are placeholders for your own data setup
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_simclr_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_simclr_aug_pipelines(size=size)
learn = Learner(dls,model,cbs=[SimCLR(aug_pipelines, temp=0.07)])
learn.fit_flat_cos(100, 1e-2)
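
The temp=0.07 argument above is the temperature of the contrastive loss. As a rough illustration of what it controls, here is a minimal pure-Python sketch of the NT-Xent loss described in the SimCLR paper. The names cosine and nt_xent and the toy embeddings are hypothetical and are not part of the self_supervised API; the library's own implementation may differ in its details.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent(view1, view2, temp=0.07):
    """NT-Xent loss; view1[i] and view2[i] are embeddings of two views of image i."""
    z = view1 + view2
    n = len(view1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of the same image
        # Temperature-scaled similarities to every other embedding (self excluded)
        logits = {j: cosine(z[i], z[j]) / temp for j in range(2 * n) if j != i}
        m = max(logits.values())  # subtract the max for numerical stability
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits.values()))
        total += log_denom - logits[pos]
    return total / (2 * n)

# Hypothetical toy embeddings: b[i] is a slightly perturbed copy of a[i]
a = [[1.0, 0.0], [0.0, 1.0]]
b = [[0.9, 0.1], [0.1, 0.9]]
aligned_loss = nt_xent(a, b)
shuffled_loss = nt_xent(a, b[::-1])  # positives paired with the wrong image
```

With correctly paired views the loss is small, and it grows when the pairs are shuffled. A lower temperature sharpens the softmax, so well-separated positives are rewarded more strongly.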

MoCo

from self_supervised.vision.moco import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_moco_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_moco_aug_pipelines(size=size)
learn = Learner(dls, model,cbs=[MOCO(aug_pipelines=aug_pipelines, K=128)])
learn.fit_flat_cos(100, 1e-2)
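
In the MoCo paper, K is the size of a queue of past key embeddings that serve as negatives. Assuming the K=128 argument above plays that role, the queue behaviour can be sketched in plain Python as follows. KeyQueue is a hypothetical helper written for illustration, not a class from this library.

```python
from collections import deque

class KeyQueue:
    """Fixed-size FIFO store of key embeddings used as negatives (illustrative only)."""

    def __init__(self, K):
        # A deque with maxlen drops the oldest entries automatically
        self.keys = deque(maxlen=K)

    def enqueue(self, batch_keys):
        """Add the keys of the current batch, evicting the oldest ones if full."""
        self.keys.extend(batch_keys)

    def negatives(self):
        """Return the stored keys, oldest first."""
        return list(self.keys)

queue = KeyQueue(K=4)
queue.enqueue(["k1", "k2", "k3"])
queue.enqueue(["k4", "k5", "k6"])
current = queue.negatives()  # only the 4 most recent keys remain
```

A larger K gives each query more negatives to contrast against without requiring a larger batch size.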

BYOL

from self_supervised.vision.byol import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_byol_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_byol_aug_pipelines(size=size)
learn = Learner(dls, model,cbs=[BYOL(aug_pipelines=aug_pipelines)])
learn.fit_flat_cos(100, 1e-2)

SWAV

from self_supervised.vision.swav import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_swav_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[128,96], 
                                       min_scales=[0.25,0.05],
                                       max_scales=[1.0,0.3])
learn = Learner(dls, model, cbs=[SWAV(aug_pipelines=aug_pipelines, crop_assgn_ids=[0,1], K=bs*2**6, queue_start_pct=0.5)])
learn.fit_flat_cos(100, 1e-2)
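
The four list arguments above describe a multi-crop setup. One way to read them, assuming entries at the same index belong together, is "2 crops of size 128 with scale range 0.25 to 1.0, plus 6 crops of size 96 with scale range 0.05 to 0.3". The sketch below only expands the lists into one specification per crop; expand_multicrop is a hypothetical helper and does not build real augmentation pipelines.

```python
def expand_multicrop(num_crops, crop_sizes, min_scales, max_scales):
    """Return one {size, min_scale, max_scale} spec per crop, in order."""
    specs = []
    for n, size, lo, hi in zip(num_crops, crop_sizes, min_scales, max_scales):
        # Repeat the same crop settings n times, one dict per crop
        specs.extend({"size": size, "min_scale": lo, "max_scale": hi} for _ in range(n))
    return specs

specs = expand_multicrop(num_crops=[2, 6],
                         crop_sizes=[128, 96],
                         min_scales=[0.25, 0.05],
                         max_scales=[1.0, 0.3])
large_crop_ids = [i for i, s in enumerate(specs) if s["size"] == 128]
```

Under this reading there are 8 crops per image, and the two large crops sit at indices 0 and 1, which would match crop_assgn_ids=[0,1] in the snippet above.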

Barlow Twins

from self_supervised.vision.barlow_twins import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_barlow_twins_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_barlow_twins_aug_pipelines(size=size)
learn = Learner(dls,model,cbs=[BarlowTwins(aug_pipelines, lmb=5e-3)])
learn.fit_flat_cos(100, 1e-2)
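
In the Barlow Twins paper, the loss pushes the cross-correlation matrix of the two views toward the identity, and a weight on the off-diagonal terms trades off invariance against redundancy reduction. Assuming lmb=5e-3 above is that weight, a minimal pure-Python sketch looks like this. The function names and toy data are hypothetical, and the code assumes no feature is constant across the batch.

```python
import math

def standardize(batch):
    """Zero-mean, unit-variance normalization of each feature across the batch."""
    n, d = len(batch), len(batch[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [row[j] for row in batch]
        mean = sum(col) / n
        std = math.sqrt(sum((x - mean) ** 2 for x in col) / n)  # assumed non-zero
        for i in range(n):
            out[i][j] = (col[i] - mean) / std
    return out

def barlow_twins_loss(z1, z2, lmb=5e-3):
    """Sum of (1 - C_ii)^2 on the diagonal plus lmb * C_ij^2 off the diagonal."""
    n, d = len(z1), len(z1[0])
    a, b = standardize(z1), standardize(z2)
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # Cross-correlation between feature i of view 1 and feature j of view 2
            c_ij = sum(a[k][i] * b[k][j] for k in range(n)) / n
            loss += (1.0 - c_ij) ** 2 if i == j else lmb * c_ij ** 2
    return loss

# Toy batch of 4 embeddings with two uncorrelated features
z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
flipped = [[-x for x in row] for row in z]
same_view_loss = barlow_twins_loss(z, z)
flipped_view_loss = barlow_twins_loss(z, flipped)
```

Identical views with decorrelated features give a loss of zero, while sign-flipped views are penalized heavily on the diagonal.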

DINO

from self_supervised.models.vision_transformer import *
from self_supervised.vision.dino import *
dls = get_dls(resize, bs)

deits16 = MultiCropWrapper(deit_small(patch_size=16, drop_path_rate=0.1))
dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
student_model = nn.Sequential(deits16,dino_head)

deits16 = MultiCropWrapper(deit_small(patch_size=16))
dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
teacher_model = nn.Sequential(deits16,dino_head)

dino_model = DINOModel(student_model, teacher_model)
aug_pipelines = get_dino_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[128,96], 
                                       min_scales=[0.25,0.05],
                                       max_scales=[1.0,0.3])
learn = Learner(dls, dino_model, cbs=[DINO(aug_pipelines=aug_pipelines)])
learn.fit_flat_cos(100, 1e-2)

SupCon

from self_supervised.vision.supcon import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_supcon_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_supcon_aug_pipelines(size=size)
learn = Learner(dls, model, cbs=[SupCon(aug_pipelines, unsup_class_id, unsup_method=UnsupMethod.All,
                                        reg_lambda=1.0, temp=0.07)])
learn.fit_flat_cos(100, 1e-2)

Multimodal

CLIP

from self_supervised.multimodal.clip import *
dls = get_dls(...)
clip_tokenizer = ClipTokenizer()
vitb32_config_dict = vitb32_config(224, clip_tokenizer.context_length, clip_tokenizer.vocab_size)
clip_model = CLIP(**vitb32_config_dict, checkpoint=False, checkpoint_nchunks=0)
learn = Learner(dls, clip_model, loss_func=noop, cbs=[CLIPTrainer()])
learn.fit_flat_cos(100, 1e-2)

CLIP-MoCo

from self_supervised.multimodal.clip_moco import *
dls = get_dls(...)
clip_tokenizer = ClipTokenizer()
vitb32_config_dict = vitb32_config(224, clip_tokenizer.context_length, clip_tokenizer.vocab_size)
clip_model = CLIPMOCO(K=4096,m=0.999, **vitb32_config_dict, checkpoint=False, checkpoint_nchunks=0)
learn = Learner(dls, clip_model, loss_func=noop, cbs=[CLIPMOCOTrainer()])
learn.fit_flat_cos(100, 1e-2)
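
Assuming m=0.999 above is the momentum coefficient from the MoCo paper, the key encoder is not trained by gradients but slowly follows the query encoder through an exponential moving average. A minimal sketch on plain lists of floats, with a hypothetical momentum_update helper that is not part of this library:

```python
def momentum_update(key_params, query_params, m=0.999):
    """Move each key parameter toward the query parameter by a factor (1 - m)."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]

# Toy parameters: with m=0.9 the key moves 10% of the way toward the query
key = [1.0, 0.0]
query = [0.0, 1.0]
updated = momentum_update(key, query, m=0.9)
```

Values of m close to 1 make the key encoder change slowly, which keeps the queued keys consistent with each other across training steps.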

ImageWang Benchmarks

All of the algorithms implemented in this library have been evaluated on the ImageWang Leaderboard.

Overall, the algorithms rank as SwAV > MoCo > BYOL > SimCLR in most of the benchmarks. For details, you may inspect the history of the ImageWang Leaderboard on GitHub.

BarlowTwins is still under testing on ImageWang.

Note that during these experiments no hyperparameter selection or tuning was done beyond using learn.lr_find() and sanity-checking the data augmentations by visualizing batches. So there is still room for improvement, and the overall rankings of the algorithms may change based on your setup. Still, the overall rankings are consistent with the papers.

Contributing

Contributions and/or requests for new self-supervised algorithms are welcome. This repo will try to keep itself up-to-date with recent SOTA self-supervised algorithms.

Before raising a PR, please create a new branch named <self-supervised-algorithm>. You may refer to the existing notebooks before implementing your Callback.

Please refer to the Developers Guide, Abbreviations Guide, and Style Guide sections at https://docs.fast.ai/dev-setup and note that the same rules apply to this library.

self_supervised's People

Contributors

keremturgutlu, muellerzr, gsganden, jimmiemunyi, dependabot[bot]
