
Invariant Information Clustering for Unsupervised Image Classification and Segmentation

A reproduction of the IIC paper in PyTorch.

Results for MNIST

Each row of the image below represents one discovered class.

(Image: MNIST digits clustered by IIC, one class per row.)

Basic command usage

train_classifier.py --config config/iic/mnist.yaml --display 10

Train an IIC classifier on MNIST, displaying results every 10 batches

train.py --config config/cifar10.yaml --display 10 --batchsize 64 --epochs 200

Train an autoencoder on cifar10 with a batch size of 64 for 200 passes through the training set, displaying images every 10 batches

Configuration

Configuration flags can be specified as argparse parameters, in yaml files, or both.

The --config parameter specifies a yaml file to load parameters from. The yaml file contents are added to the argparse namespace object.

Precedence, from highest to lowest, is:

  • Arguments from the command line
  • Arguments from the config file
  • Default values specified in config.py

Yaml files can contain nested name-value pairs, which will be flattened.
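
The precedence rule can be sketched as a simple dict merge, a later update overriding an earlier one (a minimal illustration; `resolve` and the setting names here are stand-ins, not the project's actual internals):

```python
def resolve(defaults, config_file, cli):
    """Merge settings: command line beats config file beats defaults."""
    merged = dict(defaults)     # lowest precedence
    merged.update(config_file)  # config file overrides defaults
    merged.update(cli)          # command line overrides everything
    return merged

defaults = {"batchsize": 32, "epochs": 100, "lr": 0.1}
from_yaml = {"batchsize": 128, "epochs": 350}
from_cli = {"epochs": 10}

print(resolve(defaults, from_yaml, from_cli))
# {'batchsize': 128, 'epochs': 10, 'lr': 0.1}
```

Note that `lr` survives from the defaults because neither the yaml file nor the command line mentions it.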

dataset:
  name: celeba
  train_len: 10000
  test_len: 1000

will be flattened to argparse arguments

--dataset_name celeba
--dataset_train_len 10000
--dataset_test_len 1000
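
The flattening step joins nested keys with underscores; a minimal sketch of that transformation (illustrative only, not the project's actual flattening code):

```python
def flatten(d, prefix=""):
    """Flatten nested dicts into argparse-style underscore-joined names."""
    args = {}
    for key, value in d.items():
        name = f"{prefix}_{key}" if prefix else key
        if isinstance(value, dict):
            args.update(flatten(value, name))  # recurse into nested sections
        else:
            args[name] = value
    return args

cfg = {"dataset": {"name": "celeba", "train_len": 10000, "test_len": 1000}}
print(flatten(cfg))
# {'dataset_name': 'celeba', 'dataset_train_len': 10000, 'dataset_test_len': 1000}
```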

Data package

A data package is an object that contains everything required to load the data for training.

import datasets.package as package

datapack = package.datasets['celeba']

train, test = datapack.make(train_len=10000, test_len=400, data_root='data')

get a training set of length 10000 and a test set of length 400
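
The data-package idea can be sketched in plain Python (the `DataPack` class, `datasets` registry, and toy loader below are illustrative stand-ins, not the project's actual implementation):

```python
class DataPack:
    """Bundles everything needed to construct train/test splits."""
    def __init__(self, name, loader):
        self.name = name
        self._loader = loader  # callable that returns the full dataset

    def make(self, train_len, test_len, data_root="data"):
        data = self._loader(data_root)
        assert train_len + test_len <= len(data)
        return data[:train_len], data[train_len:train_len + test_len]

# toy registry mirroring package.datasets
datasets = {"toy": DataPack("toy", lambda root: list(range(100)))}

train, test = datasets["toy"].make(train_len=80, test_len=20)
print(len(train), len(test))
# 80 20
```

Keeping dataset construction behind one object means the training script only needs a name and the split sizes.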

Example config

For training VGG16 on CIFAR-10 with a custom SGD schedule:

batchsize: 128
epochs: 350

dataset:
  name: cifar-10-normed

model:
  name: VGG16
  type: conv
  encoder: [3, 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

optim:
  class: SGD
  lr: 0.1
  momentum: 0.9
  weight_decay: 5e-4

scheduler:
  class: MultiStepLR
  milestones: [150, 250]
  gamma: 0.1

See more example configs in the config directory of the project.

Configuring the optimizers

optim:
  class: SGD
  lr: 0.1
  momentum: 0.9
  weight_decay: 5e-4

scheduler:
  class: MultiStepLR
  milestones: [150, 250]
  gamma: 0.1

and in the code

from iic import config
import torch.nn as nn

args = config.config()
model = nn.Linear(10, 2)
optim, scheduler = config.get_optim(args, model.parameters())
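
The MultiStepLR scheduler configured above multiplies the learning rate by gamma each time a milestone epoch is passed; a pure-Python sketch of that semantics (not PyTorch's implementation, just the decay rule):

```python
def lr_at_epoch(base_lr, milestones, gamma, epoch):
    """Learning rate under MultiStepLR decay: lr * gamma^(milestones passed)."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

for epoch in (0, 149, 150, 249, 250, 349):
    print(epoch, round(lr_at_epoch(0.1, [150, 250], 0.1, epoch), 6))
# lr stays at 0.1 until epoch 149, drops to 0.01 at 150 and to 0.001 at 250
```

So with the config above, training runs 150 epochs at lr 0.1, 100 at 0.01, and the final 100 at 0.001.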

Layer builder

If you get bored of typing the same NN blocks over and over, you can use the layer builder instead.

It works similarly to the PyTorch built-in layer builder and can build:

  • fully connected: type = 'fc'
  • vgg: type = 'vgg'
  • resnet: type = 'resnet'

for example, to build vgg blocks...

from iic.models.mnn import make_layers
from iic.models.layerbuilder import LayerMetaData

meta = LayerMetaData(input_shape=(32, 32))

encoder_core, meta = make_layers(['C:3', 64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], type='vgg', meta=meta)

decoder_core, meta = make_layers([512, 512, 'U', 256, 256, 'U', 256, 256, 'U', 128, 'U', 64, 'U', 'C:3'], type='vgg', meta=meta)

  • M -> Max Pooling
  • U -> Linear Upsample
  • C:3 -> Conv layer with 3 channels
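
A toy interpreter for this token vocabulary shows how a spec expands (the names and layer strings are illustrative; the real make_layers builds actual nn.Module layers and tracks shapes via LayerMetaData):

```python
def describe_layers(spec, in_channels):
    """Expand a vgg-style layer spec into readable layer descriptions."""
    layers, c = [], in_channels
    for token in spec:
        if token == "M":
            layers.append("MaxPool2d(2)")
        elif token == "U":
            layers.append("Upsample(scale_factor=2)")
        elif isinstance(token, str) and token.startswith("C:"):
            out = int(token.split(":")[1])
            layers.append(f"Conv2d({c}, {out}, 3)")
            c = out
        else:  # a plain int is a conv block with that many output channels
            layers.append(f"Conv2d({c}, {token}, 3) + BN + ReLU")
            c = token
    return layers

print(describe_layers(['C:3', 64, 'M'], in_channels=3))
```

Each integer token both emits a conv block and updates the running channel count fed to the next layer.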

Duplicating this project

git clone --bare https://github.com/DuaneNielsen/deep_learning.git
cd deep_learning.git/
git push --mirror https://github.com/DuaneNielsen/<NEW REPO NAME>.git


iic's Issues

train_autoencoder.py

Thanks for sharing.
May I ask whether train_autoencoder.py implements unsupervised segmentation?
The official code is hard to read and reproduce.

Support for Segmentation

I was trying to replicate IIC for segmentation when I came across this implementation.
Does it support segmentation as well?

License

Hi @DuaneNielsen. Could you please add a license to this repository to clarify the terms of use? GitHub repositories without a license are not usable by others by default under copyright laws, so a clear license would be much appreciated! Thanks

Best accuracy parameters

Hi,
I can only get around 30% accuracy after 200 epochs. Could you please share the parameters that achieve over 50%?

Thanks
