
yueyang1996 / labo


CVPR 2023: Language in a Bottle: Language Model Guided Concept Bottlenecks for Interpretable Image Classification

Home Page: https://arxiv.org/abs/2211.11158

Python 99.56% Shell 0.44%
image-classification interpretability few-shot-learning language-model

labo's Issues

code for multi-gpu training

Hi, it seems the provided code only supports single-GPU training. Could you please provide code for multi-GPU training?
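A minimal sketch, assuming training is driven by a PyTorch Lightning `Trainer` (pytorch-lightning 1.6.4 appears in the environment spec on this page). Where the repository constructs its `Trainer` is not shown here, so the helper `trainer_kwargs_for` and its `base` argument are hypothetical. The helper only assembles keyword arguments to pass as `pytorch_lightning.Trainer(**kwargs)`. With `strategy="ddp"` each GPU runs its own process, so the data-loader batch size is per GPU and the effective batch size grows with the number of devices.

```python
from typing import Any, Dict, Optional


def trainer_kwargs_for(num_gpus: int, base: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build keyword arguments for pytorch_lightning.Trainer (hypothetical helper).

    num_gpus > 1  -> DistributedDataParallel ("ddp"), one process per GPU.
    num_gpus == 1 -> single GPU, no distributed strategy.
    num_gpus == 0 -> CPU only.
    """
    if num_gpus < 0:
        raise ValueError("num_gpus must be >= 0")
    kwargs = dict(base or {})  # copy so the caller's dict is not mutated
    if num_gpus == 0:
        kwargs.update(accelerator="cpu", devices=1)
    else:
        kwargs.update(accelerator="gpu", devices=num_gpus)
        if num_gpus > 1:
            kwargs["strategy"] = "ddp"
    return kwargs
```

For example, `pytorch_lightning.Trainer(**trainer_kwargs_for(4, {"max_epochs": 10}))` would request four GPUs with DDP while keeping an existing `max_epochs` setting.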

conda: fail to install the environment

Hello,
thank you for releasing the code for your paper. I am having trouble installing your environment, even when simply following the instructions. I get the following error:

Collecting package metadata (current_repodata.json): - WARNING conda.models.version:get_matcher(556): Using .* with relational operator is superfluous and deprecated and will be removed in a future version of conda. Your spec was 1.7.1.*, but conda is ignoring the .* and treating it as 1.7.1
done
Solving environment: unsuccessful initial attempt using frozen solve. Retrying with flexible solve.
Collecting package metadata (repodata.json): \ WARNING conda.models.version:get_matcher(556): Using .* with relational operator is superfluous and deprecated and will be removed in a future version of conda. Your spec was 1.6.0.*, but conda is ignoring the .* and treating it as 1.6.0
WARNING conda.models.version:get_matcher(556): Using .* with relational operator is superfluous and deprecated and will be removed in a future version of conda. Your spec was 1.9.0.*, but conda is ignoring the .* and treating it as 1.9.0
WARNING conda.models.version:get_matcher(556): Using .* with relational operator is superfluous and deprecated and will be removed in a future version of conda. Your spec was 1.8.0.*, but conda is ignoring the .* and treating it as 1.8.0
done
Solving environment: unsuccessful initial attempt using frozen solve. Retrying with flexible solve.

PackagesNotFoundError: The following packages are not available from current channels:

  - sklearn==0.0=pypi_0
  - argon2-cffi-bindings==21.2.0=pypi_0
  - tensorboard-plugin-wit==1.8.1=pypi_0
  - flask-compress==1.12=pypi_0
  - prompt-toolkit==3.0.30=pypi_0
  - dash-table==5.0.0=pypi_0
  - dash-html-components==2.0.0=pypi_0
  - webencodings==0.5.1=pypi_0
  - terminaltables==3.1.10=pypi_0
  - grpcio==1.46.3=pypi_0
  - openpyxl==3.0.10=pypi_0
  - yarl==1.7.2=pypi_0
  - pydeprecate==0.3.2=pypi_0
  - parse==1.19.0=pypi_0
  - et-xmlfile==1.1.0=pypi_0
  - pyasn1==0.4.8=pypi_0
  - ipyparallel==8.4.1=pypi_0
  - cachetools==5.2.0=pypi_0
  - ipywidgets==7.7.1=pypi_0
  - apricot-select==0.6.1=pypi_0
  - jupyter-client==7.3.4=pypi_0
  - gitpython==3.1.27=pypi_0
  - nltk==3.7=pypi_0
  - mmcv==1.5.2=pypi_0
  - dash-core-components==2.0.0=pypi_0
  - pytorch-lightning==1.6.4=pypi_0
  - docker-pycreds==0.4.0=pypi_0
  - sentencepiece==0.1.96=pypi_0
  - tensorboard==2.9.1=pypi_0
  - matplotlib-inline==0.1.3=pypi_0
  - pyasn1-modules==0.2.8=pypi_0
  - pyyaml==6.0=pypi_0
  - jupyterlab-pygments==0.2.2=pypi_0
  - nbclient==0.6.4=pypi_0
  - plotly==5.9.0=pypi_0
  - jsonschema==4.6.1=pypi_0
  - pathtools==0.1.2=pypi_0
  - tensorboard-data-server==0.6.1=pypi_0
  - braceexpand==0.1.7=pypi_0
  - tokenizers==0.12.1=pypi_0
  - aiohttp==3.8.1=pypi_0
  - pytorch-pfn-extras==0.5.8=pypi_0
  - fastargs==1.2.0=pypi_0
  - sentence-transformers==2.2.2=pypi_0
  - regex==2022.6.2=pypi_0
  - webdataset==0.2.5=pypi_0
  - lesscpy==0.15.0=pypi_0
  - websocket-client==1.3.3=pypi_0
  - tenacity==8.0.1=pypi_0
  - cycler==0.11.0=pypi_0
  - threadpoolctl==3.1.0=pypi_0
  - jupyter-dash==0.4.2=pypi_0
  - soupsieve==2.3.2.post1=pypi_0
  - huggingface-hub==0.8.1=pypi_0
  - imgcat==0.5.0=pypi_0
  - fastjsonschema==2.15.3=pypi_0
  - sentry-sdk==1.5.12=pypi_0
  - anyio==3.6.1=pypi_0
  - pandas-stubs==1.4.3.220704=pypi_0
  - yapf==0.32.0=pypi_0
  - widgetsnbextension==3.6.1=pypi_0
  - notebook==6.4.12=pypi_0
  - setproctitle==1.2.3=pypi_0
  - fonttools==4.33.3=pypi_0
  - pygments==2.12.0=pypi_0
  - click==8.1.3=pypi_0
  - terminado==0.15.0=pypi_0
  - argon2-cffi==21.3.0=pypi_0
  - fsspec==2022.5.0=pypi_0
  - google-auth==2.7.0=pypi_0
  - markdown==3.3.7=pypi_0
  - pytorch==1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
  - openai==0.20.0=pypi_0
  - torchvision==0.12.0=py39_cu113
  - addict==2.4.0=pypi_0
  - seaborn==0.11.2=pypi_0
  - nbformat==5.4.0=pypi_0
  - ansi2html==1.7.0=pypi_0
  - bleach==5.0.1=pypi_0
  - kiwisolver==1.4.3=pypi_0
  - torchmetrics==0.9.1=pypi_0
  - retrying==1.3.3=pypi_0
  - gitdb==4.0.9=pypi_0
  - stack-data==0.3.0=pypi_0
  - dash==2.5.1=pypi_0
  - pycocotools==2.0.4=pypi_0
  - scipy==1.8.1=pypi_0
  - frozenlist==1.3.0=pypi_0
  - clip==1.0=dev_0
  - wikipedia==1.4.0=pypi_0
  - scikit-learn==1.1.1=pypi_0
  - tinycss2==1.1.1=pypi_0
  - promise==2.3=pypi_0
  - jupyterlab-widgets==1.1.1=pypi_0
  - google-auth-oauthlib==0.4.6=pypi_0
  - debugpy==1.6.0=pypi_0
  - jupyterthemes==0.20.0=pypi_0
  - assertpy==1.1=pypi_0
  - ipykernel==6.15.0=pypi_0
  - filelock==3.7.1=pypi_0
  - absl-py==1.1.0=pypi_0
  - tornado==6.1=pypi_0
  - flask==2.1.2=pypi_0
  - pytorch-mutex==1.0=cuda
  - nbconvert==6.5.0=pypi_0
  - tqdm==4.64.0=pypi_0
  - werkzeug==2.1.2=pypi_0
  - requests-oauthlib==1.3.1=pypi_0
  - wandb==0.13.4=pypi_0
  - async-timeout==4.0.2=pypi_0
  - shortuuid==1.0.9=pypi_0
  - nose==1.3.7=pypi_0
  - jupyter-server==1.18.1=pypi_0
  - smmap==5.0.0=pypi_0
  - protobuf==3.19.4=pypi_0
  - ftfy==6.1.1=pypi_0
  - aiosignal==1.2.0=pypi_0
  - ply==3.11=pypi_0
  - matplotlib==3.5.2=pypi_0
  - ffcv==0.0.3=pypi_0
  - joblib==1.1.0=pypi_0
  - brotli==1.0.9=pypi_0
  - pandas==1.4.2=pypi_0
  - rsa==4.8=pypi_0
  - itsdangerous==2.1.2=pypi_0
  - multidict==6.0.2=pypi_0
  - transformers==4.20.1=pypi_0

Current channels:

  - https://repo.anaconda.com/pkgs/main/linux-64
  - https://repo.anaconda.com/pkgs/main/noarch
  - https://repo.anaconda.com/pkgs/free/linux-64
  - https://repo.anaconda.com/pkgs/free/noarch
  - https://repo.anaconda.com/pkgs/r/linux-64
  - https://repo.anaconda.com/pkgs/r/noarch
  - https://conda.anaconda.org/conda-forge/linux-64
  - https://conda.anaconda.org/conda-forge/noarch
  - https://conda.anaconda.org/anaconda/linux-64
  - https://conda.anaconda.org/anaconda/noarch

To search for alternate channels that may provide the conda package you're
looking for, navigate to

    https://anaconda.org

and use the search bar at the top of the page.

Could you please tell me how to fix it? Do you have a requirements .txt file for installing with pip?

Thanks
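The unresolved specs all end in a build string: `pypi_0` marks packages that were installed with pip inside the original environment (conda cannot find these on any channel), `dev_0` typically marks a source/editable install such as `clip`, and the remaining ones (`pytorch`, `torchvision`, `pytorch-mutex`) carry `pytorch`-channel build strings, while the `pytorch` channel is absent from the channel list above. A minimal sketch, assuming the requirements file uses conda's `name=version=build` export format (as written by `conda list --export`), that splits it into a conda part and a pip part. The function names and output file names are hypothetical. It skips `sklearn==0.0`, a deprecated placeholder package on PyPI for scikit-learn, which is listed separately.

```python
from typing import Iterable, List, Tuple

# Deprecated PyPI placeholder for scikit-learn; scikit-learn itself is listed too.
SKIP = {"sklearn"}


def parse_spec(line: str) -> Tuple[str, str, str]:
    """Split 'name=version=build' (also tolerates 'name==version=build')."""
    parts = [p for p in line.strip().split("=") if p]
    if len(parts) != 3:
        raise ValueError(f"unexpected spec: {line!r}")
    return parts[0], parts[1], parts[2]


def partition_specs(lines: Iterable[str]) -> Tuple[List[str], List[str], List[str]]:
    """Return (pip_requirements, conda_specs, source_installs)."""
    pip_reqs: List[str] = []
    conda_specs: List[str] = []
    source_installs: List[str] = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # blank lines and header comments
        name, version, build = parse_spec(line)
        if name in SKIP:
            continue
        if build == "pypi_0":
            pip_reqs.append(f"{name}=={version}")
        elif build == "dev_0":
            source_installs.append(name)  # must be reinstalled from source by hand
        else:
            conda_specs.append(f"{name}={version}")
    return pip_reqs, conda_specs, source_installs


def write_split(src: str, pip_out: str = "requirements-pip.txt",
                conda_out: str = "requirements-conda.txt") -> List[str]:
    """Write the two requirement files and return the packages needing a source install."""
    with open(src) as f:
        pip_reqs, conda_specs, source_installs = partition_specs(f)
    with open(pip_out, "w") as f:
        f.write("\n".join(pip_reqs) + "\n")
    with open(conda_out, "w") as f:
        f.write("\n".join(conda_specs) + "\n")
    return source_installs
```

After `write_split("<requirements file>")`, one untested route is `conda create -n labo -c pytorch -c conda-forge --file requirements-conda.txt`, then `pip install -r requirements-pip.txt` inside the new environment, and a separate source install for `clip`. Some old pins may still need loosening.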

code for interpretation

Hi, is there any code for interpretation at inference time? For example, given a single photo, output its classification result and the most related concepts. The interpretability of the model comes from obtaining a set of concepts for each image. Is there a specific implementation available? Is it the method demonstrated in Figure 7 of the paper, where concepts are selected according to their weights in the fully connected (fc) layer?
Thanks a lot!
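A minimal sketch of concept-bottleneck inference for one image, assuming you already have (a) the image's similarity score to each concept (for example, CLIP image-text dot products) and (b) the learned class-by-concept weight matrix of the final linear layer. The names `explain_prediction`, `concept_scores`, and `class_concept_weights` are hypothetical, not the repository's API, and the sketch ignores any bias term, normalization, or softmax the real model may apply. Following the Figure 7 caption, concepts are ranked by their weight for the predicted class.

```python
from typing import List, Sequence, Tuple

import numpy as np


def explain_prediction(
    concept_scores: Sequence[float],
    class_concept_weights: Sequence[Sequence[float]],
    concept_names: Sequence[str],
    class_names: Sequence[str],
    top_k: int = 3,
) -> Tuple[str, List[str]]:
    """Predict a class from concept scores and list its top-k concepts by weight."""
    scores = np.asarray(concept_scores, dtype=float)           # (num_concepts,)
    weights = np.asarray(class_concept_weights, dtype=float)   # (num_classes, num_concepts)
    logits = weights @ scores                                  # one logit per class
    pred = int(np.argmax(logits))
    # Rank concepts by their weight in the predicted class's row, largest first.
    order = np.argsort(-weights[pred])[:top_k]
    return class_names[pred], [concept_names[i] for i in order]
```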

GPU for LaBo

Hi, I used the cub_base config for training, but I kept getting a CUDA out-of-memory error even after reducing the batch size to 4.
I suspect the problem is in the feature-precomputation step.
Here is the error message:
Traceback (most recent call last):
  File "/home2/LaBo/main.py", line 350, in <module>
    main(cfg)
  File "/home2/LaBo/main.py", line 228, in asso_opt_main
    data_module = DotProductDataModule(
  File "/home2/lLaBo/data.py", line 378, in __init__
    super().__init__(*args, **kwargs)
  File "/home2/LaBo/data.py", line 157, in __init__
    self.prepare_img_feat(self.splits, self.n_shots, self.clip_model, self.clip_ckpt)
  File "/home2/LaBo/data.py", line 317, in prepare_img_feat
    img_feat, label = self.compute_img_feat(cls2img, n_shots if mode == 'train' else 'all', clip_model, clip_ckpt)
  File "/home2/LaBo/data.py", line 303, in compute_img_feat
    img_feat = utils.prepare_img_feat(all_img_paths,
  File "/home2/LaBo/utils.py", line 94, in prepare_img_feat
    batchify_run(process_img, img_names, res, 2048, use_tqdm=True)
  File "/home2/LaBo/utils.py", line 62, in batchify_run
    batch_res = process_fn(batch_data)
  File "/home2/LaBo/utils.py", line 91, in process_img
    img_feat = model.encode_image(img_tensor)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/clip/model.py", line 341, in encode_image
    return self.visual(image.type(self.dtype))
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/clip/model.py", line 148, in forward
    x = self.layer1(x)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/clip/model.py", line 48, in forward
    out = self.bn3(self.conv3(out))
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 175, in forward
    return F.batch_norm(
  File "/home2/anaconda3/envs/labo/lib/python3.9/site-packages/torch/nn/functional.py", line 2482, in batch_norm
    return torch.batch_norm(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB. GPU 1 has a total capacity of 10.91 GiB of which 552.62 MiB is free. Including non-PyTorch memory, this process has 10.36 GiB memory in use. Of the allocated memory 6.56 GiB is allocated by PyTorch, and 3.08 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Could you give me some advice? Should I try a GPU with more memory? Thanks a lot!
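In the traceback above, the error is raised while precomputing CLIP image features, and `utils.prepare_img_feat` calls `batchify_run(..., 2048, ...)`, so that step appears to encode 2048 images per batch regardless of the training batch size in the config. A minimal sketch of a batching helper that halves the batch size whenever the wrapped function raises an out-of-memory error; `adaptive_batched_apply` is hypothetical and is not the repository's `batchify_run`. With PyTorch you would pass `oom_error=torch.cuda.OutOfMemoryError` (the exception type in the traceback), run the encoder under `torch.no_grad()`, and possibly call `torch.cuda.empty_cache()` before each retry.

```python
from typing import Callable, List, Sequence, Tuple, Type, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def adaptive_batched_apply(
    fn: Callable[[Sequence[T]], List[R]],
    items: Sequence[T],
    batch_size: int,
    oom_error: Type[BaseException] = MemoryError,
    min_batch_size: int = 1,
) -> Tuple[List[R], int]:
    """Apply fn to consecutive batches, halving batch_size whenever fn raises oom_error.

    Returns the concatenated results and the batch size that finally worked.
    """
    results: List[R] = []
    start = 0
    while start < len(items):
        batch = items[start:start + batch_size]
        try:
            out = fn(batch)
        except oom_error:
            if batch_size <= min_batch_size:
                raise  # cannot shrink any further
            batch_size = max(min_batch_size, batch_size // 2)
            continue  # retry the same start index with a smaller batch
        results.extend(out)
        start += len(batch)
    return results, batch_size
```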

Selected Concepts

Hello,

Thanks for sharing the code of your remarkable work.

I'm writing to ask about the "50 selected concepts for each class" mentioned in your paper. Computing these for each dataset is resource-intensive, and having them directly would be a great help to the community.

Many thanks in advance!

Best,
Yikai

Unexpected CIFAR10 Results

Hi!
First, thanks for your great work!
I followed your instructions and ran labo_train on CIFAR10, but I'm running into an issue with the results.
These are the results I obtained:
1 shot: 86.34% on val, 85.39% on test.
2 shot: 90.64% on val, 89.68% on test.
4 shot: 91.04% on val, 91.10% on test.
8 shot: 93.44% on val, 92.9% on test.
16 shot: 94.84% on val, 94.80% on test.

However, these results do not match those reported in the paper, especially for the lower shot counts.
Do you have any idea why this might be happening?

Thanks

The code for generating the initial concepts

First of all, great job on the paper and the overall project!!

I went through the repository, but I can't seem to find the code for generating the initial concepts, i.e., the ones before the selection mechanism with submodular optimization.
Will this part of the code be shared publicly?

Thanks in advance.
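The initial-concept generation code is not on this page, so the following is only an illustrative sketch, not the paper's pipeline. It fills a few hypothetical prompt templates with a class name and splits a language model's free-text answer into short, deduplicated candidate concepts. The templates, `build_prompts`, and `parse_concepts` are assumptions, and the call to the language model itself is left out.

```python
import re
from typing import List

# Hypothetical prompt templates; not the ones used in the paper.
TEMPLATES = [
    "Describe what the {} looks like:",
    "Describe the appearance of the {}:",
    "Describe the color of the {}:",
]


def build_prompts(class_name: str) -> List[str]:
    """One prompt per template for the given class name."""
    return [t.format(class_name) for t in TEMPLATES]


def parse_concepts(response: str) -> List[str]:
    """Split an LM answer into lowercase candidate concepts, deduplicated in order.

    Naive: splits on periods and semicolons, so numbers like '3.5' would also be split.
    """
    concepts: List[str] = []
    seen = set()
    for line in response.splitlines():
        line = re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line)  # drop list markers like '1.' or '-'
        for piece in re.split(r"[;.]\s*", line):
            text = piece.strip().lower()
            if text and text not in seen:
                seen.add(text)
                concepts.append(text)
    return concepts
```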

Query regarding selecting concept in Figure 7

Hi,

In Figure 7 of the paper, you have mentioned that:

The top-3 concepts, ranked by their weights in the linear function, for randomly selected classes, paired with a random image from the class, across 6 datasets.

So, did you select the top-3 concepts of the predicted class from mat rather than from dot_product? Am I right?
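Whatever `mat` and `dot_product` hold in the code (their exact contents are not shown on this page), the two candidate rankings differ in kind. Ranking by a class's weights gives the same concepts for every image of that class, as the Figure 7 caption describes, while ranking by weight times the image's concept score gives an image-specific explanation. A minimal sketch contrasting the two; the function names are hypothetical.

```python
from typing import List, Sequence

import numpy as np


def top_by_weight(class_weights: Sequence[float], k: int) -> List[int]:
    """Indices of the k concepts with the largest weight for one class (image-independent)."""
    return np.argsort(-np.asarray(class_weights, dtype=float))[:k].tolist()


def top_by_contribution(class_weights: Sequence[float],
                        concept_scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k concepts with the largest weight * score for one image."""
    contrib = np.asarray(class_weights, dtype=float) * np.asarray(concept_scores, dtype=float)
    return np.argsort(-contrib)[:k].tolist()
```

With weights `[2.0, 1.0, 0.5]` and an image whose scores are `[0.1, 0.9, 0.8]`, the weight ranking picks concepts 0 and 1, while the contribution ranking picks 1 and 2.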

can't train with the cifar10 and cifar100 datasets

Hi, I used the code you provided in issue #4 to process cifar10 and cifar100, but training still fails on both datasets.

This is the error for cifar10:
Exception has occurred: FileNotFoundError
[Errno 2] No such file or directory: ' /LaBo/datasets/CIFAR10/images/stealth_bomber_s_000554.png'
  File " /LaBo/utils.py", line 88, in <listcomp>
    img_tensor = torch.cat([preprocess(Image.open('{}'.format(img_name)))
  File " /LaBo/utils.py", line 88, in process_img
    img_tensor = torch.cat([preprocess(Image.open('{}'.format(img_name)))
  File " /LaBo/utils.py", line 62, in batchify_run
    batch_res = process_fn(batch_data)
  File " /LaBo/utils.py", line 98, in prepare_img_feat
    batchify_run(process_img, img_names, res, 1024, use_tqdm=True)
  File "/ /LaBo/data.py", line 306, in compute_img_feat
    img_feat = utils.prepare_img_feat(all_img_paths,
  File " /LaBo/data.py", line 320, in prepare_img_feat
    img_feat, label = self.compute_img_feat(cls2img, n_shots if mode == 'train' else 'all', clip_model, clip_ckpt)
  File " /LaBo/data.py", line 158, in __init__
    self.prepare_img_feat(self.splits, self.n_shots, self.clip_model, self.clip_ckpt)
  File " /LaBo/main.py", line 205, in asso_opt_main
    data_module = DataModule(
  File " /LaBo/main.py", line 356, in <module>
    main(cfg)
FileNotFoundError: [Errno 2] No such file or directory: ' /LaBo/datasets/CIFAR10/images/stealth_bomber_s_000554.png'

And this is the error for cifar100:
[Errno 2] No such file or directory: ' /LaBo/datasets/CIFAR100/images/eating_apple_s_000763.png'
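Both errors mean that image files the data module expects are not present under the dataset's `images/` directory. A small standard-library sketch that reports how many expected files are missing before training starts, so a path or conversion problem shows up at once instead of deep inside feature extraction. `find_missing` is hypothetical, and how you collect the expected file names depends on the split files your setup uses. If the leading space in `' /LaBo/...'` is not just an artifact of redacting the path, it would also make every lookup fail.

```python
import os
from typing import Iterable, List


def find_missing(image_dir: str, file_names: Iterable[str], show: int = 5) -> List[str]:
    """Return the expected image files that do not exist under image_dir."""
    missing = [name for name in file_names
               if not os.path.isfile(os.path.join(image_dir, name))]
    if missing:
        print(f"{len(missing)} file(s) missing in {image_dir!r}, e.g. {missing[:show]}")
    return missing
```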

Custom Data train

How can I use this model on a custom dataset?
I have a few classes that I want to train the model on, and I already have a concepts JSON file generated with GPT-3.
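A minimal sketch of the concept-side bookkeeping a custom dataset needs, assuming the GPT-3 concepts JSON maps each class name to a list of concept strings (the file layout the repository actually expects is not shown on this page). It flattens the mapping into one concept list plus a parallel list of class indices, the kind of concept-to-class mapping suggested by the `concept2cls` attribute visible in a traceback on this page. The function names and output format are hypothetical.

```python
import json
from typing import Dict, List, Tuple


def flatten_concepts(class_to_concepts: Dict[str, List[str]]) -> Tuple[List[str], List[str], List[int]]:
    """Return (class_names, concepts, concept2cls) from a {class: [concepts]} mapping.

    Classes are sorted by name so indices are reproducible; duplicate
    concepts within one class are dropped.
    """
    class_names = sorted(class_to_concepts)
    concepts: List[str] = []
    concept2cls: List[int] = []
    for cls_idx, cls in enumerate(class_names):
        seen = set()
        for concept in class_to_concepts[cls]:
            text = concept.strip()
            if text and text not in seen:
                seen.add(text)
                concepts.append(text)
                concept2cls.append(cls_idx)
    return class_names, concepts, concept2cls


def load_concepts(path: str) -> Tuple[List[str], List[str], List[int]]:
    """Read a {class: [concepts]} JSON file and flatten it."""
    with open(path) as f:
        return flatten_concepts(json.load(f))
```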

Problem related to high LaBO training time

Hi!

Does LaBo training always run for max_epochs, i.e., 15k for CIFAR10 and 1000 for ImageNet? As you mentioned in one of the earlier issues, ImageNet takes about 10 minutes per epoch, so 1000 epochs amounts to about 7 days, which is quite a lot.

Also, could you please tell us whether the trained checkpoints are available?
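For the arithmetic in the question: 10 minutes × 1000 epochs = 10,000 minutes, about 6.9 days. If a run does not need every epoch, early stopping on a validation metric can cut it short; PyTorch Lightning provides an `EarlyStopping` callback for this (with `monitor`, `patience`, and `mode` arguments). A minimal, framework-free sketch of the same patience logic, with hypothetical names, to show what such a callback decides after each validation check:

```python
class PatienceStopper:
    """Stop when a higher-is-better metric has not improved for `patience` checks."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_checks = 0

    def step(self, metric: float) -> bool:
        """Record one validation result; return True if training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```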

Can't train on CIFAR10 because data not in the right format

I am trying to run LaBo on CIFAR10, and I get this error:

Traceback (most recent call last):
  File "/home/bethge/bkr046/LaBo/main.py", line 350, in <module>
    main(cfg)
  File "/home/bethge/bkr046/LaBo/main.py", line 228, in asso_opt_main
    data_module = DotProductDataModule(
  File "/home/bethge/bkr046/LaBo/data.py", line 377, in __init__
    super().__init__(*args, **kwargs)
  File "/home/bethge/bkr046/LaBo/data.py", line 157, in __init__
    self.prepare_img_feat(self.splits, self.n_shots, self.clip_model, self.clip_ckpt)
  File "/home/bethge/bkr046/LaBo/data.py", line 316, in prepare_img_feat
    img_feat, label = self.compute_img_feat(cls2img, n_shots if mode == 'train' else 'all', clip_model, clip_ckpt)
  File "/home/bethge/bkr046/LaBo/data.py", line 302, in compute_img_feat
    img_feat = utils.prepare_img_feat(all_img_paths,
  File "/home/bethge/bkr046/LaBo/utils.py", line 94, in prepare_img_feat
    batchify_run(process_img, img_names, res, 2048, use_tqdm=True)
  File "/home/bethge/bkr046/LaBo/utils.py", line 62, in batchify_run
    batch_res = process_fn(batch_data)
  File "/home/bethge/bkr046/LaBo/utils.py", line 87, in process_img
    img_tensor = torch.cat([preprocess(Image.open('{}'.format(img_name)))\
  File "/home/bethge/bkr046/LaBo/utils.py", line 87, in <listcomp>
    img_tensor = torch.cat([preprocess(Image.open('{}'.format(img_name)))\
  File "/mnt/qb/work/bethge/bkr046/anaconda3/envs/LaBo/lib/python3.9/site-packages/PIL/Image.py", line 3227, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: 'datasets/CIFAR10/images/stealth_bomber_s_000554.png'

I think this happens because when I download CIFAR10 from the link in DATASET.md, it is in a different format:
batches.meta cifar-10-batches-py cifar-10-python.tar.gz data_batch_1 data_batch_2 data_batch_3 data_batch_4 data_batch_5 readme.html test_batch

How do I convert it into the format that LaBo expects?
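The files listed are the standard "python version" of CIFAR-10: pickled batches whose `data` entry holds one 3072-value row per image (1024 red, then 1024 green, then 1024 blue values for a 32×32 image in row-major order) and whose `filenames` entry holds names such as `stealth_bomber_s_000554.png`, which matches the file in the error. A minimal sketch, assuming LaBo only needs those images written as PNG files into `datasets/CIFAR10/images/`; it does not create any split or label files the repository may also expect. It writes PNGs with a tiny standard-library encoder so PIL is not required, although unpickling the real batches needs NumPy installed. The CIFAR-100 archive uses the same layout, with batch files named `train` and `test`.

```python
import os
import pickle
import struct
import zlib
from typing import Iterable, List

CIFAR10_FILES = [f"data_batch_{i}" for i in range(1, 6)] + ["test_batch"]
CIFAR100_FILES = ["train", "test"]
SIDE = 32
PLANE = SIDE * SIDE  # pixels per colour channel


def write_png(path: str, width: int, height: int, rgb: bytes) -> None:
    """Write 8-bit interleaved RGB pixels (row-major) as a PNG file."""
    stride = width * 3
    # Every scanline starts with filter type 0 (None).
    raw = b"".join(b"\x00" + rgb[y * stride:(y + 1) * stride] for y in range(height))

    def chunk(tag: bytes, data: bytes) -> bytes:
        crc = zlib.crc32(tag + data) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

    # width, height, bit depth 8, colour type 2 (RGB), compression, filter, interlace
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    with open(path, "wb") as f:
        f.write(b"\x89PNG\r\n\x1a\n" + chunk(b"IHDR", ihdr)
                + chunk(b"IDAT", zlib.compress(raw)) + chunk(b"IEND", b""))


def cifar_row_to_rgb(row) -> bytes:
    """Convert a planar CIFAR row (R plane, G plane, B plane) to interleaved RGB."""
    flat = bytes(row)  # works for bytes or a NumPy uint8 row
    out = bytearray(PLANE * 3)
    out[0::3] = flat[:PLANE]
    out[1::3] = flat[PLANE:2 * PLANE]
    out[2::3] = flat[2 * PLANE:]
    return bytes(out)


def convert_cifar(src_dir: str, out_dir: str, batch_files: Iterable[str]) -> List[str]:
    """Write every image in the given pickled batches to out_dir/<filename>."""
    os.makedirs(out_dir, exist_ok=True)
    written: List[str] = []
    for name in batch_files:
        with open(os.path.join(src_dir, name), "rb") as f:
            batch = pickle.load(f, encoding="bytes")  # dictionary keys are bytes, e.g. b"data"
        for row, fname in zip(batch[b"data"], batch[b"filenames"]):
            fname = fname.decode() if isinstance(fname, bytes) else fname
            write_png(os.path.join(out_dir, fname), SIDE, SIDE, cifar_row_to_rgb(row))
            written.append(fname)
    return written
```

For example, `convert_cifar("datasets/CIFAR10/cifar-10-batches-py", "datasets/CIFAR10/images", CIFAR10_FILES)`; both paths are assumptions about your layout, and the result is untested against LaBo's loader.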

Packages are not available from current channels

Hi,

I tried to install the packages following your instructions, but I got an error saying:
The following packages are not available from current channels

Full issue:

Collecting package metadata (current_repodata.json): done
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed

PackagesNotFoundError: The following packages are not available from current channels:

  - terminaltables==3.1.10=pypi_0
  - fsspec==2022.5.0=pypi_0
  - nbconvert==6.5.0=pypi_0
  - google-auth==2.7.0=pypi_0
  - grpcio==1.46.3=pypi_0
  - webencodings==0.5.1=pypi_0
  - yarl==1.7.2=pypi_0
  - jsonschema==4.6.1=pypi_0
  - markdown==3.3.7=pypi_0
  - flask-compress==1.12=pypi_0
  - pyasn1-modules==0.2.8=pypi_0
  - debugpy==1.6.0=pypi_0
  - click==8.1.3=pypi_0
  - pyyaml==6.0=pypi_0
  - retrying==1.3.3=pypi_0
  - cycler==0.11.0=pypi_0
  - imgcat==0.5.0=pypi_0
  - argon2-cffi-bindings==21.2.0=pypi_0
  - anyio==3.6.1=pypi_0
  - threadpoolctl==3.1.0=pypi_0
  - tenacity==8.0.1=pypi_0
  - et-xmlfile==1.1.0=pypi_0
  - tokenizers==0.12.1=pypi_0
  - jupyterlab-pygments==0.2.2=pypi_0
  - fastargs==1.2.0=pypi_0
  - pandas-stubs==1.4.3.220704=pypi_0
  - absl-py==1.1.0=pypi_0
  - pygments==2.12.0=pypi_0
  - pycocotools==2.0.4=pypi_0
  - stack-data==0.3.0=pypi_0
  - websocket-client==1.3.3=pypi_0
  - torchmetrics==0.9.1=pypi_0
  - prompt-toolkit==3.0.30=pypi_0
  - kiwisolver==1.4.3=pypi_0
  - dash-table==5.0.0=pypi_0
  - protobuf==3.19.4=pypi_0
  - sentence-transformers==2.2.2=pypi_0
  - regex==2022.6.2=pypi_0
  - jupyter-dash==0.4.2=pypi_0
  - wikipedia==1.4.0=pypi_0
  - ftfy==6.1.1=pypi_0
  - pytorch-mutex==1.0=cuda
  - soupsieve==2.3.2.post1=pypi_0
  - pyasn1==0.4.8=pypi_0
  - google-auth-oauthlib==0.4.6=pypi_0
  - scikit-learn==1.1.1=pypi_0
  - matplotlib==3.5.2=pypi_0
  - addict==2.4.0=pypi_0
  - bleach==5.0.1=pypi_0
  - dash==2.5.1=pypi_0
  - pathtools==0.1.2=pypi_0
  - terminado==0.15.0=pypi_0
  - brotli==1.0.9=pypi_0
  - rsa==4.8=pypi_0
  - nbformat==5.4.0=pypi_0
  - pydeprecate==0.3.2=pypi_0
  - cachetools==5.2.0=pypi_0
  - jupyterthemes==0.20.0=pypi_0
  - lesscpy==0.15.0=pypi_0
  - multidict==6.0.2=pypi_0
  - pandas==1.4.2=pypi_0
  - transformers==4.20.1=pypi_0
  - dash-core-components==2.0.0=pypi_0
  - matplotlib-inline==0.1.3=pypi_0
  - tinycss2==1.1.1=pypi_0
  - docker-pycreds==0.4.0=pypi_0
  - tensorboard-data-server==0.6.1=pypi_0
  - itsdangerous==2.1.2=pypi_0
  - tornado==6.1=pypi_0
  - pytorch-pfn-extras==0.5.8=pypi_0
  - joblib==1.1.0=pypi_0
  - filelock==3.7.1=pypi_0
  - smmap==5.0.0=pypi_0
  - apricot-select==0.6.1=pypi_0
  - plotly==5.9.0=pypi_0
  - requests-oauthlib==1.3.1=pypi_0
  - webdataset==0.2.5=pypi_0
  - widgetsnbextension==3.6.1=pypi_0
  - parse==1.19.0=pypi_0
  - nltk==3.7=pypi_0
  - promise==2.3=pypi_0
  - flask==2.1.2=pypi_0
  - mmcv==1.5.2=pypi_0
  - ipyparallel==8.4.1=pypi_0
  - sentencepiece==0.1.96=pypi_0
  - jupyter-server==1.18.1=pypi_0
  - scipy==1.8.1=pypi_0
  - nbclient==0.6.4=pypi_0
  - tensorboard==2.9.1=pypi_0
  - braceexpand==0.1.7=pypi_0
  - wandb==0.13.4=pypi_0
  - fastjsonschema==2.15.3=pypi_0
  - aiosignal==1.2.0=pypi_0
  - clip==1.0=dev_0
  - pytorch-lightning==1.6.4=pypi_0
  - assertpy==1.1=pypi_0
  - yapf==0.32.0=pypi_0
  - setproctitle==1.2.3=pypi_0
  - torchvision==0.12.0=py39_cu113
  - notebook==6.4.12=pypi_0
  - werkzeug==2.1.2=pypi_0
  - openpyxl==3.0.10=pypi_0
  - ansi2html==1.7.0=pypi_0
  - sklearn==0.0=pypi_0
  - pytorch==1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
  - ipykernel==6.15.0=pypi_0
  - ffcv==0.0.3=pypi_0
  - fonttools==4.33.3=pypi_0
  - dash-html-components==2.0.0=pypi_0
  - seaborn==0.11.2=pypi_0
  - aiohttp==3.8.1=pypi_0
  - huggingface-hub==0.8.1=pypi_0
  - tqdm==4.64.0=pypi_0
  - frozenlist==1.3.0=pypi_0
  - gitdb==4.0.9=pypi_0
  - gitpython==3.1.27=pypi_0
  - argon2-cffi==21.3.0=pypi_0
  - ply==3.11=pypi_0
  - async-timeout==4.0.2=pypi_0
  - jupyterlab-widgets==1.1.1=pypi_0
  - sentry-sdk==1.5.12=pypi_0
  - tensorboard-plugin-wit==1.8.1=pypi_0
  - nose==1.3.7=pypi_0
  - openai==0.20.0=pypi_0
  - shortuuid==1.0.9=pypi_0
  - jupyter-client==7.3.4=pypi_0
  - ipywidgets==7.7.1=pypi_0

Current channels:

  - https://repo.anaconda.com/pkgs/main/linux-64
  - https://repo.anaconda.com/pkgs/main/noarch
  - https://repo.anaconda.com/pkgs/r/linux-64
  - https://repo.anaconda.com/pkgs/r/noarch
  - https://conda.anaconda.org/conda-forge/linux-64
  - https://conda.anaconda.org/conda-forge/noarch

To search for alternate channels that may provide the conda package you're
looking for, navigate to

    https://anaconda.org

and use the search bar at the top of the page.

I also used the command conda create --name <env> --file requirement.txt to create the environment.

Although I am using Ubuntu (64-bit) with the linux-64 channel in conda, and I also appended the conda-forge channel to work around this, the issue has not been solved. Do you have any idea how to solve it?

Could you provide the environment as a yml file by exporting it, so that it can be installed with conda env create -f environment.yml?
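A minimal sketch of producing an `environment.yml` from a `name=version=build` requirements file: packages with a `pypi_0` build go into a `pip:` subsection (conda cannot resolve them, which is why they are all reported as missing), and the rest stay as conda dependencies. It lists the `pytorch` channel because `pytorch`, `torchvision`, and `pytorch-mutex` carry `pytorch`-channel build strings and that channel is absent from the channel list above. The function name, environment name, and channel order are assumptions, and the output is untested against this repository. On the maintainers' side, `conda env export > environment.yml` produces such a file directly from a working environment.

```python
from typing import Iterable, List

CHANNELS = ["pytorch", "conda-forge", "defaults"]  # assumed channel order
SKIP = {"sklearn"}  # deprecated PyPI placeholder for scikit-learn


def to_environment_yml(lines: Iterable[str], env_name: str = "labo") -> str:
    """Render conda 'name=version=build' lines as environment.yml text."""
    conda_deps: List[str] = []
    pip_deps: List[str] = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, version, build = [p for p in line.split("=") if p]
        if name in SKIP or build == "dev_0":
            continue  # dev_0 = local source install (e.g. clip); install it by hand
        if build == "pypi_0":
            pip_deps.append(f"{name}=={version}")
        else:
            conda_deps.append(f"{name}={version}")
    out = [f"name: {env_name}", "channels:"]
    out += [f"  - {c}" for c in CHANNELS]
    out.append("dependencies:")
    out += [f"  - {d}" for d in conda_deps]
    if pip_deps:
        if not any(d == "pip" or d.startswith("pip=") for d in conda_deps):
            out.append("  - pip")  # pip itself must be a conda dependency
        out.append("  - pip:")
        out += [f"      - {d}" for d in pip_deps]
    return "\n".join(out) + "\n"
```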

ValueError: X cannot contain negative values or must be entirely negative values

Hi, your work is great, but I get an error in submodular_select() when training with MixtureSelection, as shown below. Could you please give me some advice on this error? Thanks again.

(n1) zj@cad:~/tc5/LM$ python3 main.py --cfg cfg/asso_opt/flower/flower_1shot_fac.py --work-dir exp/asso_opt/flower/flower_1shot_fac --func asso_opt_main
/home/zj/n1/lib/python3.6/site-packages/mmcv/__init__.py:21: UserWarning: On January 1, 2023, MMCV will release v2.0.0, in which it will remove components related to the training process and add a data transformation module. In addition, it will rename the package names mmcv to mmcv-lite and mmcv-full to mmcv. See https://github.com/open-mmlab/mmcv/blob/master/docs/en/compatibility.md for more details.
'On January 1, 2023, MMCV will release v2.0.0, in which it will remove '
use submodular
[10000000.0, 10]
use dot product dataloader
prepare txt feat
100%|███████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:15<00:00, 12.38it/s]
select concept
0%| | 0/102 [00:00<?, ?it/s]191 25
0%| | 0/102 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 350, in <module>
    main(cfg)
  File "main.py", line 249, in asso_opt_main
    submodular_weights=cfg.submodular_weights
  File "/home/zj/tc5/LM/data.py", line 377, in __init__
    super().__init__(*args, **kwargs)
  File "/home/zj/tc5/LM/data.py", line 178, in __init__
    self.select_concept(self.concept_select_fn, self.img_feat['train'], self.concept_feat, self.n_shots, self.num_concept, self.concept2cls, self.clip_ckpt, self.num_images_per_class, self.submodular_weights)
  File "/home/zj/tc5/LM/data.py", line 267, in select_concept
    num_concepts, num_images_per_class, submodular_weights)
  File "/home/zj/tc5/LM/models/select_concept/select_algo.py", line 163, in submodular_select
    selected = selector.fit(augmented_concept_features).ranking
  File "/home/zj/n1/lib/python3.6/site-packages/apricot/functions/mixture.py", line 215, in fit
    sample_weight=sample_weight, sample_cost=sample_cost)
  File "/home/zj/n1/lib/python3.6/site-packages/apricot/functions/base.py", line 202, in fit
    raise ValueError("X cannot contain negative values or must be entirely "
ValueError: X cannot contain negative values or must be entirely negative values.
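As the message says, apricot's `fit` rejects a feature matrix that mixes negative and non-negative values, and CLIP-based similarity features can be negative. A minimal sketch of two ways to make the matrix passed to `selector.fit(...)` non-negative; which one (if either) preserves the selection behaviour the authors intended is an assumption you would need to check. Separately, the traceback runs under Python 3.6 (`/home/zj/n1/lib/python3.6/...`), while the environment spec on this page pins Python 3.9 builds, so package versions may differ from the authors' setup.

```python
import numpy as np


def shift_to_nonnegative(x) -> np.ndarray:
    """Subtract the global minimum if it is negative (keeps all pairwise differences)."""
    x = np.asarray(x, dtype=np.float64)
    lo = x.min()
    return x - lo if lo < 0 else x


def clip_to_nonnegative(x) -> np.ndarray:
    """Zero out negative entries (discards information carried by negative values)."""
    return np.clip(np.asarray(x, dtype=np.float64), 0.0, None)
```

For instance, `selector.fit(shift_to_nonnegative(augmented_concept_features))` inside `submodular_select` would be one place to try the shift.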

Error in labo_train.sh

I get the following error:

main.py: error: argument --func: expected one argument

Examining the script and the code, I found that $asso_opt_main is undefined in labo_train.sh, and that the eval() function in main.py is also undefined.
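The message comes from `argparse`: if `$asso_opt_main` is unset, the shell expands `--func $asso_opt_main` to a bare `--func`, and argparse then reports that the option expected one argument. The command shown in another issue on this page passes the name literally (`--func asso_opt_main`). A minimal, self-contained sketch of that behaviour, plus dispatching the function name through a dictionary instead of `eval()`; the parser, the stand-in `asso_opt_main`, and the dispatch table are hypothetical, not the contents of `main.py`.

```python
import argparse
from typing import Callable, Dict, List


def asso_opt_main(cfg: str) -> str:
    """Stand-in for the real entry point; returns a marker string."""
    return f"asso_opt_main({cfg})"


# Explicit name -> function table; argparse rejects names not listed here.
FUNCS: Dict[str, Callable[[str], str]] = {"asso_opt_main": asso_opt_main}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--cfg", required=True)
    parser.add_argument("--work-dir", default="exp")
    parser.add_argument("--func", required=True, choices=sorted(FUNCS))
    return parser


def run(argv: List[str]) -> str:
    args = build_parser().parse_args(argv)
    return FUNCS[args.func](args.cfg)
```

`run(["--cfg", "a.py", "--func"])` exits with `main.py: error: argument --func: expected one argument`, which is the same failure an empty shell variable produces.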

File missing

I ran bash labo_train.sh CIFAR_10 all and got this error:

Traceback (most recent call last):
  File "/home/bethge/bkr046/LaBo/main.py", line 341, in <module>
    cfg = utils.pre_exp(args.cfg, args.work_dir)
  File "/home/bethge/bkr046/LaBo/utils.py", line 23, in pre_exp
    cfg = Config.fromfile(cfg_file)
  File "/mnt/qb/work/bethge/bkr046/anaconda3/envs/LaBo/lib/python3.9/site-packages/mmcv/utils/config.py", line 340, in fromfile
    cfg_dict, cfg_text = Config._file2dict(filename,
  File "/mnt/qb/work/bethge/bkr046/anaconda3/envs/LaBo/lib/python3.9/site-packages/mmcv/utils/config.py", line 183, in _file2dict
    check_file_exist(filename)
  File "/mnt/qb/work/bethge/bkr046/anaconda3/envs/LaBo/lib/python3.9/site-packages/mmcv/utils/path.py", line 23, in check_file_exist
    raise FileNotFoundError(msg_tmpl.format(filename))
FileNotFoundError: file "/home/bethge/bkr046/LaBo/cfg/asso_opt/all/all_CIFAR_10shot_fac.py" does not exist

Looks like this file (cfg/asso_opt/all/all_CIFAR_10shot_fac.py) was not added to GitHub. Could you please add it?
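Judging from the two config paths quoted on this page (`cfg/asso_opt/flower/flower_1shot_fac.py` and the missing `cfg/asso_opt/all/all_CIFAR_10shot_fac.py`), the config path seems to follow `cfg/asso_opt/{dataset}/{dataset}_{shot}shot_fac.py`. That would mean `CIFAR_10` was taken as the shot setting and `all` as the dataset, i.e. the two arguments may be swapped. The template is inferred from these paths, not read from `labo_train.sh`; the sketch only reproduces it so you can check which file a given pair of arguments would need.

```python
import os


def cfg_path(dataset: str, shot: str, root: str = "cfg/asso_opt") -> str:
    """Config path under the inferred template {root}/{dataset}/{dataset}_{shot}shot_fac.py."""
    return f"{root}/{dataset}/{dataset}_{shot}shot_fac.py"


def check_cfg(dataset: str, shot: str, root: str = "cfg/asso_opt") -> bool:
    """Print whether the inferred config file exists and return the result."""
    path = cfg_path(dataset, shot, root)
    exists = os.path.isfile(path)
    print(("found: " if exists else "missing: ") + path)
    return exists
```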

Error in testing

Hi @YueYANG1996, I am using the ckpt file generated in the exp folder for testing, but I am getting this error. Could you please help?

Traceback (most recent call last):
  File "main.py", line 370, in <module>
    main(cfg)
  File "main.py", line 247, in asso_opt_main
    data_module = DotProductDataModule(
  File "/DATA/siddharth/LaBo/data.py", line 379, in __init__
    super().__init__(*args, **kwargs)
  File "/DATA/siddharth/LaBo/data.py", line 186, in __init__
    self.gen_init_weight_from_cls_name(self.cls_names, self.concepts_raw[self.select_idx])
  File "/DATA/siddharth/LaBo/data.py", line 234, in gen_init_weight_from_cls_name
    cls_name_feat = utils.prepare_txt_feat(cls_names, clip_model_name=self.clip_model, ckpt_path=self.clip_ckpt)
  File "/DATA/siddharth/LaBo/utils.py", line 141, in prepare_txt_feat
    model.load_state_dict(ckpt)
  File "/DATA/siddharth/anaconda3/envs/LABO/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLIP:
Missing key(s) in state_dict: "positional_embedding", "text_projection", "logit_scale", "visual.class_embedding", "visual.positional_embedding", "visual.proj", "visual.conv1.weight", "visual.ln_pre.weight", "visual.ln_pre.bias", "visual.transformer.resblocks.0.attn.in_proj_weight", "visual.transformer.resblocks.0.attn.in_proj_bias", "visual.transformer.resblocks.0.attn.out_proj.weight", "visual.transformer.resblocks.0.attn.out_proj.bias", "visual.transformer.resblocks.0.ln_1.weight", "visual.transformer.resblocks.0.ln_1.bias", "visual.transformer.resblocks.0.mlp.c_fc.weight", "visual.transformer.resblocks.0.mlp.c_fc.bias", "visual.transformer.resblocks.0.mlp.c_proj.weight", "visual.transformer.resblocks.0.mlp.c_proj.bias", "visual.transformer.resblocks.0.ln_2.weight", "visual.transformer.resblocks.0.ln_2.bias", "visual.transformer.resblocks.1.attn.in_proj_weight", "visual.transformer.resblocks.1.attn.in_proj_bias", "visual.transformer.resblocks.1.attn.out_proj.weight", "visual.transformer.resblocks.1.attn.out_proj.bias", "visual.transformer.resblocks.1.ln_1.weight", "visual.transformer.resblocks.1.ln_1.bias", "visual.transformer.resblocks.1.mlp.c_fc.weight", "visual.transformer.resblocks.1.mlp.c_fc.bias", "visual.transformer.resblocks.1.mlp.c_proj.weight", "visual.transformer.resblocks.1.mlp.c_proj.bias", "visual.transformer.resblocks.1.ln_2.weight", "visual.transformer.resblocks.1.ln_2.bias", "visual.transformer.resblocks.2.attn.in_proj_weight", "visual.transformer.resblocks.2.attn.in_proj_bias", "visual.transformer.resblocks.2.attn.out_proj.weight", "visual.transformer.resblocks.2.attn.out_proj.bias", "visual.transformer.resblocks.2.ln_1.weight", "visual.transformer.resblocks.2.ln_1.bias", "visual.transformer.resblocks.2.mlp.c_fc.weight", "visual.transformer.resblocks.2.mlp.c_fc.bias", "visual.transformer.resblocks.2.mlp.c_proj.weight", "visual.transformer.resblocks.2.mlp.c_proj.bias", "visual.transformer.resblocks.2.ln_2.weight", 
"visual.transformer.resblocks.2.ln_2.bias", "visual.transformer.resblocks.3.attn.in_proj_weight", "visual.transformer.resblocks.3.attn.in_proj_bias", "visual.transformer.resblocks.3.attn.out_proj.weight", "visual.transformer.resblocks.3.attn.out_proj.bias", "visual.transformer.resblocks.3.ln_1.weight", "visual.transformer.resblocks.3.ln_1.bias", "visual.transformer.resblocks.3.mlp.c_fc.weight", "visual.transformer.resblocks.3.mlp.c_fc.bias", "visual.transformer.resblocks.3.mlp.c_proj.weight", "visual.transformer.resblocks.3.mlp.c_proj.bias", "visual.transformer.resblocks.3.ln_2.weight", "visual.transformer.resblocks.3.ln_2.bias", "visual.transformer.resblocks.4.attn.in_proj_weight", "visual.transformer.resblocks.4.attn.in_proj_bias", "visual.transformer.resblocks.4.attn.out_proj.weight", "visual.transformer.resblocks.4.attn.out_proj.bias", "visual.transformer.resblocks.4.ln_1.weight", "visual.transformer.resblocks.4.ln_1.bias", "visual.transformer.resblocks.4.mlp.c_fc.weight", "visual.transformer.resblocks.4.mlp.c_fc.bias", "visual.transformer.resblocks.4.mlp.c_proj.weight", "visual.transformer.resblocks.4.mlp.c_proj.bias", "visual.transformer.resblocks.4.ln_2.weight", "visual.transformer.resblocks.4.ln_2.bias", "visual.transformer.resblocks.5.attn.in_proj_weight", "visual.transformer.resblocks.5.attn.in_proj_bias", "visual.transformer.resblocks.5.attn.out_proj.weight", "visual.transformer.resblocks.5.attn.out_proj.bias", "visual.transformer.resblocks.5.ln_1.weight", "visual.transformer.resblocks.5.ln_1.bias", "visual.transformer.resblocks.5.mlp.c_fc.weight", "visual.transformer.resblocks.5.mlp.c_fc.bias", "visual.transformer.resblocks.5.mlp.c_proj.weight", "visual.transformer.resblocks.5.mlp.c_proj.bias", "visual.transformer.resblocks.5.ln_2.weight", "visual.transformer.resblocks.5.ln_2.bias", "visual.transformer.resblocks.6.attn.in_proj_weight", "visual.transformer.resblocks.6.attn.in_proj_bias", "visual.transformer.resblocks.6.attn.out_proj.weight", 
"visual.transformer.resblocks.6.attn.out_proj.bias", "visual.transformer.resblocks.6.ln_1.weight", "visual.transformer.resblocks.6.ln_1.bias", "visual.transformer.resblocks.6.mlp.c_fc.weight", "visual.transformer.resblocks.6.mlp.c_fc.bias", "visual.transformer.resblocks.6.mlp.c_proj.weight", "visual.transformer.resblocks.6.mlp.c_proj.bias", "visual.transformer.resblocks.6.ln_2.weight", "visual.transformer.resblocks.6.ln_2.bias", "visual.transformer.resblocks.7.attn.in_proj_weight", "visual.transformer.resblocks.7.attn.in_proj_bias", "visual.transformer.resblocks.7.attn.out_proj.weight", "visual.transformer.resblocks.7.attn.out_proj.bias", "visual.transformer.resblocks.7.ln_1.weight", "visual.transformer.resblocks.7.ln_1.bias", "visual.transformer.resblocks.7.mlp.c_fc.weight", "visual.transformer.resblocks.7.mlp.c_fc.bias", "visual.transformer.resblocks.7.mlp.c_proj.weight", "visual.transformer.resblocks.7.mlp.c_proj.bias", "visual.transformer.resblocks.7.ln_2.weight", "visual.transformer.resblocks.7.ln_2.bias", "visual.transformer.resblocks.8.attn.in_proj_weight", "visual.transformer.resblocks.8.attn.in_proj_bias", "visual.transformer.resblocks.8.attn.out_proj.weight", "visual.transformer.resblocks.8.attn.out_proj.bias", "visual.transformer.resblocks.8.ln_1.weight", "visual.transformer.resblocks.8.ln_1.bias", "visual.transformer.resblocks.8.mlp.c_fc.weight", "visual.transformer.resblocks.8.mlp.c_fc.bias", "visual.transformer.resblocks.8.mlp.c_proj.weight", "visual.transformer.resblocks.8.mlp.c_proj.bias", "visual.transformer.resblocks.8.ln_2.weight", "visual.transformer.resblocks.8.ln_2.bias", "visual.transformer.resblocks.9.attn.in_proj_weight", "visual.transformer.resblocks.9.attn.in_proj_bias", "visual.transformer.resblocks.9.attn.out_proj.weight", "visual.transformer.resblocks.9.attn.out_proj.bias", "visual.transformer.resblocks.9.ln_1.weight", "visual.transformer.resblocks.9.ln_1.bias", "visual.transformer.resblocks.9.mlp.c_fc.weight", 
"visual.transformer.resblocks.9.mlp.c_fc.bias", "visual.transformer.resblocks.9.mlp.c_proj.weight", "visual.transformer.resblocks.9.mlp.c_proj.bias", "visual.transformer.resblocks.9.ln_2.weight", "visual.transformer.resblocks.9.ln_2.bias", "visual.transformer.resblocks.10.attn.in_proj_weight", "visual.transformer.resblocks.10.attn.in_proj_bias", "visual.transformer.resblocks.10.attn.out_proj.weight", "visual.transformer.resblocks.10.attn.out_proj.bias", "visual.transformer.resblocks.10.ln_1.weight", "visual.transformer.resblocks.10.ln_1.bias", "visual.transformer.resblocks.10.mlp.c_fc.weight", "visual.transformer.resblocks.10.mlp.c_fc.bias", "visual.transformer.resblocks.10.mlp.c_proj.weight", "visual.transformer.resblocks.10.mlp.c_proj.bias", "visual.transformer.resblocks.10.ln_2.weight", "visual.transformer.resblocks.10.ln_2.bias", "visual.transformer.resblocks.11.attn.in_proj_weight", "visual.transformer.resblocks.11.attn.in_proj_bias", "visual.transformer.resblocks.11.attn.out_proj.weight", "visual.transformer.resblocks.11.attn.out_proj.bias", "visual.transformer.resblocks.11.ln_1.weight", "visual.transformer.resblocks.11.ln_1.bias", "visual.transformer.resblocks.11.mlp.c_fc.weight", "visual.transformer.resblocks.11.mlp.c_fc.bias", "visual.transformer.resblocks.11.mlp.c_proj.weight", "visual.transformer.resblocks.11.mlp.c_proj.bias", "visual.transformer.resblocks.11.ln_2.weight", "visual.transformer.resblocks.11.ln_2.bias", "visual.transformer.resblocks.12.attn.in_proj_weight", "visual.transformer.resblocks.12.attn.in_proj_bias", "visual.transformer.resblocks.12.attn.out_proj.weight", "visual.transformer.resblocks.12.attn.out_proj.bias", "visual.transformer.resblocks.12.ln_1.weight", "visual.transformer.resblocks.12.ln_1.bias", "visual.transformer.resblocks.12.mlp.c_fc.weight", "visual.transformer.resblocks.12.mlp.c_fc.bias", "visual.transformer.resblocks.12.mlp.c_proj.weight", "visual.transformer.resblocks.12.mlp.c_proj.bias", 
"visual.transformer.resblocks.12.ln_2.weight", "visual.transformer.resblocks.12.ln_2.bias", "visual.transformer.resblocks.13.attn.in_proj_weight", "visual.transformer.resblocks.13.attn.in_proj_bias", "visual.transformer.resblocks.13.attn.out_proj.weight", "visual.transformer.resblocks.13.attn.out_proj.bias", "visual.transformer.resblocks.13.ln_1.weight", "visual.transformer.resblocks.13.ln_1.bias", "visual.transformer.resblocks.13.mlp.c_fc.weight", "visual.transformer.resblocks.13.mlp.c_fc.bias", "visual.transformer.resblocks.13.mlp.c_proj.weight", "visual.transformer.resblocks.13.mlp.c_proj.bias", "visual.transformer.resblocks.13.ln_2.weight", "visual.transformer.resblocks.13.ln_2.bias", "visual.transformer.resblocks.14.attn.in_proj_weight", "visual.transformer.resblocks.14.attn.in_proj_bias", "visual.transformer.resblocks.14.attn.out_proj.weight", "visual.transformer.resblocks.14.attn.out_proj.bias", "visual.transformer.resblocks.14.ln_1.weight", "visual.transformer.resblocks.14.ln_1.bias", "visual.transformer.resblocks.14.mlp.c_fc.weight", "visual.transformer.resblocks.14.mlp.c_fc.bias", "visual.transformer.resblocks.14.mlp.c_proj.weight", "visual.transformer.resblocks.14.mlp.c_proj.bias", "visual.transformer.resblocks.14.ln_2.weight", "visual.transformer.resblocks.14.ln_2.bias", "visual.transformer.resblocks.15.attn.in_proj_weight", "visual.transformer.resblocks.15.attn.in_proj_bias", "visual.transformer.resblocks.15.attn.out_proj.weight", "visual.transformer.resblocks.15.attn.out_proj.bias", "visual.transformer.resblocks.15.ln_1.weight", "visual.transformer.resblocks.15.ln_1.bias", "visual.transformer.resblocks.15.mlp.c_fc.weight", "visual.transformer.resblocks.15.mlp.c_fc.bias", "visual.transformer.resblocks.15.mlp.c_proj.weight", "visual.transformer.resblocks.15.mlp.c_proj.bias", "visual.transformer.resblocks.15.ln_2.weight", "visual.transformer.resblocks.15.ln_2.bias", "visual.transformer.resblocks.16.attn.in_proj_weight", 
"visual.transformer.resblocks.16.attn.in_proj_bias", "visual.transformer.resblocks.16.attn.out_proj.weight", "visual.transformer.resblocks.16.attn.out_proj.bias", "visual.transformer.resblocks.16.ln_1.weight", "visual.transformer.resblocks.16.ln_1.bias", "visual.transformer.resblocks.16.mlp.c_fc.weight", "visual.transformer.resblocks.16.mlp.c_fc.bias", "visual.transformer.resblocks.16.mlp.c_proj.weight", "visual.transformer.resblocks.16.mlp.c_proj.bias", "visual.transformer.resblocks.16.ln_2.weight", "visual.transformer.resblocks.16.ln_2.bias", "visual.transformer.resblocks.17.attn.in_proj_weight", "visual.transformer.resblocks.17.attn.in_proj_bias", "visual.transformer.resblocks.17.attn.out_proj.weight", "visual.transformer.resblocks.17.attn.out_proj.bias", "visual.transformer.resblocks.17.ln_1.weight", "visual.transformer.resblocks.17.ln_1.bias", "visual.transformer.resblocks.17.mlp.c_fc.weight", "visual.transformer.resblocks.17.mlp.c_fc.bias", "visual.transformer.resblocks.17.mlp.c_proj.weight", "visual.transformer.resblocks.17.mlp.c_proj.bias", "visual.transformer.resblocks.17.ln_2.weight", "visual.transformer.resblocks.17.ln_2.bias", "visual.transformer.resblocks.18.attn.in_proj_weight", "visual.transformer.resblocks.18.attn.in_proj_bias", "visual.transformer.resblocks.18.attn.out_proj.weight", "visual.transformer.resblocks.18.attn.out_proj.bias", "visual.transformer.resblocks.18.ln_1.weight", "visual.transformer.resblocks.18.ln_1.bias", "visual.transformer.resblocks.18.mlp.c_fc.weight", "visual.transformer.resblocks.18.mlp.c_fc.bias", "visual.transformer.resblocks.18.mlp.c_proj.weight", "visual.transformer.resblocks.18.mlp.c_proj.bias", "visual.transformer.resblocks.18.ln_2.weight", "visual.transformer.resblocks.18.ln_2.bias", "visual.transformer.resblocks.19.attn.in_proj_weight", "visual.transformer.resblocks.19.attn.in_proj_bias", "visual.transformer.resblocks.19.attn.out_proj.weight", "visual.transformer.resblocks.19.attn.out_proj.bias", 
"visual.transformer.resblocks.19.ln_1.weight", "visual.transformer.resblocks.19.ln_1.bias", "visual.transformer.resblocks.19.mlp.c_fc.weight", "visual.transformer.resblocks.19.mlp.c_fc.bias", "visual.transformer.resblocks.19.mlp.c_proj.weight", "visual.transformer.resblocks.19.mlp.c_proj.bias", "visual.transformer.resblocks.19.ln_2.weight", "visual.transformer.resblocks.19.ln_2.bias", "visual.transformer.resblocks.20.attn.in_proj_weight", "visual.transformer.resblocks.20.attn.in_proj_bias", "visual.transformer.resblocks.20.attn.out_proj.weight", "visual.transformer.resblocks.20.attn.out_proj.bias", "visual.transformer.resblocks.20.ln_1.weight", "visual.transformer.resblocks.20.ln_1.bias", "visual.transformer.resblocks.20.mlp.c_fc.weight", "visual.transformer.resblocks.20.mlp.c_fc.bias", "visual.transformer.resblocks.20.mlp.c_proj.weight", "visual.transformer.resblocks.20.mlp.c_proj.bias", "visual.transformer.resblocks.20.ln_2.weight", "visual.transformer.resblocks.20.ln_2.bias", "visual.transformer.resblocks.21.attn.in_proj_weight", "visual.transformer.resblocks.21.attn.in_proj_bias", "visual.transformer.resblocks.21.attn.out_proj.weight", "visual.transformer.resblocks.21.attn.out_proj.bias", "visual.transformer.resblocks.21.ln_1.weight", "visual.transformer.resblocks.21.ln_1.bias", "visual.transformer.resblocks.21.mlp.c_fc.weight", "visual.transformer.resblocks.21.mlp.c_fc.bias", "visual.transformer.resblocks.21.mlp.c_proj.weight", "visual.transformer.resblocks.21.mlp.c_proj.bias", "visual.transformer.resblocks.21.ln_2.weight", "visual.transformer.resblocks.21.ln_2.bias", "visual.transformer.resblocks.22.attn.in_proj_weight", "visual.transformer.resblocks.22.attn.in_proj_bias", "visual.transformer.resblocks.22.attn.out_proj.weight", "visual.transformer.resblocks.22.attn.out_proj.bias", "visual.transformer.resblocks.22.ln_1.weight", "visual.transformer.resblocks.22.ln_1.bias", "visual.transformer.resblocks.22.mlp.c_fc.weight", 
"visual.transformer.resblocks.22.mlp.c_fc.bias", "visual.transformer.resblocks.22.mlp.c_proj.weight", "visual.transformer.resblocks.22.mlp.c_proj.bias", "visual.transformer.resblocks.22.ln_2.weight", "visual.transformer.resblocks.22.ln_2.bias", "visual.transformer.resblocks.23.attn.in_proj_weight", "visual.transformer.resblocks.23.attn.in_proj_bias", "visual.transformer.resblocks.23.attn.out_proj.weight", "visual.transformer.resblocks.23.attn.out_proj.bias", "visual.transformer.resblocks.23.ln_1.weight", "visual.transformer.resblocks.23.ln_1.bias", "visual.transformer.resblocks.23.mlp.c_fc.weight", "visual.transformer.resblocks.23.mlp.c_fc.bias", "visual.transformer.resblocks.23.mlp.c_proj.weight", "visual.transformer.resblocks.23.mlp.c_proj.bias", "visual.transformer.resblocks.23.ln_2.weight", "visual.transformer.resblocks.23.ln_2.bias", "visual.ln_post.weight", "visual.ln_post.bias", "transformer.resblocks.0.attn.in_proj_weight", "transformer.resblocks.0.attn.in_proj_bias", "transformer.resblocks.0.attn.out_proj.weight", "transformer.resblocks.0.attn.out_proj.bias", "transformer.resblocks.0.ln_1.weight", "transformer.resblocks.0.ln_1.bias", "transformer.resblocks.0.mlp.c_fc.weight", "transformer.resblocks.0.mlp.c_fc.bias", "transformer.resblocks.0.mlp.c_proj.weight", "transformer.resblocks.0.mlp.c_proj.bias", "transformer.resblocks.0.ln_2.weight", "transformer.resblocks.0.ln_2.bias", "transformer.resblocks.1.attn.in_proj_weight", "transformer.resblocks.1.attn.in_proj_bias", "transformer.resblocks.1.attn.out_proj.weight", "transformer.resblocks.1.attn.out_proj.bias", "transformer.resblocks.1.ln_1.weight", "transformer.resblocks.1.ln_1.bias", "transformer.resblocks.1.mlp.c_fc.weight", "transformer.resblocks.1.mlp.c_fc.bias", "transformer.resblocks.1.mlp.c_proj.weight", "transformer.resblocks.1.mlp.c_proj.bias", "transformer.resblocks.1.ln_2.weight", "transformer.resblocks.1.ln_2.bias", "transformer.resblocks.2.attn.in_proj_weight", 
"transformer.resblocks.2.attn.in_proj_bias", "transformer.resblocks.2.attn.out_proj.weight", "transformer.resblocks.2.attn.out_proj.bias", "transformer.resblocks.2.ln_1.weight", "transformer.resblocks.2.ln_1.bias", "transformer.resblocks.2.mlp.c_fc.weight", "transformer.resblocks.2.mlp.c_fc.bias", "transformer.resblocks.2.mlp.c_proj.weight", "transformer.resblocks.2.mlp.c_proj.bias", "transformer.resblocks.2.ln_2.weight", "transformer.resblocks.2.ln_2.bias", "transformer.resblocks.3.attn.in_proj_weight", "transformer.resblocks.3.attn.in_proj_bias", "transformer.resblocks.3.attn.out_proj.weight", "transformer.resblocks.3.attn.out_proj.bias", "transformer.resblocks.3.ln_1.weight", "transformer.resblocks.3.ln_1.bias", "transformer.resblocks.3.mlp.c_fc.weight", "transformer.resblocks.3.mlp.c_fc.bias", "transformer.resblocks.3.mlp.c_proj.weight", "transformer.resblocks.3.mlp.c_proj.bias", "transformer.resblocks.3.ln_2.weight", "transformer.resblocks.3.ln_2.bias", "transformer.resblocks.4.attn.in_proj_weight", "transformer.resblocks.4.attn.in_proj_bias", "transformer.resblocks.4.attn.out_proj.weight", "transformer.resblocks.4.attn.out_proj.bias", "transformer.resblocks.4.ln_1.weight", "transformer.resblocks.4.ln_1.bias", "transformer.resblocks.4.mlp.c_fc.weight", "transformer.resblocks.4.mlp.c_fc.bias", "transformer.resblocks.4.mlp.c_proj.weight", "transformer.resblocks.4.mlp.c_proj.bias", "transformer.resblocks.4.ln_2.weight", "transformer.resblocks.4.ln_2.bias", "transformer.resblocks.5.attn.in_proj_weight", "transformer.resblocks.5.attn.in_proj_bias", "transformer.resblocks.5.attn.out_proj.weight", "transformer.resblocks.5.attn.out_proj.bias", "transformer.resblocks.5.ln_1.weight", "transformer.resblocks.5.ln_1.bias", "transformer.resblocks.5.mlp.c_fc.weight", "transformer.resblocks.5.mlp.c_fc.bias", "transformer.resblocks.5.mlp.c_proj.weight", "transformer.resblocks.5.mlp.c_proj.bias", "transformer.resblocks.5.ln_2.weight", "transformer.resblocks.5.ln_2.bias", 
"transformer.resblocks.6.attn.in_proj_weight", "transformer.resblocks.6.attn.in_proj_bias", "transformer.resblocks.6.attn.out_proj.weight", "transformer.resblocks.6.attn.out_proj.bias", "transformer.resblocks.6.ln_1.weight", "transformer.resblocks.6.ln_1.bias", "transformer.resblocks.6.mlp.c_fc.weight", "transformer.resblocks.6.mlp.c_fc.bias", "transformer.resblocks.6.mlp.c_proj.weight", "transformer.resblocks.6.mlp.c_proj.bias", "transformer.resblocks.6.ln_2.weight", "transformer.resblocks.6.ln_2.bias", "transformer.resblocks.7.attn.in_proj_weight", "transformer.resblocks.7.attn.in_proj_bias", "transformer.resblocks.7.attn.out_proj.weight", "transformer.resblocks.7.attn.out_proj.bias", "transformer.resblocks.7.ln_1.weight", "transformer.resblocks.7.ln_1.bias", "transformer.resblocks.7.mlp.c_fc.weight", "transformer.resblocks.7.mlp.c_fc.bias", "transformer.resblocks.7.mlp.c_proj.weight", "transformer.resblocks.7.mlp.c_proj.bias", "transformer.resblocks.7.ln_2.weight", "transformer.resblocks.7.ln_2.bias", "transformer.resblocks.8.attn.in_proj_weight", "transformer.resblocks.8.attn.in_proj_bias", "transformer.resblocks.8.attn.out_proj.weight", "transformer.resblocks.8.attn.out_proj.bias", "transformer.resblocks.8.ln_1.weight", "transformer.resblocks.8.ln_1.bias", "transformer.resblocks.8.mlp.c_fc.weight", "transformer.resblocks.8.mlp.c_fc.bias", "transformer.resblocks.8.mlp.c_proj.weight", "transformer.resblocks.8.mlp.c_proj.bias", "transformer.resblocks.8.ln_2.weight", "transformer.resblocks.8.ln_2.bias", "transformer.resblocks.9.attn.in_proj_weight", "transformer.resblocks.9.attn.in_proj_bias", "transformer.resblocks.9.attn.out_proj.weight", "transformer.resblocks.9.attn.out_proj.bias", "transformer.resblocks.9.ln_1.weight", "transformer.resblocks.9.ln_1.bias", "transformer.resblocks.9.mlp.c_fc.weight", "transformer.resblocks.9.mlp.c_fc.bias", "transformer.resblocks.9.mlp.c_proj.weight", "transformer.resblocks.9.mlp.c_proj.bias", 
"transformer.resblocks.9.ln_2.weight", "transformer.resblocks.9.ln_2.bias", "transformer.resblocks.10.attn.in_proj_weight", "transformer.resblocks.10.attn.in_proj_bias", "transformer.resblocks.10.attn.out_proj.weight", "transformer.resblocks.10.attn.out_proj.bias", "transformer.resblocks.10.ln_1.weight", "transformer.resblocks.10.ln_1.bias", "transformer.resblocks.10.mlp.c_fc.weight", "transformer.resblocks.10.mlp.c_fc.bias", "transformer.resblocks.10.mlp.c_proj.weight", "transformer.resblocks.10.mlp.c_proj.bias", "transformer.resblocks.10.ln_2.weight", "transformer.resblocks.10.ln_2.bias", "transformer.resblocks.11.attn.in_proj_weight", "transformer.resblocks.11.attn.in_proj_bias", "transformer.resblocks.11.attn.out_proj.weight", "transformer.resblocks.11.attn.out_proj.bias", "transformer.resblocks.11.ln_1.weight", "transformer.resblocks.11.ln_1.bias", "transformer.resblocks.11.mlp.c_fc.weight", "transformer.resblocks.11.mlp.c_fc.bias", "transformer.resblocks.11.mlp.c_proj.weight", "transformer.resblocks.11.mlp.c_proj.bias", "transformer.resblocks.11.ln_2.weight", "transformer.resblocks.11.ln_2.bias", "token_embedding.weight", "ln_final.weight", "ln_final.bias".

Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "loops", "callbacks", "optimizer_states", "lr_schedulers", "hparams_name", "hyper_parameters".
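The missing keys are plain CLIP weight names, and the unexpected keys are the top-level entries of a PyTorch Lightning checkpoint. That pattern suggests the whole checkpoint dict was passed to `load_state_dict` instead of the weights stored under its `"state_dict"` entry. Below is a minimal, dependency-free sketch that unwraps the checkpoint. The helper name `extract_model_state_dict` and the `"model."` prefix are hypothetical. Print a few keys of `ckpt["state_dict"]` to see which prefix, if any, the LightningModule actually adds.

```python
from collections import OrderedDict


def extract_model_state_dict(checkpoint, prefix=""):
    """Return bare model weights from a (possibly PyTorch Lightning) checkpoint dict.

    Lightning checkpoints keep the weights under "state_dict", next to
    bookkeeping entries such as "epoch", "optimizer_states" and "lr_schedulers".
    A plain state dict (no "state_dict" entry) is returned unchanged.
    """
    state_dict = checkpoint.get("state_dict", checkpoint)
    if not prefix:
        return OrderedDict(state_dict)
    # Keep only keys under `prefix` and strip it, e.g. "model.visual.proj" -> "visual.proj".
    # Keys outside the prefix (other submodules of the LightningModule) are dropped.
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )
```

Typical use would be `ckpt = torch.load(path, map_location="cpu")` followed by `model.load_state_dict(extract_model_state_dict(ckpt, prefix="model."))`, with the prefix adjusted to whatever the checkpoint really contains.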

Cifar10 & cifar100 datasets problem

Hi,

Sorry to disturb you. I encountered the following error while trying to run the code on the cifar dataset.

FileNotFoundError: [Errno 2] No such file or directory: '~/LaBo/datasets/CIFAR100/images/access_road_s_001073.png'

It seems the program expects the images as PNG files. However, the CIFAR download link provided in dataset.md does not contain PNG files; it consists of three files. Where does the CIFAR dataset used by the program come from, and how is it generated?

Thanks!
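The "python version" of CIFAR-100 ships as three pickled files (`train`, `test`, `meta`). Each image is stored as a row of 3072 bytes: 1024 red values, then 1024 green, then 1024 blue, each plane in row-major 32x32 order. The original file names, such as `access_road_s_001073.png`, are kept under `b"filenames"`. The following is a sketch, not the authors' preprocessing script. It assumes the loader wants a flat `images/` folder of PNGs named after those file names, as the error path suggests, and it uses only the standard library to write the PNGs. Unpickling the real files still needs NumPy installed, because `b"data"` is stored as a NumPy array. Note also that Python's file APIs do not expand `~`. If the config contains a literal `~/LaBo/...` path, that alone can cause a FileNotFoundError.

```python
import pickle
import struct
import zlib
from pathlib import Path


def _png_chunk(tag: bytes, data: bytes) -> bytes:
    """One PNG chunk: big-endian length, 4-byte tag, payload, CRC32 of tag+payload."""
    crc = zlib.crc32(tag + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)


def encode_png_rgb(pixels: bytes, width: int, height: int) -> bytes:
    """Encode row-major, interleaved 8-bit RGB pixels as a PNG file."""
    stride = width * 3
    # Every scanline starts with a filter-type byte; 0 means "no filter".
    raw = b"".join(
        b"\x00" + pixels[y * stride:(y + 1) * stride] for y in range(height)
    )
    # Width, height, bit depth 8, color type 2 (RGB), default compression/filter, no interlace.
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    return (
        b"\x89PNG\r\n\x1a\n"
        + _png_chunk(b"IHDR", ihdr)
        + _png_chunk(b"IDAT", zlib.compress(raw))
        + _png_chunk(b"IEND", b"")
    )


def cifar_row_to_rgb(row: bytes, size: int = 32) -> bytes:
    """A CIFAR row holds all red values, then all green, then all blue; interleave them."""
    plane = size * size
    rgb = bytearray(3 * plane)
    rgb[0::3] = row[:plane]
    rgb[1::3] = row[plane:2 * plane]
    rgb[2::3] = row[2 * plane:3 * plane]
    return bytes(rgb)


def convert_batch(batch_path, out_dir) -> int:
    """Write every image of one CIFAR "python version" file as images/<filename>."""
    with open(batch_path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")  # keys come back as bytes, e.g. b"data"
    # Expand "~" explicitly, since open() and Path() do not do it on their own.
    images_dir = Path(out_dir).expanduser() / "images"
    images_dir.mkdir(parents=True, exist_ok=True)
    for row, name in zip(batch[b"data"], batch[b"filenames"]):
        png = encode_png_rgb(cifar_row_to_rgb(bytes(row)), 32, 32)
        (images_dir / name.decode()).write_bytes(png)
    return len(batch[b"filenames"])
```

For example, `convert_batch("cifar-100-python/train", "~/LaBo/datasets/CIFAR100")` and the same call on `test` would fill `~/LaBo/datasets/CIFAR100/images/`. Any split or label files the repository expects alongside the images are not covered here.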

TypeError: __new__() missing 1 required positional argument: 'task'

After fixing the error about negative values not being accepted, the following error appeared next:
"""
use asso concept with dot product loader, faster
Traceback (most recent call last):
  File "/LaBo-main/main.py", line 350, in <module>
    main(cfg)
  File "/LaBo-main/main.py", line 270, in asso_opt_main
    model = AssoConceptFast(cfg, init_weight=th.load(cfg.init_weight_path) if 'init_weight_path' in cfg else None)
  File "/LaBo-main/models/asso_opt/asso_opt.py", line 68, in __init__
    self.train_acc = torchmetrics.Accuracy(num_classes=cfg.num_cls)
TypeError: __new__() missing 1 required positional argument: 'task'
"""

I didn't modify any of the code, so what does this mean?
I'd really appreciate it if you could help me.
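This error matches the API change in torchmetrics 0.11, where `torchmetrics.Accuracy` started requiring a `task` argument ("binary", "multiclass" or "multilabel"). Code written for earlier versions, like line 68 of `asso_opt.py`, fails when a newer torchmetrics is installed. There are two direct fixes: pin `torchmetrics<0.11`, or change the call to `torchmetrics.Accuracy(task="multiclass", num_classes=cfg.num_cls)`. Other `Accuracy` calls in the repository likely need the same change. The sketch below is a version-tolerant variant. The helper names `accuracy_kwargs` and `make_accuracy` are hypothetical, not part of the repository.

```python
def accuracy_kwargs(torchmetrics_version: str, num_classes: int) -> dict:
    """Constructor arguments for torchmetrics.Accuracy that suit the installed version."""
    major, minor = (int(part) for part in torchmetrics_version.split(".")[:2])
    kwargs = {"num_classes": num_classes}
    if (major, minor) >= (0, 11):
        # From 0.11 on, Accuracy dispatches on `task` and requires it.
        kwargs["task"] = "multiclass"
    return kwargs


def make_accuracy(num_classes: int):
    """Build a multiclass Accuracy metric for whichever torchmetrics is installed."""
    import torchmetrics  # imported here so accuracy_kwargs stays dependency-free

    return torchmetrics.Accuracy(**accuracy_kwargs(torchmetrics.__version__, num_classes))
```

With this helper, line 68 would become `self.train_acc = make_accuracy(cfg.num_cls)`.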

the gpu used for the experiment

Hi, I used the cub_base config for one-shot training on a single RTX 3090, but I keep getting a CUDA out-of-memory error even after reducing the batch size to 4. Which GPU did you use for training, and do you have any advice for solving this problem? Thanks a lot!

inference_code

hello,
Is there any code that displays a caption for an input image, like Figure 7 in the paper?
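The thread doesn't show how Figure 7 was produced. The sketch below assumes that the final prediction is a linear layer over image-concept scores, as in a concept bottleneck. Under that assumption, one way to build such a caption is to rank concepts by their contribution to the predicted class, i.e. weight times score. All names here (`explain_prediction`, the example concept strings) are hypothetical, and plain lists stand in for the real score and weight tensors.

```python
def explain_prediction(concept_scores, class_weights, concept_names, k=3):
    """Predict with a linear concept bottleneck and list the top-k supporting concepts.

    concept_scores: one image-concept score per concept, length n_concepts
    class_weights:  one row of concept weights per class, n_classes x n_concepts
    """
    # Linear bottleneck: logit_c = sum_j W[c][j] * s_j
    logits = [sum(w * s for w, s in zip(row, concept_scores)) for row in class_weights]
    pred = max(range(len(logits)), key=logits.__getitem__)
    # Contribution of concept j to the predicted class is W[pred][j] * s_j.
    contrib = [w * s for w, s in zip(class_weights[pred], concept_scores)]
    top = sorted(range(len(contrib)), key=contrib.__getitem__, reverse=True)[:k]
    return pred, [(concept_names[j], contrib[j]) for j in top]
```

For instance, `explain_prediction([0.9, 0.2, 0.5], [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0]], ["red crown", "webbed feet", "short beak"], k=2)` predicts class 0 and cites "red crown" and "short beak". Those concept names could then be rendered as the caption next to the image.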

Not using use_img_norm and use_txt_norm

Hi, thanks for your great work! I am a little confused about the code.
I noticed that two parameters in the cfg, use_img_norm and use_txt_norm, are set to False, which is unusual for standard CLIP-based classification. Is there a specific reason for not normalizing the image and text features? Wouldn't normalization give better image-concept alignment and better interpretability with CLIP?
Thanks a lot!
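In standard CLIP zero-shot classification, both image and text features are L2-normalized before the dot product, so each score is a cosine similarity. Without normalization, feature magnitude also affects the scores and can change which concepts rank highest. The toy example below illustrates only that difference, using made-up 2-D vectors. It says nothing about why these flags default to False in LaBo.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rank_concepts(image_feat, concept_feats, normalize):
    """Rank concept indices by image-concept score, with or without L2 normalization."""
    if normalize:
        image_feat = l2_normalize(image_feat)
        concept_feats = [l2_normalize(c) for c in concept_feats]
    scores = [dot(image_feat, c) for c in concept_feats]
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
```

With an image feature `[1.0, 0.0]` and concept features `[[3.0, 3.0], [1.0, 0.1]]`, the raw dot product prefers concept 0 because of its larger norm. Cosine similarity prefers concept 1, whose direction is closer to the image.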

Requirement Text Installation Issue

Running "conda install --file requirement.txt" fails: requirement.txt includes some unrelated packages, such as alembic and alsa-lib, which cause installation issues.
Could you please check the problem?
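Until the file is cleaned up upstream, one workaround is to drop the offending entries locally before installing. The sketch below assumes each line is either `name=version=build` (the `conda list --export` style) or a pip-style specifier. It keeps comments and blank lines. The helper `filter_requirements` is hypothetical, and the exclude set only mirrors the two packages named in this issue. Which other packages to drop is a judgment call for your platform.

```python
import re


def filter_requirements(lines, exclude):
    """Drop excluded packages from a requirements list, keeping comments and blanks.

    Accepts "name=version=build" (conda export style) and pip-style lines such as
    "name==version" or "name>=version"; package names are compared case-insensitively.
    """
    excluded = {name.lower() for name in exclude}
    kept = []
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            kept.append(line)
            continue
        # The package name is everything before the first version operator or space.
        name = re.split(r"[=<>!~\s]", stripped, maxsplit=1)[0].lower()
        if name not in excluded:
            kept.append(line)
    return kept
```

For example, read `requirement.txt` with `readlines()`, write `filter_requirements(lines, {"alembic", "alsa-lib"})` to a new file, and pass that file to `conda install --file`.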
