
wise-ft's Introduction

Robust fine-tuning of zero-shot models

This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

TLDR: We fine-tune zero-shot models while preserving or improving OOD accuracy at no extra computational cost during fine-tuning or inference.

Abstract

Large pre-trained models such as CLIP or ALIGN offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they often reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models (WiSE-FT). Compared to standard fine-tuning, WiSE-FT provides large accuracy improvements out-of-distribution, while preserving high in-distribution accuracy. On ImageNet (in-distribution) and five derived distribution shifts, WiSE-FT improves out-of-distribution accuracy by 4 to 6 percentage points (pp) over prior work while increasing in-distribution accuracy by 1.6 pp. WiSE-FT achieves similarly large robustness improvements (2 to 23 pp) on a diverse set of six further distribution shifts, and in-distribution accuracy gains of 0.8 to 3.3 pp compared to standard fine-tuning on seven commonly used transfer learning datasets. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

[Figure 1 from the paper]

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1-alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model according to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when the zero-shot and fine-tuned models are already available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls whether only a linear classifier is fine-tuned, or if all weights are fine-tuned (end-to-end).
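
For example, a sketch of the linear-classifier variant of the command above (only the added --freeze-encoder flag and the placeholder save path differ; hyperparameters such as the learning rate may need retuning for this setting):

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --freeze-encoder  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32-linear  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0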

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

[Scatter plots of in-distribution vs. out-of-distribution accuracy for ImageNet-Sketch, ImageNet-A, ImageNet-R, ImageNetV2, and ObjectNet]

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}

wise-ft's People

Contributors

gabrielilharco, mitchellnw, mmatena


wise-ft's Issues

Possibility to save ensemble as a full model?

If I'm understanding the method correctly, it's a mix between model A and model B at some ratio C.

Is it possible to interpolate the weights once and save the premixed model, instead of having to add the ensembling code to every downstream application?
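
One way to do this is to interpolate once and save the merged weights as an ordinary checkpoint. A minimal sketch based on the snippet in the README (the checkpoint paths, the value of alpha, and the import path are assumptions):

import torch
from src.models.modeling import ImageClassifier  # module path assumed from this repo's layout

alpha = 0.5  # the fixed mixing ratio C

zeroshot = ImageClassifier.load('models/zeroshot.pt')
finetuned = ImageClassifier.load('models/finetuned.pt')
theta_0, theta_1 = zeroshot.state_dict(), finetuned.state_dict()

# interpolate once and persist the premixed weights as an ordinary checkpoint
theta = {key: (1 - alpha) * theta_0[key] + alpha * theta_1[key] for key in theta_0}
finetuned.load_state_dict(theta)
torch.save(finetuned.state_dict(), 'models/wiseft_alpha05.pt')

Downstream applications can then load the single merged checkpoint with a plain load_state_dict call, with no extra ensembling code.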

Poor performance on ResNet

Although fine-tuning the ViT models gives good performance, I found that performance is poor with the ResNet models. How should the CLIP model be fine-tuned when using the pre-trained ResNet backbones? Thanks.

Zero Shot Classification on my own Dataset

Hello,

I am trying to fine-tune CLIP on my own dataset for zero-shot classification.
My question is: is there a way to load a CSV containing all the file paths and their corresponding labels, or a folder that contains all the images in class subfolders?
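
For the folder-of-subfolders case, torchvision.datasets.ImageFolder already handles this. For the CSV case, a rough sketch of a custom dataset (the column names path and label are assumptions, not something the repo defines):

import csv

import torch
from PIL import Image


class CSVImageDataset(torch.utils.data.Dataset):
    """Loads (image, label) pairs from a CSV with 'path' and 'label' columns."""

    def __init__(self, csv_path, classnames, transform=None):
        self.transform = transform
        label_to_idx = {name: i for i, name in enumerate(classnames)}
        with open(csv_path, newline='') as f:
            self.samples = [(row['path'], label_to_idx[row['label']])
                            for row in csv.DictReader(f)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label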

About fine-tuning

Hi, good work here. I am following the steps to get CLIP fine-tuned. I downloaded two of the datasets used in your example and simplified the command to this:

python src/wise_ft.py   \
    --train-dataset=ImageNetR  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=32  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNetR,ImageNetA  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.5 0.9

I then got the following error. Checking the code, I found there is no train_loader attribute. Is the code out of date, or am I missing something? Can you please give me some hints? Thanks.

Traceback (most recent call last):
  File "/Users/happymind/local_dev/wise-ft/src/wise_ft.py", line 104, in <module>
    wise_ft(args)
  File "/Users/happymind/local_dev/wise-ft/src/wise_ft.py", line 61, in wise_ft
    finetuned_checkpoint = finetune(args)
  File "/Users/happymind/local_dev/wise-ft/src/models/finetune.py", line 50, in finetune
    num_batches = len(dataset.train_loader)
AttributeError: 'ImageNetR' object has no attribute 'train_loader'. Did you mean: 'test_loader'?

Replicating few-shot results

In Table 7 of the paper, there are results showing WiSE-FT with a linear classifier and the ViT-B/16 backbone can get 73% accuracy on a 16-shot ImageNet dataset. It was mentioned that the learning rate was 10e-5 and it was trained for 10 epochs, but even with this information I still cannot replicate the result shown in the paper. Could you provide an exact command, or additional hyperparameters (e.g., batch size, number of warmup steps, etc.) so that this result can be replicated?

Training Parameters

Hello

Can you please tell me what Data(t) and Batch(t) mean when training from scratch using ViT-B/32?

[screenshot of the training output]

zero-shot model

Hi,

I would like to apply the WiSE-FT method to other tasks or pretrained models (e.g., BERT, GPT). In this context, the so-called zero-shot model is just the original model without fine-tuning, right? And the zero-shot model parameters are simply the pretrained parameters as loaded?

Thank you!
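
Yes, in that setting the zero-shot weights are simply the pretrained weights as loaded, and fine-tuning produces the second endpoint. A rough sketch of the same interpolation for a Hugging Face model (the model name, task head, and alpha are placeholders, not part of this repo):

import copy
from transformers import AutoModelForSequenceClassification

alpha = 0.5
pretrained = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
finetuned = copy.deepcopy(pretrained)
# ... fine-tune `finetuned` on the downstream task here ...

theta_0, theta_1 = pretrained.state_dict(), finetuned.state_dict()
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate every parameter and buffer between the two checkpoints
theta = {key: (1 - alpha) * theta_0[key] + alpha * theta_1[key] for key in theta_0}
finetuned.load_state_dict(theta)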

Where to find pre-trained model weights

Hi, newbie here. I am trying to fine-tune this model, which was uploaded to Hugging Face: https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K

I want to fine-tune it on my custom dataset.

Looking at the example below, the checkpoints to load are .pt files. May I ask where I can find these checkpoints for the pre-trained model specified in the link?

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Side question: why do I need to pass the finetuned.pt checkpoint for fine-tuning? Won't the fine-tuned weights be missing before I start fine-tuning on my custom dataset?

Finetuning configs for more models

Hi, dear authors.
In this codebase you have provided an example for fine-tuning ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

By running it, I get the following final WiSE-FT results at alpha=0.5:

ImageNet Top-1 accuracy: 0.7554
ImageNetR Top-1 accuracy: 0.7145
ImageNetA Top-1 accuracy: 0.3452
ImageNetSketch Top-1 accuracy: 0.4696
  • Are these results aligned with yours? Since I cannot find official results for ViT-B/32 in the paper, I just want to make sure I ran the code correctly.
  • What hyperparameter configs should be used for other models, such as ViT-L and ViT-B?

Custom Dataset Class Usage

Hello, I am planning to fine-tune a classifier for my dataset and have created a class for it:

import os
import torch
import numpy as np
import torchvision
from PIL import Image
from torchvision import transforms

# define class names
classnames = ['real', 'fake']
# Define the labels and their corresponding integer values
label_dict = {name: i for i, name in enumerate(classnames)}

class ImageFolderDataset(torch.utils.data.Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform

        # Initialize the lists to store the image paths and labels
        self.image_paths = []
        self.labels = []

        # Loop over the subfolders and their contents
        for label_name in classnames:
            label_path = os.path.join(self.folder_path, label_name)
            for filename in os.listdir(label_path):
                # Create the full path to the image file
                image_path = os.path.join(label_path, filename)
                # Add the image path and label to their respective lists
                self.image_paths.append(image_path)
                self.labels.append(label_dict[label_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image from disk
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)

        # Retrieve the label for this image
        label = self.labels[idx]

        return image, label
    
class ForenSynths:
    def __init__(self, preprocess,
                 location=os.path.expanduser('~/ForenSynths/biggan'),
                 batch_size=128,
                 num_workers=16,
                 classnames=None):

        ################# training #################
        self.train_dataset = ImageFolderDataset(folder_path=location, transform=preprocess)

        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        
        ################# testing #################
        self.test_dataset = ImageFolderDataset(folder_path=location, transform=preprocess)

        self.test_dataloader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self.classnames = classnames

from src.templates.utils import append_proper_article

forensynths_template = [
    lambda c: f"a {c} photo.",
    lambda c: f"this is {c}.",
    lambda c: f"a {c} image is shown.",
    lambda c: f"a {c} image is displayed.",
    lambda c: f"The image presented is a {c} image.",
    lambda c: f"The image presented is {c}.",
    lambda c: f"The depicted image is {c}.",
    lambda c: f"A picture is showcased, which can be described as {c}.",
]

Based on the instructions, I should run the command like this:

python src/wise_ft.py   \
    --train-dataset=ForenSynths  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=32  \
    --cache-dir=cache  \
    --model=RN50  \
    --eval-datasets=ForenSynths  \
    --classnames=['real','fake']  \
    --template=forensynths_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ForenSynths  \
    --data-location=~/ForenSynths/biggan  \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

How to find out exactly which labels are used to calculate the logits?

Hi, I'm facing some questions when running your fine-tuning code.
In src/models/finetune.py, line 83, the logits are calculated by logits = model(inputs), with inputs = batch[input_key].cuda(). I chose end-to-end fine-tuning, so input_key is image.
When I print the shape of the logits, each one is a 1000-dimensional tensor, so does that mean my image is compared against 1000 labels?
However, I have no idea how to find out exactly which labels are used to calculate the logits.
I traced back to src/models/modeling.py, line 72: the inputs are run through an image_encoder and the output logits are computed by calling classification_head. However, I still don't know which labels are used in this process:
def forward(self, inputs):
    if self.process_images:
        inputs = self.image_encoder(inputs)
    outputs = self.classification_head(inputs)
    return outputs
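
The classification_head itself encodes the labels: it is a linear layer whose rows are built from text embeddings of the training dataset's class names (here, the 1000 ImageNet classes), so logit i corresponds to class name i in that list. A rough sketch of how such a zero-shot head is typically constructed with the openai clip package (not the exact code from this repo):

import torch
import clip

model, _ = clip.load('ViT-B/32', device='cpu')
classnames = ['golden retriever', 'tabby cat']    # e.g. the 1000 ImageNet class names
templates = [lambda c: f'a photo of a {c}.']      # e.g. openai_imagenet_template

with torch.no_grad():
    rows = []
    for name in classnames:
        tokens = clip.tokenize([t(name) for t in templates])
        embeddings = model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        rows.append(embeddings.mean(dim=0))
    head_weight = torch.stack(rows, dim=0)        # shape [num_classes, embed_dim]

# logits = image_features @ head_weight.T, so logit i corresponds to classnames[i]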

ModuleNotFoundError when running wise_ft.py on Google Colab

Hi, I'm getting a ModuleNotFoundError when running wise-ft on Google Colab. I tried many solutions I found on the internet, but none of them worked.
After cloning your repo, I run this command:

!python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

And got this error:

Traceback (most recent call last):
  File "/content/wise-ft/src/wise_ft.py", line 7, in <module>
    from src.models.eval import evaluate
ModuleNotFoundError: No module named 'src'
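
The missing piece is the PYTHONPATH step from the README. On Colab, one way that usually works (assuming the repo was cloned into /content/wise-ft) is:

%cd /content/wise-ft
%env PYTHONPATH=/content/wise-ft

and then re-running the same !python src/wise_ft.py command as above.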

Question about Table 2 in the paper

[screenshot of Table 2 from the paper]

How can we get the results shown above? Do we need to design text prompts for each task and use them to initialize the classification head?

I tried adding a classification head with randomly initialized weights, but got poor results with WiSE-FT.

OSError (undefined symbol) running wise_ft.py

Hi, I'm getting an OSError when trying to run the interpolation.

I want to interpolate ViT-L-14-336px.pt with my fine-tuned.pt model but can't solve this issue. Any ideas?

I ran the code below to create the env (no errors or warnings):

conda env create
conda activate wiseft

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

And the command to interpolate:

python wise_ft.py   \
    --load=/home/user/.cache/clip/ViT-L-14-336px.pt,/home/user/model_checkpoint/ft_01_6ep_lr2e6.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

error:

Traceback (most recent call last):
  File "wise_ft.py", line 5, in <module>
    import torch
  File "/home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/__init__.py", line 189, in <module>
    _load_global_deps()
  File "/home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/__init__.py", line 142, in _load_global_deps
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
  File "/home/user/miniconda3/envs/wiseft/lib/python3.6/ctypes/__init__.py", line 348, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/lib/../../../../libcublas.so.11: undefined symbol: free_gemm_select, version libcublasLt.so.11

Fine-tune on your own dataset

Hi,

I was wondering where to get started if I want to use this to fine-tune CLIP on my own dataset (a dataset of sketch-text pairs)?

Does fine-tuning only tweak the image encoder?

First of all, thanks for sharing the codebase.
I briefly went through the code and it seems like you only fine-tune the image encoder, is that right? If so, I'm curious whether you have tried tweaking both the image and text encoders.

Baseline curve for effective robustness for WILDS

Hi,

Thanks for the amazing work! I wonder if you could share some information on fitting the baseline curve, i.e., the list of standard ImageNet models' ID vs. OOD performance and the fitted coefficients (w and b) on FMoW and iWildCam. Thanks in advance!

-K
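
Not an official answer, but the usual recipe in the effective-robustness literature is a linear fit on logit-transformed accuracies of the baseline models. A rough sketch with placeholder accuracy values (not the numbers used in the paper):

import numpy as np

def logit(p):
    p = np.asarray(p, dtype=float)
    return np.log(p / (1.0 - p))

# ID vs. OOD accuracies (as fractions) of a set of standard baseline models -- placeholder values
id_acc = np.array([0.70, 0.74, 0.78, 0.81])
ood_acc = np.array([0.35, 0.40, 0.46, 0.51])

# fit logit(ood) ~ w * logit(id) + b
w, b = np.polyfit(logit(id_acc), logit(ood_acc), deg=1)

def baseline_ood(id_accuracy):
    """Predicted OOD accuracy on the baseline curve for a given ID accuracy."""
    return 1.0 / (1.0 + np.exp(-(w * logit(id_accuracy) + b)))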
