jianjieluo / scd-net

[CVPR23] A cascaded diffusion captioning model with a novel semantic-conditional diffusion process that upgrades the conventional diffusion model with an additional semantic prior.

Home Page: https://arxiv.org/abs/2212.03099

License: Other

Shell 0.30% Python 99.70%
diffusion-model image-caption

scd-net's Introduction

Semantic-Conditional Diffusion Networks for Image Captioning [CVPR2023]

Introduction

This is the official repository for Semantic-Conditional Diffusion Networks for Image Captioning (SCD-Net). SCD-Net is a cascaded diffusion captioning model with a novel semantic-conditional diffusion process that upgrades the conventional diffusion model with an additional semantic prior. A novel guided self-critical sequence training strategy is further devised to stabilize and boost the diffusion process.

To the best of our knowledge, SCD-Net is the first diffusion-based captioning model that outperforms the naive auto-regressive transformer captioning model conditioned on the same visual features (i.e., bottom-up attention region features) in both the XE and RL training stages. SCD-Net is also the first diffusion-based captioning model that successfully adopts CIDEr-D optimization, via a novel guided self-critical sequence training strategy.

SCD-Net achieves state-of-the-art performance among non-autoregressive/diffusion captioning models and comparable performance against state-of-the-art autoregressive captioning models, which indicates the promising potential of diffusion models for the challenging image captioning task.
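As background for the diffusion process described above: the Acknowledgement credits the bit-diffusion code, and Bit Diffusion represents discrete tokens as "analog bits" so that a continuous diffusion model can operate on them. The following is a minimal sketch of that encoding only, assuming a most-significant-bit-first layout and a hypothetical vocabulary size; the function names are illustrative and are not taken from this repository.

```python
from typing import List


def token_to_bits(token_id: int, num_bits: int) -> List[float]:
    """Encode a token id as analog bits in {-1.0, +1.0}, most significant bit first."""
    if not 0 <= token_id < 2 ** num_bits:
        raise ValueError("token_id does not fit in num_bits")
    return [
        1.0 if (token_id >> shift) & 1 else -1.0
        for shift in range(num_bits - 1, -1, -1)
    ]


def bits_to_token(bits: List[float]) -> int:
    """Decode real-valued (possibly noisy) bits by thresholding each value at zero."""
    token_id = 0
    for value in bits:
        token_id = (token_id << 1) | (1 if value > 0 else 0)
    return token_id


# Hypothetical vocabulary size, used only to pick the number of bits.
VOCAB_SIZE = 10000
NUM_BITS = (VOCAB_SIZE - 1).bit_length()
```

A denoised output such as `[-0.8, 0.9, -0.2, 0.7]` is mapped back to a token by `bits_to_token`, which is why small residual noise in each bit does not change the decoded word.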

Framework

[Figure: SCD-Net framework]

Data Preparation

  1. Download the training misc data (Google Drive, BaiduYun, extraction code: 6os2) for SCD-Net, and extract it into the open_source_dataset/mscoco_dataset/ folder.
  2. Download the official Bottom-up features (10 to 100 regions) and preprocess them:
python tools/create_feats.py --infeats karpathy_train_resnet101_faster_rcnn_genome.tsv.0 --outfolder ../open_source_dataset/mscoco_dataset/features/up_down
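For readers who want to see what this preprocessing step roughly involves, here is a minimal sketch that converts a bottom-up attention TSV file into one .npz file per image. It is not the repository's tools/create_feats.py. It assumes the field layout of the official bottom-up attention release (base64-encoded float32 arrays), and the output key name `features` and the function name `convert_tsv` are assumptions made for illustration.

```python
import base64
import csv
import os

import numpy as np

# Column order of the official bottom-up attention TSV files.
FIELDNAMES = ["image_id", "image_w", "image_h", "num_boxes", "boxes", "features"]

# Each row holds a large base64 string, so raise the csv field size limit.
csv.field_size_limit(2**31 - 1)


def convert_tsv(tsv_path, out_folder):
    """Write one <image_id>.npz per TSV row and return the written paths."""
    os.makedirs(out_folder, exist_ok=True)
    written = []
    with open(tsv_path, "r", newline="") as handle:
        reader = csv.DictReader(handle, delimiter="\t", fieldnames=FIELDNAMES)
        for row in reader:
            num_boxes = int(row["num_boxes"])
            raw = base64.b64decode(row["features"])
            # Region features are stored as a flat float32 buffer.
            features = np.frombuffer(raw, dtype=np.float32).reshape(num_boxes, -1)
            out_path = os.path.join(out_folder, f"{int(row['image_id'])}.npz")
            # The key name "features" is an assumption, not confirmed by the README.
            np.savez_compressed(out_path, features=features)
            written.append(out_path)
    return written
```

The command above names only one input file through --infeats, so each downloaded TSV file presumably needs its own run.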

Training

Since SCD-Net is a cascaded diffusion captioning model, the stage1 and stage2 models must be trained sequentially.

# Train stage1 XE
bash configs/image_caption/scdnet/stage1/1_train_xe.sh

# Train stage1 RL
bash configs/image_caption/scdnet/stage1/2_train_rl.sh

# Run inference on the training images with the stage1 XE model; the generated sentences are used to train stage2 XE
bash configs/image_caption/scdnet/stage1/3_xe_inf_train.sh

# Run inference on the training images with the stage1 RL model; the generated sentences are used to train stage2 RL
bash configs/image_caption/scdnet/stage1/4_rl_inf_train.sh

# Train stage2 XE
bash configs/image_caption/scdnet/stage2/1_train_xe.sh

# Train stage2 RL
bash configs/image_caption/scdnet/stage2/2_train_rl.sh

# Run inference on the training images with the stage2 RL model, then update the guided sentences with the better ones
bash configs/image_caption/scdnet/stage2/3_rl_inf_train.sh
cd tools/update_kd_sents
python compare_merge.py --last_kd {path_to_autoregressive_teacher_pred_ep25.pkl} --new_pred {path_to_stage2_rl_infernece_train} --out {path_to_updated_sentences}

# Train stage2 RL with updated guided sentences
bash configs/image_caption/scdnet/stage2/4_train_rl_update_kd.sh
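Because every step depends on the output of the previous one, it can help to drive the scripts from a single entry point that stops at the first failure. The sketch below is one hypothetical way to do that and is not part of the repository. It covers only the scripts listed above up to stage2/3_rl_inf_train.sh, since the compare_merge.py step needs user-supplied paths; `build_commands` and `run_pipeline` are illustrative names.

```python
import subprocess
from typing import List

CONFIG_ROOT = "configs/image_caption/scdnet"

# Script order copied from the training steps above.
STAGE_SCRIPTS = [
    "stage1/1_train_xe.sh",
    "stage1/2_train_rl.sh",
    "stage1/3_xe_inf_train.sh",
    "stage1/4_rl_inf_train.sh",
    "stage2/1_train_xe.sh",
    "stage2/2_train_rl.sh",
    "stage2/3_rl_inf_train.sh",
]


def build_commands(scripts: List[str] = STAGE_SCRIPTS) -> List[List[str]]:
    """Return one bash command per script, in training order."""
    return [["bash", f"{CONFIG_ROOT}/{script}"] for script in scripts]


def run_pipeline(dry_run: bool = True) -> None:
    """Print each command; when dry_run is False, also execute it."""
    for command in build_commands():
        print(" ".join(command))
        if not dry_run:
            # check=True raises CalledProcessError, so later stages never
            # start after an earlier stage has failed.
            subprocess.run(command, check=True)
```

Calling `run_pipeline()` only prints the commands; `run_pipeline(dry_run=False)` from the repository root would execute them in order.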

Citation

If you use this code for your research, please cite:

@article{luo2022semantic,
  title={Semantic-Conditional Diffusion Networks for Image Captioning},
  author={Luo, Jianjie and Li, Yehao and Pan, Yingwei and Yao, Ting and Feng, Jianlin and Chao, Hongyang and Mei, Tao},
  journal={arXiv preprint arXiv:2212.03099},
  year={2022}
}

Acknowledgement

This code uses resources from the X-Modaler codebase and the bit-diffusion code. We thank the authors for open-sourcing their awesome projects.

License

MIT

scd-net's Issues

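Fields of mscoco_clip_ret_sents.pkl

Hello, I would like to know what exactly the fields in mscoco_clip_ret_sents.pkl represent. Does "r_token_ids" store the 20 sentences produced by the cross-modal retrieval model? Is the cross-modal retrieval model CLIP? Thank you.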
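While the question above waits for an answer, the layout of the file can at least be inspected directly. The helper below is a generic sketch that prints the nested structure of any pickled object; it makes no claim about what the file actually contains, and `describe` is an illustrative name. Only unpickle files from a source you trust, since unpickling can execute arbitrary code.

```python
import pickle
from typing import Any, List


def describe(obj: Any, name: str = "root", depth: int = 0, max_depth: int = 3) -> List[str]:
    """Return indented lines summarising the type and length of nested containers."""
    size = len(obj) if hasattr(obj, "__len__") else None
    suffix = "" if size is None else f" (len={size})"
    lines = ["  " * depth + f"{name}: {type(obj).__name__}{suffix}"]
    if depth >= max_depth:
        return lines
    if isinstance(obj, dict):
        # Look at the first few keys only, to keep the summary short.
        for key in list(obj)[:5]:
            lines.extend(describe(obj[key], str(key), depth + 1, max_depth))
    elif isinstance(obj, (list, tuple)) and obj:
        # Assume the elements are homogeneous and show the first one.
        lines.extend(describe(obj[0], f"{name}[0]", depth + 1, max_depth))
    return lines


def describe_pickle(path: str) -> List[str]:
    """Load a pickle file and summarise its structure."""
    with open(path, "rb") as handle:
        return describe(pickle.load(handle))
```

Printing `"\n".join(describe_pickle("mscoco_clip_ret_sents.pkl"))` shows, for example, whether "r_token_ids" is a list and how many entries it holds per image.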

Unable to process up-down features

[image: screenshot of the error]
I ran tools/create_feats.py, but it threw the error shown above. I changed the file mode to read text, but it still does not work.
Could you please tell me what is wrong?

Hi, I encountered an error during the first stage of training. After running for some time, the following error occurred: "zipfile.BadZipFile: File is not a zip file"

[08/31 09:51:13 xl.utils.events]: eta: 22:35:08 iter: 2839 total_loss: 2.535 MSE loss(U): 0.3188 LabelSmoothing(G) loss: 2.217 time: 0.4365 data_time: 0.2023 lr: 4.4375e-05 max_mem: 3045M
Traceback (most recent call last):
File "train_net.py", line 78, in <module>
args=(args,),
File "/home/username/SCD/xmodaler/engine/launch.py", line 83, in launch
daemon=False,
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
while not context.join():
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 160, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/username/SCD/xmodaler/engine/launch.py", line 129, in _distributed_worker
main_func(*args)
File "/home/username/SCD/train_net.py", line 66, in main
return trainer.train()
File "/home/username/SCD/xmodaler/engine/defaults.py", line 411, in train
super().train(self.start_iter, self.max_iter)
File "/home/username/SCD/xmodaler/engine/train_loop.py", line 151, in train
self.run_step()
File "/home/username/SCD/xmodaler/engine/bit_diffusion_trainer.py", line 37, in run_step
data = next(self._train_data_loader_iter)
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1313, in _next_data
return self._process_data(data)
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
data.reraise()
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/_utils.py", line 543, in reraise
raise exception
zipfile.BadZipFile: Caught BadZipFile in DataLoader worker process 2.
Original Traceback (most recent call last):
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/username/SCD/xmodaler/datasets/common.py", line 42, in __getitem__
data = self._map_func(self._dataset[cur_idx])
File "/home/username/SCD/xmodaler/datasets/images/mscoco_diffusion.py", line 133, in __call__
ret = super().__call__(dataset_dict)
File "/home/username/SCD/xmodaler/datasets/images/mscoco_diffusion.py", line 50, in __call__
ret = super().__call__(dataset_dict)
File "/home/username/SCD/xmodaler/datasets/images/mscoco.py", line 120, in __call__
content = read_np(feat_path)
File "/home/username/SCD/xmodaler/functional/func_io.py", line 22, in read_np
content = np.load(path)
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/numpy/lib/npyio.py", line 433, in load
pickle_kwargs=pickle_kwargs)
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/numpy/lib/npyio.py", line 189, in __init__
_zip = zipfile_factory(fid)
File "/home/username/anaconda3/envs/py37/lib/python3.7/site-packages/numpy/lib/npyio.py", line 112, in zipfile_factory
return zipfile.ZipFile(file, *args, **kwargs)
File "/home/username/anaconda3/envs/py37/lib/python3.7/zipfile.py", line 1258, in __init__
self._RealGetContents()
File "/home/username/anaconda3/envs/py37/lib/python3.7/zipfile.py", line 1325, in _RealGetContents
raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file
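The traceback shows np.load failing inside zipfile while reading a feature file, which suggests (as a guess, not a confirmed diagnosis) that at least one .npz file under the feature folder is truncated or corrupt, for example after an interrupted preprocessing run. Since .npz files are zip archives, a quick scan with the standard zipfile module can list the damaged files before training starts. The function name `find_corrupt_npz` is illustrative and not part of the repository.

```python
import os
import zipfile
from typing import List


def find_corrupt_npz(folder: str) -> List[str]:
    """Return the paths of .npz files in folder that are not readable zip archives."""
    bad = []
    for name in sorted(os.listdir(folder)):
        if not name.endswith(".npz"):
            continue
        path = os.path.join(folder, name)
        try:
            with zipfile.ZipFile(path) as archive:
                # testzip() returns the first member with a bad CRC, or None.
                if archive.testzip() is not None:
                    bad.append(path)
        except zipfile.BadZipFile:
            # Raised for empty, truncated, or non-zip files.
            bad.append(path)
    return bad
```

Any file reported here would need to be regenerated from the original TSV features before resuming training.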

Question about diversity

Hello! I have a question: since the diffusion model generates sentences starting from random noise, the generated sentences should be diverse. Have you conducted any experiments on diversity?

Question about Generating Repeated Tokens while Inferencing

Hi @jianjieluo ,

Thanks for the amazing work! I am working on transferring the SCD-Net architecture to medical radiology report generation, but I found that the trained model tends to generate repeated tokens during testing. Have you ever encountered this situation before?

The training losses and visualized results are shown below.
[image: training loss curves]
[image: visualized results]
Note that the grey curve in the first figure is the training loss on a small medical dataset (only about 5,000 images). The yellow curve is the one that produces repeated tokens.

Could you please help me out with this? Thank you in advance for taking the time to reply.

Best,

How long does it take to finish the first stage of training?

Thanks for your excellent work and codes!
I tried to retrain the model on my own dataset (174,350 samples for training), and it took 3-4 days to finish merely 5 epochs... Is that normal? Could you tell me how long it took to finish the first stage of training in your study?

Error indicating that an .npz file cannot be found

Hello, when I ran 1_train_xe.sh, training failed with an error saying that 446093.npz could not be found. How can I solve this problem? The path ... /open_source_dataset/mscoco_dataset/features/up_down exists, and the data in it can be accessed.
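One way to narrow this down is to check which image ids have no matching feature file. The sketch below assumes files are named `<image_id>.npz`, as the file name in the error suggests; `missing_features` is an illustrative name, and the list of image ids has to come from your own annotation files. A possible cause, offered only as a guess, is that not every downloaded TSV file was passed through tools/create_feats.py.

```python
import os
from typing import Iterable, List


def missing_features(image_ids: Iterable[int], feature_folder: str) -> List[int]:
    """Return the image ids that have no <image_id>.npz file in feature_folder."""
    present = {
        name for name in os.listdir(feature_folder) if name.endswith(".npz")
    }
    return [image_id for image_id in image_ids if f"{image_id}.npz" not in present]
```

If the returned list is large, whole TSV files were probably never converted; if it holds only a few ids, individual files may have been lost or written elsewhere.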

How to use the cross-modal retrieval model in this task?

I could not find the step that searches the training sentence pool for semantically relevant sentences using an off-the-shelf cross-modal retrieval model, as described in the paper. Could you please show me where this step is done in the code?

How to generate long sentences

I tried to use this code to generate long text. However, I found that the model tends to generate short sentences. Why is that?
