Comments (5)
Hello and thank you for your interest,
Just to give a bit of background: RPB is in theory a continuous function, at least the way we intended it for NA/DiNA.
Here we're only learning a discrete set of weights because our kernel size is typically fixed.
As for its implementation here, most tokens (the non-edge cases) share an identical RPB grid: north, south, east, west, and the positions in between, e.g. northwest.
And of course there's a magnitude: 1 north, 2 north, etc.
As a result, if you look at a visualization of NA, you will see that, edge cases aside, the key-value positions for the rest of the feature map are identical: the query is centered and its neighbors wrap around it, so they share the same RPB.
It becomes different for the edge cases precisely because they are not centered. For instance, the north-west (top-left) most pixel always attends to an "extended neighborhood", which is explained in the original NAT paper, so its relative positional biases with respect to its key-value pairs (its neighborhood) differ from the non-edge cases, where the query is always centered.
To clarify further, you can try plotting much larger inputs: you will see the RPB indices differ only at the corners and edges, while the middle positions all share an identical RPB index window.
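A tiny sketch of this, assuming the clamped-window rule that (as I understand it) NAT uses to keep the neighborhood size constant near edges; `window_start` and `relative_offsets` are hypothetical helper names, not NATTEN API:

```python
def window_start(q: int, length: int, kernel_size: int) -> int:
    """Start index of the neighborhood for query position q on an axis of size length,
    clamped so the window always stays inside the feature map."""
    return min(max(q - kernel_size // 2, 0), length - kernel_size)

def relative_offsets(q: int, length: int, kernel_size: int):
    """Relative offsets (key - query) seen by query q; all interior queries share the same set."""
    s = window_start(q, length, kernel_size)
    return [s + i - q for i in range(kernel_size)]

length, k = 10, 5
print(relative_offsets(5, length, k))  # interior query, centered: [-2, -1, 0, 1, 2]
print(relative_offsets(0, length, k))  # corner query, shifted:    [0, 1, 2, 3, 4]
```

Every interior query sees the same centered offset set, so one bias entry per offset can be shared; only the edge/corner queries index different entries.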
By the way, thank you for taking the time to plot these, I'm sure it'll help other users as well.
I hope this explains the idea, but if that's not the case, please let us know so we can clarify further.
from neighborhood-attention-transformer.
Oh, I understand it now. Thank you so much for your patient reply.
Thanks for your reply!
About the original question
In my original example, some settings were obscuring my understanding. I reworked the code so it is more intuitive now.
But this also leads to another problem, see the discussion in the next section.
import matplotlib.pyplot as plt
import numpy as np
import torch

# specify the height and width of the feature map
height = width = 10
# construct a figure containing height*width subplots, one per (h, w) pixel
fig, axes = plt.subplots(nrows=height, ncols=width, figsize=(8, 8))
fig.suptitle('All Index Windows of RPB for Each Position of the H-W Plane')
# specify the kernel size for the relative position bias
kernel_size = 5
# construct a shared relative position bias map
rpb_size = 2 * kernel_size - 1
shared_rpb_bg = np.zeros((rpb_size, rpb_size), dtype=np.uint8)
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
# flattened 1D indices of the top-left window in the RPB map of shape
# (2*kernel_size-1, 2*kernel_size-1); other windows are obtained by
# adding a per-pixel start offset to this `idx_k`
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)
# construct the window start offsets in the RPB map for each (h, w) pixel
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
# the base h and w of the four edge regions differ from the interior
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
# each (h, w) in the H-W plane corresponds to a kernel_size*kernel_size window of indices
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2)  # (height*width, kernel_size**2)
# traverse all positions to visualize and highlight each pixel's index window in the shared RPB map
for h in range(height):
    for w in range(width):
        new_rpb_bg = shared_rpb_bg.flatten().copy()
        new_start_idx = h * width + w  # flatten (h, w) into a row index of bias_idx
        new_rpb_bg[bias_idx[new_start_idx].numpy()] = 255  # highlight this pixel's window
        new_rpb_bg = new_rpb_bg.reshape(rpb_size, rpb_size)
        axes[h, w].imshow(new_rpb_bg)
        axes[h, w].set_title(f"Win {(h, w)}")
plt.show()
About the relative position bias for NAT
Let's consider a simple case: kernel_size=3, so the RPB map has shape [2*3-1, 2*3-1] = [5, 5].
The real relative indices of the RPB map are:
(-2, -2), (-2, -1), (-2, 0), (-2, 1), (-2, 2),
(-1, -2), (-1, -1), (-1, 0), (-1, 1), (-1, 2),
(0, -2), (0, -1), (0, 0), (0, 1), (0, 2),
(1, -2), (1, -1), (1, 0), (1, 1), (1, 2),
(2, -2), (2, -1), (2, 0), (2, 1), (2, 2),
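For reference, this grid of relative indices can be generated programmatically (a small pure-Python sketch matching the table above):

```python
kernel_size = 3
r = kernel_size - 1  # largest relative offset representable in the RPB map
# rpb_indices[i][j] is the (dy, dx) pair stored at position (i, j) of the RPB map
rpb_indices = [[(dy, dx) for dx in range(-r, r + 1)] for dy in range(-r, r + 1)]
for row in rpb_indices:
    print(row)
```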
In traditional attention, the RPB is simple: there are no edge regions that require special consideration, so it has only one index pattern. (Fixed: in fact, the traditional RPB is more like a special form of NAT's with only edge regions, so the implementation here is generic.)
In the convolution-like NAT operation, the RPB becomes more complicated.
The index window in the non-edge region consists of the following pattern:
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 0), (0, 1),
(1, -1), (1, 0), (1, 1),
The index windows in the edge regions are:
# start at the pixel (h=0, w=0), we denote the matrix as $W_{0,0}$
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);
(2, 0), (2, 1), (2, 2);
# start at the pixel (h=0, w=1), we denote the matrix as $W_{0,1}$
(0, -1), (0, 0), (0, 1);
(1, -1), (1, 0), (1, 1);
(2, -1), (2, 0), (2, 1);
# start at the pixel (h=1, w=0), we denote the matrix as $W_{1,0}$
(-1, 0), (-1, 1), (-1, 2);
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);
# ....
# start at the pixel (h=height-1, w=width-2), we denote the matrix as $W_{height-1,width-2}$
(-2, -1), (-2, 0), (-2, 1),
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 0), (0, 1),
# start at the pixel (h=height-1, w=width-1), we denote the matrix as $W_{height-1,width-1}$
(-2, -2), (-2, -1), (-2, 0),
(-1, -2), (-1, -1), (-1, 0),
(0, -2), (0, -1), (0, 0),
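All of the windows above (interior and edge) follow from one rule: clamp the window start so the neighborhood stays inside the feature map, then take offsets relative to the query. A sketch of that rule (`window_start` and `rel_index_window` are hypothetical helpers, not NATTEN API):

```python
def window_start(q: int, length: int, k: int) -> int:
    """Clamped start of the kernel-size-k neighborhood for query q on an axis of size length."""
    return min(max(q - k // 2, 0), length - k)

def rel_index_window(h: int, w: int, height: int, width: int, k: int):
    """The k*k window of relative (dy, dx) indices into the RPB map for query (h, w)."""
    sh, sw = window_start(h, height, k), window_start(w, width, k)
    return [[(sh + i - h, sw + j - w) for j in range(k)] for i in range(k)]

# corner query (0, 0): matches W_{0,0} above
print(rel_index_window(0, 0, 5, 5, 3))
# interior query: centered window from (-1, -1) to (1, 1)
print(rel_index_window(2, 2, 5, 5, 3))
# opposite corner: matches W_{height-1, width-1} above
print(rel_index_window(4, 4, 5, 5, 3))
```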
In the current implementation of the RPB in LegacyNeighborhoodAttention2D, the index pattern does not correspond to the real RPB indices described above.
Maybe it's the flip?
Neighborhood-Attention-Transformer/natten/nattentorch2d.py
Lines 54 to 56 in 1437787
We have this additional flip to make sure the behavior is identical to the behavior programmed in NATTEN.
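If I understand the flip correctly, reversing the RPB table along both spatial axes maps the entry stored at relative offset (dy, dx) to (-dy, -dx), which would reconcile the two index conventions. A minimal sketch of that effect with a dummy table (not the actual NATTEN code):

```python
import torch

kernel_size = 3
rpb_size = 2 * kernel_size - 1
center = rpb_size // 2
# a dummy RPB table with distinct entries so we can track where each one lands
rpb = torch.arange(rpb_size * rpb_size, dtype=torch.float32).reshape(rpb_size, rpb_size)
flipped = rpb.flip(dims=(0, 1))
# reading offset (dy, dx) from the flipped table returns the entry
# stored at offset (-dy, -dx) in the original table
dy, dx = 1, 2
assert flipped[center + dy, center + dx] == rpb[center - dy, center - dx]
```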
I'm closing this issue now because we're moving our extension to its own separate repository, and due to inactivity.
Please feel free to reopen it if you still have questions, or open an issue in NATTEN if it's related to that.