
Comments (5)

alihassanijr commented on July 27, 2024

Hello and thank you for your interest,

Just to give a bit of background: RPB is in theory a continuous function, at least the way we intended it for NA/DiNA.
Here we're only learning a discrete set of weights because our kernel size is typically fixed.

As for its implementation here, most tokens (the non-edge cases) share an identical RPB grid: north, south, east, west -- and positions in between, e.g. northwest.
And of course there's a magnitude: 1 north, 2 north, etc.

As a result, if you look at a visualization of NA, you would see that, edge cases aside, the key-value positions across the rest of the feature map are identical: the query is centered and its neighbors wrap around it, hence those queries share the same RPB.

It becomes different for the edge cases precisely because they are not centered. For instance, the north-west (top-left) pixel always attends to an "extended neighborhood", as explained in the original NAT paper; therefore its relative positional biases with respect to its key-value pairs (its neighborhood) differ from those of the non-edge cases, which are always centered.
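
For intuition, here is a minimal 1D sketch (an illustration of the clamping rule, not the actual NATTEN code) showing why only edge queries see different relative offsets:

# Minimal 1D sketch of neighborhood selection in NA (illustration only).
def neighborhood_start(q: int, length: int, kernel_size: int) -> int:
    # Clamp the window start so the full window stays inside [0, length).
    return max(0, min(q - kernel_size // 2, length - kernel_size))

length, kernel_size = 10, 5
for q in range(length):
    start = neighborhood_start(q, length, kernel_size)
    # Relative offsets of the key-value positions with respect to the query.
    print(q, [start + i - q for i in range(kernel_size)])
# Non-edge queries (q = 2..7) all print [-2, -1, 0, 1, 2] and thus share one RPB pattern;
# edge queries (q = 0, 1, 8, 9) print shifted offsets, hence different RPB entries.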

To clarify further, you can try plotting much larger inputs: you would see the RPB indices differ only in the corners and remain identical in the middle.
By the way, thank you for taking the time to plot these, I'm sure it'll help other users as well.

I hope this explains the idea, but if that's not the case, please let us know so we can clarify further.


lartpang commented on July 27, 2024

@alihassanijr

Oh, I understand it now. Thank you so much for your patient reply.


lartpang commented on July 27, 2024

@alihassanijr

Thanks for your reply!

About the original question

In my original example, some of the settings were getting in the way of my understanding. I have cleaned up the code, and it is more intuitive now.
But this also leads to another problem; see the discussion in the next section.

import matplotlib.pyplot as plt
import numpy as np
import torch

# specify the height and width of the feature map
height = width = 10

# construct a figure containing height*width subplots, one per (h, w) pixel
fig, axes = plt.subplots(nrows=height, ncols=width, figsize=(8, 8))
fig.suptitle('All RPB index windows for each position of the H-W plane')

# specify the size of kernel for position bias
kernel_size = 5

# construct a shared relative position bias map
rpb_size = 2 * kernel_size - 1
shared_rpb_bg = np.zeros((rpb_size, rpb_size), dtype=np.uint8)

idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
# absolute 1D indices of the top-left window in the rpb map (2*kernel_size-1, 2*kernel_size-1)
# the indices of any other window can be obtained by adding a start offset to this `idx_k`
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)

# construct the indices of the rpb-map window for each (h, w) pixel
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
# the base h and w offsets of the four edge regions differ from the interior
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
# each (h, w) in the H-W plane corresponds to a kernel_size*kernel_size window of indices
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2) # height*width,kernel_size**2

# traverse all positions to visualize and highlight each pixel's index window in the shared rpb map
for h in range(height):
    for w in range(width):
        new_rpb_bg = shared_rpb_bg.flatten().copy()

        new_start_idx = h * width + w  # flat 1D index of pixel (h, w)
        new_rpb_bg[bias_idx[new_start_idx]] = 255  # index the specific window in rpb map
        new_rpb_bg = new_rpb_bg.reshape(rpb_size, rpb_size)
        axes[h, w].imshow(new_rpb_bg)
        axes[h, w].set_title(f"Win {(h,w)}")

plt.show()

[Figure rpb-k5-h10: the resulting 10x10 grid of subplots, each highlighting one pixel's RPB index window]

About the relative position bias for NAT

Let's consider a simple case: kernel_size=3, so the rpb map has shape [2*3-1, 2*3-1] = [5, 5].

The real (relative) indices of the rpb map are:

(-2, -2), (-2, -1), (-2, 0), (-2, 1), (-2, 2), 
(-1, -2), (-1, -1), (-1, 0), (-1, 1), (-1, 2), 
(0, -2), (0, -1), (0, 0), (0, 1), (0, 2),
(1, -2), (1, -1), (1, 0), (1, 1), (1, 2),
(2, -2), (2, -1), (2, 0), (2, 1), (2, 2),
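
These relative indices can be generated directly; here is a small sketch that reproduces the table above:

import torch

kernel_size = 3
# relative offsets on each axis range over [-(k-1), k-1], giving a (2k-1, 2k-1) grid
rng = torch.arange(-(kernel_size - 1), kernel_size)
rel_h, rel_w = torch.meshgrid(rng, rng, indexing="ij")
for row_h, row_w in zip(rel_h.tolist(), rel_w.tolist()):
    print(", ".join(f"({dh}, {dw})" for dh, dw in zip(row_h, row_w)))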

In traditional attention, the RPB is simple: there are no edge regions that require special consideration, so there is only one index pattern. (Fixed: in fact, the traditional RPB is more like a special form of NAT with only edge regions, so the implementation here is generic.)

In the convolution-like NAT operation, the RPB becomes more complicated.

The index window in the non-edge region is consistent with the following pattern:

(-1, -1), (-1, 0), (-1, 1), 
(0, -1), (0, 0), (0, 1), 
(1, -1), (1, 0), (1, 1),

The index windows in the edge regions are:

# window starting at pixel (h=0, w=0); we denote this matrix as $W_{0,0}$
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);
(2, 0), (2, 1), (2, 2);

# window starting at pixel (h=0, w=1); we denote this matrix as $W_{0,1}$
(0, -1), (0, 0), (0, 1);
(1, -1), (1, 0), (1, 1);
(2, -1), (2, 0), (2, 1);

# window starting at pixel (h=1, w=0); we denote this matrix as $W_{1,0}$
(-1, 0), (-1, 1), (-1, 2);
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);

# ....

# window starting at pixel (h=height-1, w=width-2); we denote this matrix as $W_{height-1,width-2}$
(-2, -1), (-2, 0), (-2, 1),
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 0), (0, 1),

# window starting at pixel (h=height-1, w=width-1); we denote this matrix as $W_{height-1,width-1}$
(-2, -2), (-2, -1), (-2, 0), 
(-1, -2), (-1, -1), (-1, 0), 
(0, -2), (0, -1), (0, 0),
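
All of the windows above follow from clamping the neighborhood window to the feature map. Here is a small sketch (my own helper, not the library code) that reproduces them under that assumption:

def window_offsets(h, w, height, width, kernel_size):
    # clamp the window start so the kernel_size x kernel_size neighborhood
    # stays inside the height x width feature map
    start_h = max(0, min(h - kernel_size // 2, height - kernel_size))
    start_w = max(0, min(w - kernel_size // 2, width - kernel_size))
    return [[(start_h + i - h, start_w + j - w) for j in range(kernel_size)]
            for i in range(kernel_size)]

height = width = 4
kernel_size = 3
print(window_offsets(0, 0, height, width, kernel_size))  # W_{0,0}
print(window_offsets(1, 1, height, width, kernel_size))  # centered, non-edge pattern
print(window_offsets(height - 1, width - 1, height, width, kernel_size))  # W_{height-1,width-1}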

In the current implementation of the RPB in LegacyNeighborhoodAttention2D, the index pattern does not correspond to the real indices of the rpb map described above.


alihassanijr commented on July 27, 2024

Maybe it's the flip?

# Index flip
# Our RPB indexing in the kernel is in a different order, so we flip these indices to ensure weights match.
bias_idx = torch.flip(bias_idx.reshape(-1, self.kernel_size**2), [0])

We apply this additional flip to make sure the behavior is identical to what is programmed in NATTEN.
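
As a toy illustration of what flipping along dim 0 does to such an index table (just the ordering, not actual RPB values):

import torch

# toy table: one row of RPB indices per query position (3 positions, window of 2)
bias_idx = torch.tensor([[0, 1],
                         [2, 3],
                         [4, 5]])
# flipping along dim 0 reverses the position order: the window computed for the
# last position is now looked up for the first position, and vice versa
print(torch.flip(bias_idx, [0]))
# tensor([[4, 5],
#         [2, 3],
#         [0, 1]])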


alihassanijr commented on July 27, 2024

I'm closing this issue now due to inactivity, and because we're moving our extension to its own separate repository.

Please feel free to reopen it if you still have questions, or open an issue in NATTEN if it's related to that.

