Official PyTorch implementation of the paper "RADiff: Controllable Diffusion Models for Radio Astronomical Maps Generation" (https://arxiv.org/abs/2307.02392)

License: Apache License 2.0


RADiff: Controllable Diffusion Models for Radio Astronomical Maps Generation

Renato Sortino, Thomas Cecconello, Andrea DeMarco, Giuseppe Fiameni, Andrea Pilzer, Andrew M. Hopkins, Daniel Magro, Simone Riggi, Eva Sciacca, Adriano Ingallinera, Cristobal Bordiu, Filomena Bufano, Concetto Spampinato

Paper: https://arxiv.org/abs/2307.02392

Overview

This repository contains the official PyTorch implementation of the paper titled "RADiff: Controllable Diffusion Models for Radio Astronomical Maps Generation". In the paper, we propose an approach based on diffusion models to augment small datasets for training segmentation methods in radio astronomy.


Teaser image

Requirements

  1. Clone the repository:

     git clone https://github.com/SKA-INAF/radiff.git
     cd radiff

  2. Create a conda environment and install the required dependencies:

     conda env create -f environment.yaml
     conda activate radiff

Dataset

The data used to run the experiments in this paper is subject to privacy constraints, so we are not allowed to publish it. However, the model can be trained on any collection of radio astronomical images in FITS format, provided the data follows the folder structure described in this section.

Data should be placed in the data folder. This folder should contain all the images and annotations, plus two text files listing one image path per line: train.txt for training and val.txt for validation. The structure of the folder should be as follows:

data
├── train.txt
├── val.txt
├── images
│   ├── img0001.fits
│   ├── img0002.fits
│   ├── img0003.fits
│   └── ...
└── annotations
    ├── mask_img0001.json
    ├── mask_img0001_obj1.fits
    ├── mask_img0002.json
    ├── mask_img0002_obj1.fits
    ├── mask_img0002_obj2.fits
    ├── mask_img0002_obj3.fits
    ├── mask_img0003.json
    ├── mask_img0003_obj1.fits
    ├── mask_img0003_obj2.fits
    └── ...

The images folder contains the 128x128 images used to train the model, while the annotations folder contains per-image information in JSON format (class, bounding-box coordinates, flux intensity). Additionally, each FITS file in annotations contains the segmentation mask of one object. Note that this folder structure matches the DataLoader in this implementation, which can be adapted to other file structures.
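The layout above can be indexed with a few lines of Python. The helper below is a hypothetical sketch (not part of this repository) that pairs each image listed in a split file with its JSON annotation and per-object FITS masks, following the mask_{stem}.json / mask_{stem}_objN.fits naming convention shown in the tree:

```python
from pathlib import Path

def index_dataset(data_dir, split_file):
    """Pair each image listed in the split file with its JSON annotation
    and its per-object FITS segmentation masks."""
    data_dir = Path(data_dir)
    ann_dir = data_dir / "annotations"
    samples = []
    for raw in (data_dir / split_file).read_text().splitlines():
        line = raw.strip()
        if not line:
            continue
        stem = Path(line).stem  # e.g. "img0001"
        samples.append({
            "image": data_dir / line,
            "annotation": ann_dir / f"mask_{stem}.json",
            "masks": sorted(ann_dir.glob(f"mask_{stem}_obj*.fits")),
        })
    return samples
```

A custom DataLoader would then read each FITS file (e.g. with astropy) from the paths this index returns.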

Pretrained Models

| Model | {METRIC1} | {METRIC2} | Link | Comments |
| --- | --- | --- | --- | --- |
| Autoencoder (VQ-VAE f4) | 5.11 | 3.29 | link | |

Usage

Inference

The implementation supports two inference modes:

  1. CLI inference, which iterates over a folder of masks and generates an image for each one
  2. Interactive inference, which is more user-friendly but supports only one mask at a time
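The first mode amounts to a loop over a mask folder. Since the CLI command is still marked TODO below, here is only a conceptual sketch; `generate_from_mask` is a stand-in for the diffusion sampler, not an actual function in this codebase:

```python
from pathlib import Path

def run_batch_inference(mask_dir, out_dir, generate_from_mask):
    """Iterate over every FITS mask in mask_dir and write one generated
    image per mask into out_dir (CLI-style batch inference)."""
    mask_dir, out_dir = Path(mask_dir), Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for mask_path in sorted(mask_dir.glob("*.fits")):
        image_bytes = generate_from_mask(mask_path)  # placeholder model call
        out_path = out_dir / f"gen_{mask_path.stem}.fits"
        out_path.write_bytes(image_bytes)
        written.append(out_path)
    return written
```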

CLI inference

TODO

Interactive Interface

Run the interface using Gradio:

python gradio_app.py

Train on custom datasets

Train the Autoencoder

The first step is to train the autoencoder to reconstruct the images and prepare the latent space for the diffusion model. To do so, run the following command and specify the configuration file for its architecture:

python scripts/train_ae.py \
    --ae-config vae-f4.yaml \
    --dataset train.txt \
    --run-dir {OUTPUT_DIR} \
    --run-name {RUN_NAME} \
    --on_wandb
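The "f4" in vae-f4.yaml denotes a spatial downsampling factor of 4: a 128x128 image is encoded into a 32x32 latent, which is the space the diffusion model later operates in. A toy NumPy illustration of that compression (average pooling standing in for the learned encoder, nearest-neighbour upsampling for the decoder; the real VQ-VAE is learned, this only shows the shapes):

```python
import numpy as np

def toy_encode(x, f=4):
    """Stand-in for the VQ-VAE encoder: f x f average pooling."""
    h, w = x.shape
    return x.reshape(h // f, f, w // f, f).mean(axis=(1, 3))

def toy_decode(z, f=4):
    """Stand-in for the decoder: nearest-neighbour upsampling."""
    return np.repeat(np.repeat(z, f, axis=0), f, axis=1)

x = np.random.default_rng(0).standard_normal((128, 128))
z = toy_encode(x)       # latent: (32, 32)
x_rec = toy_decode(z)   # reconstruction: (128, 128)
```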

Train the diffusion model

Once the autoencoder is capable of reconstructing the images, we can train the diffusion model. To do so, run the following command, specifying, as before, the configuration files and the autoencoder checkpoint:

python scripts/train_ldm.py \
    --ae-config vae-f4.yaml \
    --ae-ckpt weights/autoencoder/vae-f4.pt \
    --unet-config unet-cldm-mask.yaml \
    --dataset train.txt \
    --run-dir {OUTPUT_DIR} \
    --run-name {RUN_NAME} \
    --on_wandb
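During this stage, each training step adds Gaussian noise to the autoencoder's latents and the U-Net learns to predict that noise. A minimal NumPy sketch of the forward (noising) process, assuming a linear beta schedule and a 4-channel 32x32 latent; the schedule and latent shape used by this repository may differ:

```python
import numpy as np

def make_alpha_bar(T=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)

def q_sample(z0, t, alpha_bar, rng):
    """Sample z_t ~ q(z_t | z_0): the noisy latent the U-Net must denoise."""
    eps = rng.standard_normal(z0.shape)
    zt = np.sqrt(alpha_bar[t]) * z0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return zt, eps  # eps is the U-Net's regression target

alpha_bar = make_alpha_bar()
rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 32, 32))  # assumed latent shape
zt, eps = q_sample(z0, t=500, alpha_bar=alpha_bar, rng=rng)
```

The mask conditioning (unet-cldm-mask.yaml) enters as an extra input to the noise-prediction network; it does not change this forward process.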

Additional Experiments

We provide additional experiments to test the performance of the model on different tasks. To run them, use the following commands:

Train DDPM for unconditional image generation

python scripts/train_ddpm.py \
    --unet-config unet-pixel.yaml \
    --dataset train.txt \
    --run-dir {OUTPUT_DIR} \
    --run-name {RUN_NAME} \
    --on_wandb

Train DDPM for mask generation

python scripts/train_ddpm.py \
    --unet-config unet-pixel-masks.yaml \
    --dataset train.txt \
    --run-dir {OUTPUT_DIR} \
    --run-name {RUN_NAME} \
    --on_wandb

Train LDM for unconditional image generation

python scripts/train_ldm.py \
    --ae-config vae-f4.yaml \
    --ae-ckpt weights/autoencoder/vae-f4.pt \
    --unet-config unet-ldm.yaml \
    --dataset train.txt \
    --run-dir {OUTPUT_DIR} \
    --run-name {RUN_NAME} \
    --on_wandb

Results

TODO

BibTeX

@article{sortino2023radiff,
  title={RADiff: Controllable Diffusion Models for Radio Astronomical Maps Generation},
  author={Sortino, Renato and Cecconello, Thomas and DeMarco, Andrea and Fiameni, Giuseppe and Pilzer, Andrea and Hopkins, Andrew M and Magro, Daniel and Riggi, Simone and Sciacca, Eva and Ingallinera, Adriano and others},
  journal={arXiv preprint arXiv:2307.02392},
  year={2023}
}
