agent-attention's Introduction

Agent Attention

This repo contains the official PyTorch code and pre-trained models for Agent Attention.

Introduction

The attention module is the key component of Transformers. While the global attention mechanism offers robust expressiveness, its quadratic computational cost limits its applicability in many scenarios. In this paper, we propose a novel attention paradigm, Agent Attention, to strike a favorable balance between computational efficiency and representation power. Specifically, Agent Attention, denoted as a quadruple $(Q, A, K, V)$, introduces an additional set of agent tokens $A$ into the conventional attention module. The agent tokens first act as agents for the query tokens $Q$, aggregating information from $K$ and $V$, and then broadcast that information back to $Q$. Because the number of agent tokens can be made much smaller than the number of query tokens, agent attention is significantly more efficient than the widely adopted Softmax attention while preserving the ability to model global context. Interestingly, we show that the proposed agent attention is equivalent to a generalized form of linear attention. Agent attention therefore seamlessly integrates the powerful Softmax attention and the highly efficient linear attention.
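The quadruple $(Q, A, K, V)$ described above can be written compactly as follows, with $N$ query/key tokens, $n \ll N$ agent tokens, channel dimension $d$, and the $1/\sqrt{d}$ scaling inside $\sigma$ omitted for brevity. This is a sketch consistent with the description here; the symbols $\phi_q$ and $\phi_k$ are our shorthand for the implied feature maps.

```latex
\begin{aligned}
O^{A} &= \sigma\!\left(Q A^{\top}\right)\,\sigma\!\left(A K^{\top}\right) V
       = \operatorname{Attn}^{S}\!\left(Q,\, A,\, \operatorname{Attn}^{S}(A, K, V)\right),
       \qquad \operatorname{Attn}^{S}(Q, K, V) = \sigma\!\left(Q K^{\top}\right) V, \\
      &= \phi_q(Q)\,\phi_k(K)^{\top} V,
       \qquad \phi_q(Q) = \sigma\!\left(Q A^{\top}\right) \in \mathbb{R}^{N \times n},
       \quad \phi_k(K) = \sigma\!\left(A K^{\top}\right)^{\top} \in \mathbb{R}^{N \times n}.
\end{aligned}
```

The second line is the linear-attention form: computing $\sigma(AK^{\top})V$ first (an $n \times d$ matrix) and then multiplying by $\sigma(QA^{\top})$ costs $\mathcal{O}(Nnd)$, instead of the $\mathcal{O}(N^2 d)$ of Softmax attention.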

Motivation

(a) In Softmax attention, each query aggregates information from all features, incurring quadratic complexity. (b) Leveraging the redundancy among attention weights, agent attention uses a small number of agent tokens to act as "agents" for the queries, capturing diverse semantic information from all features and then presenting it to each query.

Method

An illustration of agent attention and the agent attention module. (a) Agent attention uses agent tokens to aggregate global information and distribute it to individual image tokens, resulting in a practical integration of Softmax and linear attention. $\sigma(\cdot)$ denotes the Softmax function. (b) The information flow of the agent attention module. As an example, we obtain agent tokens through pooling. The agent tokens then aggregate information from $V$, and $Q$ queries features from the agent features. In addition, agent bias and a depthwise convolution (DWC) are adopted to add positional information and maintain feature diversity.
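The information flow in (b) can be sketched as a small PyTorch module. This is a minimal, single-head illustration under our own assumptions (the class name AgentAttentionSketch, a perfect-square agent_num, and a DWC applied to $V$), not the repository's implementation; agent bias and multi-head splitting are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AgentAttentionSketch(nn.Module):
    """Single-head agent attention on an h x w token grid (illustrative only)."""

    def __init__(self, dim: int, agent_num: int = 49):
        super().__init__()
        pool = int(agent_num ** 0.5)
        assert pool * pool == agent_num, "this sketch assumes a perfect-square agent_num"
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.pool = nn.AdaptiveAvgPool2d((pool, pool))
        # Depthwise 3x3 convolution on V (assumed placement of the DWC branch).
        self.dwc = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        b, n, c = x.shape  # n == h * w
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Agent tokens A: pool the query map down to agent_num tokens -> (b, m, c).
        a = self.pool(q.transpose(1, 2).reshape(b, c, h, w)).flatten(2).transpose(1, 2)
        # Step 1: agents aggregate from K and V -> (b, m, c).
        agent_v = F.softmax(a @ k.transpose(1, 2) * self.scale, dim=-1) @ v
        # Step 2: queries read back from the agents -> (b, n, c).
        out = F.softmax(q @ a.transpose(1, 2) * self.scale, dim=-1) @ agent_v
        # DWC branch on V, computed on the 2-D grid and flattened back.
        out = out + self.dwc(v.transpose(1, 2).reshape(b, c, h, w)).flatten(2).transpose(1, 2)
        return self.proj(out)
```

For example, `AgentAttentionSketch(64, 49)(torch.randn(2, 3136, 64), 56, 56)` returns a `(2, 3136, 64)` tensor while only ever forming `3136 x 49` and `49 x 3136` attention maps.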

Results

Classification

See the agent_transformer folder for detailed documentation.

  • Comparison of different models on ImageNet-1K.

  • Accuracy-Runtime curve on ImageNet.

  • Increasing resolution to $\{256^2, 288^2, 320^2, 352^2, 384^2\}$.

Downstream tasks

See the detection and segmentation folders for detailed documentation.

AgentSD

When applied to Stable Diffusion, our agent attention accelerates generation and substantially enhances image generation quality without any additional training. See the agentsd folder for detailed documentation.

  • Quantitative Results of Stable Diffusion, ToMeSD and our AgentSD.

  • Samples generated by Stable Diffusion, ToMeSD ($r=0.4$) and AgentSD ($r=0.4$).

TODO

  • Classification
  • Segmentation
  • Detection
  • Agent Attention for Stable Diffusion

Acknowledgements

Our code is built on top of PVT, Swin Transformer, CSwin Transformer and ToMeSD.

Citation

If you find this repo helpful, please consider citing us.

@article{han2023agent,
  title={Agent Attention: On the Integration of Softmax and Linear Attention},
  author={Han, Dongchen and Ye, Tianzhu and Han, Yizeng and Xia, Zhuofan and Song, Shiji and Huang, Gao},
  journal={arXiv preprint arXiv:2312.08874},
  year={2023}
}

Contact

If you have any questions, please feel free to contact the authors.

Dongchen Han: [email protected]

Tianzhu Ye: [email protected]

agent-attention's People

Contributors

eltociear, tian-qing001

agent-attention's Issues

Question about the difference between agent attention and anchored stripe attention

Thank you for your outstanding work! I wanted to inquire if you are familiar with the concept of anchored stripe attention discussed in the paper titled "Efficient and Explicit Modelling of Image Hierarchies for Image Restoration." It appears that there are striking similarities between these two attention mechanisms. Could you elucidate the key distinctions between them?
[Screenshot 2023-12-27 15:41:21]

How to apply AgentAttention to an input with seq_length=320

Hi! @tian-qing001
Thank you for your outstanding contributions. I want to FOLLOW your work. AgentAttention takes the square root of the input sequence length to get h and w for the subsequent agent tokens.

However, when applying it in my own framework, the input sequence length is 320, which is not a perfect square, so taking the square root does not give valid h and w. Is there a solution for this?
Thanks in advance!
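Two possible workarounds, which are our own suggestions rather than the authors' answer: if the 320 tokens really come from a non-square 2-D grid, pass the true (h, w) (for example 16 x 20) and pool with a rectangular output size; if the sequence has no 2-D structure, pool along the sequence axis instead. The helper names below are hypothetical, and factoring n only makes sense when the tokens actually form that grid.

```python
import torch
import torch.nn.functional as F


def near_square_factors(n: int) -> tuple:
    """Return (h, w) with h * w == n and h <= w, h as close to sqrt(n) as possible."""
    h = int(n ** 0.5)
    while n % h != 0:  # terminates at h == 1 at the latest
        h -= 1
    return h, n // h


def pool_agents_2d(q: torch.Tensor, h: int, w: int, agent_hw: tuple) -> torch.Tensor:
    """(b, h*w, c) queries on an h x w grid -> (b, ah*aw, c) agent tokens."""
    b, n, c = q.shape
    grid = q.transpose(1, 2).reshape(b, c, h, w)
    return F.adaptive_avg_pool2d(grid, agent_hw).flatten(2).transpose(1, 2)


def pool_agents_1d(q: torch.Tensor, agent_num: int) -> torch.Tensor:
    """(b, n, c) queries -> (b, agent_num, c) agents, pooling along the sequence."""
    return F.adaptive_avg_pool1d(q.transpose(1, 2), agent_num).transpose(1, 2)
```

With (h, w) = (16, 20), a (4, 5) agent grid gives 20 agents; any position-bias tensors in the repository code would then need matching shapes, which we have not checked.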

Applying agent attention to auto-regressive models

Thank you for your inspiring work. I'm eager to apply ideas like agent attention to train prevailing auto-regressive models like GPT. However, using pooling on Q to get the A matrix will cause information leakage when training auto-regressive models using teacher forcing. I haven't found related discussions in your paper. Is there any straightforward extension or variation of agent attention to adapt it to auto-regressive models?

pretrained

Are pre-trained weights available for download?

3D AgentToken

When I train with 3D agent tokens, I get a CUDA out-of-memory error at epoch 40. The agent number is 343.
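For a volumetric input, one way to obtain 343 = 7³ agents (assumed here; the repository code discussed on this page is 2-D) is 3-D adaptive pooling of the queries. The second helper is a back-of-the-envelope estimate of the two agent attention maps per layer, which grow linearly with both the token count and agent_num; it ignores activations kept for backward, so real usage is higher. Both helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def pool_agents_3d(q: torch.Tensor, d: int, h: int, w: int, side: int = 7) -> torch.Tensor:
    """(b, d*h*w, c) queries on a d x h x w volume -> (b, side**3, c) agents."""
    b, n, c = q.shape
    vol = q.transpose(1, 2).reshape(b, c, d, h, w)
    return F.adaptive_avg_pool3d(vol, side).flatten(2).transpose(1, 2)


def attn_map_bytes(batch: int, heads: int, n_tokens: int, agent_num: int,
                   bytes_per_el: int = 4) -> int:
    """Bytes for the two agent attention maps (agents->keys, queries->agents)
    of one layer; activations saved for backward are not counted."""
    return 2 * batch * heads * n_tokens * agent_num * bytes_per_el
```

For example, batch 2, 4 heads and a 32³ volume give about 0.72 GB for these two maps in fp32, per layer.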

meaning and method

Thanks for the information.
I have an additional question.
What do remove_patch and apply_patch do, and what do sx, sy, and ratio mean?
Is agentsd your model?
agentsd.remove_patch(self.model)
agentsd.apply_patch(model, sx=4, sy=4, ratio=0.4, agent_ratio=0.95)

Actually, I would like to apply your agent attention module to DDIM from the guided diffusion model.

Thanks,
jungmin

Different image resolutions for training and testing

Hi! I am using the agent_swin model as a backbone. Because the image resolution during testing is higher than during training, I get errors. Can the model support inputs with dynamic resolution, and if so, how should I modify it?
Looking forward to your reply.
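One generic way to make position biases resolution-agnostic, assumed here rather than taken from the agent_swin code, is to store them at a fixed size and resize them to the actual window or feature-map size at run time, as the fixed 7 x 7 an_bias/na_bias parameters in the AgentAttention code quoted later on this page suggest. Window partitioning additionally requires the feature map to be divisible by the window size, which is typically handled by padding. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def resize_agent_bias(bias: torch.Tensor, new_hw: tuple) -> torch.Tensor:
    """Bilinearly resize a (num_heads, agent_num, H, W) bias to (num_heads, agent_num, *new_hw).

    F.interpolate treats num_heads as the batch dim and agent_num as channels.
    """
    return F.interpolate(bias, size=new_hw, mode="bilinear", align_corners=False)


def bias_for_tokens(bias: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Resize, then flatten to (num_heads, agent_num, h*w) for agent-to-token logits."""
    return resize_agent_bias(bias, (h, w)).flatten(2)
```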

application in cross transformer

Hello,

I've been exploring your work on Cross Transformer, and I'm intrigued by the potential of integrating Agent Attention into this architecture. Agent Attention, as a method to balance computational efficiency and representation power, seems like it could complement the Cross Transformer's design quite well.

I'm particularly interested in understanding how Agent Attention might be incorporated into the Cross Transformer framework. Specifically, my questions are:

  1. How many agent tokens would be optimal in the context of Cross Transformer, and how should they be initialized?
  2. In the existing architecture, what would be the best way to integrate agent tokens in terms of the attention mechanism – should they replace or complement the current attention queries or keys?
  3. Are there any specific considerations or potential challenges you foresee in adapting Agent Attention to this context?
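As a starting point for question 2: agent attention does not require K and V to come from the same sequence as Q, since the agents only attend to the keys and are then attended to by the queries. A minimal single-head sketch under that assumption, with agents pooled from the queries along the sequence axis (pooling from the context, or learnable agents, are equally plausible choices), might look like this; the class name is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AgentCrossAttentionSketch(nn.Module):
    """Single-head agent attention: queries from x, keys/values from context."""

    def __init__(self, dim: int, agent_num: int = 16):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.agent_num = agent_num
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        q = self.to_q(x)                                   # (b, n_q, c)
        k, v = self.to_kv(context).chunk(2, dim=-1)        # (b, n_kv, c) each
        # Agents pooled from the queries along the sequence axis -> (b, m, c).
        a = F.adaptive_avg_pool1d(q.transpose(1, 2), self.agent_num).transpose(1, 2)
        # Agents gather from the context, then queries read from the agents.
        agent_v = F.softmax(a @ k.transpose(1, 2) * self.scale, dim=-1) @ v
        out = F.softmax(q @ a.transpose(1, 2) * self.scale, dim=-1) @ agent_v
        return self.proj(out)
```

The cost is proportional to (n_q + n_kv) x agent_num x c rather than n_q x n_kv x c.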

Any insights or suggestions you could provide would be greatly appreciated. I believe such an integration could offer a promising direction for further research and application.

Thank you for your time and for the impactful work you've shared with the community.

Best regards,

missing_keys: 'layers.3.blocks.0.attn.relative_position_index', 'layers.2.blocks.1.attn_mask'

Hi! I found some missing keys when loading the pre-trained weights. Does this matter?
WARNING _IncompatibleKeys(missing_keys=['layers.2.blocks.0.attn.relative_position_index', 'layers.2.blocks.1.attn_mask', 'layers.2.blocks.1.attn.relative_position_index', 'layers.2.blocks.2.attn.relative_position_index', 'layers.2.blocks.3.attn_mask', 'layers.2.blocks.3.attn.relative_position_index', 'layers.2.blocks.4.attn.relative_position_index', 'layers.2.blocks.5.attn_mask', 'layers.2.blocks.5.attn.relative_position_index', 'layers.2.blocks.6.attn.relative_position_index', 'layers.2.blocks.7.attn_mask', 'layers.2.blocks.7.attn.relative_position_index', 'layers.2.blocks.8.attn.relative_position_index', 'layers.2.blocks.9.attn_mask', 'layers.2.blocks.9.attn.relative_position_index', 'layers.2.blocks.10.attn.relative_position_index', 'layers.2.blocks.11.attn_mask', 'layers.2.blocks.11.attn.relative_position_index', 'layers.2.blocks.12.attn.relative_position_index', 'layers.2.blocks.13.attn_mask', 'layers.2.blocks.13.attn.relative_position_index', 'layers.2.blocks.14.attn.relative_position_index', 'layers.2.blocks.15.attn_mask', 'layers.2.blocks.15.attn.relative_position_index', 'layers.2.blocks.16.attn.relative_position_index', 'layers.2.blocks.17.attn_mask', 'layers.2.blocks.17.attn.relative_position_index', 'layers.3.blocks.0.attn.relative_position_index', 'layers.3.blocks.1.attn.relative_position_index'], unexpected_keys=[])
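Whether these keys matter depends on the model, but in Swin-style code relative_position_index and attn_mask are usually buffers computed deterministically from the window size in __init__, so a checkpoint that lacks them typically loses nothing. A small helper (hypothetical, built on the standard load_state_dict(strict=False) return value) can confirm that nothing else is missing:

```python
import torch
import torch.nn as nn

# Buffers that Swin-style blocks rebuild from the window size in __init__.
RECOMPUTABLE_SUFFIXES = ("relative_position_index", "attn_mask")


def load_and_check(model: nn.Module, state_dict: dict) -> list:
    """Load non-strictly and return missing keys that are NOT recomputable buffers."""
    result = model.load_state_dict(state_dict, strict=False)  # _IncompatibleKeys
    return [k for k in result.missing_keys if not k.endswith(RECOMPUTABLE_SUFFIXES)]
```

An empty return value means only recomputable buffers were missing.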

scatter_reduce() argument 'reduce' (position 3) must be str, not Tensor

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --precision full
stable-diffusion-v1-4
Error:
File "scripts/txt2img.py", line 353, in <module>
main()
File "scripts/txt2img.py", line 303, in main
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
File "/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 97, in sample
samples, intermediates = self.plms_sampling(conditioning, size,
File "/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 159, in plms_sampling
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
File "/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 225, in p_sample_plms
e_t = get_model_output(x, t)
File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 192, in get_model_output
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
File "/code/stable-diffusion-main/ldm/models/diffusion/ddpm.py", line 987, in apply_model
x_recon = self.model(x_noisy, t, **cond)
File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/code/stable-diffusion-main/ldm/models/diffusion/ddpm.py", line 1410, in forward
out = self.diffusion_model(x, t, context=cc)
File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
result = forward_call(*input, **kwargs)
File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/openaimodel.py", line 731, in forward
h = module(h, emb, context)
File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/openaimodel.py", line 85, in forward
x = layer(x, context)
File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/code/stable-diffusion-main/ldm/modules/attention.py", line 258, in forward
x = block(x, context=context)
File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/code/stable-diffusion-main/ldm/modules/attention.py", line 209, in forward
return checkpoint(self.forward, (x, context), self.parameters(), self.checkpoint)
File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/util.py", line 114, in checkpoint
return CheckpointFunction.apply(func, len(inputs), *args)
File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/util.py", line 127, in forward
output_tensors = ctx.run_function(*ctx.input_tensors)
File "/code/stable-diffusion-main/ldm/models/diffusion/agentsd/patch.py", line 66, in forward
feature, agent = m_a(y)
File "/code/stable-diffusion-main/ldm/models/diffusion/agentsd/merge.py", line 118, in merge
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
TypeError: scatter_reduce(): argument 'reduce' (position 3) must be str, not Tensor
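This error most likely comes from the PyTorch version rather than from AgentSD itself. In PyTorch 1.11, Tensor.scatter_reduce was a prototype with the signature scatter_reduce(dim, index, reduce, *, output_size=None), so the third positional argument (src above) lands in the reduce slot; from 1.12 onward the signature is scatter_reduce(dim, index, src, reduce, *, include_self=True), which is what the merge code expects. The sketch below (hypothetical helper names) guards against the old API and shows the expected call:

```python
import torch


def torch_version() -> tuple:
    """(major, minor) of the installed PyTorch, e.g. (1, 12)."""
    major, minor = torch.__version__.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def merge_mean(dst: torch.Tensor, dst_idx: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Average src rows into the dst rows selected by dst_idx (ToMe-style merge).

    dst: (n, t, c); dst_idx: (n, r, 1) int64; src: (n, r, c).
    """
    if torch_version() < (1, 12):
        raise RuntimeError("scatter_reduce(dim, index, src, reduce) needs PyTorch >= 1.12")
    n, r, c = src.shape
    # reduce must be a string such as "sum" or "mean"; include_self defaults to True.
    return dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce="mean")
```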

Black Images Generated

When I try to run agentsd with Stable Diffusion v2.1, it generates black images. Sample below:

[Image: generated sample]

I added the agentsd folder to the root and added the following lines of code

  import agentsd
  if i == 0:
      agentsd.remove_patch(self.model)
      agentsd.apply_patch(self.model, sx=4, sy=4, ratio=0.4, agent_ratio=0.95, attn_precision="fp32")
  elif i == 20:
      agentsd.remove_patch(self.model)
      agentsd.apply_patch(self.model, sx=2, sy=2, ratio=0.4, agent_ratio=0.5, attn_precision="fp32")

to the ldm/models/diffusion/ddim.py file after line 152, as per the instructions. Without these lines, it works fine.
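Black outputs from Stable Diffusion are often a symptom of NaN or Inf values appearing somewhere in the U-Net (for example from half-precision overflow); whether that is the cause here is an assumption. A generic way to locate the first offending module is a set of forward hooks. The helper below is hypothetical and independent of agentsd:

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model: nn.Module, *inputs):
    """Run one forward pass; return the name of the first module (in completion
    order) whose tensor output contains NaN or Inf, or None if all are finite."""
    found = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            if not found and torch.is_tensor(output) and not torch.isfinite(output).all():
                found.append(name)
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:  # always detach the hooks again
            handle.remove()
    return found[0] if found else None
```

Because hooks fire when a module finishes, children report before their parents, so the returned name points at the innermost module that first produced a non-finite value.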

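Questions about the bias

Hello author,
Today I had the pleasure of reading this paper, and I think the method has great potential. I am currently doing research on image fusion and have a few questions that I hope you can clear up.
1. In the line agent_tokens = self.pool(q[:, 1:, :].reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1), Q is sliced. After slicing, can it still be reshaped to (b, h, w, c)?
2. A bias is added to the agent features, but in position_bias = torch.cat([self.ac_bias.repeat(b, 1, 1, 1), position_bias], dim=-1) the biases are concatenated. Doesn't that add an extra dimension?
3. If the bias is not added, will the results differ much?
4. On what basis is the bias divided?
Looking forward to your reply.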
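Regarding questions 1 and 2, the shapes work out if the first token is a class token and ac_bias has shape (1, num_heads, agent_num, 1); both are assumptions on our part, inferred from the slicing and the concatenation. Slicing removes only the class token, leaving exactly h x w patch tokens, and concatenating along dim=-1 lengthens the last axis by one (a class-token column) without adding a dimension:

```python
import torch

b, heads, agent_num, h, w, c = 2, 4, 49, 14, 14, 64

q = torch.randn(b, 1 + h * w, c)             # token 0 assumed to be the class token
patch_q = q[:, 1:, :].reshape(b, h, w, c)    # 196 patch tokens -> (b, 14, 14, c)

ac_bias = torch.zeros(1, heads, agent_num, 1)            # assumed shape: agent -> class token
position_bias = torch.zeros(b, heads, agent_num, h * w)  # agent -> patch tokens
full_bias = torch.cat([ac_bias.repeat(b, 1, 1, 1), position_bias], dim=-1)
# full_bias is still 4-D: (b, heads, agent_num, 1 + h * w), matching agent-to-key logits
```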

Training on a single GPU

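Hello, when debugging the code on my local machine, I get parallelism-related errors involving local_rank and world_size. Are my arguments set incorrectly?

[Screenshot: Snipaste_2023-12-20_16-07-49]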
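If the script expects a distributed launch, one option is to start it with torchrun --nproc_per_node=1, which sets the rank and world-size environment variables. Alternatively, a one-process group can be created by hand before the model is built. The helper below is a hypothetical sketch using the env:// initialization of torch.distributed; whether the repository's main script then picks up LOCAL_RANK correctly is not verified here.

```python
import os
import socket

import torch.distributed as dist


def init_single_process_group(backend: str = "gloo") -> None:
    """Create a 1-process group so code calling dist.get_rank() etc. still runs."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    if "MASTER_PORT" not in os.environ:
        with socket.socket() as s:  # pick a free local port
            s.bind(("127.0.0.1", 0))
            os.environ["MASTER_PORT"] = str(s.getsockname()[1])
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")
    # env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.
    dist.init_process_group(backend=backend, init_method="env://")
```

Use backend="nccl" when the single process owns a GPU; "gloo" also works on CPU.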

attn_type in agent-swin

Hello author, first of all, thank you very much for your work.
Secondly, I am confused: in the Agent-Swin code and config, attn_type in the AgentSwinTransformer function is set to 'BBBB'. Does this mean that agent attention is not actually used when the code runs?

attention mask

Hello, thank you for sharing this idea. How can I add an attention mask to agent attention to mask out part of the input?
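A possible approach, sketched under our own assumptions (self-attention, a boolean key_padding_mask of shape (b, n) with True marking positions to ignore, and agents pooled along the sequence axis): exclude masked positions when pooling the agents, and set the agent-to-key logits of masked keys to -inf before the softmax. The query-to-agent step then needs no mask, because the agents carry no information from masked positions. At least one position per sample must remain unmasked, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def masked_agent_attention(q, k, v, agent_num, key_padding_mask=None):
    """Single-head agent attention on (b, n, c) tensors with an optional key padding mask."""
    b, n, c = q.shape
    scale = c ** -0.5
    if key_padding_mask is None:
        key_padding_mask = torch.zeros(b, n, dtype=torch.bool, device=q.device)
    keep = (~key_padding_mask).unsqueeze(1).to(q.dtype)                  # (b, 1, n)
    # Agents: per-window average over unmasked queries only.
    summed = F.adaptive_avg_pool1d(q.transpose(1, 2) * keep, agent_num)  # (b, c, m)
    count = F.adaptive_avg_pool1d(keep, agent_num).clamp_min(1e-6)       # (b, 1, m)
    a = (summed / count).transpose(1, 2)                                 # (b, m, c)
    # Agent -> key step: masked keys get -inf, so they receive zero weight.
    logits = (a @ k.transpose(1, 2)) * scale                             # (b, m, n)
    logits = logits.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
    agent_v = F.softmax(logits, dim=-1) @ v                              # (b, m, c)
    # Query -> agent step needs no mask.
    return F.softmax((q @ a.transpose(1, 2)) * scale, dim=-1) @ agent_v  # (b, n, c)
```

Outputs at masked query positions are still computed and should simply be ignored downstream.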

How to decide the window size in agent_swin

Thank you for your excellent work! I find that the original window size in Swin-T is 7, whereas in agent_swin, it is 56. I am curious about your design choices regarding the window size and stage attention types in agent-swin-T/S/B. Are there any guiding principles behind these decisions?
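Part of the answer may be visible from a quick count of query-key score entries, which are proportional to attention FLOPs per channel. This is our own arithmetic, not the authors' stated rationale: at stage 1 of a 224 x 224 input with patch size 4 the feature map is 56 x 56, so a 56 x 56 window is global for that stage.

```python
def score_entries(n_queries: int, n_keys: int) -> int:
    """Number of entries in an n_queries x n_keys attention matrix."""
    return n_queries * n_keys


n = 56 * 56                                     # 3136 tokens at stage 1
global_softmax = score_entries(n, n)            # full Softmax attention
swin_7x7 = score_entries(n, 7 * 7)              # each token attends to its 7x7 window
# Agent attention with 49 agents over the 56x56 window:
# agents x keys, plus queries x agents.
agent_56 = score_entries(49, n) + score_entries(n, 49)

print(global_softmax, swin_7x7, agent_56)       # prints: 9834496 153664 307328
```

So 49 agents over a global 56 x 56 window cost about twice the score entries of 7 x 7 window attention, and 1/32 of full attention, while every token still receives global information.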

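Why use A as a proxy for K?

Hello, after reading your paper I have a question: after A replaces Q and is combined with K and V to compute V_A, why is A then used in place of K, rather than V_A? A and V_A are not identical, so can the attention scores computed from Q and A be applied to V_A?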

Inquiry About Integrating Agent Attention into xformers Library

Dear Dr. Han and Dr. Ye,

I have been greatly impressed by your work on the Agent Attention model, as detailed in your recent publication and the associated GitHub repository. The method of integrating Softmax with linear attention mechanisms to enhance computational efficiency while maintaining robust expressiveness is particularly compelling.

Given that the xformers library is a platform for optimizing and enhancing the efficiency of Transformers, I am curious to know if there are any plans to integrate the Agent Attention mechanism into xformers. Such an integration could potentially make your innovative approach more accessible and practical for a broader audience, enabling developers and researchers to utilize Agent Attention in real-world applications more readily.

Could you please share any information regarding plans to migrate Agent Attention code to xformers or similar libraries, or if there are any ongoing projects aimed at such integration?

Thank you for your time and consideration.

Best regards,

xczhou

Other approaches to designing agent_tokens?

Hi, thank you for your excellent work! I would like to migrate this work to a Transformer model that solves combinatorial optimization problems such as the traveling salesman problem. Since this type of problem does not involve operations such as pooling, DWC, or bias, I have the following questions:

  1. Are there other ways to design agent_tokens? The paper mentions that they can be "set to a set of learnable parameters", but I don't quite understand how this should be implemented in code.
  2. Can Agent Attention be integrated into other Transformers as a plug-and-play module? It seems somewhat difficult because of the differing tensor dimensions.
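Following the "learnable parameters" option mentioned above, agent tokens can simply be an nn.Parameter shared across the batch, which removes the need for pooling, DWC, or positional bias and works on plain (b, n, c) sequences. This is a hypothetical sketch, not the repository's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LearnableAgentAttention(nn.Module):
    """Multi-head agent attention with agent tokens stored as nn.Parameter."""

    def __init__(self, dim: int, num_heads: int = 8, agent_num: int = 16):
        super().__init__()
        assert dim % num_heads == 0
        self.h, self.d = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # One learnable set of agents per head, shared across the batch.
        self.agents = nn.Parameter(torch.randn(num_heads, agent_num, self.d) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, c = x.shape
        q, k, v = (t.reshape(b, n, self.h, self.d).transpose(1, 2)
                   for t in self.qkv(x).chunk(3, dim=-1))                  # (b, h, n, d)
        a = self.agents.unsqueeze(0)                                       # (1, h, m, d)
        scale = self.d ** -0.5
        agent_v = F.softmax(a @ k.transpose(-2, -1) * scale, dim=-1) @ v   # (b, h, m, d)
        out = F.softmax(q @ a.transpose(-2, -1) * scale, dim=-1) @ agent_v  # (b, h, n, d)
        return self.proj(out.transpose(1, 2).reshape(b, n, c))
```

Because the agents do not depend on token order, the layer is permutation-equivariant (permuting the input tokens permutes the output the same way), which may suit set-like inputs such as TSP nodes.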

I would be extremely grateful for any advice you could provide, and thank you so much for sharing such great work!

IndexError: tuple index out of range

/home/z/anaconda3/envs/agent_detection/bin/python3.7 /home/z/zky/Agent-Attention-master/downstream/detection/tools/debug_train.py
/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/utils/setup_env.py:39: UserWarning: Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
f'Setting OMP_NUM_THREADS environment variable for each process '
/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/utils/setup_env.py:49: UserWarning: Setting MKL_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
f'Setting MKL_NUM_THREADS environment variable for each process '
2023-12-31 21:14:42,606 - mmdet - INFO - Environment info:

sys.platform: linux
Python: 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0]
CUDA available: True
GPU 0,1,2,3: NVIDIA TITAN X (Pascal)
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 10.1, V10.1.24
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.12.1
PyTorch compiling details: PyTorch built with:

  • GCC 9.3
  • C++ Version: 201402
  • Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 11.3
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  • CuDNN 8.3.2 (built against CUDA 11.5)
  • Magma 2.5.2
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.13.1
OpenCV: 4.6.0
MMCV: 1.6.1
MMCV Compiler: GCC 9.3
MMCV CUDA Compiler: 11.3
MMDetection: 2.25.2+daeda61

2023-12-31 21:14:43,557 - mmdet - INFO - Distributed training: False
2023-12-31 21:14:44,477 - mmdet - INFO - Config:
model = dict(
type='RetinaNet',
backbone=dict(
type='AgentPVT',
img_size=224,
patch_size=4,
in_chans=3,
num_classes=6,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
depths=[3, 4, 18, 3],
sr_ratios=[8, 4, 2, 1],
agent_sr_ratios='1111',
num_stages=4,
agent_num=[9, 16, 49, 49],
downstream_agent_shapes=[(12, 12), (16, 16), (28, 28), (28, 28)],
kernel_size=3,
attn_type='AAAA',
scale=-0.5,
init_cfg=dict(type='Pretrained', checkpoint=None)),
neck=dict(
type='FPN',
in_channels=[64, 128, 320, 512],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=6,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0.0, 0.0, 0.0, 0.0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
dataset_type = 'CocoDataset'
data_root = '/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
train_dataloader=dict(
samples_per_gpu=2, workers_per_gpu=10, pin_memory=True),
train=dict(
type='CocoDataset',
ann_file=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/annotations/instances_train2017.json',
img_prefix=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/train2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]),
val=dict(
type='CocoDataset',
ann_file=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/annotations/instances_val2017.json',
img_prefix=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/val2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]),
test=dict(
type='CocoDataset',
ann_file=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/annotations/instances_val2017.json',
img_prefix=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/val2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]))
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
pretrained = None
lr = 0.0001
work_dir = './work_dirs/agent_pvt_m_rtn_1x_12-16-28-28'
auto_resume = False
gpu_ids = [0]

2023-12-31 21:14:44,477 - mmdet - INFO - Set random seed to 2100000934, deterministic: False
Traceback (most recent call last):
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 69, in build_from_cfg
return obj_cls(**args)
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 300, in __init__
for j in range(depths[i])])
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 300, in <listcomp>
for j in range(depths[i])])
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 217, in __init__
agent_num=agent_num, downstream_agent_shape=downstream_agent_shape, kernel_size=kernel_size, scale=scale)
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 129, in __init__
print('Agent Attention sr{} v{} n{} k{} scale{} reso{}'.format(sr_ratio, agent_num, kernel_size, scale, window_size))
IndexError: tuple index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 69, in build_from_cfg
return obj_cls(**args)
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/detectors/retinanet.py", line 19, in __init__
test_cfg, pretrained, init_cfg)
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/detectors/single_stage.py", line 32, in __init__
self.backbone = build_backbone(backbone)
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/builder.py", line 20, in build_backbone
return BACKBONES.build(cfg)
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 237, in build
return self.build_func(*args, **kwargs, registry=self)
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/cnn/builder.py", line 27, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 72, in build_from_cfg
raise type(e)(f'{obj_cls.__name__}: {e}')
IndexError: AgentPVT: tuple index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/z/zky/Agent-Attention-master/downstream/detection/tools/debug_train.py", line 244, in <module>
main()
File "/home/z/zky/Agent-Attention-master/downstream/detection/tools/debug_train.py", line 215, in main
test_cfg=cfg.get('test_cfg'))
File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/builder.py", line 59, in build_detector
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 237, in build
return self.build_func(*args, **kwargs, registry=self)
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/cnn/builder.py", line 27, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 72, in build_from_cfg
raise type(e)(f'{obj_cls.__name__}: {e}')
IndexError: RetinaNet: AgentPVT: tuple index out of range

Process finished with exit code 1
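The innermost frame points at a print call whose format string has six {} placeholders (sr, v, n, k, scale, reso) but receives only five arguments. str.format raises IndexError when a positional field has no matching argument, and on Python 3.7 the message is exactly "tuple index out of range". A minimal fix, assuming the print is purely informational, is to make fields and arguments agree, for example with an f-string. The function names and field labels below are our guess at the intent.

```python
# Six "{}" fields but only five arguments, as in agent_pvt.py line 129 above.
BROKEN = "Agent Attention sr{} v{} n{} k{} scale{} reso{}"


def reproduce_bug() -> bool:
    """Return True if formatting BROKEN with five values raises IndexError."""
    try:
        BROKEN.format(1, 49, 3, -0.5, (28, 28))
    except IndexError:
        return True
    return False


def agent_attention_banner(sr_ratio, agent_num, kernel_size, scale, window_size) -> str:
    # An f-string binds each value by name, so fields cannot drift out of
    # sync with the arguments.
    return (f"Agent Attention sr{sr_ratio} n{agent_num} k{kernel_size} "
            f"scale{scale} reso{window_size}")
```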

Connection with Guided diffusion

Hi, could you share the part of your code that connects to guided diffusion? I would like to adapt your code to guided diffusion rather than Stable Diffusion. I think this should be possible, right? Could you help me map q, k, v, and the agent tokens onto guided diffusion?

Thanks,
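As we understand openai/guided-diffusion, AttentionBlock flattens the feature map to (N, C, T) and passes a fused qkv tensor to QKVAttention, which (with the new attention order) expects shape (N, 3*H*C, T) and returns (N, H*C, T). A hypothetical drop-in with the same interface could pool agents along the flattened spatial axis T, since the 2-D shape is no longer available at that point. This is an untested sketch, not part of either repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AgentQKVAttention(nn.Module):
    """Agent attention on a fused qkv tensor (N, 3*H*C, T) -> (N, H*C, T)."""

    def __init__(self, n_heads: int, agent_num: int = 64):
        super().__init__()
        self.n_heads = n_heads
        self.agent_num = agent_num

    def forward(self, qkv: torch.Tensor) -> torch.Tensor:
        bs, width, length = qkv.shape
        ch = width // (3 * self.n_heads)
        scale = ch ** -0.5
        # Split q, k, v, then heads, and move to token-first layout (N*H, T, C).
        q, k, v = (t.reshape(bs * self.n_heads, ch, length).transpose(1, 2)
                   for t in qkv.chunk(3, dim=1))
        # Agents: average-pool the queries along the flattened spatial axis.
        a = F.adaptive_avg_pool1d(q.transpose(1, 2), self.agent_num).transpose(1, 2)
        agent_v = F.softmax(a @ k.transpose(1, 2) * scale, dim=-1) @ v    # (N*H, m, C)
        out = F.softmax(q @ a.transpose(1, 2) * scale, dim=-1) @ agent_v  # (N*H, T, C)
        return out.transpose(1, 2).reshape(bs, -1, length)                # (N, H*C, T)
```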

Agent_num

Hello, I encountered several questions while reading your code.

1. What does agent_num mean? I did not find a clear definition in the paper.

2. When defining the biases, why do their dimensions depend on agent_num? If I modify agent_num, I get a dimension mismatch. However, I noticed that your paper includes comparisons with different agent_num values.

I hope to receive your response.
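One possible source of the mismatch, assuming code like the AgentAttention class quoted later on this page: the agent tokens come from AdaptiveAvgPool2d((pool_size, pool_size)) with pool_size = int(agent_num ** 0.5), so pooling yields pool_size² agents, while the biases are allocated with agent_num rows. The two only agree when agent_num is a perfect square (for example 9, 16, 49, 64). Separately, changing agent_num changes the bias parameter shapes, so biases from a checkpoint trained with a different agent_num cannot be loaded directly. A hypothetical check:

```python
import math


def check_agent_num(agent_num: int) -> int:
    """Return the pooling side length for agent_num, or raise if pooling
    (side x side) would produce a different number of agents."""
    side = math.isqrt(agent_num)
    if side * side != agent_num:
        raise ValueError(f"agent_num={agent_num} is not a perfect square; "
                         f"pooling would give {side * side} agents")
    return side
```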

A small bug in agent_pvt.py

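Hello, and thank you very much for sharing such interesting work.
While reproducing your work, I found a small issue in agent_pvt.py: if self.sr_ratio > 1, then after line 134 the downsampled k and v should have 1/self.sr_ratio² as many tokens as q, but lines 144-146 reshape q, k, and v with the same dimensions, so the reshape of k and v may be wrong. I noticed that you use sr_ratio=sr_ratios[i] if attn_type[i] == 'B' else int(agent_sr_ratios[i]) to force sr_ratio=1 for agent attention. How should the case sr_ratio ≠ 1 be handled? Or does agent attention not support downsampling k and v?
Thanks in advance for your reply.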
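Mathematically, the agent step softmax(A K^T) V only requires K and V to have the same number of tokens as each other, not as Q, so downsampled keys and values are compatible in principle as long as each tensor is reshaped with its own length. A minimal multi-head sketch of that idea follows (hypothetical helper names; agent bias and DWC omitted, and whether the repository's position biases would also need resizing is not checked here):

```python
import torch


def split_heads(t: torch.Tensor, num_heads: int) -> torch.Tensor:
    """(b, length, c) -> (b, heads, length, c // heads), using t's own length."""
    b, length, c = t.shape
    return t.reshape(b, length, num_heads, c // num_heads).permute(0, 2, 1, 3)


def agent_attention_sr(q, k, v, agent_tokens, num_heads):
    """q: (b, n, c); k, v: (b, n_kv, c), e.g. n_kv = n // sr_ratio**2 after
    spatial reduction; agent_tokens: (b, m, c). Returns (b, n, c)."""
    q, k, v, a = (split_heads(t, num_heads) for t in (q, k, v, agent_tokens))
    scale = q.shape[-1] ** -0.5
    # The agent step only needs k and v to share a length (n_kv).
    agent_v = torch.softmax(a @ k.transpose(-2, -1) * scale, dim=-1) @ v    # (b, h, m, d)
    out = torch.softmax(q @ a.transpose(-2, -1) * scale, dim=-1) @ agent_v  # (b, h, n, d)
    b, _, n, _ = out.shape
    return out.transpose(1, 2).reshape(b, n, -1)
```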

Agent-attention

Hi, I would like to use your model in a DDPM or DDIM pipeline. Is that possible?
Could you tell me which code or files are essential to add?

KeyError: 'stages.0.blocks.0.attn.w_msa.relative_position_bias_table'

Thank you for your excellent work. Sorry to bother you, but I have one more question: I replaced the attention in Swin Transformer with your AgentAttention:

class AgentAttention(nn.Module):
    r""" Window-based multi-head agent attention module with agent bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): Height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        shift_size (int, optional): Shift size for shifted-window attention. Default: 0
        agent_num (int, optional): Number of agent tokens. Default: 49
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 shift_size=0, agent_num=49, **kwargs):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.shift_size = shift_size

        self.agent_num = agent_num
        self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
        self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0], 1))
        self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1]))
        self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
        self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
        trunc_normal_(self.an_bias, std=.02)
        trunc_normal_(self.na_bias, std=.02)
        trunc_normal_(self.ah_bias, std=.02)
        trunc_normal_(self.aw_bias, std=.02)
        trunc_normal_(self.ha_bias, std=.02)
        trunc_normal_(self.wa_bias, std=.02)
        pool_size = int(agent_num ** 0.5)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        b, n, c = x.shape
        h = int(n ** 0.5)
        w = int(n ** 0.5)
        num_heads = self.num_heads
        head_dim = c // num_heads
        qkv = self.qkv(x).reshape(b, n, 3, c).permute(2, 0, 1, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # q, k, v: b, n, c

        agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)

        position_bias1 = nn.functional.interpolate(self.an_bias, size=self.window_size, mode='bilinear')
        position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias = position_bias1 + position_bias2
        agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
        agent_attn = self.attn_drop(agent_attn)
        agent_v = agent_attn @ v

        agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
        agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
        agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
        agent_bias = agent_bias1 + agent_bias2
        q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
        q_attn = self.attn_drop(q_attn)
        x = q_attn @ agent_v

        x = x.transpose(1, 2).reshape(b, n, c)
        v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

I used the pretrained weights: /home/class1/work/modify/G/checkpoints/swin_tiny_patch4_window7_224.pth
https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
The dataset is in COCO format. After switching to your AgentAttention, the following error occurred:

python-BaseException
Traceback (most recent call last):
  File "/home/class1/.pycharm_helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/class1/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/class1/work/modify/G/tools/train.py", line 276, in <module>
    main()
  File "/home/class1/work/modify/G/tools/train.py", line 239, in main
    model.init_weights()
  File "/home/class1/.conda/envs/mm100/lib/python3.7/site-packages/mmcv/runner/base_module.py", line 117, in init_weights
    m.init_weights()
  File "/home/class1/work/modify/G/mmdet/models/backbones/swin_test.py", line 1296, in init_weights
    table_current = self.state_dict()[table_key]
KeyError: 'stages.0.blocks.0.attn.w_msa.relative_position_bias_table'

Do you know how to fix it? Thank you!
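The traceback points to a likely cause. The pretrained Swin checkpoint contains a relative_position_bias_table for every block, and init_weights looks each of those keys up in self.state_dict() (the line table_current = self.state_dict()[table_key]). The AgentAttention module above defines no such parameter, so the lookup fails. Below is a minimal sketch of a workaround, assuming you control the loading code: drop checkpoint entries the model does not have (or whose shapes differ) and load the rest with strict=False. filter_pretrained and the Toy module are hypothetical. Inside init_weights, the equivalent change would be to skip any table_key that is not in self.state_dict().

```python
import torch
import torch.nn as nn


def filter_pretrained(state_dict, model):
    """Keep only checkpoint entries whose key exists in `model` with the same shape."""
    own = model.state_dict()
    kept = {k: v for k, v in state_dict.items() if k in own and own[k].shape == v.shape}
    skipped = sorted(set(state_dict) - set(kept))
    return kept, skipped


class Toy(nn.Module):
    """Stand-in for a backbone whose attention has agent biases instead of a relative position table."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.an_bias = nn.Parameter(torch.zeros(2, 9, 7, 7))


model = Toy()
ckpt = {
    "proj.weight": torch.randn(4, 4),
    "proj.bias": torch.randn(4),
    "attn.relative_position_bias_table": torch.randn(169, 3),  # Swin-only key
}
kept, skipped = filter_pretrained(ckpt, model)
result = model.load_state_dict(kept, strict=False)
print("skipped:", skipped)
print("missing:", result.missing_keys)
```

The agent-specific biases (an_bias, na_bias and so on) then keep their trunc_normal_ initialization, which seems unavoidable since the Swin checkpoint has no values for them.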

ModuleNotFoundError: No module named 'mmcv._ext'

/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/cnn/bricks/transformer.py:33: UserWarning: Fail to import MultiScaleDeformableAttention from mmcv.ops.multi_scale_deform_attn, You should install mmcv-full if you need this module.
  warnings.warn('Fail to import MultiScaleDeformableAttention from '
Traceback (most recent call last):
  File "tools/test.py", line 17, in <module>
    from mmseg.apis import multi_gpu_test, single_gpu_test
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/apis/__init__.py", line 2, in <module>
    from .inference import inference_segmentor, init_segmentor, show_result_pyplot
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/apis/inference.py", line 9, in <module>
    from mmseg.models import build_segmentor
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/__init__.py", line 2, in <module>
    from .backbones import * # noqa: F401,F403
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/backbones/__init__.py", line 7, in <module>
    from .fast_scnn import FastSCNN
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/backbones/fast_scnn.py", line 7, in <module>
    from mmseg.models.decode_heads.psp_head import PPM
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/decode_heads/__init__.py", line 2, in <module>
    from .ann_head import ANNHead
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/decode_heads/ann_head.py", line 8, in <module>
    from .decode_head import BaseDecodeHead
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/decode_heads/decode_head.py", line 12, in <module>
    from ..losses import accuracy
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/losses/__init__.py", line 6, in <module>
    from .focal_loss import FocalLoss
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/losses/focal_loss.py", line 6, in <module>
    from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
  File "/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/ops/__init__.py", line 2, in <module>
    from .active_rotated_filter import active_rotated_filter
  File "/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/ops/active_rotated_filter.py", line 10, in <module>
    ext_module = ext_loader.load_ext(
  File "/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/utils/ext_loader.py", line 13, in load_ext
    ext = importlib.import_module('mmcv.' + name)
  File "/root/miniconda3/envs/agent_segmentation/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
ModuleNotFoundError: No module named 'mmcv._ext'
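The warning at the top ("You should install mmcv-full") and the paths in the traceback suggest that Python is importing the source tree under downstream/mmcv, which lacks the compiled mmcv._ext extension. The following is a small diagnostic sketch using only the standard library's importlib.util.find_spec. It reports where a top-level module would be loaded from and whether mmcv._ext can be located. Note that looking up a dotted name makes find_spec import the parent package (mmcv). The helper names are hypothetical.

```python
import importlib.util
from typing import Optional


def where_is(name: str) -> Optional[str]:
    """Return the file a top-level module would be loaded from, or None if it is not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


def mmcv_ops_compiled() -> bool:
    """Return True if the compiled extension mmcv._ext can be located."""
    try:
        # For a dotted name, find_spec imports the parent package (mmcv) first.
        return importlib.util.find_spec("mmcv._ext") is not None
    except ImportError:  # mmcv itself is missing or fails to import
        return False


print("mmcv loaded from:", where_is("mmcv"))
print("mmcv._ext available:", mmcv_ops_compiled())
```

If where_is("mmcv") points into downstream/mmcv and mmcv_ops_compiled() returns False, the compiled ops are missing from the copy of mmcv that is actually being imported. In that case, building them in that tree or installing a matching mmcv-full would be the next thing to check.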

forward() got an unexpected keyword argument 'encoder_hidden_states'

Hello, I greatly appreciate your awesome work.

It seems that self.attn1 in https://github.com/LeapLabTHU/Agent-Attention/blob/master/agentsd/patch.py#L220 is replaced with AgentAttention, whose forward signature is forward(self, x, agent=None, context=None, mask=None). However, in https://github.com/LeapLabTHU/Agent-Attention/blob/master/agentsd/patch.py#L220, encoder_hidden_states and attention_mask are passed to self.attn1, which raises forward() got an unexpected keyword argument 'encoder_hidden_states'.

Do you have any solution? Thanks a lot.
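Here is a minimal sketch of one workaround, assuming the only mismatch is keyword naming. The idea is to wrap the AgentAttention instance so that it accepts encoder_hidden_states and attention_mask and maps them onto its own context and mask arguments. AgentAttnAdapter is hypothetical and not part of the repo or of diffusers. It also assumes the caller's attention mask uses a format AgentAttention understands, which may not hold.

```python
import torch
import torch.nn as nn


class AgentAttnAdapter(nn.Module):
    """Accept diffusers-style keywords and call a module with signature
    forward(x, agent=None, context=None, mask=None)."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **unused):
        # Any other keyword arguments the caller passes are ignored in this sketch.
        return self.inner(hidden_states, context=encoder_hidden_states, mask=attention_mask)


class Recorder(nn.Module):
    """Stand-in for AgentAttention that records what it receives."""

    def forward(self, x, agent=None, context=None, mask=None):
        self.received = {"agent": agent, "context": context, "mask": mask}
        return x * 2


inner = Recorder()
attn1 = AgentAttnAdapter(inner)
x = torch.ones(1, 4, 8)
out = attn1(x, encoder_hidden_states=None, attention_mask=None)
print(tuple(out.shape), inner.received)
```

In the patching code, this would amount to assigning the wrapped module to attn1 instead of the bare AgentAttention instance.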

How can it be used as the backbone of a Siamese tracking network?

May I ask how this can be applied as a weight-shared backbone in a Siamese tracking network, where the input images have different sizes? I noticed that the agent_num and window parameters depend on the input image size. How can I set them so that they work for different input sizes at the same time?
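Here is a minimal sketch, assuming agent tokens are produced by adaptive average pooling as in the Swin variant quoted above. Under that scheme, the pooled grid size (and hence agent_num) does not depend on the input resolution. A block bias stored at a fixed resolution can also be bilinearly interpolated to whatever grid each branch produces. This suggests the same weights could serve a template branch and a search branch of different sizes. The sizes below are hypothetical. The row and column biases (ah_bias, aw_bias and so on), whose shapes are tied to window_size, are not handled here; they would also need resizing, or the window size would need to be fixed across branches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

agent_num, num_heads, c = 16, 2, 32
p = int(agent_num ** 0.5)
pool = nn.AdaptiveAvgPool2d((p, p))
# A block bias stored at a fixed 7x7 resolution, like an_bias in the Swin variant.
an_bias = torch.randn(num_heads, agent_num, 7, 7)

shapes = []
for h, w in [(8, 8), (16, 16)]:  # e.g. template and search feature maps
    feat = torch.randn(1, c, h, w)
    agents = pool(feat).flatten(2).transpose(1, 2)  # (1, agent_num, C) for any h, w
    bias = F.interpolate(an_bias, size=(h, w), mode="bilinear", align_corners=False)
    bias = bias.reshape(num_heads, agent_num, h * w)  # one bias per (agent, token) pair
    shapes.append(((h, w), tuple(agents.shape), tuple(bias.shape)))
print(shapes)
```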

Agent bias

The model's input size is (1, 4, 128, 128, 128): 1 is the batch size, 4 the number of channels, and the three 128s are h, w and d. I read your Appendix A on agent bias. You say each position offset is composed of three parameters, e.g. $B_1 = B'^{c}_1 + B'^{r}_1 + B'^{b}_1$, comprising a column bias $B^{c}_1 \in \mathbb{R}^{n\times 1\times w}$, a row bias $B^{r}_1 \in \mathbb{R}^{n\times h\times 1}$ and a block bias $B^{b}_1 \in \mathbb{R}^{n\times h_0\times w_0}$. My data now has the extra dimension d. How should I modify the agent bias? Should I add a new parameter, i.e. use $B^{c}_1$, $B^{r}_1$, $B^{b}_1$ and $B^{d}_1$? Could you give me some advice? Looking forward to your reply.
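Here is a minimal sketch of one possible 3D extension, assuming the 2D decomposition carries over by adding one axis. It uses a depth bias of shape n × d × 1 × 1 alongside the row and column biases, plus a block bias of shape n × d0 × h0 × w0 resized with trilinear interpolation. A per-head axis is included as in the Swin variant above. The class AgentBias3D, its parameter names and the 7 × 7 × 7 block size are hypothetical, not the authors' design. The query-to-agent bias ($B_2$) would be built the same way with the agent axis moved last.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AgentBias3D(nn.Module):
    """Hypothetical 3D agent bias: a trilinearly resized block bias plus separate
    depth, row and column biases, broadcast to (heads, agent_num, d, h, w)."""

    def __init__(self, num_heads, agent_num, d, h, w, block=(7, 7, 7)):
        super().__init__()
        self.block = nn.Parameter(torch.zeros(num_heads, agent_num, *block))
        self.depth = nn.Parameter(torch.zeros(num_heads, agent_num, d, 1, 1))
        self.row = nn.Parameter(torch.zeros(num_heads, agent_num, 1, h, 1))
        self.col = nn.Parameter(torch.zeros(num_heads, agent_num, 1, 1, w))
        for p in self.parameters():
            nn.init.trunc_normal_(p, std=0.02)
        self.size = (d, h, w)

    def forward(self):
        # Heads act as the batch axis and agents as channels for 5D trilinear resizing.
        blk = F.interpolate(self.block, size=self.size, mode="trilinear", align_corners=False)
        bias = blk + self.depth + self.row + self.col  # (heads, agent_num, d, h, w)
        return bias.flatten(2)  # (heads, agent_num, d*h*w), added to agent-to-key scores


agent_bias = AgentBias3D(num_heads=2, agent_num=8, d=4, h=4, w=4)
print(tuple(agent_bias().shape))
```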

Out Of Memory

Hello, I must commend your work; it's truly impressive. However, when I run AgentSwin with the same batch size that works for the standard Swin Transformer, I get out-of-memory errors that Swin Transformer does not produce. Is this expected behavior, or are there adjustments or optimizations I could make to alleviate it?
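Judging from the module quoted above, there are two plausible contributors. First, each window computes two attention maps (agents to keys, and queries to agents) instead of one, and both softmax outputs are kept for the backward pass. Second, the bias terms are expanded with .repeat(b, 1, 1, 1), where b is num_windows × batch, which allocates a full copy per window before the addition. The sketch below, with hypothetical sizes, shows that plain broadcasting gives the same result without that copy. Whether this closes the gap with Swin is untested; gradient checkpointing (torch.utils.checkpoint.checkpoint) or a smaller batch size are the usual fallbacks.

```python
import torch

# Hypothetical sizes: b = num_windows * batch, as in window-based attention.
b, heads, n_agents, n_tokens = 64, 3, 49, 49
scores = torch.randn(b, heads, n_agents, n_tokens)
bias = torch.randn(1, heads, n_agents, n_tokens)

repeated = scores + bias.repeat(b, 1, 1, 1)  # allocates b copies of the bias first
broadcast = scores + bias  # broadcasting: the bias is never copied per window

copy_bytes = b * bias.numel() * bias.element_size()  # size of the avoided temporary
print(torch.allclose(repeated, broadcast), copy_bytes)
```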
