
cyclip's People

Contributors

goel-shashank, hritikbansal

cyclip's Issues

proof for ordering of distances

Thanks for your great work and for making it public!

As shown in Figure 1 (b), if the representations of any two image-text pairs, (I_dog, T_dog) and (I_cat, T_cat), exactly satisfy both forms of cyclic consistency, can we then guarantee that any test image I_test respects the same ordering of distances in both the image and text spaces (i.e., if d(I_test, I_dog) > d(I_test, I_cat), then d(I_test, T_dog) > d(I_test, T_cat))?

Could you please provide a proof or explanation? Thanks!
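
For reference, here is my reading of the two consistency conditions for this pair (written explicitly; please correct me if I have them wrong):

\text{cross-modal: } \langle I_{dog}, T_{cat} \rangle = \langle I_{cat}, T_{dog} \rangle, \qquad
\text{in-modal: } \langle I_{dog}, I_{cat} \rangle = \langle T_{dog}, T_{cat} \rangle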

Possible bug in inmodal/crossmodal loss calculation

Hi,

I hope I'm not misunderstanding this, but I think there may be an issue with the following line in train.py:

inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (umodel.logit_scale.exp() * umodel.logit_scale.exp()) * batch_size

I think this multiplies by the batch size in the numerator instead of dividing by it in the denominator, so for a large batch size (e.g., 4096) the loss becomes very large, which can cause problems (for me, learning did not progress with a batch size of 4096).

I believe the correction is simply to pull the batch-size term inside the parentheses:
inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (umodel.logit_scale.exp() * umodel.logit_scale.exp() * batch_size)
and the same change would apply to the crossmodal cyclic loss, as sketched below.
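
For concreteness, a hedged sketch of the crossmodal counterpart of this fix (my suggestion, not the authors' code; I am assuming the analogous tensors in train.py are named logits_image_per_text and logits_text_per_image):

# Proposed crossmodal cyclic loss with batch_size moved into the denominator (tensor names assumed).
crossmodal_cyclic_loss = (logits_image_per_text - logits_text_per_image).square().mean() / (umodel.logit_scale.exp() * umodel.logit_scale.exp() * batch_size)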

Reproducing zero-shot retrieval experiments with CyCLIP

Dear authors,
First of all, thank you for the wonderful project and for sharing the code and checkpoints.

While reproducing the zero-shot retrieval experiments from Table 6 of your main paper, I observed a large gap between the results reported in the paper and my reproduced results.

For clarity, here are the results of the zero-shot retrieval performance I obtained:
COCO dataset

  • Image-to-Text Retrieval: {'r1': 21.3, 'r5': 45.0, 'r10': 57.1}
  • Text-to-Image Retrieval: {'r1': 15.97, 'r5': 36.57, 'r10': 48.40}

Flickr dataset

  • Image-to-Text Retrieval: {'r1': 41.2, 'r5': 70.0, 'r10': 80.0}
  • Text-to-Image Retrieval: {'r1': 30.42, 'r5': 57.1, 'r10': 68.82}

The results seem to align closely with the reported performance only in the case of Image-to-Text retrieval on Flickr30k.

I used the CyCLIP checkpoint provided via Google Drive and conducted tests using the Karpathy test split of the COCO and Flickr datasets.

I adapted the code from this repo for the retrieval experiments. I hope you can take a quick look at the code below and help me identify any potential issues.

Additionally, it would be immensely helpful if you could share the code you used for the zero-shot retrieval experiments.

Best regards,


Usage: python test_retrieval.py --dataset coco  # or flickr

test_retrieval.py:

import argparse

import open_clip
import torch

from src.retrieval import get_loader_image, get_loader_text


def compute_retrieval(similarity_scores, txt2img, img2txt):
    # compute text -> image
    t2i_similarity_score = similarity_scores.t()
    t2i_ranks = torch.zeros(t2i_similarity_score.shape[0])

    for index, score in enumerate(t2i_similarity_score):
        inds = torch.argsort(score, descending=True)
        t2i_ranks[index] = torch.where(inds == txt2img[index])[0][0]
        print(
            'Evaluating batch {}/{}, {}'.format(
                index, t2i_similarity_score.shape[0], t2i_ranks[index]
            ),
            end="\r"
        )

    # Compute metrics
    tr1 = 100.0 * len(torch.where(t2i_ranks < 1)[0]) / len(t2i_ranks)
    tr5 = 100.0 * len(torch.where(t2i_ranks < 5)[0]) / len(t2i_ranks)
    tr10 = 100.0 * len(torch.where(t2i_ranks < 10)[0]) / len(t2i_ranks)
    t2i_report_dict = {"r1": tr1, "r5": tr5, "r10": tr10}

    # compute image -> text
    i2t_similarity_score = similarity_scores
    i2t_ranks = torch.zeros(i2t_similarity_score.shape[0])
    for index, score in enumerate(i2t_similarity_score):
        print('Evaluating batch {}/{}'.format(index, i2t_similarity_score.shape[0]), end="\r")
        inds = torch.argsort(score, descending=True)
        # Score
        rank = 1e10
        for i in img2txt[index]:
            tmp = torch.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        i2t_ranks[index] = rank

    # Compute metrics
    ir1 = 100.0 * len(torch.where(i2t_ranks < 1)[0]) / len(i2t_ranks)
    ir5 = 100.0 * len(torch.where(i2t_ranks < 5)[0]) / len(i2t_ranks)
    ir10 = 100.0 * len(torch.where(i2t_ranks < 10)[0]) / len(i2t_ranks)
    i2t_report_dict = {"r1": ir1, "r5": ir5, "r10": ir10}
    return t2i_report_dict, i2t_report_dict


def get_image_feature(model, data_loader):
    image_features = []
    for batch_idx, batch in enumerate(data_loader):
        print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
        images, _ = batch
        image_emb = model.encode_image(images.cuda())  # embed with image encoder
        image_features.append(image_emb.detach().cpu())
    image_features = torch.cat(image_features, 0)

    print('Done image feature extract.')
    print(image_features.shape)

    # normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features


def get_text_feature(model, data_loader):
    text_features = []
    for batch_idx, batch in enumerate(data_loader):
        print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
        text = batch.squeeze()
        text_emb = model.encode_text(text.cuda())
        text_features.append(text_emb.detach().cpu())

    text_features = torch.cat(text_features, 0)
    print('Done text feature extract.')
    print(text_features.shape)

    # normalized features
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features


def main(args):
    pretrained = "/home/appuser/.cache/torch/hub/CyCLIP/cc3m/CyCLIP.pt"
    model, _, transform = open_clip.create_model_and_transforms(
        "RN50", pretrained=pretrained, device="cuda"
    )
    model = model.eval().cuda()

    if args.dataset == "coco":
        # karpathy split
        ann_file = "/home/appuser/datasets/coco/coco_karpathy_test.json"
        data_root = "/home/appuser/datasets/coco/"
        image_root = "images/val2014"
    else:
        # karpathy split
        ann_file = "/home/appuser/datasets/flickr30k/annotations/flickr30k_test.json"
        data_root = "/home/appuser/datasets/flickr30k/"
        image_root = "images/flickr30k-images"

    text_loader = get_loader_text(ann_file, data_root, image_root, args.batch_size, transform)
    text_features = get_text_feature(model, text_loader)

    image_loader, txt2img, img2txt = get_loader_image(
        ann_file, data_root, image_root, args.batch_size, transform
    )
    image_features = get_image_feature(model, image_loader)

    similarity_scores = image_features.cuda() @ text_features.cuda().t()
    t2i_dict, i2t_dict = compute_retrieval(similarity_scores, txt2img, img2txt)
    print('Image-to-Text retrieval', i2t_dict)
    print('Text-to-Image retrieval', t2i_dict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ZeroShot")
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--dataset", default="coco", type=str, help='coco or flickr')
    args = parser.parse_args()
    main(args)

src/retrieval.py:

import json
import os

from open_clip import tokenize
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class TextDataset(Dataset):

    def __init__(self, text_data, tokenizer):
        self.tokenizer = tokenizer
        self.caption = text_data

    def __len__(self):
        return len(self.caption)

    def __getitem__(self, index):
        text_data = self.caption[index]
        # optional
        # text_data = 'a photo of ' + text_data
        text_token = self.tokenizer(text_data)
        return text_token


class CaptionsDataset(Dataset):

    def __init__(self, ann_file, transform, data_root, image_root):
        self.ann_file = json.load(open(ann_file, 'r'))
        self.transform = transform
        self.image_root = image_root
        self.caption = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_num = 0
        for num, line in enumerate(self.ann_file):
            image_name = line['image'].split('/')[1]
            image_path = os.path.join(data_root, image_root, image_name)
            self.image.append(image_path)
            self.caption += line['caption']
            for i in range(txt_num, txt_num + len(line['caption'])):
                self.txt2img[i] = num
                if num not in self.img2txt.keys():
                    self.img2txt[num] = [i]
                else:
                    self.img2txt[num].append(i)
            txt_num += len(line['caption'])

    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        image_path = os.path.join(self.image[index])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        return image, index


def get_loader_image(ann_file, data_root, image_root, batch_size, preprocess):
    valid_dataset = CaptionsDataset(ann_file, preprocess, data_root, image_root)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)
    return valid_dataloader, valid_dataset.txt2img, valid_dataset.img2txt


def get_loader_text(ann_file, data_root, image_root, batch_size, preprocess):
    valid_dataset = CaptionsDataset(ann_file, preprocess, data_root, image_root)
    text_dataset = TextDataset(valid_dataset.caption, tokenize)
    valid_dataloader = DataLoader(text_dataset, batch_size=batch_size, shuffle=False)
    return valid_dataloader
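
For completeness, CaptionsDataset above assumes each entry of the Karpathy-split annotation JSON looks roughly like the following (field names are taken from the code; the concrete values are illustrative only):

# One annotation entry as read by CaptionsDataset (illustrative values only).
{
    "image": "val2014/COCO_val2014_000000000042.jpg",   # split on '/' to recover the file name
    "caption": [
        "a first caption describing this image",
        "a second caption describing this image"
    ]
}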

Run Command used for Experiments

Hi,

Thanks for all this cool research and for making your code public!

I was hoping to explore this area a bit and wanted to train a CyCLIP model from scratch on the CC3M dataset, as described in the paper. I was wondering if there is an example run command that was used to generate, e.g., the base CyCLIP model? I know you describe the setup in the preprint - I just want to make sure I am running the correct command.

Thanks very much!

Pretraining for I-CyCLIP and C-CyCLIP

Thanks for releasing your code and checkpoints!

The Google Drive checkpoints folder contains checkpoints for the I-CyCLIP and C-CyCLIP models. How many examples were these models trained on? My guess would be CC3M data only (~2.6M datapoints), but I can't find an explicit mention in the repo/paper.

Using other model with checkpoint

Hello, in your example you use RN50 with your checkpoint weights. Is it possible to load, for example, ViT14 and use your checkpoints, or are there checkpoints only for RN50?
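
For context, a minimal sketch of how the released checkpoint is loaded, mirroring the retrieval script above (the path is hypothetical, and I am assuming the Google Drive weights are RN50 weights as in your example). The architecture string passed to open_clip has to match the weights in the checkpoint, so these weights presumably cannot be loaded into a ViT without a separate ViT checkpoint:

import open_clip

# Hypothetical local path to the released CyCLIP weights (assumed to be RN50 weights).
pretrained = "CyCLIP.pt"

# The architecture name must match the checkpoint; an RN50 state dict will not load into a ViT model
# and would raise missing/unexpected key errors.
model, _, preprocess = open_clip.create_model_and_transforms("RN50", pretrained=pretrained)
model = model.eval()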
