
berenslab / t-simcne


Unsupervised visualization of image datasets using contrastive learning

Home Page: https://t-simcne.readthedocs.io/en/latest/

dimensionality-reduction visualization

t-simcne's Introduction

Unsupervised visualization of image datasets using contrastive learning

This is the code for the paper “Unsupervised visualization of image datasets using contrastive learning” (ICLR 2023).

If you use the code, please cite our paper:

@inproceedings{boehm2023unsupervised,
  title={Unsupervised visualization of image datasets using contrastive learning},
  author={B{\"o}hm, Jan Niklas and Berens, Philipp and Kobak, Dmitry},
  booktitle={International Conference on Learning Representations},
  year={2023},
}

We show that it is possible to visualize datasets such as CIFAR-10 and CIFAR-100 in 2D with a contrastive learning objective, while preserving a lot of structure! We call our method t-SimCNE.

[Figure: t-SimCNE architecture]
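To make the contrastive objective more concrete, here is a minimal NumPy sketch of a SimCLR-style InfoNCE loss where the cosine similarity is replaced by a Cauchy kernel 1 / (1 + ||z_i - z_j||^2) on the 2D outputs. This follows our reading of the paper and is not the package's implementation; the function name `cauchy_infonce_loss` and the row layout (rows i and i + n are the two views of one image) are assumptions made for illustration.

```python
import numpy as np


def cauchy_infonce_loss(z: np.ndarray) -> float:
    """InfoNCE-style loss with a Cauchy similarity kernel (sketch).

    z has shape (2n, d). Rows i and i + n are assumed to be the two
    augmented views of the same image (a positive pair); every other
    row acts as a negative for row i.
    """
    z = np.asarray(z, dtype=float)
    m = z.shape[0]
    if m % 2 != 0:
        raise ValueError("expected an even number of rows (two views per image)")
    n = m // 2
    # squared Euclidean distances between all pairs of embeddings, shape (2n, 2n)
    sq_dist = ((z[:, None, :] - z[None, :, :]) ** 2).sum(axis=-1)
    # Cauchy (Student-t, one degree of freedom) similarity instead of cosine
    sim = 1.0 / (1.0 + sq_dist)
    # a sample is never its own negative
    np.fill_diagonal(sim, 0.0)
    # index of each row's positive partner: i <-> i + n
    partner = np.concatenate([np.arange(n, m), np.arange(0, n)])
    numerator = sim[np.arange(m), partner]
    denominator = sim.sum(axis=1)
    return float(np.mean(-np.log(numerator / denominator)))


# two images, each seen twice; positive pairs sit close together in 2D
z_good = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 0.1], [10.0, 0.1]])
# same points, but the positive pairs are swapped so they lie far apart
z_bad = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 0.1], [0.0, 0.1]])
print(cauchy_infonce_loss(z_good) < cauchy_infonce_loss(z_bad))  # True
```

Because the denominator includes the positive pair, each term is -log of a value at most 1, so the loss is never negative and approaches 0 as positives collapse together while negatives move apart.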

Installation

Installation should be as easy as calling:

pip install tsimcne

The package is now available on PyPI. If you want to install it from source, you can do as follows.

git clone https://github.com/berenslab/t-simcne
cd t-simcne
pip install .

Since the project uses a pyproject.toml file, make sure that your pip version is at least v22.3.1.
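If you are unsure which pip you have, the check below compares the installed version against the v22.3.1 minimum. It is a small standard-library sketch; `pip_is_new_enough` and `parse_release` are hypothetical helpers, and the parser only handles plain release numbers such as 22.3.1 (suffixes like b1 are cut off). Upgrading is usually just `python -m pip install --upgrade pip`.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_PIP = (22, 3, 1)  # minimum stated in the installation notes


def parse_release(v: str) -> tuple:
    """Turn '22.3.1' into (22, 3, 1); stops at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):  # e.g. '1b1': keep the 1, ignore the rest
            break
    return tuple(parts)


def pip_is_new_enough(minimum: tuple = MIN_PIP) -> bool:
    """True if the installed pip is at least `minimum`."""
    try:
        installed = version("pip")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= tuple(minimum)


if __name__ == "__main__":
    print("pip version OK:", pip_is_new_enough())
```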

Usage example

The documentation is available at readthedocs. Below is a simple usage example.

import torch
import torchvision
from matplotlib import pyplot as plt
from tsimcne.tsimcne import TSimCNE

# get the cifar dataset (make sure to adapt `data_root` to point to your folder)
data_root = "experiments/cifar/out/cifar10"
dataset_train = torchvision.datasets.CIFAR10(
    root=data_root,
    download=True,
    train=True,
)
dataset_test = torchvision.datasets.CIFAR10(
    root=data_root,
    download=True,
    train=False,
)
dataset_full = torch.utils.data.ConcatDataset([dataset_train, dataset_test])

# create the object (here we run t-SimCNE with fewer epochs
# than in the paper; there we used [1000, 50, 450]).
tsimcne = TSimCNE(total_epochs=[500, 50, 250])

# train on the augmented/contrastive dataloader (this takes the most time)
tsimcne.fit(dataset_full)

# map the original images to 2D
Y = tsimcne.transform(dataset_full)

# get the original labels from the dataset
labels = [lbl for img, lbl in dataset_full]

# plot the data
fig, ax = plt.subplots()
ax.scatter(*Y.T, c=labels)
fig.savefig("tsimcne.png")
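Building `labels` by iterating over `dataset_full` decodes every image just to read its label. torchvision's CIFAR10 datasets also expose their labels as a `targets` list, so a cheaper alternative is to concatenate those lists in the same order as the ConcatDataset. The helper `concat_targets` below is hypothetical, a sketch rather than part of t-simcne.

```python
def concat_targets(*datasets) -> list:
    """Concatenate the `targets` lists of several datasets, in order.

    The order matches torch.utils.data.ConcatDataset built from the same
    datasets in the same order, so labels line up with the rows of Y.
    """
    labels = []
    for ds in datasets:
        # CIFAR10 stores labels as a list of ints; int() also accepts 0-d tensors
        labels.extend(int(t) for t in ds.targets)
    return labels
```

With the datasets from the example above, this becomes `labels = concat_targets(dataset_train, dataset_test)`.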

CIFAR-10

[Figure: annotated plot of the CIFAR-10 embedding]

CIFAR-100

[Figure: label density for the CIFAR-100 embedding]

Reproducibility

For reproducing the results of the paper, please see the iclr2023 branch in this repository.

t-simcne's People

Contributors

dkobak, fabioseel, jnboehm, konstantinwilleke


t-simcne's Issues

Bug when running example [uninitialized dataset_transforms]

When running the usage example, a TypeError is raised when mapping the images to the 2D coordinates:
Y = tsimcne.transform(dataset_full)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
[...]
File "[...]/t-simcne/tsimcne/imagedistortions.py", line 159, in __getitem__
item1 = self.transform(orig_item)
^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not callable

It seems that this is due to line 623 of tsimcne.py, in make_dataloader:
data_transform = self.data_transform
Here, self.data_transform is still None.

Additionally, because of this line it is currently not possible to pass a data_transform to tsimcne.transform, as it gets overwritten.
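The TypeError only surfaces inside a DataLoader worker, far from its cause. A fail-fast guard could catch it when the dataset is built instead. `TransformedDataset` below is a hypothetical stand-in for the wrapper in imagedistortions.py, not its actual code, and assumes the wrapped dataset yields (image, label) pairs.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


class TransformedDataset:
    """Hypothetical stand-in for the dataset wrapper in imagedistortions.py.

    Validates the transform when the wrapper is built, so a missing
    transform raises immediately in the main process instead of as a
    "'NoneType' object is not callable" error inside a DataLoader worker.
    """

    def __init__(
        self,
        dataset: Sequence[Tuple[Any, Any]],
        transform: Optional[Callable[[Any], Any]],
    ):
        if not callable(transform):
            raise TypeError(
                "transform is None or not callable; resolve the default "
                "transform before building the dataloader"
            )
        self.dataset = dataset
        self.transform = transform

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, i: int) -> Tuple[Any, Any]:
        orig_item, label = self.dataset[i]
        return self.transform(orig_item), label


if __name__ == "__main__":
    ds = TransformedDataset([(1, "a"), (2, "b")], lambda x: x * 10)
    print(ds[1])  # (20, 'b')
    try:
        TransformedDataset([(1, "a")], None)
    except TypeError as err:
        print("caught:", err)
```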

Issues when running the usage example

Hi!

I'm interested in t-simcne and tried the usage example in the README. However, I encountered an issue when running tsimcne.fit(), as follows:

[Screenshot of the error, taken 2023-10-19]

Do you have any clue how to fix it? I'm using the following package versions:

lightning=2.1.0
python=3.8
torch=2.0.0

Which versions are you using for t-simcne?

Thank you so much in advance!
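When reporting version problems like this one, a short standard-library script can print all the relevant versions in one go. This is a sketch; the package list is an assumption (tsimcne is the PyPI name from the installation instructions), and packages that are missing are reported as "not installed".

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("tsimcne", "torch", "torchvision", "lightning")) -> dict:
    """Map 'python' and each package name to its installed version string."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, ver in environment_report().items():
        print(f"{name}={ver}")
```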

Default behavior of use_ffcv is leading to an error

When trying to run the example script, I get the following error:

        if self.use_ffcv:
            try:
                import ffcv

                ffcv.transforms.RandomGrayscale
            except ModuleNotFoundError:
                raise ValueError(
                    "`use_ffcv` is not False, but `ffcv` is not installed. "
                    "Install https://github.com/facebookresearch/FFCV-SSL"
                )

The requirements of t-simcne do not include the ffcv package. When it is not installed, the default use_ffcv='auto' causes the ValueError above to be raised. So I'd suggest either including ffcv as a requirement, or changing use_ffcv='auto' so that it falls back to not using ffcv when ffcv is not installed.
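The fallback suggested in this issue could look like the sketch below. `ffcv_ssl_available` and `resolve_use_ffcv` are hypothetical helpers, not part of t-simcne. Like the original check, they treat the presence of `ffcv.transforms.RandomGrayscale` as the signal that FFCV-SSL (rather than plain FFCV) is installed, but with 'auto' they return False instead of raising.

```python
import importlib


def ffcv_ssl_available() -> bool:
    """True if ffcv is importable and provides the FFCV-SSL transform checked for."""
    try:
        transforms = importlib.import_module("ffcv.transforms")
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return hasattr(transforms, "RandomGrayscale")


def resolve_use_ffcv(use_ffcv="auto") -> bool:
    """Turn the use_ffcv setting into a plain bool.

    'auto' quietly falls back to False when FFCV-SSL is missing; only an
    explicit True raises.
    """
    if use_ffcv == "auto":
        return ffcv_ssl_available()
    if use_ffcv and not ffcv_ssl_available():
        raise ValueError(
            "`use_ffcv` is True, but FFCV-SSL is not installed. "
            "Install https://github.com/facebookresearch/FFCV-SSL"
        )
    return bool(use_ffcv)
```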

data_transform in transform method of class TSimCNE is not being used.

I'm sorry for the flurry of issues that I'm raising here, but I thought I'd just raise them as I go to replicate the code with my MEI data.

Here, the kwarg data_transform is unused. I understand that at this step the data transforms are not needed by default, but I can see cases where passing one might be useful.

data_transform=None,

I'd be happy to make a PR to change L487 to pass self.data_transform_none if data_transform is None else data_transform.

And because this function is facing the user, it would benefit from a more extensive docstring.
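A sketch of the proposed change, together with the kind of docstring this issue asks for. `TSimCNESketch` is a hypothetical minimal class that only mirrors the attribute name `data_transform_none` from the issue; the identity function stands in for the package's real default transform.

```python
from typing import Any, Callable, Optional


class TSimCNESketch:
    """Hypothetical minimal class that mirrors only the names used in the issue."""

    def __init__(self) -> None:
        # stand-in for the package's default, non-augmenting transform
        self.data_transform_none: Callable[[Any], Any] = lambda x: x

    def resolve_data_transform(
        self, data_transform: Optional[Callable[[Any], Any]] = None
    ) -> Callable[[Any], Any]:
        """Choose the transform applied to each image before embedding it.

        Parameters
        ----------
        data_transform : callable or None, default None
            Applied to every image before it is passed through the network.
            If None, the default non-augmenting transform is used.

        Returns
        -------
        callable
            The transform that will actually be applied.
        """
        # compare with `is None` so only a missing argument triggers the default
        return self.data_transform_none if data_transform is None else data_transform
```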

Required Python version needs to be >=3.10 instead of >=3.9

Hi,

I was trying to run the example script in my python3.9 environment.

But I'm getting this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_62/1150179161.py in <module>
      2 import torchvision
      3 from matplotlib import pyplot as plt
----> 4 from tsimcne.tsimcne import TSimCNE
      5 
      6 # get the cifar dataset (make sure to adapt `data_root` to point to your folder)

/src/nnfabrik/t-simcne/tsimcne/__init__.py in <module>
----> 1 from .tsimcne import PLtSimCNE, TSimCNE
      2 
      3 __version__ = "0.3.2"

/src/nnfabrik/t-simcne/tsimcne/tsimcne.py in <module>
    129 
    130 
--> 131 class TSimCNE:
    132     """The main entry point for fitting tSimCNE on a dataset.
    133 

/src/nnfabrik/t-simcne/tsimcne/tsimcne.py in TSimCNE()
    356     def fit_transform(
    357         self,
--> 358         X: torch.utils.data.Dataset | str,
    359         data_transform=None,
    360         return_labels: bool = False,

TypeError: unsupported operand type(s) for |: 'type' and 'type'

As far as I know, the str | None syntax is only supported in Python 3.10 or newer.
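Two common fixes are to declare `requires-python = ">=3.10"` in pyproject.toml, or to keep Python 3.9 support by writing the annotations with `typing.Union` (alternatively, `from __future__ import annotations` from PEP 563 stops annotations from being evaluated at definition time). The sketch below illustrates the `Union` route; `Dataset` is a placeholder class standing in for torch.utils.data.Dataset, and `fit_transform` is a toy, not the package's method.

```python
import sys
from typing import Union


class Dataset:
    """Placeholder for torch.utils.data.Dataset (hypothetical, for this sketch)."""


def supports_pep604_unions() -> bool:
    """`X | Y` between classes is valid at runtime only on Python 3.10+."""
    return sys.version_info >= (3, 10)


# `Union[Dataset, str]` means the same as `Dataset | str`, but is
# evaluated without error on Python 3.9 as well.
def fit_transform(
    X: Union[Dataset, str],
    data_transform=None,
    return_labels: bool = False,
) -> str:
    """Toy body: report which kind of input was passed."""
    return "path" if isinstance(X, str) else "dataset"
```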

Any numerical results to indicate average distances between classes?

Thanks for the great work. I implemented t-SimCNE on the VeRi-776 vehicle dataset to visualize groups by vehicle type. I trained with total_epochs=[500, 50, 250] and obtained the nice result below, in which the model clearly separates 'truck' (TypeID=8) from the other vehicles. Is it possible to output numerical results indicating the average distances between classes? That would make it convenient to show quantitatively how different the classes are.
[Figure: t-SimCNE embedding of VeRi-776 (tsimcne_VeRi)]
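Since `Y` from `tsimcne.transform` is an ordinary n-by-2 array, simple summary statistics can be computed from it directly. The sketch below offers two options: distances between class centroids, and the mean distance over all cross-class point pairs. Neither is a metric from the paper, and large distances between clusters in a 2D embedding may not be quantitatively meaningful, so treat these numbers with care.

```python
import numpy as np


def class_centroid_distances(Y, labels):
    """Euclidean distances between the per-class mean positions in the embedding.

    Returns (classes, D) where D[a, b] is the distance between the
    centroids of classes[a] and classes[b].
    """
    Y = np.asarray(Y, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    centroids = np.stack([Y[labels == c].mean(axis=0) for c in classes])
    diff = centroids[:, None, :] - centroids[None, :, :]
    return classes, np.sqrt((diff ** 2).sum(axis=-1))


def mean_cross_class_distance(Y, labels, a, b) -> float:
    """Average distance over all pairs (x in class a, y in class b).

    Builds an n_a-by-n_b distance matrix, so subsample very large classes.
    """
    Y = np.asarray(Y, dtype=float)
    labels = np.asarray(labels)
    Ya, Yb = Y[labels == a], Y[labels == b]
    diff = Ya[:, None, :] - Yb[None, :, :]
    return float(np.sqrt((diff ** 2).sum(axis=-1)).mean())


if __name__ == "__main__":
    Y = np.array([[0.0, 0.0], [0.0, 2.0], [3.0, 0.0], [3.0, 2.0]])
    labels = [0, 0, 1, 1]
    classes, D = class_centroid_distances(Y, labels)
    print(classes, D[0, 1])  # [0 1] 3.0
    print(mean_cross_class_distance(Y, labels, 0, 1))
```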

Is the raw data available?

Hi,

Thanks for sharing this repository! Since this is pretty computationally heavy, even for CIFAR-10, I was wondering if y'all would be open to making available the raw results you have for the datasets you demonstrated in the paper (CIFAR-10 and CIFAR-100)?

Thanks!
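Until precomputed results are available, one way to avoid repeating the expensive fit is to store the embedding once it has been computed. A sketch using NumPy's `.npz` format; `save_embedding` and `load_embedding` are hypothetical helpers, and the file name is only an example.

```python
import os
import tempfile

import numpy as np


def save_embedding(path: str, Y, labels) -> None:
    """Store the 2D embedding and its labels in one compressed .npz file."""
    np.savez_compressed(path, Y=np.asarray(Y), labels=np.asarray(labels))


def load_embedding(path: str):
    """Return (Y, labels) as saved by save_embedding."""
    with np.load(path) as data:
        return data["Y"], data["labels"]


if __name__ == "__main__":
    Y = np.random.default_rng(0).normal(size=(5, 2))
    labels = np.arange(5) % 2
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "tsimcne_embedding.npz")
        save_embedding(path, Y, labels)
        Y2, labels2 = load_embedding(path)
        print(np.allclose(Y, Y2), bool((labels == labels2).all()))  # True True
```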

setup.py missing

Hello,

I'm interested in trying out t-simcne. However, when I try to install it via pip install -e . following the instructions under Implementation, I encounter the following error:

> git clone git@github.com:berenslab/t-simcne.git
> cd t-simcne
> cd cnexp
> pip install -e .
ERROR: File "setup.py" not found. Directory cannot be installed in editable mode

It seems like there is a setup.py script missing from the cnexp/ folder.

Best,
Jean

Unable to use multiple devices

It appears the number of devices is hardcoded:

trainer = pl.Trainer(max_epochs=n_epochs, devices=1)

trainer = pl.Trainer(devices=1)

Given the contrastive learning task, it would be preferable to utilize more resources for training and allow the devices argument to be passed to your model classes.

I'm unsure to what extent other aspects of the model would require changes, but I believe the learning rate calculation in lr_from_batchsize would need to be updated as well (batch_size * devices).
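The sketch below illustrates the two changes proposed here: exposing `devices` when building the trainer arguments, and scaling the learning rate with the effective batch size `batch_size * devices`. The linear scaling rule with a reference batch of 256 is an assumption (it is common in SimCLR-style training) and is not necessarily what `lr_from_batchsize` computes; `trainer_kwargs` and `scaled_lr` are hypothetical helpers. One further caveat: with data-parallel training, each device computes the contrastive loss on its own sub-batch unless embeddings are gathered across devices, so the number of negatives per sample does not automatically grow with `devices`.

```python
def effective_batch_size(batch_size: int, devices: int) -> int:
    """Samples per optimizer step when each device processes its own batch."""
    return batch_size * devices


def scaled_lr(batch_size: int, devices: int, base_lr: float, reference_batch: int = 256) -> float:
    """Linear scaling rule (an assumption, not necessarily lr_from_batchsize)."""
    return base_lr * effective_batch_size(batch_size, devices) / reference_batch


def trainer_kwargs(n_epochs: int, devices: int = 1) -> dict:
    """Arguments for pl.Trainer(**trainer_kwargs(...)) with devices exposed
    instead of hardcoded to 1."""
    return {"max_epochs": n_epochs, "devices": devices}


if __name__ == "__main__":
    print(trainer_kwargs(n_epochs=500, devices=4))
    print(scaled_lr(batch_size=512, devices=4, base_lr=0.1))
```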

Multiple values for keyword argument 'out_dim'

It appears a recent commit is causing the following exception to be thrown when initializing the PLtSimCNE model for the final stage:

TypeError: tsimcne.tsimcne.PLtSimCNE() got multiple values for keyword argument 'out_dim'

The train_args dictionary already has out_dim, which is causing this exception.

The following line should be removed:

out_dim=self.out_dim,

I am able to replicate this error by running the example code from the readme (I'm setting total_epochs=[1, 1, 1] to get to stage 3 quicker).
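The error is plain Python behaviour: a keyword given both explicitly and through `**` unpacking raises a TypeError at call time. The minimal reproduction below uses a hypothetical `make_model` in place of PLtSimCNE. The fix is either to drop the explicit keyword, as suggested above, or to merge the dictionaries (`{**train_args, "out_dim": out_dim}`) if the explicit value should win.

```python
def make_model(out_dim: int = 128, **kwargs) -> dict:
    """Hypothetical stand-in for the PLtSimCNE constructor."""
    return {"out_dim": out_dim, **kwargs}


def build_buggy(train_args: dict, out_dim: int) -> dict:
    # out_dim is passed explicitly AND again via **train_args -> TypeError
    return make_model(out_dim=out_dim, **train_args)


def build_fixed(train_args: dict) -> dict:
    # rely on the out_dim already stored in train_args
    return make_model(**train_args)


if __name__ == "__main__":
    train_args = {"out_dim": 2, "lr": 1e-3}  # hypothetical contents
    try:
        build_buggy(train_args, out_dim=2)
    except TypeError as err:
        print(err)  # the "got multiple values for keyword argument" message
    print(build_fixed(train_args))
```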
