Comments (4)
Hey, it's actually simpler than you might be led to believe: it's basically just attention in a 3D convolution pattern, where each query token attends only to the keys that fall inside its kernel window (the kernel size sets how far the neighborhood extends along frames, height and width).
Haha, I'm human. I copy-pasted a lot of code from my other repos, since it is largely the same as DALL-E, just accounting for the extra frame dimension.
from nuwa-pytorch.
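Here is a small sketch of what "attending in a 3D convolution pattern" means for a single query: it lists the (frame, row, column) positions whose keys the query at (t, h, w) would attend to. The helper `nearby_keys` is hypothetical and not taken from nuwa-pytorch; it assumes an odd, centered kernel, simply drops positions that fall outside the video, and ignores any causal masking an autoregressive decoder would add.

```python
from itertools import product


def nearby_keys(query, video_shape, kernel_size):
    """Return the (t, h, w) key positions a query attends to under a 3D kernel.

    Hypothetical helper: the kernel is centered on the query (odd sizes assumed)
    and positions that fall outside the video are dropped rather than padded.
    """
    t, h, w = query
    T, H, W = video_shape
    # offsets run from -k//2 to +k//2 along each axis, like a centered conv kernel
    ranges = [range(-(k // 2), k // 2 + 1) for k in kernel_size]
    keys = []
    for dt, dh, dw in product(*ranges):
        nt, nh, nw = t + dt, h + dh, w + dw
        if 0 <= nt < T and 0 <= nh < H and 0 <= nw < W:
            keys.append((nt, nh, nw))
    return keys


# a query in the middle of a 4x8x8 video with a 3x3x3 kernel sees 3*3*3 keys
print(len(nearby_keys((1, 4, 4), (4, 8, 8), (3, 3, 3))))  # 27
# a query in the corner only sees the 2x2x2 keys that exist
print(len(nearby_keys((0, 0, 0), (4, 8, 8), (3, 3, 3))))  # 8
```

As with a convolution, the window slides with the query, so the number of keys per query depends on the kernel size rather than on the total number of frames or pixels.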
Hi, many thanks for such a quick reply! So wouldn't that yield multiple context vectors after 3DNA (one per kernel position) once the attention operation is done? Do we then simply concatenate those vectors and move on to the FF layer to get the final vector?
No problem! Yeah, you'll have multiple contexts (keys) per query vector, but you then aggregate them by summing the values (also taken from the contexts), weighted by the query/key similarity scores. So in the end there is one output vector per query and the length is preserved: the output of attention always has the same length as its queries!
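The aggregation step can be checked on shapes alone. In this sketch (random tensors, names chosen for illustration) every query already has `n_neighbors` gathered keys and values; the softmax-weighted sum over the neighbor axis leaves exactly one output vector per query, so the per-neighbor contexts never need to be concatenated.

```python
import torch

num_queries, n_neighbors, dim = 10, 27, 64

q = torch.randn(num_queries, dim)               # one query vector per token
k = torch.randn(num_queries, n_neighbors, dim)  # several keys (contexts) per query
v = torch.randn(num_queries, n_neighbors, dim)  # one value per key

# similarity of each query with each of its own keys: (num_queries, n_neighbors)
scores = torch.einsum('qd,qnd->qn', q, k) / dim ** 0.5
weights = scores.softmax(dim=-1)                # weights sum to 1 for every query

# the weighted sum of values removes the neighbor axis: (num_queries, dim)
out = torch.einsum('qn,qnd->qd', weights, v)

print(out.shape)  # torch.Size([10, 64])
```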
Hmm.. 🤔 I tried writing it out in some code, but I'm unsure whether the shapes and logic are correct (sorry for any newbie mistakes, this is my first time writing pseudocode). I also couldn't figure out how to obtain the convolutions from a torch layer, so I went with split instead (not sure how you did it, but I couldn't seem to get the actual convolutions) 🤗
```python
# Simple pseudocode for a NUWA attempt
import torch

# constants
img_dim = 8
window_size = 2

# inputs: two frames of the same video
input_img_1 = torch.ones((img_dim, img_dim, 3))  # 8x8x3 matrix
input_img_2 = torch.ones((img_dim, img_dim, 3))  # 8x8x3 matrix

# the windowing emulates convolutions by breaking each image into
# chunks of 'window_size' rows to compute attention over
window_1 = torch.split(input_img_1, window_size)  # chunks of size 2x8x3; 2 from window_size
window_2 = torch.split(input_img_2, window_size)  # thus we get 4 such cuboidal chunks per image

# projection to 128 dimensions, created once so every chunk shares the same weights
# (in reality we have 3 different matrices for Q, K and V, but that doesn't change the shapes)
projection_layer = torch.nn.Linear(1, 128)

def SelfAttention(input_matrix_1: torch.Tensor, input_matrix_2: torch.Tensor) -> torch.Tensor:
    '''
    input_matrix_1: chunk of the first frame to attend from
    input_matrix_2: spatially corresponding chunk of the second frame to attend against
    '''
    input_matrix_1 = input_matrix_1.view(-1, 1)  # 48x1, one scalar token per pixel channel
    input_matrix_2 = input_matrix_2.view(-1, 1)  # 48x1
    input_seq = torch.cat((input_matrix_1, input_matrix_2))  # concatenating to 96x1
    Q = K = V = projection_layer(input_seq)  # 96x128
    Atten = torch.softmax(Q @ K.T, dim=-1) @ V  # scaling by sqrt(d_k) ignored for simplicity
    # attention retains the 96x128 shape; summing over the feature dimension
    # leaves one value per token (96), matching the number of input elements
    logits = Atten.sum(dim=1)
    return logits

context_vec = []  # stores the output of each window/chunk
for chunk_1, chunk_2 in zip(window_1, window_2):  # iterating over spatially corresponding chunks
    context_vec.append(SelfAttention(chunk_1, chunk_2))  # computing attention over chunks

# each output is [frame 1 chunk (2x8x3), frame 2 chunk (2x8x3)] flattened, so split the
# frame axis out and move it first before stitching the row chunks back together
stacked = torch.stack(context_vec).reshape(-1, 2, window_size, img_dim, 3)  # 4x2x2x8x3
final_attention_matrix = stacked.permute(1, 0, 2, 3, 4).reshape(2, img_dim, img_dim, 3)  # 2x8x8x3
# Thus we re-obtain our original dimensional tensor to pass along to other heads.
# however, we don't use a FF layer to project it for simplicity :)
```
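One way to get the convolution-shaped neighborhoods without a conv layer is to zero-pad the video and call `Tensor.unfold` along time, height and width, which yields every position's full 3D window of keys at once (the split above only groups rows). This is a minimal single-head sketch under stated assumptions: odd kernel sizes, padded slots masked out of the softmax, separate Q/K/V linear projections passed in by the caller, and no causal mask. It is not the nuwa-pytorch implementation, and the name `nearby_attention_3d` is hypothetical.

```python
import torch
import torch.nn.functional as F


def nearby_attention_3d(x, to_q, to_k, to_v, kernel_size=(3, 3, 3)):
    """Single-head 3D nearby attention over a video x of shape (T, H, W, dim).

    Hypothetical sketch: every position attends only to the keys inside a
    centered kt x kh x kw window (odd sizes assumed); window slots that fall
    outside the video are zero-padded and then masked out. No causal mask.
    """
    T, H, W, dim = x.shape
    kt, kh, kw = kernel_size
    pt, ph, pw = kt // 2, kh // 2, kw // 2

    q, k, v = to_q(x), to_k(x), to_v(x)  # each (T, H, W, dim)

    def windows(t):
        # (T, H, W, c) -> (T, H, W, kt*kh*kw, c): the 3D neighborhood of every position
        c = t.shape[-1]
        t = t.permute(3, 0, 1, 2)                   # (c, T, H, W)
        t = F.pad(t, (pw, pw, ph, ph, pt, pt))      # zero-pad W, H and T
        t = t.unfold(1, kt, 1).unfold(2, kh, 1).unfold(3, kw, 1)  # (c, T, H, W, kt, kh, kw)
        return t.reshape(c, T, H, W, kt * kh * kw).permute(1, 2, 3, 4, 0)

    k_win, v_win = windows(k), windows(v)
    # True where a window slot is a real position, False where it is padding
    valid = windows(torch.ones(T, H, W, 1))[..., 0].bool()

    scores = torch.einsum('thwd,thwnd->thwn', q, k_win) / dim ** 0.5
    scores = scores.masked_fill(~valid, float('-inf'))
    attn = scores.softmax(dim=-1)  # weights over each query's own neighbors
    return torch.einsum('thwn,thwnd->thwd', attn, v_win)  # same shape as x


torch.manual_seed(0)
video = torch.randn(2, 8, 8, 16)  # 2 frames of 8x8 tokens, 16 features each
to_q, to_k, to_v = (torch.nn.Linear(16, 16, bias=False) for _ in range(3))
out = nearby_attention_3d(video, to_q, to_k, to_v)
print(out.shape)  # torch.Size([2, 8, 8, 16])
```

The output has exactly the input's shape, which matches the point made earlier in the thread: each query gathers many keys, but the weighted sum returns one vector per position.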
Related Issues
- cant wait
- what's the recommanded hardware for using?
- Why the video does not pass through the encoder?
- Looks like Microsoft made a significant advancement on there initial NUWA
- Colab
- Type of dataset for training VQ-GAN
- Question about generated videos?
- Questions about function forward() in NUWA please.