
retain_graph=True (vqtorch issue, open)

DiffDynamo commented on August 22, 2024
retain_graph=True


Comments (3)

minyoungg commented on August 22, 2024

Could you provide a short description of what you are trying to do, along with a code snippet that reproduces this error?

Note that when inplace_optimizer is provided, each forward call in training mode will update the codebook.
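
For example (a minimal sketch using only the VectorQuant arguments that appear later in this thread; the shapes and hyperparameters are illustrative):

import torch
from vqtorch.nn import VectorQuant

# same factory pattern as in the snippet below
inplace_optimizer = lambda *args, **kwargs: torch.optim.SGD(*args, **kwargs, lr=10.0, momentum=0.9)

vq = VectorQuant(feature_size=8, num_codes=16, inplace_optimizer=inplace_optimizer)
vq.train()

x = torch.randn(4, 32, 8)  # quantized along the last dim by default
z_q, vq_dict = vq(x)       # this forward already applies one codebook update
z_q, vq_dict = vq(x)       # each further forward applies another update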


DiffDynamo commented on August 22, 2024

My goal is discrete representation learning for one-dimensional time-series data, and I built my own autoencoder for this purpose. I added your quantization layer at the bottleneck of the autoencoder to discretize the continuous latent representations.
I am using Python 3.9 and PyTorch 1.12.1.
A short code snippet:

from torch.nn import MSELoss
import torch.optim as optim
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from vqtorch.nn import VectorQuant
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, kernel_size=7):
        super().__init__()

        self.layers = nn.Sequential(
            Conv1d(in_channels=in_channels, out_channels=in_channels,
                   kernel_size=kernel_size, stride=1, dilation=dilation,
                   padding=dilation * (kernel_size - 1) // 2),  # 'same' padding keeps the sequence length
            nn.ELU())

        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=1)

    def forward(self, x):
        out = self.layers(x)
        out = self.conv(out)
        return x + out

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, stride=2):
        super().__init__()

        self.layers = nn.Sequential(
            Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, padding=0, bias=False),
            nn.ELU(),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1),
            nn.ELU(),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3),
            nn.ELU(),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=5),
            nn.ELU(),
            Conv1d(in_channels=in_channels, out_channels=2*in_channels,
                   kernel_size=stride, stride=stride)
        )

    def forward(self, x):
        return self.layers(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, stride=2):
        super().__init__()

        self.layers = nn.Sequential(
            ConvTranspose1d(in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=stride, stride=stride),
            nn.ELU(),
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=1),
            nn.ELU(),
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=3),
            nn.ELU(),
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=5),
            nn.ELU(),
            Conv1d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, padding=0, bias=False)

        )

    def forward(self, x):
        return self.layers(x)

class Encoder(nn.Module):
    def __init__(self, C, D):
        super().__init__()
        self.conv_1 = nn.Sequential(Conv1d(in_channels=1, out_channels=C, kernel_size=3, padding=1), nn.ELU())
        self.conv_2 = nn.Sequential(EncoderBlock(in_channels=C, stride=1), nn.ELU())
        self.conv_3 = nn.Sequential(EncoderBlock(in_channels=2*C, stride=2), nn.ELU())
        self.conv_4 = nn.Sequential(EncoderBlock(in_channels=4*C, stride=2), nn.ELU())
        self.conv_5 = nn.Sequential(EncoderBlock(in_channels=8*C, stride=2), nn.ELU())
        self.conv_6 = nn.Sequential(Conv1d(in_channels=16*C, out_channels=D, kernel_size=3, padding=1))

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        x = self.conv_5(x)
        x = self.conv_6(x)
        return x

class Decoder(nn.Module):
    def __init__(self, C, D):
        super().__init__()
        self.conv1 = Conv1d(in_channels=C, out_channels=1, kernel_size=1, padding=0, bias=False)
        self.conv_1 = nn.Sequential(Conv1d(in_channels=D, out_channels=16*C, kernel_size=3, padding=1), nn.ELU())
        self.conv_2 = nn.Sequential(DecoderBlock(in_channels=16*C, stride=2), nn.ELU())
        self.conv_3 = nn.Sequential(DecoderBlock(in_channels=8*C, stride=2), nn.ELU())
        self.conv_4 = nn.Sequential(DecoderBlock(in_channels=4*C, stride=2), nn.ELU())
        self.conv_5 = nn.Sequential(DecoderBlock(in_channels=2*C, stride=1), nn.ELU())
        self.conv_6 = nn.Sequential(Conv1d(in_channels=C, out_channels=C, kernel_size=3, padding=1), nn.ELU())

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        x = self.conv_5(x)
        x = self.conv_6(x)
        x = self.conv1(x)
        return x

class VQencoder(nn.Module):
    def __init__(self, C, D, num_codes=512, embedding_dim=64, inplace_optimizer=None):
        super().__init__()

        self.inplace_optimizer = inplace_optimizer
        self.encoder = Encoder(C=C, D=D)
        self.vq_layer = VectorQuant(
            feature_size=embedding_dim,     # feature dimension corresponding to the vectors
            num_codes=num_codes,      # number of codebook vectors
            beta=1,           # (default: 0.9) commitment trade-off
            kmeans_init=True,    # (default: False) whether to use kmeans++ init
            norm='l2',           # (default: None) normalization for the input vectors
            cb_norm='l2',        # (default: None) normalization for codebook vectors
            affine_lr=20,      # (default: 0.0) lr scale for affine parameters
            sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
            replace_freq=20,     # (default: None) frequency to replace dead codes
            inplace_optimizer=self.inplace_optimizer,
            dim=1,              # (default: -1) dimension to be quantized
        )
        self.decoder = Decoder(C=C, D=D)

    def forward(self, x):
        e = self.encoder(x)
        z_q, vq_dict = self.vq_layer(e)
        vq_loss = vq_dict['loss']
        perplexity = vq_dict['perplexity']
        encodings = vq_dict['q']
        out = self.decoder(z_q)
        return out, vq_loss, perplexity, encodings

def train():
    weight_decay = 1e-4
    batch_size = 64
    num_workers = 4
    learning_rate = 1e-3
    loss_function = MSELoss(reduction='mean')
    inplace_optimizer = lambda *args, **kwargs: torch.optim.SGD(*args, **kwargs, lr=10.0, momentum=0.9)
    net = VQencoder(C=4, D=256, num_codes=1024, embedding_dim=256, inplace_optimizer=inplace_optimizer)
    net = net.to(device)
    optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.95))
    input_data = torch.randn(10000, 1, 2048)  # synthetic 1D time series
    train_dataset = TensorDataset(input_data)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
    for epoch in range(100):
        net.train()
        for batch_idx, data in enumerate(train_loader):
            data = data[0].to(device)
            res, vq_loss, perplexity, encodings = net(data)
            res_loss = loss_function(res, data)
            optimizer.zero_grad()
            res_loss.backward()  # the retain_graph RuntimeError is raised here
            optimizer.step()

if __name__ == '__main__':
    train()


SeanNobel commented on August 22, 2024

@minyoungg Hi, thanks for the interesting work and the great library.
I've encountered the same problem as @DiffDynamo, and setting retain_graph=True here seems to work fine.
Could this be because the computation graph for z is freed after the backward pass performed by the inplace optimizer, leaving the main optimizer unable to backpropagate through the straight-through estimator?
Thanks.
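
For what it's worth, the failure can be reproduced in plain PyTorch without vqtorch (a toy sketch of the suspected mechanism, not the library's actual code): an extra backward through the shared encoder subgraph frees its saved tensors, so the main loss's backward then fails unless the first backward used retain_graph=True.

import torch

encoder = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
e = encoder(x)               # shared "encoder" subgraph

vq_loss = (e ** 2).mean()    # stands in for the internal VQ loss
vq_loss.backward()           # like the inplace optimizer's backward during forward;
                             # vq_loss.backward(retain_graph=True) would keep the graph alive

recon_loss = e.sum()
recon_loss.backward()        # without retain_graph above, this raises the
                             # "Trying to backward through the graph a second time" RuntimeError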

