
berniwal / swin-transformer-pytorch

762 stars · 123 forks · 206 KB

Implementation of the Swin Transformer in PyTorch.

Home Page: https://arxiv.org/pdf/2103.14030.pdf

License: MIT License

Python 100.00%
artificial-intelligence attention-model deep-learning machine-learning pytorch transformer-architecture transformer-pytorch


swin-transformer-pytorch's Issues

Training instance segmentation

Hi,
I am trying to use this model for instance segmentation with no success.
Would you be so kind as to guide me on how to do that?
Thank you so much!

Cyclic shift with masking

Hello sir, I'm trying to understand the "efficient batch computation" that the authors suggest. Probably because of my limited knowledge, it was hard to grasp how it works. Your implementation really helped me understand its mechanism, thanks a lot!

Here's my question: it seems the masked area of q * k^T / sqrt(d) vanishes during the computation of self-attention. I'm not sure I understood the code correctly, but is this originally intended in the paper? I'm wondering whether each sub-window's self-attention should be computed before reversing the shift.

Apologies if I misunderstood something, and thanks again!
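
A minimal sketch of the masking being asked about, assuming the create_mask pattern used in this repository (window_size = 4 and displacement = 2 are illustrative values, not the defaults):

import torch

window_size, displacement = 4, 2
# one row/column per token in the window, so the mask is (ws**2, ws**2)
mask = torch.zeros(window_size ** 2, window_size ** 2)
# upper/lower case: after the cyclic shift, the last `displacement` rows of a
# window come from the opposite edge of the feature map, so attention between
# the two groups is suppressed with -inf before the softmax
mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

scores = torch.randn(window_size ** 2, window_size ** 2)  # stand-in for q @ k.T / sqrt(d)
attn = (scores + mask).softmax(dim=-1)                    # masked positions get weight exactly 0

If this reading is right, the masked entries vanishing is intended: the mask restricts attention inside a shifted window to tokens from the same sub-window, which is equivalent to computing each sub-window's self-attention separately.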

Fails to run the code

Hi, I'm interested in your code! But when I run the example, I get:

Traceback (most recent call last):
File "D:/Code/Pytorch/swin-transformer-pytorch-0.4/example.py", line 16, in
logits = net(dummy_x) # (1,3)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 219, in forward
x = self.stage1(img)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 190, in forward
x = regular_block(x)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 149, in forward
x = self.attention_block(x)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 22, in forward
return self.fn(x, **kwargs) + x
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 32, in forward
return self.fn(self.norm(x), **kwargs)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 117, in forward
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors

And when I change the type to long, I get another error:

Traceback (most recent call last):
File "D:/Code/Pytorch/swin-transformer-pytorch-0.4/example.py", line 16, in
logits = net(dummy_x) # (1,3)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 219, in forward
x = self.stage1(img)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 188, in forward
x = self.patch_partition(x)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 164, in forward
x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\fold.py", line 295, in forward
self.padding, self.stride)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\functional.py", line 4313, in unfold
return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
RuntimeError: "im2col_out_cpu" not implemented for 'Long'

Why is the create_mask function 49x49?

def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)

It is 49x49 throughout the whole Swin network. Why?
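
In case it helps, the 49 comes from the number of tokens inside one window: with the default window_size = 7, each window holds 7^2 = 49 tokens, and the mask must match the (tokens x tokens) attention matrix. Because this implementation uses the same window_size at every stage, the mask is 49x49 throughout. A trivial check:

window_size = 7                       # default in this repo and in the paper
tokens_per_window = window_size ** 2  # 49 tokens attend to one another
mask_shape = (tokens_per_window, tokens_per_window)
print(mask_shape)                     # (49, 49) at every stage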

Pretrained weights?

Hello, thanks for your code.
Are there pretrained weights available?
Thanks again.

Training advice with swin_transformer - initialization with GELU, etc.

Hi,
I've been setting up swin_transformer but am having a hard time getting it to actually train.
I figured one immediate issue is the lack of weight initialization, so I'm using the truncated-normal init from rwightman's pytorch-image-models, the one he used in his ViT implementation, since that also uses GELU.
But regardless, I'm not able to get it to learn at the moment, even after testing a range of learning rates.

So I'm wondering whether anyone has found starting hyperparameters and/or an init method that gets it up and training?
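
For reference, a minimal sketch of the truncated-normal initialization described above, along the lines of the ViT init in rwightman's pytorch-image-models; the std constant and the apply pattern are assumptions, not values confirmed for this repository:

import torch.nn as nn

def init_weights(m):
    # truncated-normal init for linear layers, zeros for biases,
    # identity-style init for LayerNorm (a common ViT-style recipe)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# net.apply(init_weights)  # hypothetical usage on a SwinTransformer instance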

Shifting attention-calculating windows

Hello, sir. A question popped up again, unfortunately.

I've followed your shifting code, and it seems to differ from (my understanding of) the paper.
I understood the original paper's window shifting as the black arrow in the image below (self-attention is calculated among elements inside the bold lines). The left red arrow points to the result of patch-wise rolling, and the right red arrow points to the result of rolling the entire feature map.

In my opinion, self-attention should be computed according to the top-right figure; therefore, the boxes in the bottom-right should be used (the green dotted lines separate sub-windows), since they preserve each region of the top-right figure.

Please let me know if I misunderstood your code or something in the paper. Thanks a lot!

Additionally, this is how I mimicked your code:

import torch
from einops import rearrange

A = torch.Tensor(list(range(1, 17))).view(1, 4, 4)              # 4x4 "feature map"
A_patched = A.view(4, 2, 2).permute(1, 2, 0).view(1, 2, 2, 4)   # split into 2x2 patches
A_patched_rolled = torch.roll(A_patched, shifts=(-1, -1), dims=(1, 2))         # patch-wise roll
A_rearranged = rearrange(A, 'a (b c) (d e)->a (b d) (c e)', b=2, d=2)          # reorder to match the image
A_rearranged_rolled = torch.roll(A_rearranged, shifts=(-1, -1), dims=(1, 2))   # roll the whole map
A_rearranged_rolled2 = torch.roll(A_rearranged, shifts=(1, 1), dims=(1, 2))    # roll the other way

where A can be considered a 4x4 feature map (though the element order does not match the image above), A_patched is a divided version of A, and A_patched_rolled is the patch-wise shifted version of A_patched, following torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) in your code. A_rearranged is rearranged to match the image above.


>>> A
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.],
         [13., 14., 15., 16.]]])
>>> A_patched
tensor([[[[ 1.,  5.,  9., 13.],
          [ 2.,  6., 10., 14.]],

         [[ 3.,  7., 11., 15.],
          [ 4.,  8., 12., 16.]]]])
>>> A_patched_rolled
tensor([[[[ 4.,  8., 12., 16.],
          [ 3.,  7., 11., 15.]],

         [[ 2.,  6., 10., 14.],
          [ 1.,  5.,  9., 13.]]]])
>>> A_rearranged
tensor([[[ 1.,  2.,  5.,  6.],
         [ 3.,  4.,  7.,  8.],
         [ 9., 10., 13., 14.],
         [11., 12., 15., 16.]]])
>>> A_rearranged_rolled
tensor([[[ 4.,  7.,  8.,  3.],
         [10., 13., 14.,  9.],
         [12., 15., 16., 11.],
         [ 2.,  5.,  6.,  1.]]])
>>> A_rearranged_rolled2
tensor([[[16., 11., 12., 15.],
         [ 6.,  1.,  2.,  5.],
         [ 8.,  3.,  4.,  7.],
         [14.,  9., 10., 13.]]])
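
For comparison, a minimal sketch (my reading of the paper's Figure 4, not the author's code) of the alternative ordering: roll the entire feature map first, then partition into windows. This reuses A and rearrange from the snippet above:

shifted = torch.roll(A, shifts=(-1, -1), dims=(1, 2))  # cyclic shift of the whole 4x4 map
windows = rearrange(shifted, 'a (nh wh) (nw ww) -> a (nh nw) (wh ww)', wh=2, ww=2)  # then 2x2 windows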

deeplabv3 + swintransformer

I tried this Swin Transformer as a backbone for DeepLabV3 (https://github.com/VainF/DeepLabV3Plus-Pytorch) and the following errors occurred:

Exception has occurred: EinopsError
Error while processing rearrange-reduction pattern "b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d".
Input tensor shape: torch.Size([1, 104, 104, 96]). Additional info: {'h': 3, 'w_h': 7, 'w_w': 7}.
Shape mismatch, can't divide axis of length 104 in chunks of 7

During handling of the above exception, another exception occurred:

File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 111, in
lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 110, in forward
q, k, v = map(
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 32, in forward
return self.fn(self.norm(x), **kwargs)
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 22, in forward
return self.fn(x, **kwargs) + x
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 149, in forward

Thank you for your answer.
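
For what it's worth, the error message itself suggests the cause (an assumption on my part): with window_size = 7, the feature map at every stage must be divisible by 7, and the DeepLabV3 pipeline here produces a 104x104 map, where 104 % 7 = 6. A quick check:

h, window_size = 104, 7
print(h % window_size)  # 6 -> 104 cannot be tiled by 7-wide windows;
                        # resize or pad the input so every stage divides evenly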

Image Size

Hi @berniwal

Thank you for your great work. I want to pass an image-size argument to the SwinTransformer class because my image size is different. Would it be possible for you to include an image-size parameter in SwinTransformer?

Image size is 112x112 in my case.

Regards,
Khawar
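
Until such a parameter exists, a hedged sketch of one workaround (constructor arguments taken from the repository README; the specific factors are my assumption, not a confirmed recipe): with the default downscaling_factors=(4, 2, 2, 2), a 112x112 input reaches 112/32 = 3.5 at the last stage, so one option is to shrink the first factor so every stage stays divisible by the window size (112 -> 56 -> 28 -> 14 -> 7):

import torch
from swin_transformer_pytorch import SwinTransformer

net = SwinTransformer(
    hidden_dim=96,
    layers=(2, 2, 6, 2),
    heads=(3, 6, 12, 24),
    channels=3,
    num_classes=3,
    head_dim=32,
    window_size=7,
    downscaling_factors=(2, 2, 2, 2),  # changed from the default (4, 2, 2, 2)
    relative_pos_embedding=True,
)
logits = net(torch.randn(1, 3, 112, 112))  # (1, 3)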

relative pos embedding errs out with "IndexError: tensors used as indices must be long, byte or bool tensors"

Many thanks for making this implementation!
I just upgraded to the relative-position-embedding update from an hour ago, and when trying to train I get this type error.

---> 32         y_pred = model(images)
     33         #print(f" y_pred = {y_pred}")
     34         #print(f" y_pred shape = {y_pred.shape}")

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, img)
    229 
    230     def forward(self, img):
--> 231         x = self.stage1(img)
    232         x = self.stage2(x)
    233         x = self.stage3(x)

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
    189         x = self.patch_partition(x)
    190         for regular_block, shifted_block in self.layers:
--> 191             x = regular_block(x)
    192             x = shifted_block(x)
    193         return x.permute(0, 3, 1, 2)

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
    148 
    149     def forward(self, x):
--> 150         x = self.attention_block(x)
    151         x = self.mlp_block(x)
    152         return x

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x, **kwargs)
     21 
     22     def forward(self, x, **kwargs):
---> 23         return self.fn(x, **kwargs) + x
     24 
     25 

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x, **kwargs)
     31 
     32     def forward(self, x, **kwargs):
---> 33         return self.fn(self.norm(x), **kwargs)
     34 
     35 

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
    116 
    117         if self.relative_pos_embedding:
--> 118             dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
    119         else:
    120             dots += self.pos_embedding

IndexError: tensors used as indices must be long, byte or bool tensors

Runtime error

I'm running into an error in your code at line 117:
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors

Encountered an error

Hi, thank you so much for your code, but it raises an error when run on my machine. Here is the error description:

File "E:/Detection/swin/swin_transformer.py", line 117, in forward
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors

Your response will be highly appreciated.

Apply to another dataset

Hello, thanks very much for the work you have done. I have a question: how can I apply this code to train a ViT-style model on another dataset, and how should I adjust the parameters?
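
For anyone in the same situation, a minimal, hedged training sketch (the dataset, batch size, and hyperparameters below are placeholders, not recommendations; swap in your own Dataset and set num_classes to match it):

import torch
from torch.utils.data import DataLoader, TensorDataset
from swin_transformer_pytorch import SwinTransformer

net = SwinTransformer(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24),
                      channels=3, num_classes=10, head_dim=32, window_size=7,
                      downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True)
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=0.05)
criterion = torch.nn.CrossEntropyLoss()

# stand-in data: replace with your own Dataset of (image, label) pairs
dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for images, labels in loader:          # images: (B, 3, 224, 224)
    optimizer.zero_grad()
    loss = criterion(net(images), labels)
    loss.backward()
    optimizer.step()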

How to use for generation work

Thanks for your great work. I work on image generation. In my opinion, the current Swin Transformer is an encoder structure. Is there a corresponding Swin Transformer that can be used as a decoder?

About window size

Dear Sir, thank you very much for your great work. I would like to ask whether you have any suggestions on how to set the window size.
For a 224x224 input, a window size of 7 is reasonable because every stage's feature map is divisible by 7, but for other sizes, such as 768x768 in Cityscapes, 7 will undoubtedly report an error since 768 / 32 = 24, so the window setting is quite subtle.
The closest value is 8, but does the window size behave like a convolution kernel, where odd numbers work better?
Also, is it possible to set different window sizes at different stages? That seems useful for non-standard image sizes.
Since the window size is a critical hyperparameter that determines the receptive field and the amount of computation, I would like to request your opinion. Thanks!
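
A quick check of the divisibility argument above (assuming the default downscaling_factors of (4, 2, 2, 2)); for what it's worth, the window merely tiles the feature map, so unlike a convolution kernel there is no centering argument that favors odd sizes:

for side in (768 // 4, 768 // 8, 768 // 16, 768 // 32):  # stage sides: 192, 96, 48, 24
    print(side, side % 7, side % 8)  # 7 leaves a remainder at every stage; 8 divides all four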
