berniwal / swin-transformer-pytorch
Implementation of the Swin Transformer in PyTorch.
Home Page: https://arxiv.org/pdf/2103.14030.pdf
License: MIT License
Hi,
I am trying to use this model for instance segmentation with no success.
Would you be so kind and guide me on how to do that?
Thank you so much!
I've noticed that the default setting is patch_norm=False, drop_path_rate=0.2.
Will this be better than patch_norm=True, drop_path_rate=0?
Thank you!
Hello @berniwal ,
I have a question about this:
What is the function of the scale? I can't understand why this is done.
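For context, here is a minimal sketch of why attention scores are commonly multiplied by a scale factor. It assumes the `scale` in question is the usual `head_dim ** -0.5` applied to the q·k dot products; the function name `dot_variance` and the numbers are illustrative and not taken from the repository. With unit-variance entries, the variance of an unscaled dot product grows with the head dimension, which pushes the softmax towards saturation; scaling keeps it near 1.

```python
import random
import statistics

random.seed(0)

def dot_variance(dim, scale=1.0, trials=2000):
    """Empirical variance of scale * (q . k) for unit-variance random q and k."""
    dots = []
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(dim)]
        k = [random.gauss(0, 1) for _ in range(dim)]
        dots.append(scale * sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(dots)

head_dim = 32  # assumed head dimension, for illustration only
unscaled = dot_variance(head_dim)                         # grows roughly like head_dim
scaled = dot_variance(head_dim, scale=head_dim ** -0.5)   # stays close to 1
print(f"unscaled variance ~ {unscaled:.1f}, scaled variance ~ {scaled:.2f}")
```

The unscaled variance comes out near `head_dim`, while the scaled one stays close to 1 regardless of the head dimension.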
Best regards
Hello, have you tried replacing DETR's backbone with the Swin Transformer? Looking forward to your reply, thanks!
Hello sir, I'm trying to understand the "efficient batch computation" that the authors suggested. Probably because of my limited knowledge, it was hard to grasp how it works. Your implementation really helped me understand its mechanism, thanks a lot!
Here's my question: it seems the masked area of q * k / sqrt(d) vanishes during the computation of self-attention. I'm not sure that I understood the code correctly, but is this what the paper originally intended? I'm wondering if each subwindow's self-attention might be computed before reversing.
Apologies if I misunderstood something, and thanks again!
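As a hedged illustration of what the masking does: assuming, as is usual for shifted-window attention, that the mask adds negative infinity to the scores of key positions outside the query's sub-window before the softmax, those positions receive exactly zero attention weight and the remaining weights still sum to 1. The scores and mask below are made up for the example.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Scores of one query against four keys; the last two keys are assumed to lie
# in a different sub-window, so the mask adds -inf to them.
scores = [0.5, 1.5, 2.0, -1.0]
mask = [0.0, 0.0, float("-inf"), float("-inf")]
weights = softmax([s + m for s, m in zip(scores, mask)])
print(weights)  # roughly [0.269, 0.731, 0.0, 0.0]
```

In this picture, each sub-window's attention is effectively computed separately even though all tokens of the shifted window pass through one batched matrix product.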
In this repository, patch merging is implemented with nn.Unfold, but I would expect it to behave differently from the official implementation.
Is there something I'm missing out on?
Hi, I'm interested in your code! But when I run the example, I get:
Traceback (most recent call last):
File "D:/Code/Pytorch/swin-transformer-pytorch-0.4/example.py", line 16, in
logits = net(dummy_x) # (1,3)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 219, in forward
x = self.stage1(img)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 190, in forward
x = regular_block(x)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 149, in forward
x = self.attention_block(x)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 22, in forward
return self.fn(x, **kwargs) + x
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 32, in forward
return self.fn(self.norm(x), **kwargs)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 117, in forward
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors
And when I change the type to long, the code raises another error:
Traceback (most recent call last):
File "D:/Code/Pytorch/swin-transformer-pytorch-0.4/example.py", line 16, in
logits = net(dummy_x) # (1,3)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 219, in forward
x = self.stage1(img)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 188, in forward
x = self.patch_partition(x)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Code\Pytorch\swin-transformer-pytorch-0.4\swin_transformer_pytorch\swin_transformer.py", line 164, in forward
x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\modules\fold.py", line 295, in forward
self.padding, self.stride)
File "D:\Softwares\Anaconda\envs\pytorch_18\lib\site-packages\torch\nn\functional.py", line 4313, in unfold
return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
RuntimeError: "im2col_out_cpu" not implemented for 'Long'
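A minimal sketch of one likely fix, under the assumption that the index tensor ends up as 32-bit integers (for example when it is built from a NumPy array on a platform whose default integer is 32-bit). The idea is to cast only the relative index tensor to long; casting the input image to long is what produces the second error, because the unfold (im2col) operation is not implemented for integer tensors. The variable names mirror the traceback, but the construction below is an illustration, not the repository's exact code.

```python
import torch

window_size = 7

# Coordinates of every token inside one window. int32 is used on purpose to
# mimic the situation that triggers the IndexError on older PyTorch versions.
coords = torch.tensor(
    [[x, y] for x in range(window_size) for y in range(window_size)],
    dtype=torch.int32,
)

# Pairwise offsets, shifted into the range [0, 2 * window_size - 2].
relative_indices = coords[None, :, :] - coords[:, None, :] + window_size - 1

# Cast the index tensor only; the image stays a float tensor.
relative_indices = relative_indices.long()

# Stand-in for the learned relative position table (random here).
pos_embedding = torch.randn(2 * window_size - 1, 2 * window_size - 1)
bias = pos_embedding[relative_indices[:, :, 0], relative_indices[:, :, 1]]
print(bias.shape)  # torch.Size([49, 49])
```

The resulting bias has one entry per pair of tokens in a window and can be added to the attention scores as in the failing line.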
def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)
The mask is 49*49 throughout the whole Swin network. Why?
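A short sketch of the arithmetic behind this, assuming a 224x224 input, a window size of 7 and per-stage downscaling factors of (4, 2, 2, 2). The mask describes attention between the tokens of a single window, so its side is window_size ** 2 = 49 in every stage; what changes from stage to stage is the number of windows, not the size of each one.

```python
window_size = 7
image_size = 224
downscaling_factors = (4, 2, 2, 2)  # assumed per-stage reduction

resolution = image_size
summary = []
for stage, factor in enumerate(downscaling_factors, start=1):
    resolution //= factor                       # token grid side after this stage
    windows_per_side = resolution // window_size
    tokens_per_window = window_size ** 2        # independent of the stage
    summary.append((stage, resolution, windows_per_side ** 2, tokens_per_window))

for stage, res, n_windows, tokens in summary:
    print(f"stage {stage}: {res}x{res} tokens, {n_windows} windows, "
          f"mask {tokens}x{tokens}")
```

Under these assumptions the stages have 64, 16, 4 and 1 windows, each with a 49x49 mask.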
Hello, thanks for your code.
Are there pretrained weights available?
Thanks again.
Hi,
I've been setting up with swin_transformer but having a hard time getting it to actually train.
I figured one immediate issue is the lack of init, so I'm using the truncated init setup from rwightman/pytorch he used in ViT impl since that also uses GELU.
But regardless, I'm not able to get it to learn atm even after testing out a range of lr.
Thus wondering if anyone has found some starting hyperparams and/or init method, to get it up and training?
Are there any training results on imagenet11k and imagenet22k?
Hello, sir. A question popped up again, unfortunately.
I've followed your shifting code, and it seems to have a difference with (my comprehension of) the paper.
I understood the behavior of the original paper's window shifting as a black arrow in the image below (self-attention is calculated with elements inside of bold lines). The left red arrow points to the result of patch-wise rolling and the right red arrow points results of rolling the entire feature map.
In my opinion, self-attention should be computed according to the top-right figure; therefore the boxes in the bottom-right figure should be used (the green dotted line separates sub-windows), since they preserve each region of the top-right figure.
Please let me know if I misunderstood your code or something in the paper. Thanks a lot!
Additionally, this is how I mimicked your code:
import torch
from einops import rearrange
A = torch.Tensor(list(range(1, 17))).view(1, 4, 4)
A_patched = A.view(4, 2, 2).permute(1, 2, 0).view(1, 2, 2, 4)
A_patched_rolled = torch.roll(A_patched, shifts=(-1, -1), dims=(1, 2))
A_rearranged = rearrange(A, 'a (b c) (d e)->a (b d) (c e)', b=2, d=2)
A_rearranged_rolled = torch.roll(A_rearranged, shifts=(-1, -1), dims=(1, 2))
A_rearranged_rolled2 = torch.roll(A_rearranged, shifts=(1, 1), dims=(1, 2))
where A can be considered as a 4x4 feature map (though element order is not matched with image above), A_patched is a divided version of A, and A_patched_rolled is a patch-wise shifted version of A_patched, following torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) in your code. A_rearranged is rearranged to match the image above.
>>> A
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]])
>>> A_patched
tensor([[[[ 1., 5., 9., 13.],
[ 2., 6., 10., 14.]],
[[ 3., 7., 11., 15.],
[ 4., 8., 12., 16.]]]])
>>> A_patched_rolled
tensor([[[[ 4., 8., 12., 16.],
[ 3., 7., 11., 15.]],
[[ 2., 6., 10., 14.],
[ 1., 5., 9., 13.]]]])
>>> A_rearranged
tensor([[[ 1., 2., 5., 6.],
[ 3., 4., 7., 8.],
[ 9., 10., 13., 14.],
[11., 12., 15., 16.]]])
>>> A_rearranged_rolled
tensor([[[ 4., 7., 8., 3.],
[10., 13., 14., 9.],
[12., 15., 16., 11.],
[ 2., 5., 6., 1.]]])
>>> A_rearranged_rolled2
tensor([[[16., 11., 12., 15.],
[ 6., 1., 2., 5.],
[ 8., 3., 4., 7.],
[14., 9., 10., 13.]]])
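For comparison, here is a dependency-free sketch of the cyclic shift as I read it in the paper: the whole token grid is rolled towards the top-left by window_size // 2 positions, and only then split into regular windows. The helpers `roll2d` and `windows` are hypothetical names written for this example; `roll2d` follows the same indexing convention as torch.roll applied to both spatial dimensions. The sign of the shift is an assumption taken from the paper's description.

```python
def roll2d(grid, shift):
    """Cyclically shift a 2D list by `shift` along both axes (torch.roll convention)."""
    n_rows, n_cols = len(grid), len(grid[0])
    return [[grid[(r - shift) % n_rows][(c - shift) % n_cols]
             for c in range(n_cols)]
            for r in range(n_rows)]

def windows(grid, size):
    """Split a 2D list into non-overlapping size x size windows, in row-major order."""
    out = []
    for r0 in range(0, len(grid), size):
        for c0 in range(0, len(grid[0]), size):
            out.append([grid[r][c]
                        for r in range(r0, r0 + size)
                        for c in range(c0, c0 + size)])
    return out

feature_map = [[4 * r + c + 1 for c in range(4)] for r in range(4)]  # tokens 1..16
window_size = 2
shift = window_size // 2

shifted = roll2d(feature_map, -shift)        # roll the entire map, not each patch
shifted_windows = windows(shifted, window_size)
for w in shifted_windows:
    print(w)
```

This prints [6, 7, 10, 11], [8, 5, 12, 9], [14, 15, 2, 3] and [16, 13, 4, 1]. The first window contains only tokens that are neighbours in the original map, while the other three mix tokens from opposite borders, which is where the attention mask is needed.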
Why is mlp_dim = hidden_dimension * 4?
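As background, the factor of 4 is the MLP expansion ratio used in the original Transformer and in ViT, and the Swin paper keeps the same ratio. It is a convention rather than a requirement. A small sketch of what the ratio costs in parameters, assuming the MLP is two linear layers with biases (`mlp_param_count` is a hypothetical helper, not part of the repository):

```python
def mlp_param_count(hidden_dim, ratio=4):
    """Weights and biases of Linear(d, r*d) followed by Linear(r*d, d)."""
    mlp_dim = hidden_dim * ratio
    expand = hidden_dim * mlp_dim + mlp_dim      # first layer: d -> r*d
    project = mlp_dim * hidden_dim + hidden_dim  # second layer: r*d -> d
    return expand + project

for hidden_dim in (96, 192, 384, 768):
    print(hidden_dim, mlp_param_count(hidden_dim))
```

For hidden_dim = 96 this gives 74208 parameters per MLP block; a smaller ratio reduces that roughly in proportion, at the price of a narrower hidden layer.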
I tried this Swin Transformer with DeepLabV3 (https://github.com/VainF/DeepLabV3Plus-Pytorch) and got the following errors:
Exception has occurred: EinopsError
Error while processing rearrange-reduction pattern "b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d".
Input tensor shape: torch.Size([1, 104, 104, 96]). Additional info: {'h': 3, 'w_h': 7, 'w_w': 7}.
Shape mismatch, can't divide axis of length 104 in chunks of 7
During handling of the above exception, another exception occurred:
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 111, in
lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 110, in forward
q, k, v = map(
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 32, in forward
return self.fn(self.norm(x), **kwargs)
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 22, in forward
return self.fn(x, **kwargs) + x
File "D:\TangYong\Src\VS\Python\PyTorch\deeplabv3-vainf\network\backbone\swintransformer.py", line 149, in forward
Thank you for your answer.
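The error message says that a 104x104 token grid cannot be split into windows of 7. One common workaround, sketched here as an assumption rather than something the repository provides, is to pad the feature map up to the next multiple of the window size before the window partition (and crop afterwards); the alternative is to resize the input so that every stage divides evenly. `pad_to_multiple` is a hypothetical helper.

```python
def pad_to_multiple(length, window_size):
    """Number of tokens to append so that `length` divides evenly into windows."""
    return (-length) % window_size

height = width = 104   # feature-map size reported in the error message
window_size = 7
pad_h = pad_to_multiple(height, window_size)
pad_w = pad_to_multiple(width, window_size)
print(height + pad_h, width + pad_w)  # 105 105
```

Padded tokens should not contribute to attention, so a padding-based approach also needs a mask for them.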
Hi @berniwal
Thank you for your great work. I would like to pass an image size argument to the SwinTransformer class because my image size is different. Would it be possible for you to add the image size to the SwinTransformer parameters?
The image size is 112x112 in my case.
Regards,
Khawar
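A small sketch for checking whether a given image size fits, assuming the defaults are a window size of 7 and downscaling factors of (4, 2, 2, 2). The helper names are hypothetical. The check is that every stage must produce a token grid whose side is an integer and divisible by the window size.

```python
def stage_resolutions(image_size, downscaling_factors=(4, 2, 2, 2)):
    """Token-grid side after each stage, or None once it stops dividing evenly."""
    sizes = []
    current = image_size
    for factor in downscaling_factors:
        if current is None or current % factor:
            current = None
        else:
            current //= factor
        sizes.append(current)
    return sizes

def is_compatible(image_size, window_size=7, downscaling_factors=(4, 2, 2, 2)):
    """True if every stage yields a grid that splits evenly into windows."""
    return all(size is not None and size % window_size == 0
               for size in stage_resolutions(image_size, downscaling_factors))

print(stage_resolutions(112), is_compatible(112))  # [28, 14, 7, None] False
print(stage_resolutions(224), is_compatible(224))  # [56, 28, 14, 7] True
```

Under these assumptions a 112x112 input works for the first three stages and breaks at the last one, so it would need either resizing to 224x224 or a different combination of window size and downscaling factors.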
Very big thanks for making this implementation!
I just upgraded to the relative pos embedding update from an hour ago, and when trying to train I get this type error:
---> 32 y_pred = model(images)
33 #print(f" y_pred = {y_pred}")
34 #print(f" y_pred shape = {y_pred.shape}")
~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, img)
229
230 def forward(self, img):
--> 231 x = self.stage1(img)
232 x = self.stage2(x)
233 x = self.stage3(x)
~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
189 x = self.patch_partition(x)
190 for regular_block, shifted_block in self.layers:
--> 191 x = regular_block(x)
192 x = shifted_block(x)
193 return x.permute(0, 3, 1, 2)
~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
148
149 def forward(self, x):
--> 150 x = self.attention_block(x)
151 x = self.mlp_block(x)
152 return x
~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x, **kwargs)
21
22 def forward(self, x, **kwargs):
---> 23 return self.fn(x, **kwargs) + x
24
25
~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x, **kwargs)
31
32 def forward(self, x, **kwargs):
---> 33 return self.fn(self.norm(x), **kwargs)
34
35
~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
116
117 if self.relative_pos_embedding:
--> 118 dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
119 else:
120 dots += self.pos_embedding
IndexError: tensors used as indices must be long, byte or bool tensors
As described in the title.
Thank you
I'm running into an error in your code at line 117:
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors
Can this be used for object detection? I couldn't get it to work.
Hi, thank you so much for your code. But I encounter an error when running it in my terminal. The error description follows:
File "E:/Detection/swin/swin_transformer.py", line 117, in forward
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors
Your response will be highly appreciated.
Could you let me know how to change the image size?
Hello, thank you very much for the work you have done. I have a question: how can I apply this code to train a ViT model on another dataset, and how should I adjust the parameters?
Thanks for your great work. I work on image generation. In my opinion, the current Swin Transformer is an encoder structure. Is there a corresponding Swin Transformer that can be used as a decoder?
Dear Sir, thank you very much for your great work. I would like to ask if you have any suggestions on how to set the window size.
For a 224x224 input, a window size of 7 is reasonable because the feature map sizes are divisible by 7, but for other sizes, such as 768x768 in Cityscapes, 7 will undoubtedly raise an error since 768 / 32 = 24, so the window setting looks quite delicate.
The closest value is 8, but does the window size behave like a convolution kernel size, where odd numbers work better?
Also, is it possible to set different window sizes at different stages? That seems feasible for non-regular image sizes.
Since the window size is a very critical hyperparameter that determines the receptive field and the amount of computation, I would like to request your opinion, thanks!
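A rough sketch for listing the window sizes that fit a given input, assuming downscaling factors of (4, 2, 2, 2), an image side divisible by 32, and a single window size shared by all stages. `valid_window_sizes` is a hypothetical helper. Nothing in the window partition itself requires an odd size; the publicly released 384x384 Swin configurations use a window size of 12, for instance.

```python
from functools import reduce
from math import gcd

def valid_window_sizes(image_size, downscaling_factors=(4, 2, 2, 2)):
    """Window sizes > 1 that divide the token grid of every stage.

    Assumes image_size is divisible by the product of the downscaling factors.
    """
    resolutions = []
    current = image_size
    for factor in downscaling_factors:
        current //= factor
        resolutions.append(current)
    common = reduce(gcd, resolutions)   # any valid size must divide this
    return [w for w in range(2, common + 1) if common % w == 0]

print(valid_window_sizes(768))  # grids 192, 96, 48, 24 -> [2, 3, 4, 6, 8, 12, 24]
print(valid_window_sizes(224))  # grids 56, 28, 14, 7   -> [7]
```

For 768x768 this suggests 8 or 12 as natural candidates. Using a different window size per stage would relax the constraint to each stage's own resolution, at the cost of changing the model's constructor.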