zhangyikaii / model-spider

The code repository for "Model Spider: Learning to Rank Pre-Trained Models Efficiently"

Home Page: https://arxiv.org/abs/2306.03900

License: MIT License

Shell 0.70% Python 99.30%
large-models learning-to-rank model-selection pre-trained-model transfer-learning training-paradigms learnware model-reuse

Model Spider: Learning to Rank Pre-Trained Models Efficiently (NeurIPS 2023 Spotlight)

📑 [Paper] [Code]

Detailed Introduction

Figuring out which Pre-Trained Model (PTM) from a model zoo fits the target task is essential to take advantage of plentiful model resources. With the availability of numerous heterogeneous PTMs from diverse fields, efficiently selecting the most suitable PTM is challenging due to the time-consuming costs of carrying out forward or backward passes over all PTMs. In this paper, we propose Model Spider, which tokenizes both PTMs and tasks by summarizing their characteristics into vectors to enable efficient PTM selection.
By leveraging the approximated performance of PTMs on a separate set of training tasks, Model Spider learns to construct representations and to measure the fitness score of a model-task pair from those representations. The ability to rank relevant PTMs higher than others generalizes to new tasks. With the top-ranked PTM candidates, we further learn to enrich the task representations with their PTM-specific semantics and re-rank the PTMs for better selection. Model Spider balances efficiency and selection ability, making PTM selection like a spider preying on a web.
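The ranking step described above can be illustrated with a minimal sketch. It assumes each PTM and each task has already been summarized as a fixed-length vector ("token"), and it uses plain cosine similarity as a stand-in for the learned fitness score. The token values, the `cosine` and `rank_models` helpers, and the model names as dictionary keys are all hypothetical and are not taken from the repository.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def rank_models(task_token, model_tokens):
    """Return model names sorted from highest to lowest fitness score.

    Assumption: cosine similarity stands in for the learned
    model-task similarity; the real scoring function is learned.
    """
    scores = {name: cosine(task_token, tok) for name, tok in model_tokens.items()}
    return sorted(scores, key=scores.get, reverse=True)

# Hypothetical 3-dimensional tokens, purely illustrative.
model_tokens = {
    "resnet50": [0.9, 0.1, 0.0],
    "densenet201": [0.2, 0.8, 0.1],
    "inception_v3": [0.1, 0.2, 0.9],
}
task_token = [0.8, 0.3, 0.1]

ranking = rank_models(task_token, model_tokens)
print(ranking)
```

Because scoring only compares vectors, ranking a new task needs no forward or backward pass through each candidate PTM, which is the efficiency argument made above.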
Model Spider demonstrates promising performance across various model categories, including visual models and Large Language Models (LLMs). In this repository, we have built a comprehensive and user-friendly PyTorch-based model ranking toolbox for evaluating the future generalization performance of models. It aids in selecting the most suitable foundation pre-trained models for achieving optimal performance in real-world tasks after fine-tuning. In this benchmark for selecting/ranking PTMs, we have reproduced relevant model selection methods such as H-Score, LEEP, LogME, NCE, NLEEP, OTCE, PACTran, GBC, and LFC.

  1. We introduce a single-source model zoo, building 10 PTMs on ImageNet across five architecture families, i.e., Inception, ResNet, DenseNet, MobileNet, and MNASNet. These models can be evaluated on 9 downstream datasets using measures such as weighted Kendall's tau, including Aircraft, Caltech101, Cars, CIFAR10, CIFAR100, DTD, Pet, and SUN397 for classification, and UTKFace and dSprites for regression.
  2. We construct a multi-source model zoo where 42 heterogeneous PTMs are pre-trained from multiple datasets, with 3 architectures of similar magnitude, i.e., Inception-V3, ResNet-50, and DenseNet-201, pre-trained on 14 datasets, including animals, general and 3D objects, plants, scene-based, remote sensing, and multi-domain recognition. We evaluate the ability to select PTMs on Aircraft, DTD, and Pet datasets.

In this repo, you will find:

  • Implementations of pre-trained model selection / ranking methods (for unseen data) with an accompanying benchmark evaluation, including H-Score, LEEP, LogME, NCE, NLEEP, OTCE, PACTran, GBC, and LFC.
  • A quick start for our method, Model Spider, with user-friendly inference.
  • Room to customize the application scenarios of Model Spider.
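To give a feel for what one of the listed baselines computes, here is a minimal pure-Python sketch of the LEEP score as it is usually defined: the average log-likelihood of target labels under the empirical conditional distribution of target labels given source-class predictions. This is a textbook-style sketch and not the repository's implementation; the function name `leep` and the toy inputs are assumptions.

```python
import math
from collections import defaultdict

def leep(source_probs, target_labels):
    """LEEP score from source-model probabilities and target labels.

    source_probs[i][z]: predicted probability of source class z for
    target sample i. target_labels[i]: target label of sample i.
    """
    n = len(target_labels)
    num_z = len(source_probs[0])

    # Empirical joint P(y, z), averaged over the target samples.
    joint = defaultdict(lambda: [0.0] * num_z)
    for probs, y in zip(source_probs, target_labels):
        for z, p in enumerate(probs):
            joint[y][z] += p / n

    # Marginal P(z) and conditional P(y | z) = P(y, z) / P(z).
    marginal = [sum(joint[y][z] for y in joint) for z in range(num_z)]
    cond = {
        y: [joint[y][z] / marginal[z] if marginal[z] > 0 else 0.0
            for z in range(num_z)]
        for y in joint
    }

    # Average log-likelihood of the expected empirical predictor.
    total = 0.0
    for probs, y in zip(source_probs, target_labels):
        total += math.log(sum(cond[y][z] * probs[z] for z in range(num_z)))
    return total / n

# Toy data: two target classes, two source classes.
labels = ["cat", "dog", "cat", "dog"]
perfect = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
uniform = [[0.5, 0.5]] * 4

leep_perfect = leep(perfect, labels)   # source classes separate the targets
leep_uniform = leep(uniform, labels)   # source predictions carry no signal
print(leep_perfect, leep_uniform)
```

A higher (less negative) score suggests the source model's predictions are more informative about the target labels, which is why such scores can be used to rank PTMs without fine-tuning.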

 


Pre-trained Model Ranking Performance

Performance comparison of 9 baseline approaches and Model Spider on the single-source model zoo, measured by weighted Kendall's tau on each downstream target dataset. The best result in each column is marked with an asterisk (*).

Method               Aircraft  Caltech101    Cars  CIFAR10  CIFAR100     DTD    Pets  SUN397    Mean
H-Score                 0.328       0.738   0.616    0.797     0.784   0.395   0.610   0.918   0.648
NCE                     0.501       0.752   0.771    0.694     0.617   0.403   0.696   0.892   0.666
LEEP                    0.244       0.014   0.704    0.601     0.620  -0.111   0.680   0.509   0.408
N-LEEP                 -0.725       0.599   0.622    0.768     0.776   0.074   0.787   0.730   0.454
LogME                   0.540*      0.666   0.677    0.802     0.798   0.429   0.628   0.870   0.676
PACTran                 0.031       0.200   0.665    0.717     0.620  -0.236   0.616   0.565   0.397
OTCE                   -0.241      -0.011  -0.157    0.569     0.573  -0.165   0.402   0.218   0.149
LFC                     0.279      -0.165   0.243    0.346     0.418  -0.722   0.215  -0.344   0.034
GBC                    -0.744      -0.055  -0.265    0.758     0.544  -0.102   0.163   0.457   0.095
Model Spider (Ours)     0.506       0.761*  0.785*   0.909*    1.000*  0.695*  0.788*  0.954*  0.800*
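For readers unfamiliar with the metric, the sketch below shows a simplified weighted Kendall's tau: each pair of models is weighted by 1/(r+1) summed over the two models' ground-truth ranks r, so mistakes among the best models cost more. This is an illustrative variant, not the evaluation code of the repository. The library routine `scipy.stats.weightedtau` uses the same hyperbolic weights but by default averages the results obtained from ranking by each of the two inputs, so its values can differ. The accuracy and score values here are made up.

```python
def weighted_tau(scores, truth):
    """Simplified weighted Kendall's tau.

    Pairs are weighted by the ground-truth ranks only; tied pairs
    count as neither concordant nor discordant.
    """
    n = len(truth)
    # rank[i] = 0 for the model with the highest ground-truth value.
    order = sorted(range(n), key=lambda i: truth[i], reverse=True)
    rank = [0] * n
    for r, i in enumerate(order):
        rank[i] = r

    num = den = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            w = 1.0 / (rank[i] + 1) + 1.0 / (rank[j] + 1)
            sign = (scores[i] - scores[j]) * (truth[i] - truth[j])
            num += w * ((sign > 0) - (sign < 0))
            den += w
    return num / den

# Hypothetical fine-tuned accuracies of four PTMs on one target task.
truth = [0.91, 0.85, 0.78, 0.60]

tau_perfect = weighted_tau([4, 3, 2, 1], truth)      # same order as truth
tau_reversed = weighted_tau([1, 2, 3, 4], truth)     # fully reversed order
tau_top_swap = weighted_tau([3, 4, 2, 1], truth)     # best two swapped
tau_bottom_swap = weighted_tau([4, 3, 1, 2], truth)  # worst two swapped
print(tau_perfect, tau_reversed, tau_top_swap, tau_bottom_swap)
```

Swapping the two best models lowers the score more than swapping the two worst, which matches the goal of model selection: the top of the ranking matters most.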

 

Code Implementation

Quick Start & Reproduce

  • Set up the environment:

    conda create --name modelspider python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
    conda activate modelspider
    git clone https://github.com/zhangyikaii/Model-Spider.git
    cd Model-Spider
    pip install -r requirements.txt
  • Choose your path xxx/xx to store data & model:

    source ./scripts/modify-path.sh xxx/xx
  • Download the data and the pre-trained Model Spider checkpoint (link in the original README) to the path xxx/xx chosen above. (Note that the training set for Model Spider is sampled from EuroSAT, OfficeHome, PACS, SmallNORB, STL10, and VLCS.)

  • Unzip c_data.zip to path xxx/xx/data/ and then run:

    bash scripts/test-model-spider.sh xxx/xx/best.pth

    The results will be displayed on the screen.
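Before running the test script, it may help to confirm that the chosen path contains what the steps above expect. The following is a small pre-flight sketch, assuming the layout implied by the instructions (a `data` directory and a `best.pth` checkpoint directly under the chosen path). The helper name `check_layout` is hypothetical, and the temporary directory only demonstrates the behavior.

```python
import tempfile
from pathlib import Path

def check_layout(root):
    """Return the expected paths that are missing under `root`.

    Assumed layout (from the quick-start steps): <root>/data and
    <root>/best.pth.
    """
    root = Path(root)
    expected = [root / "data", root / "best.pth"]
    return [str(p) for p in expected if not p.exists()]

# Demonstration on a temporary directory standing in for xxx/xx.
with tempfile.TemporaryDirectory() as tmp:
    missing_before = check_layout(tmp)   # nothing downloaded yet
    (Path(tmp) / "data").mkdir()
    (Path(tmp) / "best.pth").touch()
    missing_after = check_layout(tmp)    # both items now present

print(len(missing_before), missing_after)
```

In practice you would call `check_layout` with your own storage path and only launch `scripts/test-model-spider.sh` once the returned list is empty.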

 

Reproduce for Other Baseline Methods

We provide the results of the baseline methods in the assests/baseline_results.csv file. Ensure the test datasets (Aircraft, Cars, CIFAR10, CIFAR100, DTD, Pet, SUN397) are in xxx/xx/data, then run the following command to reproduce them:

bash scripts/reproduce-baseline-methods.sh

 

Contributing

Model Spider is currently in active development, and we warmly welcome any contributions aimed at enhancing its capabilities. Whether you have insights to share regarding pre-trained models, data, or innovative ranking methods, we eagerly invite you to join us in making Model Spider even better.

 

Citing Model Spider

@inproceedings{ModelSpiderNeurIPS23,
  author    = {Yi{-}Kai Zhang and
               Ting{-}Ji Huang and
               Yao{-}Xiang Ding and
               De{-}Chuan Zhan and
               Han{-}Jia Ye},
  title     = {Model Spider: Learning to Rank Pre-Trained Models Efficiently},
  booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference
               on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans,
               LA, USA, December 10 - 16, 2023},
  year      = {2023},
}

model-spider's Issues

Implementation of baselines

Hello, thanks for the nice work.
Is there an implementation of the baseline methods, such as LFC, in your code?
The mptms module seems to be imported in feature_extractor, but I could not find the module itself.

learnware.learnware_model not found.

Dear author,
Thanks for your work!
I tried to run the command line you provided. However, it fails to execute (many module-not-found errors in train.py), which suggests that the complete files have not been uploaded. Would you be able to upload the entire project so that we can use it?

Can not open pretrained model link

Hello, I think you did a great job, and I have been trying to follow your work recently. However, I cannot open the download link for the data and pre-trained model. Could you please fix this? Thank you very much!
