Cifar10-Classification(Pattern Recognition Homework5)

This assignment trains a model to classify the CIFAR-10 images. All models in this project were built with PyTorch.

In addition, please refer to the following report for a detailed description of the experimental results: https://github.com/secondlevel/Cifar10-Classification/blob/main/310551031_HW5.pdf


Hardware

Operating System: Ubuntu 20.04.3 LTS  

CPU: Intel(R) Core(TM) i7-6700 CPU @ 3.40GHz  

GPU: NVIDIA GeForce GTX TITAN X  

Requirement

I use Anaconda and pip to build the execution environment.

Either of the following two options can be used to build it:

  • First Option

conda env create -f environment.yml
  • Second Option

conda create --name cifar python=3.8
conda activate cifar
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
conda install matplotlib pandas scikit-learn -y
pip install tqdm

Directory Tree

For training and testing, place the files in the locations shown in the following directory tree.

The model weights and data can be downloaded from the following link. Please put the model weights under the checkpoint directory and put the data files in the repository root as shown below. https://drive.google.com/drive/folders/1Boe0EZT1cyV6MxqqTFk1mufsGbThs4BG?usp=sharing

├─ 310551031_HW5.py
├─ environment.yml
├─ history_csv
│  └─ BEST_VIT_CIFAR.csv
├─ checkpoint
│  └─ BEST_VIT_CIFAR.rar
├─ x_train.npy
├─ x_test.npy
├─ y_train.npy
├─ y_test.npy
└─ README.md
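
The x_train.npy, x_test.npy, y_train.npy, and y_test.npy files can be loaded with NumPy. A minimal sketch; the array shapes in the comments are assumptions based on the standard CIFAR-10 layout, not taken from the provided files.

import numpy as np

# Load the provided CIFAR-10 arrays from the repository root.
x_train = np.load("x_train.npy")  # e.g. (50000, 32, 32, 3) uint8 images
y_train = np.load("y_train.npy")  # e.g. (50000,) integer labels in [0, 9]
x_test = np.load("x_test.npy")
y_test = np.load("y_test.npy")

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)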

Flow Chart

Hyperparameter Setting

image_size = 224
number_worker = 4
batch_size = 64
epochs = 10
lr = 2e-5
optimizer = AdamW
loss function = CrossEntropy
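
As a rough sketch, the loss function and optimizer above correspond to the following PyTorch objects; the model argument would be the VIT module defined in the Model Architecture section below.

import torch.nn as nn
import torch.optim as optim

lr = 2e-5

def build_loss_and_optimizer(model, lr=lr):
    """Build the CrossEntropy loss and AdamW optimizer listed above."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    return criterion, optimizer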

Data Preprocessing

Data preprocessing includes two parts: normalizing the pixel values from [0, 255] to [0, 1], and resizing each image to 224 x 224. A torchvision sketch of both steps follows the list below.

1. Pixel Value Normalization

2. Resize Image to 224x224
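
A possible torchvision implementation of these two steps; this is a sketch, not necessarily the exact transform used in 310551031_HW5.py.

from torchvision import transforms

transform = transforms.Compose([
    # Convert the uint8 HWC array in [0, 255] to a float CHW tensor in [0, 1].
    transforms.ToTensor(),
    # Resize each image to the 224 x 224 input size expected by the ViT.
    transforms.Resize((224, 224)),
])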

Data Loader

To avoid CUDA out-of-memory errors, I created a custom dataset so the data can be processed batch by batch.

  • Input: image array, label array, data augmentation (transform) method.
  • Output: a Dataset that can be wrapped in a PyTorch DataLoader (see the usage sketch after the class below).
import torch.utils.data as data


class CIFARLoader(data.Dataset):
    """Wrap the image and label arrays and apply the transform on the fly."""

    def __init__(self, image, label, transform=None):
        self.img_name, self.labels = image, label
        self.transform = transform
        print("> Found %d images..." % (len(self.img_name)))

    def __len__(self):
        """Return the size of the dataset."""
        return len(self.img_name)

    def __getitem__(self, index):
        """Return one (image, label) pair, transformed if a transform is given."""
        img = self.img_name[index]
        label = self.labels[index]

        if self.transform:
            img = self.transform(img)

        return img, label
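
A usage sketch: wrap the dataset in a standard PyTorch DataLoader so the images are transformed and batched on the fly. The x_train, y_train, transform, batch_size, and number_worker names come from the earlier sections and sketches.

from torch.utils.data import DataLoader

train_set = CIFARLoader(x_train, y_train, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=True, num_workers=number_worker)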

Model Architecture

In this homework, I used a pretrained Vision Transformer model to classify the images.

In addition, I added a linear layer on top of the Vision Transformer (ViT) [1], and all of the ViT weights are left unfrozen so the whole model is fine-tuned.

The architecture of the classification model is as follows.

import torch.nn as nn
from torchvision import models

num_classes = 10  # CIFAR-10 has 10 classes


class VIT(nn.Module):
    def __init__(self, pretrained=True):
        super(VIT, self).__init__()
        # Pretrained ViT-B/32 backbone; its original head outputs 1000 ImageNet logits.
        self.model = models.vit_b_32(pretrained=pretrained)
        # Extra linear layer that maps the 1000 logits to the 10 CIFAR-10 classes.
        self.classify = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.model(x)
        x = self.classify(x)
        return x


model = VIT()
# Keep every parameter trainable so the whole ViT is fine-tuned.
for name, child in model.named_children():
    for param in child.parameters():
        param.requires_grad = True
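
A quick sanity-check sketch: a dummy batch of two 224 x 224 RGB images should produce one logit per CIFAR-10 class.

import torch

dummy = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([2, 10])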

Training

You can switch to training mode with the following command to start training the classification model.

python 310551031_HW5.py --mode train
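
For reference, a minimal sketch of one possible training loop behind --mode train. It assumes the model and train_loader objects and the build_loss_and_optimizer helper from the sketches above, and is not necessarily the exact implementation in 310551031_HW5.py.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
criterion, optimizer = build_loss_and_optimizer(model)

for epoch in range(10):  # epochs = 10 from the hyperparameter setting
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device).long()

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    print(f"epoch {epoch + 1}: loss {running_loss / total:.4f}, "
          f"accuracy {correct / total:.4f}")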

The best model weights found during training are stored in the checkpoint directory, and the training history is saved in the history_csv directory.

The training accuracy and loss histories are saved in history_csv/BEST_VIT_CIFAR.csv; see the report linked above for the corresponding plots.

Testing

You can switch to testing mode with the following command to evaluate the classification results.
The best model weight file is BEST_VIT_CIFAR.rar, located in the checkpoint directory.

python 310551031_HW5.py --mode test
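
For reference, a minimal evaluation sketch of what --mode test might do. The state-dict file name below is an assumption (the weights ship as BEST_VIT_CIFAR.rar and must be extracted first), and model, device, and a test_loader built the same way as the training loader are also assumed from the sketches above.

import torch

model.load_state_dict(torch.load("checkpoint/BEST_VIT_CIFAR.pt",
                                 map_location=device))
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device).long()
        logits = model(images)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

print(f"test accuracy: {correct / total:.4f}")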

Reference

[1] A. Dosovitskiy et al., “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,” arXiv, arXiv:2010.11929, Jun. 2021. doi: 10.48550/arXiv.2010.11929.
