amnet's Introduction

AMNet: Memorability Estimation with Attention

A PyTorch implementation of our paper AMNet: Memorability Estimation with Attention by Jiri Fajtl, Vasileios Argyriou, Dorothy Monekosso and Paolo Remagnino. This paper will be presented at CVPR 2018.

Installation

Development and evaluation were done on the following configuration:

System configuration

  • Platform : Linux-4.13.0-38-generic-x86_64-with-Ubuntu-16.04-xenial
  • Display driver : NVRM version: NVIDIA UNIX x86_64 Kernel Module 384.111 Tue Dec 19 23:51:45 PST 2017 GCC version: gcc version 5.4.0 (Ubuntu 5.4.0-6ubuntu1~16.04.9)
  • GPU: NVIDIA Titan Xp
  • CUDA: 8.0.61
  • CUDNN: 6.0.21

Python packages

  • Python: 3.5.2
  • PyTorch: 0.2.0_2
  • Torchvision: 0.1.9
  • NumPy: 1.14.2
  • OpenCV: 3.2.0
  • PIL: 1.1.7

There are no explicit version requirements for any of the components, and AMNet is expected to run with other configurations too. Note, however, that results obtained with a different setup may differ slightly from those in our publication. With PyTorch version 0.3.0.post4, the average Spearman's rank correlation over the five LaMem splits was RC=0.67641, while with version 0.2.0_2 it was RC=0.67666 (published).

Datasets

AMNet was evaluated on two datasets, LaMem and SUN memorability. The original SUN memorability dataset was converted to the same format as LaMem. Both datasets can be downloaded by running the following commands. You will need ~3GB of space on your drive (2.7GB for LaMem and 280MB for SUN).

cd datasets
../download.sh lamem_url.txt
../download.sh sun_memorability_url.txt

You can also use the wildcard '*.txt' to download both.
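Before starting the downloads, it can help to confirm that the target drive has enough room. The following is a minimal sketch and not part of the AMNet repository: the helper name has_free_space is hypothetical, and the 3GB figure is simply the estimate quoted above.

```python
import shutil


def has_free_space(path, required_gb):
    """Return True if the drive holding `path` has at least `required_gb` GB free."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 1024 ** 3


if __name__ == "__main__":
    # ~3GB is the estimate quoted above for LaMem plus SUN (hypothetical check).
    if has_free_space(".", 3):
        print("Enough free space for both datasets.")
    else:
        print("Less than 3GB free; consider another drive.")
```

The same check with 11 in place of 3 would cover the trained models described in the next section.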

Trained Models

To quickly verify the published results or use AMNet in your own application, you can download
models fully trained on the LaMem and SUN datasets. You can download them all by running the following commands. You will need ~11GB of space on your drive.

cd data
../download.sh *.txt

Alternatively, you can download the weights for each dataset separately.

cd data
../download.sh lamem_weights_urls.txt
../download.sh sun_weights_urls.txt

The models will be stored in the 'data' directory, one for each split.

Model                                   Size
lamem_ResNet50FC_lstm3_train_*          822MB
lamem_ResNet50FC_lstm3_noatt_train_*    822MB
lamem_ResNet101FC_lstm3_train_*         1.2GB
sun_ResNet50FC_lstm3_train_*            4GB
sun_ResNet50FC_lstm3_noatt_train_*      4GB

Here ResNet* is the name of the CNN model used for feature extraction, noatt stands for 'no visual attention', and lstm3 denotes an LSTM sequence with three steps.
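The naming convention can be unpacked programmatically, for example when picking a model directory in a script. This is only a sketch under assumptions: parse_model_name is a hypothetical helper, not something the repository provides, and it assumes that the trailing part of the name is a split identifier such as train_1.

```python
import re

# Pattern for names such as "lamem_ResNet50FC_lstm3_noatt_train_1".
# ASSUMPTION: the name always ends with a split identifier like "train_1".
_NAME_RE = re.compile(
    r"^(?P<dataset>[a-z]+)_"
    r"(?P<cnn>[A-Za-z0-9]+)_"
    r"lstm(?P<steps>\d+)_"
    r"(?P<noatt>noatt_)?"
    r"(?P<split>train_\w+)$"
)


def parse_model_name(name):
    """Split a model directory name into its documented components."""
    match = _NAME_RE.match(name)
    if match is None:
        raise ValueError("unrecognised model directory name: " + name)
    return {
        "dataset": match.group("dataset"),
        "cnn": match.group("cnn"),
        "lstm_steps": int(match.group("steps")),
        # Attention is enabled unless the "noatt" marker is present.
        "attention": match.group("noatt") is None,
        "split": match.group("split"),
    }


if __name__ == "__main__":
    print(parse_model_name("lamem_ResNet50FC_lstm3_noatt_train_1"))
```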

Evaluation

Evaluation on the LaMem and SUN datasets followed the protocols suggested by the authors of the datasets. LaMem was evaluated on 5 and SUN on 25 train/test splits. Each evaluation was done twice, with the attention enabled and disabled. To run the LaMem evaluation, please first download the LaMem dataset (see Datasets) and the trained models (see Trained Models), and then run

python3 main.py --test --dataset lamem --cnn ResNet50FC --test-split 'test_*'

To run the evaluation without the attention, specify the --att-off argument

python3 main.py --test --dataset lamem --cnn ResNet50FC --test-split 'test_*' --att-off
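The rank correlation (RC) quoted in this README is Spearman's rank correlation between predicted and ground-truth memorability, averaged over the splits. The sketch below shows how that quantity can be computed in plain Python. It is not the repository's evaluation code: all function names are hypothetical, and the numbers in the usage example are made-up toy values, not results.

```python
def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of tied values starting at position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2.0 + 1.0
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / (var_x * var_y) ** 0.5


def spearman_rc(predicted, target):
    """Spearman's rank correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(predicted), average_ranks(target))


def mean_rc_over_splits(splits):
    """Average RC over (predicted, target) pairs, one pair per split."""
    return sum(spearman_rc(p, t) for p, t in splits) / len(splits)


if __name__ == "__main__":
    # Toy values for illustration only.
    toy_split = ([0.1, 0.4, 0.3, 0.9], [0.2, 0.5, 0.6, 0.8])
    print(mean_rc_over_splits([toy_split]))
```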

Predicting memorability of images

If you wish to estimate memorability for your own images, you have two options: process all images in a given directory, or create a csv file with a list of images to process. In both cases you need to specify the file with the model weights. To predict the memorability of all images in a directory, run this command

python3 main.py --cnn ResNet50FC --model-weights data/lamem_ResNet50FC_lstm3_train_5/weights_35.pkl --eval-images images/high

The memorability of each image will be printed to stdout. To save the memorabilities to a csv file, specify the argument --csv-out <filename.txt>

python3 main.py --cnn ResNet50FC --model-weights data/lamem_ResNet50FC_lstm3_train_5/weights_35.pkl --eval-images images/high --csv-out memorabilities.txt
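Once the scores are saved, they can be loaded and ranked in a few lines. This sketch rests on an assumption that this README does not confirm: that each row of the output file has the form "<image path>,<memorability>". The helper load_memorabilities is hypothetical; check the actual file and adjust the column handling if the layout differs.

```python
import csv


def load_memorabilities(csv_path):
    """Read (image, score) rows and return them sorted by descending score.

    ASSUMPTION: each row is "<image path>,<memorability>". Rows that do not
    fit this shape (for example a header line) are skipped.
    """
    rows = []
    with open(csv_path, newline="") as handle:
        for row in csv.reader(handle):
            if len(row) < 2:
                continue
            try:
                rows.append((row[0], float(row[1])))
            except ValueError:
                continue
    return sorted(rows, key=lambda item: item[1], reverse=True)
```

For example, load_memorabilities("memorabilities.txt")[:10] would give the ten highest-scoring images under the assumed layout.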

Attention maps for each LSTM step can be stored as a jpg image for each input image by specifying the output path with --att-maps-out <out_dir>

python3 main.py --cnn ResNet50FC --model-weights data/lamem_ResNet50FC_lstm3_train_5/weights_35.pkl --eval-images images/high --att-maps-out att_maps

Each attention map image includes the original image and one image for each LSTM step, with the attention map shown as a heatmap overlay.

Training

To train AMNet, follow these steps:

  • Select a CNN front end for image feature extraction. Available models are ResNet18FC, ResNet50FC, ResNet101FC and VGG16FC.
  • Select the lamem or sun dataset.
  • Specify the training and validation splits. Note that the SUN memorability dataset doesn't come with a validation split, so the test split needs to be used instead.
  • Optionally, set the batch sizes and the GPU id, disable the visual attention, and so on. Run main.py --help to see the other options.

python3 main.py --train-batch-size 222 --test-batch-size 222 --cnn ResNet50FC --dataset lamem --train-split train_1 --val-split val_1

To see other command line arguments please run

python3 main.py --help

or see main.py. If you want to experiment with other parameters the best place to go is config.py.
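Training on every LaMem split means repeating the command above with a different split each time. The sketch below builds those commands from Python. It uses only the flags shown in this README, and it assumes the five LaMem splits are named train_1 to train_5 and val_1 to val_5. The helper names are hypothetical, and dry_run defaults to printing the commands rather than running them.

```python
import subprocess


def training_command(split, dataset="lamem", cnn="ResNet50FC", batch_size=222):
    """Build the argument list for training on one split."""
    return [
        "python3", "main.py",
        "--train-batch-size", str(batch_size),
        "--test-batch-size", str(batch_size),
        "--cnn", cnn,
        "--dataset", dataset,
        # ASSUMPTION: splits follow the train_<n> / val_<n> naming.
        "--train-split", "train_{}".format(split),
        "--val-split", "val_{}".format(split),
    ]


def train_all_splits(num_splits=5, dry_run=True):
    """Print (or run, if dry_run is False) one training command per split."""
    for split in range(1, num_splits + 1):
        cmd = training_command(split)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    train_all_splits(dry_run=True)
```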

Cite

If you use this code or reference our paper in your work please cite this publication.

@inproceedings{fajtl2018amnet,
  title={AMNet: Memorability Estimation with Attention},
  author={Fajtl, Jiri and Argyriou, Vasileios and Monekosso, Dorothy and Remagnino, Paolo},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={6363--6372},
  year={2018}
}

amnet's Issues

Missing keys in state_dict error in amnet_model.py

Hello, thank you for sharing your code. While trying to follow your guidelines in a Windows-based environment, I see this error message.

Traceback (most recent call last):
  File "C:\Users........\AmNet-master\amnet.py", line 284, in load_checkpoint
    self.model.load_weights(cpnt['model'])
  File "C:\Users........\AmNet-master\amnet_model.py", line 334, in load_weights
    raise KeyError('missing keys in state_dict: "{}"'.format(missing))
KeyError: 'missing keys in state_dict: "{'core_cnn.core_cnn.layer2.2.bn3.num_batches_tracked', 'core_cnn.core_cnn.layer3.3.bn3.num_batches_tracked', 'core_cnn.core_cnn.layer4.0.bn1.num_batches_tracked', 'core_cnn.core_cnn.layer3.4.bn1.num_batches_tracked', 'core_cnn.core_cnn.layer1.2.bn2.num_batches_tracked', ......,}

RuntimeError: Error(s) in loading state_dict for AMemNetModel:
  Unexpected key(s) in state_dict: "e11.weight", "e11.bias", "eh2.weight", "eh2.bias", "eh12.weight", "eh12.bias", "eh30.weight", "eh30.bias", "eh31.weight", "eh31.bias", "eh11.weight", "eh11.bias", "eh22.weight", "eh22.bias", "regnet2.weight", "regnet2.bias".

Could you kindly advise me how to solve this issue? I think the runtime error is raised due to the missing keys in state_dict error. I'm relatively new to this world and your comment will be super helpful! Thank you.
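One possible reading of the first error, offered as an assumption rather than a confirmed diagnosis: every key shown ends in num_batches_tracked, a buffer that BatchNorm layers keep in their state_dict from PyTorch 0.4.1 onwards, while the published weights were trained with PyTorch 0.2.0_2 and so would not contain it. The sketch below only helps to check whether anything else is missing. The helper split_missing_keys is hypothetical, and it says nothing about the unexpected keys in the second error.

```python
def split_missing_keys(missing):
    """Separate BatchNorm bookkeeping keys from other missing keys."""
    suffix = ".num_batches_tracked"
    bookkeeping = sorted(k for k in missing if k.endswith(suffix))
    other = sorted(k for k in missing if not k.endswith(suffix))
    return bookkeeping, other


if __name__ == "__main__":
    # Keys copied from the (truncated) traceback above.
    reported = {
        "core_cnn.core_cnn.layer2.2.bn3.num_batches_tracked",
        "core_cnn.core_cnn.layer3.3.bn3.num_batches_tracked",
        "core_cnn.core_cnn.layer4.0.bn1.num_batches_tracked",
    }
    bookkeeping, other = split_missing_keys(reported)
    print(len(bookkeeping), "bookkeeping keys;", len(other), "other keys")
```

If the second list comes back empty for the full set of missing keys, the checkpoint and the model probably differ only in that bookkeeping buffer.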

Missing Parameters 'target_scale' and 'target_mean' in config.py?

Hi, I really appreciate the great work. I'm trying to use your code to train on my own dataset, but it seems some hyper-parameters are missing from config.py.
I got this error during training:
AttributeError: 'HParameters' object has no attribute 'target_scale'
For now I'm just setting them to 1 manually. Could you please suggest some values to use?
Looking forward to your prompt reply!

Problem transferring the weights from your model to ResNet50

I am trying to use the pretrained weights you have saved as pkl files.

Here is the code that I used to load the weights and attempt to merge them with a blank ResNet50:

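import torch
from torchvision.models import resnet50

# Path to one of the downloaded AMNet checkpoint files
data_directory = '/content/drive/MyDrive/AMNet-master/data/lamem_ResNet50FC_lstm3_train_1/weights_30.pkl'

# Plain torchvision ResNet50 with randomly initialised weights
model = resnet50()

checkpoint = torch.load(data_directory, map_location='cuda:0')

# Attempt to load the checkpoint into the plain ResNet50
model.load_state_dict(checkpoint)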

This is the error message I got:

[Screenshot of the error message: Screen Shot 2022-10-10 at 2 39 02 PM]

How do you suggest I handle this?
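A hedged sketch of one direction to try, based only on the traceback quoted in the first issue above. That traceback reads the weights from cpnt['model'], and its keys carry the prefix core_cnn.core_cnn., so the checkpoint may be a dictionary wrapping the AMNet state under 'model', with the CNN weights nested under that prefix. Both points are assumptions about the file layout, and extract_prefixed is a hypothetical helper.

```python
def extract_prefixed(state, prefix):
    """Return the entries of `state` whose keys start with `prefix`, prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    # Toy stand-in for a state dict; real values would be tensors.
    toy_state = {
        "core_cnn.core_cnn.conv1.weight": "w0",
        "core_cnn.core_cnn.layer1.0.bn1.bias": "b0",
        "regnet2.bias": "b1",
    }
    print(extract_prefixed(toy_state, "core_cnn.core_cnn."))

    # With a real checkpoint, under the assumptions stated above:
    #   checkpoint = torch.load(path, map_location="cpu")
    #   resnet_state = extract_prefixed(checkpoint["model"], "core_cnn.core_cnn.")
    #   model.load_state_dict(resnet_state, strict=False)
```

Passing strict=False lets load_state_dict tolerate keys that exist on only one side, which may matter if the ResNet50FC variant does not contain every layer of the stock ResNet50. The values returned by that call list what was skipped.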

Can this code be applied to a CLASSIFICATION problem as well?

Hi and Thank you so much for your AMNet code.

I've been wondering whether this code can be applied to a classification problem, coz I can't seem to get a satisfactory result on my classification problem, compared with the regression ones.

To be specific, I've changed the MSE loss to CE loss.

Could you plz give me some advice?

Thx a lot and wish you a good day :)

img.mean and img.std in config.py

Hello and thanks for sharing your code, and the trained models in particular.

Are img.mean and img.std the mean and std of each channel's pixel values over all the images in the dataset? I'm assuming so since img.mean and img.std are 1x3 arrays, and they depend on hps.dataset.name. Is that the case?

I am testing the code on some images that are neither from the sun nor the lamem dataset: would you suggest that I set the img.mean and img.std parameters to the mean and std of each channel across my test images (the mean across the eval_images batch) ?
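For reference, the quantity described in the question (per-channel mean and standard deviation over all pixels of all images) can be computed as in the sketch below. Whether the values in config.py were obtained this way is not confirmed here. The function channel_mean_std is hypothetical, and the pixel values are assumed to be already scaled to whatever range the model expects.

```python
def channel_mean_std(images):
    """Per-channel mean and standard deviation over all pixels of all images.

    `images` is an iterable of images; each image is an iterable of
    (r, g, b) pixel tuples.
    """
    count = 0
    sums = [0.0, 0.0, 0.0]
    sq_sums = [0.0, 0.0, 0.0]
    for image in images:
        for pixel in image:
            count += 1
            for c in range(3):
                sums[c] += pixel[c]
                sq_sums[c] += pixel[c] ** 2
    means = [s / count for s in sums]
    # Population variance: E[x^2] - (E[x])^2, clamped at 0 against rounding.
    stds = [max(sq / count - m * m, 0.0) ** 0.5 for sq, m in zip(sq_sums, means)]
    return means, stds


if __name__ == "__main__":
    # One toy image with two pixels, values assumed to lie in [0, 1].
    toy_images = [[(0.0, 0.5, 1.0), (1.0, 0.5, 1.0)]]
    print(channel_mean_std(toy_images))
```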

Many thanks for sharing this amazing model!
