
ARNIQA (WACV 2024 Oral)

Learning Distortion Manifold for Image Quality Assessment


πŸ”₯πŸ”₯πŸ”₯ [2024/06/06] ARNIQA is now included in the IQA-PyTorch toolbox

This is the official repository of the paper "ARNIQA: Learning Distortion Manifold for Image Quality Assessment".

Note

If you are interested in IQA, take a look at our new dataset with UHD images and our latest work on CLIP-based opinion-unaware NR-IQA.

Overview

Abstract

No-Reference Image Quality Assessment (NR-IQA) aims to develop methods to measure image quality in alignment with human perception without the need for a high-quality reference image. In this work, we propose a self-supervised approach named ARNIQA (leArning distoRtion maNifold for Image Quality Assessment) for modeling the image distortion manifold to obtain quality representations in an intrinsic manner. First, we introduce an image degradation model that randomly composes ordered sequences of consecutively applied distortions. In this way, we can synthetically degrade images with a large variety of degradation patterns. Second, we propose to train our model by maximizing the similarity between the representations of patches of different images distorted equally, despite varying content. Therefore, images degraded in the same manner correspond to neighboring positions within the distortion manifold. Finally, we map the image representations to the quality scores with a simple linear regressor, thus without fine-tuning the encoder weights. The experiments show that our approach achieves state-of-the-art performance on several datasets. In addition, ARNIQA demonstrates improved data efficiency, generalization capabilities, and robustness compared to competing methods.

Comparison between our approach and the State of the Art for NR-IQA

Comparison between our approach and the State of the Art for NR-IQA. While the SotA maximizes the similarity between the representations of crops from the same image, we propose to consider crops from different images degraded equally to learn the image distortion manifold. The t-SNE visualization of the embeddings of the KADID dataset shows that, compared to Re-IQA, ARNIQA yields more discernible clusters for different distortions. In the plots, a higher alpha value corresponds to a stronger degradation intensity.
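To make the training objective concrete, below is a minimal sketch of a SimCLR-style NT-Xent loss over pairs of equally-distorted patches. This is an illustrative reconstruction, not the repository's actual training code, and the function name and temperature value are assumptions.

import torch
import torch.nn.functional as F

def nt_xent_loss(z_a: torch.Tensor, z_b: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """SimCLR-style NT-Xent loss. z_a[i] and z_b[i] embed patches from two
    *different* images degraded with the *same* distortion sequence, so they
    form a positive pair; all other samples in the batch act as negatives."""
    n = z_a.shape[0]
    z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)  # (2N, D), unit-norm embeddings
    sim = z @ z.t() / temperature                         # scaled cosine similarities
    sim.fill_diagonal_(float("-inf"))                     # exclude self-similarity
    # The positive for sample i is its equally-distorted counterpart at i +/- N
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)

With this objective, patches distorted in the same way are pulled together regardless of their content, which is what shapes the distortion manifold described above.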

Citation

@inproceedings{agnolucci2024arniqa,
  title={ARNIQA: Learning Distortion Manifold for Image Quality Assessment},
  author={Agnolucci, Lorenzo and Galteri, Leonardo and Bertini, Marco and Del Bimbo, Alberto},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={189--198},
  year={2024}
}

Usage

Note

If you want to employ ARNIQA just for inference, you can also use it through the IQA-PyTorch toolbox.

Minimal Working Example

Thanks to torch.hub, you can use our model for inference without the need to clone our repo or install any specific dependencies. By default, ARNIQA computes a quality score in the range [0, 1], where higher is better.

import torch
import torchvision.transforms as transforms
from PIL import Image

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = torch.hub.load(repo_or_dir="miccunifi/ARNIQA", source="github", model="ARNIQA",
                       regressor_dataset="kadid10k")    # You can choose any of the available datasets
model.eval().to(device)

# Define the preprocessing pipeline
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the full-scale image
img_path = "<path_to_your_image>"
img = Image.open(img_path).convert("RGB")

# Get the half-scale image
img_ds = transforms.Resize((img.size[1] // 2, img.size[0] // 2))(img)

# Preprocess the images
img = preprocess(img).unsqueeze(0).to(device)
img_ds = preprocess(img_ds).unsqueeze(0).to(device)

# NOTE: here, for simplicity, we compute the quality score of the whole image.
# In the paper, we average the scores of the center and four corner crops of the image.

# Compute the quality score
with torch.no_grad(), torch.cuda.amp.autocast():
    score = model(img, img_ds, return_embedding=False, scale_score=True)

print(f"Image quality score: {score.item()}")

Getting Started

Installation

We recommend using the Anaconda package manager to avoid dependency/reproducibility problems. For Linux systems, you can find a conda installation guide here.

  1. Clone the repository
git clone https://github.com/miccunifi/ARNIQA
  2. Install Python dependencies
conda create -n ARNIQA -y python=3.10
conda activate ARNIQA
cd ARNIQA
chmod +x install_requirements.sh
./install_requirements.sh

Data Preparation

You need to download the datasets and place them under the same directory data_base_path.

  1. LIVE: Download the Release 2 folder from here and the annotations from here (corresponding to the realigned subjective quality data)
  2. CSIQ: Create a folder containing the source and distorted images from here and the annotations from here.
  3. TID2013
  4. KADID10K
  5. FLIVE
  6. SPAQ

For each dataset, move the corresponding splits folder from the datasets directory of our repo into that dataset's directory under data_base_path.

At the end, the directory structure should look like this:

β”œβ”€β”€ data_base_path
|
|    β”œβ”€β”€ LIVE
|    |   β”œβ”€β”€ fastfading
|    |   β”œβ”€β”€ gblur
|    |   β”œβ”€β”€ jp2k
|    |   β”œβ”€β”€ jpeg
|    |   β”œβ”€β”€ refimgs
|    |   β”œβ”€β”€ splits
|    |   β”œβ”€β”€ wn
|    |   LIVE.txt
|        
|    β”œβ”€β”€ CSIQ
|    |   β”œβ”€β”€ dst_imgs
|    |   β”œβ”€β”€ src_imgs
|    |   β”œβ”€β”€ splits
|    |   CSIQ.txt
|        
|    β”œβ”€β”€ TID2013
|    |    β”œβ”€β”€ distorted_images
|    |    β”œβ”€β”€ reference_images
|    |    β”œβ”€β”€ splits
|    |    mos_with_names.txt
|        
|    β”œβ”€β”€ KADID10K
|    |    β”œβ”€β”€ images
|    |    β”œβ”€β”€ splits
|    |    dmos.csv
|        
|    β”œβ”€β”€ FLIVE
|    |    β”œβ”€β”€ database
|    |    |    β”œβ”€β”€ blur_dataset
|    |    |    β”œβ”€β”€ EE371R
|    |    |    β”œβ”€β”€ voc_emotic_ava
|    |    β”œβ”€β”€ splits
|    |    labels_image.csv
|        
|    β”œβ”€β”€ SPAQ
|    |    β”œβ”€β”€ Annotations
|    |    β”œβ”€β”€ splits
|    |    β”œβ”€β”€ TestImage

Single Image Inference

To get the quality score of a single image, run the following command:
python single_image_inference.py --img_path assets/01.png --regressor_dataset kadid10k
--img_path                  Path to the image to be evaluated
--regressor_dataset         Dataset used to train the regressor. Options: ["live",
                            "csiq", "tid2013", "kadid10k", "flive", "spaq", "clive", "koniq10k"]

By default, ARNIQA computes a quality score in the range [0, 1], where higher is better.

Training

Before training, you need to download the pristine images belonging to the KADIS700 dataset. Download the .zip file from here and unzip it. At the end, the directory structure should look like this:

β”œβ”€β”€ data_base_path
|
|    β”œβ”€β”€ KADIS700
|    |   β”œβ”€β”€ ref_imgs
|        
|    β”œβ”€β”€ LIVE
|        
|    β”œβ”€β”€ CSIQ
|        
|    β”œβ”€β”€ TID2013
|        
|    β”œβ”€β”€ KADID10K
|        
|    β”œβ”€β”€ FLIVE
|        
|    β”œβ”€β”€ SPAQ

To train our model from scratch, run the following command:

python main.py --config config.yaml
--config <str>       Path to the configuration file

The configuration file must contain all the parameters needed for training and testing. See config.yaml for more details on each parameter. You need a W&B account for online logging.

For the training to be successful, you need to specify the following parameters:

experiment_name: <str>                  # name of the experiment
data_base_path: <str>                   # path to the base directory containing the datasets

logging.wandb.project: <str>            # name of the W&B project
logging.wandb.entity: <str>             # name of the W&B entity

You can overwrite all the parameters contained in the config file from the command line. For example:

python main.py --config config.yaml --experiment_name new_experiment --training.data.max_distortions 7 --validation.datasets live csiq --test.grid_search true

After training, main.py will run the test with the parameters provided in the config file and log the results, both offline and online. The encoder weights and the regressors will be saved under the experiments directory.

Testing

To manually test a model, run the following command:
python test.py --config config.yaml --eval_type scratch
--config <str>        Path to the configuration file
--eval_type <str>     Whether to test a model trained from scratch or the one pretrained by the authors of the paper.
                      Options: ['scratch', 'arniqa']

If eval_type == scratch, the script will test the encoder related to the experiment_name provided in the config file or from the command line. If eval_type == arniqa, the script will test our pretrained model.
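
For example, to evaluate the pretrained model released by the authors:

python test.py --config config.yaml --eval_type arniqa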

Authors

Lorenzo Agnolucci, Leonardo Galteri, Marco Bertini, Alberto Del Bimbo

Acknowledgements

This work was partially supported by the European Commission under European Horizon 2020 Programme, grant number 101004545 - ReInHerit.

LICENSE

All material is made available under Creative Commons BY-NC 4.0. You can use, redistribute, and adapt the material for non-commercial purposes, as long as you give appropriate credit by citing our paper and indicate any changes that you've made.


arniqa's Issues

How to use ARNIQA locally?

Hello Lorenzo!
I'm new to data science.
How can I use your model locally?
I have a server with limited internet access that cannot connect to GitHub.
Thank you!
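
(A possible workaround, not an official reply from the maintainers: torch.hub supports loading from a local clone with source="local", so you can clone the repo and download the weights on a machine with internet access, copy everything to the server, and load the model offline. A sketch, assuming the clone sits at ./ARNIQA:)

import torch

# Assumes the repo was cloned elsewhere and copied to the offline server, and
# that the pretrained weights were copied into the torch.hub cache (by default
# ~/.cache/torch/hub/checkpoints), since the server cannot download them.
model = torch.hub.load(repo_or_dir="./ARNIQA", source="local",
                       model="ARNIQA", regressor_dataset="kadid10k")
model.eval()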

regressor load with torch

Thanks for your great work

I trained a regressor myself, similar to #3. I saved the regressor model with torch.save and tried to load it with torch.load. But when I evaluate it, it shows TypeError: 'Ridge' object is not callable. Do you have any suggestions for solving this?
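
(A hedged guess, since the saved object isn't shown: Ridge is a scikit-learn model, which is not callable like an nn.Module and exposes .predict() instead. A sketch, with a hypothetical path and placeholder features:)

import numpy as np
import torch

# Hypothetical path; Ridge is a pickled scikit-learn object, so disable weights-only loading
regressor = torch.load("regressor.pth", weights_only=False)

features = np.random.rand(4, 4096)    # placeholder; use the real encoder embeddings here
scores = regressor.predict(features)  # scikit-learn models are used via .predict(), not __call__
print(scores)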

Minimal example does not work

I've been trying to run the minimal example, but I got the following error:

TypeError: img should be Tensor Image. Got <class 'PIL.Image.Image'>

when trying to execute this line:

img = preprocess(img).unsqueeze(0).to(device)

How to solve it?
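
(A hedged guess: that error is what Normalize raises when it receives a PIL image instead of a tensor, which suggests ToTensor() is missing or mis-ordered in the local pipeline, or that torchvision is outdated. Make sure ToTensor() runs before Normalize():)

import torchvision.transforms as transforms

# Normalize() only accepts tensors, so ToTensor() must come first
preprocess = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])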

C:\Users\sham\anaconda3\envs\szj\python.exe E:\code\ARNIQA-main\single_image_inference.py
Traceback (most recent call last):
  File "E:\code\ARNIQA-main\single_image_inference.py", line 6, in <module>
    from utils.utils_data import center_corners_crop
  File "E:\code\ARNIQA-main\utils\utils_data.py", line 6, in <module>
    from utils.distortions import *
  File "E:\code\ARNIQA-main\utils\distortions.py", line 16, in <module>
    dither_cpp = ctypes.CDLL(str(PROJECT_ROOT / "utils" / "dither_extension/dither.so")).dither
  File "C:\Users\sham\anaconda3\envs\szj\lib\ctypes\__init__.py", line 373, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: [WinError 193] %1 is not a valid Win32 application.

Process finished with exit code 1

regressor of torch version

Hi, I tried using torch nn.Linear to regress the encoder features to the final score; the code is below. I tried using an MSE loss or a PLCC loss to train the regression layer for 20 epochs, but the loss does not converge: the MSE loss is around 0.3-0.4, and the PLCC loss is around 0.3. Do you have any suggestions about how I can train a regressor with PyTorch?

import torch
import torch.nn as nn

from models.resnet import ResNet  # encoder from the ARNIQA repo; import path may differ

class ARNIQA(nn.Module):
    def __init__(self, encoder_path=None, regressor_path=None):
        super(ARNIQA, self).__init__()
        self.encoder = ResNet(embedding_dim=128, pretrained=False, use_norm=True)

        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))

        self.regressor = nn.Linear(4096, 1)

    def forward(self, img, img_ds, return_embedding: bool = False):
        f, _ = self.encoder(img)
        f_ds, _ = self.encoder(img_ds)
        f_combined = torch.hstack((f, f_ds))
        score = self.regressor(f_combined)
        if return_embedding:
            return score, f_combined
        else:
            return score

KADIS700K config?


I get an error while loading the pre-training dataset KADIS700k.
I have already downloaded the KADIS700k dataset and tried changing the database path several times.
Is there some other configuration needed for this?
(By the way, FLIVE is currently not accessible due to a server problem.)
