
self-attention-gan's People

Contributors

capworkshop, doctorteeth


self-attention-gan's Issues

Attention map implementation

Hi,

I couldn't find the implementation of the attention layer inside the network models. The SAGAN paper mentions that the self-attention mechanism was added at different stages of the network and that these variants were compared with each other. Could you please point out where this is handled in the code?

Bests,
Samaneh

Max pooling and layer dimensions inside attention layer

  • There are a couple of max_pooling2d() layers inside the attention layer sn_non_local_block_sim() which reduce the number of locations by a factor of 4, as in downsampled_num = location_num // 4. However, no such downsampling step is reported in the original paper.

  • Also, the first two sn_conv1x1() layers, which stand for Wg and Wf in the paper, both have shape C/8 x C, but the third one, standing for Wh, has shape C/2 x C, whereas it should also be C/8 x C; the same goes for the last conv layer.

Is there a reason for these discrepancies?

Related #8
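
For reference, here is a rough sketch of the shape bookkeeping described in this issue, using plain Conv2D/MaxPool2D stand-ins for the repository's sn_conv1x1 and max_pooling2d calls. The exact code and layer names in the repo may differ; treat this only as an illustration of the C/8, C/2 and downsampling-by-4 dimensions mentioned above (it also assumes even H and W).

import tensorflow as tf

def non_local_block_sketch(x):
    # x: [B, H, W, C]; location_num = H*W, downsampled_num = H*W // 4 after 2x2 max pooling.
    _, h, w, c = x.get_shape().as_list()
    theta = tf.keras.layers.Conv2D(c // 8, 1)(x)                              # first 1x1 conv:  [B, H, W, C/8]
    phi = tf.keras.layers.MaxPool2D(2)(tf.keras.layers.Conv2D(c // 8, 1)(x))  # second + pool:   [B, H/2, W/2, C/8]
    g = tf.keras.layers.MaxPool2D(2)(tf.keras.layers.Conv2D(c // 2, 1)(x))    # third + pool:    [B, H/2, W/2, C/2]

    theta = tf.reshape(theta, [-1, h * w, c // 8])                            # [B, HW, C/8]
    phi = tf.reshape(phi, [-1, (h * w) // 4, c // 8])                         # [B, HW/4, C/8]
    g = tf.reshape(g, [-1, (h * w) // 4, c // 2])                             # [B, HW/4, C/2]

    attn = tf.nn.softmax(tf.matmul(theta, phi, transpose_b=True))             # [B, HW, HW/4]
    attn_g = tf.matmul(attn, g)                                               # [B, HW, C/2]
    attn_g = tf.reshape(attn_g, [-1, h, w, c // 2])
    return tf.keras.layers.Conv2D(c, 1)(attn_g)                               # last 1x1 conv back to C channels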

Pre-trained model (imagenet)?

Hello and thank you for the repository!
Training the model on the ImageNet dataset will take a lot of time. Could someone upload a pre-trained model?

Train on unconditional dataset?

How can we train on an unconditional dataset? I tried the CelebA dataset with the parameter number_class set to 1, but training did not go well and I got an error after 60K steps. How can I train on an unconditional dataset?

(Screenshots of the training error attached: capture1, capture, capture3.)

Possible error in reshape?

attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])

Is there a mistake in the reshape operation quoted above? Shouldn't it be

attn_g = tf.reshape(attn_g, [batch_size, h // 2, w // 2, num_channels // 2])
attn_g = tf.depth_to_space(attn_g, [batch_size, h, w, num_channels // 8])
attn_g = sn_conv1x1(attn_g, num_channels, update_collection, init, 'sn_conv_attn')

instead of attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])?
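
For what it's worth, if attn_g going into this reshape already has shape [batch_size, h*w, num_channels // 2] (the product of an [h*w, downsampled_num] attention map with a [downsampled_num, num_channels // 2] value tensor, as suggested by the layer-dimension issue above), then the existing reshape is shape-consistent. A small self-contained check under those assumed shapes:

import tensorflow as tf

# Hypothetical shapes for illustration only, not taken from the repository.
batch_size, h, w, num_channels = 4, 32, 32, 64
attn = tf.random.uniform([batch_size, h * w, (h * w) // 4])            # attention over pooled key locations
g = tf.random.uniform([batch_size, (h * w) // 4, num_channels // 2])   # value tensor, flattened
attn_g = tf.matmul(attn, g)                                            # [batch_size, h*w, num_channels // 2]
attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])     # element counts match, so this reshape is valid
print(attn_g.shape)                                                    # (4, 32, 32, 32)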

How to train on only one GPU?

Hello,
Which parameters do I need to change to make this train and evaluate on one GPU?
I am currently getting an OOM (Resource Exhausted) error when I try to train on a single GTX 1080.
I tried setting num_towers=1 in train_imagenet.py but this did not help.

No use of reuse_vars?

It seems that the variable reuse_vars is created in the build_model_single_gpu function, but I can't find where this variable is used to reuse variables across the multiple GPUs. Could you please check that? Thank you so much!
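
For context, a generic TF1-style multi-tower pattern is sketched below. It is not this repository's code (build_tower_fn is a hypothetical stand-in), but it shows one common way variables are shared across GPUs, namely via the reuse flag of tf.variable_scope.

import tensorflow as tf

def build_towers(tower_inputs, build_tower_fn):
    # build_tower_fn is hypothetical; it would build one replica of the model.
    outputs = []
    for i, inputs in enumerate(tower_inputs):
        with tf.device('/gpu:%d' % i):
            # reuse=True on every tower after the first shares one set of weights.
            with tf.variable_scope('model', reuse=(i > 0)):
                outputs.append(build_tower_fn(inputs))
    return outputs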

Residual vs attentional blocks

All generator and discriminator types implemented here are made of either block() or block_no_sn() modules, both of which internally apply a residual connection x_0 + x by default. However, in the associated paper residual and attentional blocks are compared as if the two architectures were mutually exclusive. So, does the attentional architecture reported in the paper also include residual blocks, or does this implementation not fully follow the reported architectures?

Thanks.
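
To make the terminology concrete, here is a minimal sketch of the x_0 + x residual pattern referred to above. It is a simplification, not the repository's block() or block_no_sn() implementation, and it assumes channels equals the input's channel count so the addition is valid.

import tensorflow as tf

def residual_block_sketch(x_0, channels):
    # Two 3x3 convolutions followed by a skip connection back to the input.
    x = tf.keras.layers.Conv2D(channels, 3, padding='same', activation='relu')(x_0)
    x = tf.keras.layers.Conv2D(channels, 3, padding='same')(x)
    return x_0 + x  # the residual connection in question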

Attention map visualization

Hi,

I am having trouble understanding the right way to visualize the attention maps. Let's say the attention block is in the last layer and the image has w = 128, h = 128; that means the attention map has dimensions N x N with N = w*h.

If I want to visualize the attention map for the midpoint, for example, which part of the attention map should I access?

The only idea I have is to take either row or column n:

attention_map[:,n] or attention_map[n,:]

Could you explain how to correctly access the attention map for a specific point?

Thanks in advance
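
A minimal sketch of one way to do the indexing, assuming the extracted attention map has shape [N, N] with N = h*w and that the softmax was taken over the key axis (so row n holds the attention that query pixel n pays to every location; double-check which axis applies in your setup):

import numpy as np

h, w = 128, 128
y, x = h // 2, w // 2                                # the midpoint pixel
n = y * w + x                                        # its flattened index in the N = h*w ordering

attention_map = np.random.rand(h * w, h * w)         # placeholder for the map extracted from the model
attn_for_pixel = attention_map[n, :].reshape(h, w)   # where pixel n attends to, as an h x w image
# attention_map[:, n].reshape(h, w) would instead show which pixels attend to location n.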

Help regarding the non-local block code

def Nonlocalblock(x):
    batch_size, height, width, in_channels = x.get_shape().as_list()
    print("height", height)
    print("width", width)
    print("in_channels", in_channels)
    print("shape", x.get_shape())

    # 1x1 convolution followed by an element-wise gate with the input.
    g1 = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(x)
    g1 = tf.math.multiply(g1, x)
    print("g1", g1.shape)

    # g / phi / theta projections, all computed from the gated features.
    g = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    phi = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    theta = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    print("g", g.shape, "phi", phi.shape, "theta", theta.shape)

    # Flatten the spatial dimensions to [batch_size, H*W, C].
    hw = height * width
    g_x = tf.reshape(g, [batch_size, hw, in_channels])
    g_x = tf.squeeze(g_x, axis=0)
    theta_x = tf.reshape(theta, [batch_size, hw, in_channels])
    phi_x = tf.reshape(phi, [batch_size, hw, in_channels])
    phi_x1 = tf.squeeze(phi_x, axis=0)
    print("g_x", g_x.shape, "theta_x", theta_x.shape, "phi_x", phi_x.shape)

    # Pairwise similarities and attention-weighted aggregation.
    f = tf.matmul(theta_x, g_x, transpose_b=True)
    f = tf.nn.softmax(f, -1)
    print("f", f.shape)
    y = tf.matmul(phi_x1, f)
    y1 = tf.nn.softmax(y)
    print("y1", y1.shape)

    # Restore the spatial layout.
    y1 = tf.reshape(y1, [batch_size, height, width, in_channels])
    print("y1", y1.shape)
    return y1

I have implemented this non-local attention block as shown in the code above, but the problem is that when I use it inside a network the batch size is always None, so the reshape and matrix multiplication calls that rely on a concrete batch_size raise errors.
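
For what it's worth, a minimal sketch of one way to make such a block batch-size-agnostic (my own simplification, not the repository's code): use -1 in the reshapes instead of the static batch dimension and keep the batch axis everywhere, so tf.matmul works on batched tensors and no tf.squeeze is needed. It assumes height, width and in_channels are known statically.

import tensorflow as tf

def non_local_block(x):
    _, height, width, in_channels = x.get_shape().as_list()
    hw = height * width

    # 1x1 convolutions for the three projections (full channel width here).
    theta = tf.keras.layers.Conv2D(in_channels, 1, padding='same')(x)
    phi = tf.keras.layers.Conv2D(in_channels, 1, padding='same')(x)
    g = tf.keras.layers.Conv2D(in_channels, 1, padding='same')(x)

    # Flatten spatial dimensions; -1 lets TensorFlow infer the dynamic batch size.
    theta_x = tf.reshape(theta, [-1, hw, in_channels])    # [B, HW, C]
    phi_x = tf.reshape(phi, [-1, hw, in_channels])        # [B, HW, C]
    g_x = tf.reshape(g, [-1, hw, in_channels])            # [B, HW, C]

    # Batched attention: [B, HW, HW], computed without squeezing the batch axis.
    attn = tf.nn.softmax(tf.matmul(theta_x, phi_x, transpose_b=True), axis=-1)
    y = tf.matmul(attn, g_x)                              # [B, HW, C]

    # Restore the spatial layout.
    return tf.reshape(y, [-1, height, width, in_channels])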

Cuda out of memory

When I run the line energy = torch.bmm(proj_query, proj_key), the program hits RuntimeError: CUDA out of memory. My graphics card has 12 GB of memory, and I am looking for a way to reduce the size of the intermediate variables, i.e. energy, which in my case is 1 x 65536 x 65536. I've already used torch.no_grad() and split the intermediate matrices into smaller sub-matrices, then used del to release the memory, but it doesn't seem to help. Could you please share some tips for this kind of problem? (My batch size is 1 and the input size is 256 x 256.)
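
One common workaround, sketched below under assumptions about the tensor shapes (proj_query and proj_key as [B, N, C'], proj_value as [B, N, C]; adjust the transpose if your layout differs), is to process the query rows in chunks so the full 65536 x 65536 energy matrix is never materialized at once:

import torch

def chunked_attention(proj_query, proj_key, proj_value, chunk=4096):
    # Computes softmax(Q K^T) V one block of query rows at a time.
    outputs = []
    for start in range(0, proj_query.size(1), chunk):
        q = proj_query[:, start:start + chunk, :]           # [B, chunk, C']
        energy = torch.bmm(q, proj_key.transpose(1, 2))     # [B, chunk, N]
        attn = torch.softmax(energy, dim=-1)
        outputs.append(torch.bmm(attn, proj_value))         # [B, chunk, C]
    return torch.cat(outputs, dim=1)                        # [B, N, C]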

About learning rate decay

Hello,

Is the GAN trained with a fixed learning rate?

The discriminator LR: 0.0004
The generator LR: 0.0001

Are these learning rates decayed? If so, where may I find the implementation?
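
For reference only, this is what a decayed schedule typically looks like in TensorFlow 1.x; it is a generic illustration, not code from this repository, and the 0.0004 / 0.0001 base rates are simply the ones quoted above (decay_steps and decay_rate are made-up values).

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
d_lr = tf.train.exponential_decay(0.0004, global_step, decay_steps=100000, decay_rate=0.9, staircase=True)
g_lr = tf.train.exponential_decay(0.0001, global_step, decay_steps=100000, decay_rate=0.9, staircase=True)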
