
change size of image (gan-tutorial, open issue)

yangyangii commented on August 17, 2024
How can I change the size of the generated images?



renan-siqueira commented on August 17, 2024

Hi gunahn,

I didn't write this code, but I also work on GANs. From what I understand of the code, to generate 512x512 images you need to change the following parts:


Generator:

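import torch
import torch.nn as nn

# ResidualBlock(in_ch, out_ch) is the residual block already defined in the gan-tutorial code.

class Generator(nn.Module):
    """
        Convolutional Generator for 512x512 images
    """
    def __init__(self, out_channel=3, n_filters=128, n_noise=512):
        # n_filters is kept for compatibility with the tutorial but is not used here
        super(Generator, self).__init__()
        # Project the noise vector to a 1024-channel 4x4 feature map
        self.fc = nn.Linear(n_noise, 1024*4*4)
        # Seven 2x upsampling stages: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512
        # (align_corners=False is the default; passing it explicitly silences a PyTorch warning)
        self.G = nn.Sequential(
            ResidualBlock(1024, 512),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 512, 8, 8)
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 256, 16, 16)
            ResidualBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 128, 32, 32)
            ResidualBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 64, 64, 64)
            ResidualBlock(64, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 64, 128, 128)
            ResidualBlock(64, 32),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 32, 256, 256)
            ResidualBlock(32, 16),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 16, 512, 512)
            ResidualBlock(16, 16),
            nn.Conv2d(16, out_channel, 3, padding=1) # (N, 3, 512, 512)
        )

    def forward(self, z):
        B = z.size(0)
        h = self.fc(z)              # (N, 1024*4*4)
        h = h.view(B, 1024, 4, 4)   # (N, 1024, 4, 4)
        x = self.G(h)               # (N, out_channel, 512, 512)
        return x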
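A quick way to check the Generator layout for other sizes: it starts from a 4x4 map and every nn.Upsample(scale_factor=2) doubles the side length, so a square image of side S needs log2(S/4) upsampling stages (7 for 512). The helper below is a hypothetical sketch, not part of gan-tutorial, and assumes square images whose side is 4 times a power of two.

```python
def generator_upsample_plan(image_size, base_size=4):
    """Return the spatial size after each 2x upsampling stage.

    Hypothetical helper: assumes a square output whose side is
    base_size multiplied by a power of two.
    """
    if image_size < base_size or image_size % base_size:
        raise ValueError("image_size must be base_size times a power of two")
    ratio = image_size // base_size
    if ratio & (ratio - 1):  # a power of two has exactly one bit set
        raise ValueError("image_size must be base_size times a power of two")
    sizes = []
    size = base_size
    while size < image_size:
        size *= 2  # each nn.Upsample(scale_factor=2) doubles height and width
        sizes.append(size)
    return sizes


plan = generator_upsample_plan(512)
print(len(plan), plan)  # 7 [8, 16, 32, 64, 128, 256, 512]
```

Each entry in the plan corresponds to one ResidualBlock plus one Upsample pair in the Generator above.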

Discriminator:

class Discriminator(nn.Module):
    """
        Convolutional Discriminator for 512x512 images
    """
    def __init__(self, in_channel=3):
        super(Discriminator, self).__init__()
        # Six AvgPool2d(3, 2, padding=1) stages halve the side: 512 -> 256 -> ... -> 8
        self.D = nn.Sequential(
            nn.Conv2d(in_channel, 32, 3, padding=1), # (N, 32, 512, 512)
            ResidualBlock(32, 64),
            nn.AvgPool2d(3, 2, padding=1), # (N, 64, 256, 256)
            ResidualBlock(64, 128),
            nn.AvgPool2d(3, 2, padding=1), # (N, 128, 128, 128)
            ResidualBlock(128, 256),
            nn.AvgPool2d(3, 2, padding=1), # (N, 256, 64, 64)
            ResidualBlock(256, 512),
            nn.AvgPool2d(3, 2, padding=1), # (N, 512, 32, 32)
            ResidualBlock(512, 1024),
            nn.AvgPool2d(3, 2, padding=1), # (N, 1024, 16, 16)
            ResidualBlock(1024, 2048),
            nn.AvgPool2d(3, 2, padding=1) # (N, 2048, 8, 8)
        )
        # Flattened 2048 x 8 x 8 features -> one real/fake score
        self.fc = nn.Linear(2048*8*8, 1) # (N, 1)

    def forward(self, x):
        B = x.size(0)
        h = self.D(x)        # (N, 2048, 8, 8)
        h = h.view(B, -1)    # (N, 2048*8*8)
        y = self.fc(h)       # (N, 1) raw score, no sigmoid applied
        return y
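The only size-dependent number in the Discriminator is the input size of the final nn.Linear. With ceil_mode=False, nn.AvgPool2d produces floor((n + 2*padding - kernel) / stride) + 1 along each spatial axis, so AvgPool2d(3, 2, padding=1) maps an even side n to n/2. A minimal sketch of that arithmetic (the helper names are hypothetical, not from the tutorial):

```python
def avgpool_out(n, kernel=3, stride=2, padding=1):
    # Output length of nn.AvgPool2d along one spatial axis (ceil_mode=False)
    return (n + 2 * padding - kernel) // stride + 1


def discriminator_fc_in_features(image_size, n_pools, final_channels):
    """Input size of the Discriminator's final nn.Linear (hypothetical helper)."""
    size = image_size
    for _ in range(n_pools):
        size = avgpool_out(size)
    return final_channels * size * size


print(avgpool_out(512))  # 256
print(discriminator_fc_in_features(512, n_pools=6, final_channels=2048))  # 131072
```

131072 equals 2048*8*8, which matches nn.Linear(2048*8*8, 1) above; if you change the resolution or the number of pooling stages, this is the value to update.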

Other adaptations:

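import numpy as np
import torch

IMAGE_DIM = (512, 512, 3)

def tensor2img(tensor):
    # (C, H, W) tensor in [-1, 1] -> (H, W, C) NumPy array in [0, 1]
    img = (np.transpose(tensor.detach().cpu().numpy(), [1,2,0])+1)/2.
    return img

def get_sample_image(G, n_noise=512, n_samples=64):
    """
        Generate n_samples images and tile them into a square grid
        (sqrt(n_samples) per row and column); returns an (H, W, C) array.
    """
    n_rows = int(np.sqrt(n_samples))
    z = (torch.rand(size=[n_samples, n_noise])*2-1).to(DEVICE) # U[-1, 1]; DEVICE is defined in the tutorial
    with torch.no_grad():  # sampling only, so skip autograd bookkeeping and save memory
        x_fake = G(z)
    # Join each row of samples along width (dim=2), then stack the rows along height (dim=1)
    x_fake = torch.cat([torch.cat([x_fake[n_rows*j+i] for i in range(n_rows)], dim=2) for j in range(n_rows)], dim=1)
    result = tensor2img(x_fake)
    return result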
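To see what the nested torch.cat in get_sample_image produces, here is a NumPy mirror of the same tiling on a tiny stand-in batch. tile_grid is a hypothetical helper for illustration only; it assumes an (N, C, H, W) batch with N equal to n_rows squared.

```python
import numpy as np


def tile_grid(batch, n_rows):
    """Tile an (N, C, H, W) batch into one (C, n_rows*H, n_rows*W) image.

    The inner concatenate joins one row of samples along width (axis 2);
    the outer one stacks those rows along height (axis 1).
    """
    rows = [np.concatenate([batch[n_rows * j + i] for i in range(n_rows)], axis=2)
            for j in range(n_rows)]
    return np.concatenate(rows, axis=1)


# Stand-in batch: 4 single-channel 2x2 "images", each filled with its index
batch = np.stack([np.full((1, 2, 2), k) for k in range(4)])
grid = tile_grid(batch, n_rows=2)
print(grid.shape)  # (1, 4, 4)
print(grid[0])
```

Sample k lands in grid row k // n_rows and column k % n_rows. With the defaults (64 samples of 3x512x512), the grid is 3x4096x4096 before tensor2img moves the channel axis last.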

I haven't tested it, but if you run into any errors, post them here and I'll try to help.
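One error you may plausibly hit at this resolution is running out of GPU memory. As rough arithmetic, assuming float32 (4 bytes per value), one dense (N, C, H, W) feature map takes N*C*H*W*4 bytes, so each (N, 16, 512, 512) activation near the end of the Generator costs 1 GiB when N = 64. This back-of-the-envelope sketch ignores the other layers' activations, the weights, and optimizer state, so real usage is higher.

```python
def feature_map_bytes(n, c, h, w, bytes_per_value=4):
    # Size of one dense (N, C, H, W) tensor; 4 bytes per value assumes float32
    return n * c * h * w * bytes_per_value


GIB = 1024 ** 3
print(feature_map_bytes(64, 16, 512, 512) / GIB)  # 1.0
print(feature_map_bytes(8, 16, 512, 512) / GIB)   # 0.125
```

Since the cost scales linearly with N, lowering the training batch size or n_samples in get_sample_image is the simplest lever if you hit out-of-memory errors.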
