cyclip's People
Forkers
gstoica27 aditya-grover anminhhung addicted-to-coding d-roberts sarahesl pipiku915 dyishiou whuhxb jam1ezhang nguyenvuthientrang wujianp abrar-fahim

cyclip's Issues
classes.py file for datasets
Do you intend to release the prompt sets used for evaluating the datasets as well?
Creating embedding for own image and text
Is there an inference script I can use to get embeddings for my own text and image?
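For reference, here is a minimal sketch of what I have in mind, assuming the RN50 checkpoint loads through open_clip (as in the retrieval script quoted further down this page); the checkpoint path, image path, and caption are placeholders:

import torch
import open_clip
from PIL import Image

# Load CyCLIP weights into an open_clip RN50 model (path is a placeholder).
model, _, preprocess = open_clip.create_model_and_transforms("RN50", pretrained="path/to/CyCLIP.pt")
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0)
text = open_clip.tokenize(["a photo of a dog"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize so dot products are cosine similarities.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)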
proof for ordering of distances
Thanks for your great work and for making it public!
As shown in Figure 1(b), if the representations of any two image-text pairs, (I_dog, T_dog) and (I_cat, T_cat), exactly satisfy both forms of cyclic consistency, then we can guarantee that any test image I_test respects the ordering of distances in both the image and text spaces (i.e., if d(I_test, I_dog) > d(I_test, I_cat), then d(I_test, T_dog) > d(I_test, T_cat)).
Could you please provide a proof or explanation? Thanks!
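To clarify what I am asking, here is my own partial reading, only a sketch under the extra assumption that I_test itself belongs to a pair (I_test, T_test) satisfying the same consistency conditions (I may well be missing the intended argument):

\[
d(I_{test}, I_{dog}) = d(T_{test}, T_{dog}), \qquad d(I_{test}, I_{cat}) = d(T_{test}, T_{cat}) \quad \text{(in-modal consistency)}
\]
\[
\Rightarrow\; d(I_{test}, I_{dog}) > d(I_{test}, I_{cat}) \;\Rightarrow\; d(T_{test}, T_{dog}) > d(T_{test}, T_{cat}).
\]

What I do not see is how this transfers to the cross-modal ordering d(I_test, T_dog) > d(I_test, T_cat) without a further assumption, e.g. that paired image and text embeddings coincide so that d(I_test, T_dog) = d(T_test, T_dog) and d(I_test, T_cat) = d(T_test, T_cat).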
Possible bug in inmodal/crossmodal loss calculation
Hi,
I hope I'm not misunderstanding this, but I think there may be an issue with the following line in the train.py file:
inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (umodel.logit_scale.exp() * umodel.logit_scale.exp()) * batch_size
I think this multiplies by the batch size in the numerator instead of dividing by it in the denominator, so for a large batch size, e.g. 4096, the loss becomes very large, which can cause problems (for me, training did not progress with a 4096 batch size).
I believe the fix is simply to pull the batch size term inside the brackets:
inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (umodel.logit_scale.exp() * umodel.logit_scale.exp() * batch_size)
and similarly for the crossmodal cyclic loss.
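For concreteness, here is a minimal sketch of the proposed correction, assuming the image/text embeddings are already L2-normalized and the logits and logit_scale names follow the snippet above (the crossmodal variable names are my guess at the repo's convention):

import torch

def cyclic_losses(image_embeds, text_embeds, logit_scale):
    # image_embeds, text_embeds: L2-normalized [batch_size, dim] tensors
    batch_size = image_embeds.shape[0]
    scale = logit_scale.exp()

    logits_text_per_image = scale * image_embeds @ text_embeds.t()
    logits_image_per_text = logits_text_per_image.t()
    logits_image_per_image = scale * image_embeds @ image_embeds.t()
    logits_text_per_text = scale * text_embeds @ text_embeds.t()

    # batch_size moved into the denominator, as suggested above
    inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (scale * scale * batch_size)
    crossmodal_cyclic_loss = (logits_text_per_image - logits_image_per_text).square().mean() / (scale * scale * batch_size)
    return inmodal_cyclic_loss, crossmodal_cyclic_loss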
Any pretrained checkpoints?
Reproducing zero-shot retrieval experiments with CyCLIP
Dear authors,
First of all, thank you for the wonderful project and for sharing the code and checkpoints.
While reproducing the zero-shot retrieval experiments from Table 6 in the main paper, I observed a large gap between the results reported in the paper and my reproduced results.
For clarity, here are the results of the zero-shot retrieval performance I obtained:
COCO dataset
- Image-to-Text Retrieval: {'r1': 21.3, 'r5': 45.0, 'r10': 57.1}
- Text-to-Image Retrieval: {'r1': 15.97, 'r5': 36.57, 'r10': 48.40}
Flickr dataset
- Image-to-Text Retrieval: {'r1': 41.2, 'r5': 70.0, 'r10': 80.0}
- Text-to-Image Retrieval: {'r1': 30.42, 'r5': 57.1, 'r10': 68.82}
The results seem to align closely with the reported performance only in the case of Image-to-Text retrieval on Flickr30k.
I used the CyCLIP checkpoint provided via Google Drive and conducted tests using the Karpathy test split of the COCO and Flickr datasets.
I adapted the code from this repo for the retrieval experiments. I hope you can take a quick look at the code below and help me identify any potential issues.
Additionally, it would be immensely helpful if you could share the code you used for the zero-shot retrieval experiments.
Best regards,
Usage: python test_retrieval.py --dataset coco  # or flickr
where test_retrieval.py is:
import argparse

import open_clip
import torch

from src.retrieval import get_loader_image, get_loader_text


def compute_retrieval(similarity_scores, txt2img, img2txt):
    # compute text -> image
    t2i_similarity_score = similarity_scores.t()
    t2i_ranks = torch.zeros(t2i_similarity_score.shape[0])
    for index, score in enumerate(t2i_similarity_score):
        inds = torch.argsort(score, descending=True)
        t2i_ranks[index] = torch.where(inds == txt2img[index])[0][0]
        print(
            'Evaluating batch {}/{}, {}'.format(
                index, t2i_similarity_score.shape[0], t2i_ranks[index]
            ),
            end="\r"
        )

    # compute metrics
    tr1 = 100.0 * len(torch.where(t2i_ranks < 1)[0]) / len(t2i_ranks)
    tr5 = 100.0 * len(torch.where(t2i_ranks < 5)[0]) / len(t2i_ranks)
    tr10 = 100.0 * len(torch.where(t2i_ranks < 10)[0]) / len(t2i_ranks)
    t2i_report_dict = {"r1": tr1, "r5": tr5, "r10": tr10}

    # compute image -> text
    i2t_similarity_score = similarity_scores
    i2t_ranks = torch.zeros(i2t_similarity_score.shape[0])
    for index, score in enumerate(i2t_similarity_score):
        print('Evaluating batch {}/{}'.format(index, i2t_similarity_score.shape[0]), end="\r")
        inds = torch.argsort(score, descending=True)
        # rank of the best-matching ground-truth caption
        rank = 1e10
        for i in img2txt[index]:
            tmp = torch.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        i2t_ranks[index] = rank

    # compute metrics
    ir1 = 100.0 * len(torch.where(i2t_ranks < 1)[0]) / len(i2t_ranks)
    ir5 = 100.0 * len(torch.where(i2t_ranks < 5)[0]) / len(i2t_ranks)
    ir10 = 100.0 * len(torch.where(i2t_ranks < 10)[0]) / len(i2t_ranks)
    i2t_report_dict = {"r1": ir1, "r5": ir5, "r10": ir10}

    return t2i_report_dict, i2t_report_dict


def get_image_feature(model, data_loader):
    image_features = []
    for batch_idx, batch in enumerate(data_loader):
        print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
        images, _ = batch
        image_emb = model.encode_image(images.cuda())  # embed with image encoder
        image_features.append(image_emb.detach().cpu())
    image_features = torch.cat(image_features, 0)
    print('Done image feature extract.')
    print(image_features.shape)

    # normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features


def get_text_feature(model, data_loader):
    text_features = []
    for batch_idx, batch in enumerate(data_loader):
        print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
        text = batch.squeeze()
        text_emb = model.encode_text(text.cuda())
        text_features.append(text_emb.detach().cpu())
    text_features = torch.cat(text_features, 0)
    print('Done text feature extract.')
    print(text_features.shape)

    # normalized features
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features


def main(args):
    pretrained = "/home/appuser/.cache/torch/hub/CyCLIP/cc3m/CyCLIP.pt"
    model, _, transform = open_clip.create_model_and_transforms(
        "RN50", pretrained=pretrained, device="cuda"
    )
    model = model.eval().cuda()

    if args.dataset == "coco":
        # Karpathy split
        ann_file = "/home/appuser/datasets/coco/coco_karpathy_test.json"
        data_root = "/home/appuser/datasets/coco/"
        image_root = "images/val2014"
    else:
        # Karpathy split
        ann_file = "/home/appuser/datasets/flickr30k/annotations/flickr30k_test.json"
        data_root = "/home/appuser/datasets/flickr30k/"
        image_root = "images/flickr30k-images"

    text_loader = get_loader_text(ann_file, data_root, image_root, args.batch_size, transform)
    text_features = get_text_feature(model, text_loader)

    image_loader, txt2img, img2txt = get_loader_image(
        ann_file, data_root, image_root, args.batch_size, transform
    )
    image_features = get_image_feature(model, image_loader)

    similarity_scores = image_features.cuda() @ text_features.cuda().t()
    t2i_dict, i2t_dict = compute_retrieval(similarity_scores, txt2img, img2txt)
    print('Image-to-Text retrieval', i2t_dict)
    print('Text-to-Image retrieval', t2i_dict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ZeroShot")
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--dataset", default="coco", type=str, help='coco or flickr')
    args = parser.parse_args()
    main(args)
src/retrieval.py:
import json
import os

from open_clip import tokenize
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class TextDataset(Dataset):
    def __init__(self, text_data, tokenizer):
        self.tokenizer = tokenizer
        self.caption = text_data

    def __len__(self):
        return len(self.caption)

    def __getitem__(self, index):
        text_data = self.caption[index]
        # optional
        # text_data = 'a photo of ' + text_data
        text_token = self.tokenizer(text_data)
        return text_token


class CaptionsDataset(Dataset):
    def __init__(self, ann_file, transform, data_root, image_root):
        self.ann_file = json.load(open(ann_file, 'r'))
        self.transform = transform
        self.image_root = image_root
        self.caption = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}
        txt_num = 0
        for num, line in enumerate(self.ann_file):
            image_name = line['image'].split('/')[1]
            image_path = os.path.join(data_root, image_root, image_name)
            self.image.append(image_path)
            self.caption += line['caption']
            for i in range(txt_num, txt_num + len(line['caption'])):
                self.txt2img[i] = num
                if num not in self.img2txt.keys():
                    self.img2txt[num] = [i]
                else:
                    self.img2txt[num].append(i)
            txt_num += len(line['caption'])

    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        image_path = os.path.join(self.image[index])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        return image, index


def get_loader_image(ann_file, data_root, image_root, batch_size, preprocess):
    valid_dataset = CaptionsDataset(ann_file, preprocess, data_root, image_root)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)
    return valid_dataloader, valid_dataset.txt2img, valid_dataset.img2txt


def get_loader_text(ann_file, data_root, image_root, batch_size, preprocess):
    valid_dataset = CaptionsDataset(ann_file, preprocess, data_root, image_root)
    text_dataset = TextDataset(valid_dataset.caption, tokenize)
    valid_dataloader = DataLoader(text_dataset, batch_size=batch_size, shuffle=False)
    return valid_dataloader
Run Command used for Experiments
Hi,
Thanks for all this cool research and for making your code public!
I was hoping to explore this area a bit and wanted to train a CyCLIP model from scratch on the CC3M dataset, as described in the paper. Is there by chance an example run command, e.g. the one used to generate the base CyCLIP model? I know you describe the setup in the preprint; I just want to make sure I am running the correct command.
Thanks very much!
Pretraining for I-CyCLIP and C-CyCLIP
Thanks for releasing your code and checkpoints!
The Google Drive folder contains checkpoints for the I-CyCLIP and C-CyCLIP models. How many examples were these models trained on? My guess would be CC3M data only (~2.6M datapoints), but I can't find an explicit mention in the repo or paper.
Using other model with checkpoint
Hello, in your example you use RN50 with your checkpoint weights. Is it possible to load, for example, ViT14 and use your checkpoints, or are there checkpoints only for RN50?