
Speed benchmarks · cyclegan-tensorflow (open issue)

vanhuyz commented on May 25, 2024
Speed benchmarks


Comments (5)

lucasgreene commented on May 25, 2024

Looking more closely, I'm seeing things like:
    tf.summary.histogram('D_Y/true', self.D_Y(y))
    tf.summary.histogram('D_Y/fake', self.D_Y(self.G(x)))
    tf.summary.histogram('D_X/true', self.D_X(x))
    tf.summary.histogram('D_X/fake', self.D_X(self.F(y)))

I'm pretty sure that every time you call G(), F(), D_Y(), or D_X() you are constructing a separate piece of the graph. So when you run these histograms you have to recompute the whole forward pass, rather than reusing the tensors already produced by the training ops. I'm fairly sure that means this code is doing at least 2x the work it should be. Am I way off base here?
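
To make the point concrete, here is a minimal, self-contained TF1 sketch (not from this repo; tiny_net and the shapes are made up) showing that every call to a network-building function adds a fresh copy of its ops to the graph even when variables are shared, so reusing an already-built tensor for the summaries is essentially free:

    import tensorflow as tf

    def tiny_net(inp, scope='tiny_net'):
      # stand-in for Generator/Discriminator __call__: builds new ops on every call,
      # sharing variables through the scope
      with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        return tf.layers.conv2d(inp, filters=3, kernel_size=3, padding='same')

    x = tf.placeholder(tf.float32, shape=[1, 256, 256, 3])

    before = len(tf.get_default_graph().get_operations())
    y1 = tiny_net(x)   # first call: builds the conv ops
    y2 = tiny_net(x)   # second call: builds a second, duplicate set of conv ops
    after = len(tf.get_default_graph().get_operations())
    print('ops added by two calls:', after - before)

    # Reusing y1 here adds no extra compute; calling tiny_net(x) again would.
    tf.summary.histogram('tiny/out', y1)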


lucasgreene commented on May 25, 2024

OK, changing these pieces seems to give about a 2x speedup. The code is below. Note that I've changed a few things for reading my own custom data. The functions I modified are prefixed with _, so model() becomes _model(). I'm not sure why the markdown formatting is weird. I think the various loss functions are still calculated correctly, but let me know if not.

import tensorflow as tf
import ops
import utils
from reader import Reader
from discriminator import Discriminator
from generator import Generator

REAL_LABEL = 0.9

class CycleGAN:
  def __init__(self,
               x_data_dir,
               y_data_dir,
               batch_size=1,
               image_size=256,
               use_lsgan=True,
               norm='instance',
               lambda1=10.0,
               lambda2=10.0,
               learning_rate=2e-4,
               beta1=0.5,
               ngf=64
              ):
    """
    Args:
      x_data_dir: string, path to X training data
      y_data_dir: string, path to Y training data
      batch_size: integer, batch size
      image_size: integer, image size
      lambda1: float, weight for forward cycle loss (X->Y->X)
      lambda2: float, weight for backward cycle loss (Y->X->Y)
      use_lsgan: boolean
      norm: 'instance' or 'batch'
      learning_rate: float, initial learning rate for Adam
      beta1: float, momentum term of Adam
      ngf: number of gen filters in first conv layer
    """
    self.lambda1 = lambda1
    self.lambda2 = lambda2
    self.use_lsgan = use_lsgan
    use_sigmoid = not use_lsgan
    self.batch_size = batch_size
    self.image_size = image_size
    self.learning_rate = learning_rate
    self.beta1 = beta1
    self.x_data_dir = x_data_dir
    self.y_data_dir = y_data_dir

    self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training')

    self.G = Generator('G', self.is_training, ngf=ngf, norm=norm, image_size=image_size)
    self.D_Y = Discriminator('D_Y',
        self.is_training, norm=norm, use_sigmoid=use_sigmoid)
    self.F = Generator('F', self.is_training, norm=norm, image_size=image_size)
    self.D_X = Discriminator('D_X',
        self.is_training, norm=norm, use_sigmoid=use_sigmoid)

    self.fake_x = tf.placeholder(tf.float32,
        shape=[batch_size, image_size, image_size, 3])
    self.fake_y = tf.placeholder(tf.float32,
        shape=[batch_size, image_size, image_size, 3])


  def _model(self):
    print("Using faster model")
    reader1 = Reader(self.x_data_dir, name="X")
    reader2 = Reader(self.y_data_dir, name="Y")
    self.x = x = reader1.feed()
    self.y = y = reader2.feed()

    fake_xy = self.G(x)
    fake_xyx = self.F(fake_xy)
    fake_yx = self.F(y)
    fake_yxy = self.G(fake_yx)

    disc_y = self.D_Y(y)
    disc_xy = self.D_Y(fake_xy)
    disc_x = self.D_X(x)
    disc_yx = self.D_X(fake_yx)

    disc_fake_x = self.D_X(self.fake_x)
    disc_fake_y = self.D_Y(self.fake_y)


    print("Creating losses")
    cycle_loss = self._cycle_consistency_loss(fake_xyx, fake_yxy, x, y)

    # X -> Y
    G_gan_loss = self._generator_loss(disc_xy, use_lsgan=self.use_lsgan)
    G_loss =  G_gan_loss + cycle_loss
    D_Y_loss = self._discriminator_loss(disc_y, disc_fake_y, use_lsgan=self.use_lsgan)

    # Y -> X
    F_gan_loss = self._generator_loss(disc_yx, use_lsgan=self.use_lsgan)
    F_loss = F_gan_loss + cycle_loss
    D_X_loss = self._discriminator_loss(disc_x,  disc_fake_x, use_lsgan=self.use_lsgan)

    # summary
    tf.summary.histogram('D_Y/true', disc_y)
    tf.summary.histogram('D_Y/fake', disc_xy)
    tf.summary.histogram('D_X/true', disc_x)
    tf.summary.histogram('D_X/fake', disc_yx)

    tf.summary.scalar('loss/G', G_gan_loss)
    tf.summary.scalar('loss/D_Y', D_Y_loss)
    tf.summary.scalar('loss/F', F_gan_loss)
    tf.summary.scalar('loss/D_X', D_X_loss)
    tf.summary.scalar('loss/cycle', cycle_loss)

    tf.summary.image('X/generated', utils.batch_convert2int(fake_xy))
    tf.summary.image('X/reconstruction', utils.batch_convert2int(fake_xyx))
    tf.summary.image('Y/generated', utils.batch_convert2int(fake_yx))
    tf.summary.image('Y/reconstruction', utils.batch_convert2int(fake_yxy))

    return G_loss, D_Y_loss, F_loss, D_X_loss, fake_xy, fake_yx

  def model(self):
    reader1 = Reader(self.x_data_dir, name="X")
    reader2 = Reader(self.y_data_dir, name="Y")
    self.x = x = reader1.feed()
    self.y = y = reader2.feed()

    print("Creating cycle loss")
    cycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)

    # X -> Y
    print("Creating fake y")
    fake_y = self.G(x)
    G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan)
    G_loss =  G_gan_loss + cycle_loss
    D_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y, use_lsgan=self.use_lsgan)

    # Y -> X
    print("Creating fake x")
    fake_x = self.F(y)
    F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan)
    F_loss = F_gan_loss + cycle_loss
    D_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan)

    # summary
    tf.summary.histogram('D_Y/true', self.D_Y(y))
    tf.summary.histogram('D_Y/fake', self.D_Y(self.G(x)))
    tf.summary.histogram('D_X/true', self.D_X(x))
    tf.summary.histogram('D_X/fake', self.D_X(self.F(y)))

    tf.summary.scalar('loss/G', G_gan_loss)
    tf.summary.scalar('loss/D_Y', D_Y_loss)
    tf.summary.scalar('loss/F', F_gan_loss)
    tf.summary.scalar('loss/D_X', D_X_loss)
    tf.summary.scalar('loss/cycle', cycle_loss)

    tf.summary.image('X/generated', utils.batch_convert2int(self.G(x)))
    tf.summary.image('X/reconstruction', utils.batch_convert2int(self.F(self.G(x))))
    tf.summary.image('Y/generated', utils.batch_convert2int(self.F(y)))
    tf.summary.image('Y/reconstruction', utils.batch_convert2int(self.G(self.F(y))))

    return G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x

  def optimize(self, G_loss, D_Y_loss, F_loss, D_X_loss):
    def make_optimizer(loss, variables, name='Adam'):
      """ Adam optimizer with learning rate 0.0002 for the first 100k steps (~100 epochs)
          and a linearly decaying rate that goes to zero over the next 100k steps
      """
      global_step = tf.Variable(0, trainable=False)
      starter_learning_rate = self.learning_rate
      end_learning_rate = 0.0
      start_decay_step = 100000
      decay_steps = 100000
      beta1 = self.beta1
      learning_rate = (
          tf.where(
                  tf.greater_equal(global_step, start_decay_step),
                  tf.train.polynomial_decay(starter_learning_rate, global_step-start_decay_step,
                                            decay_steps, end_learning_rate,
                                            power=1.0),
                  starter_learning_rate
          )

      )
      tf.summary.scalar('learning_rate/{}'.format(name), learning_rate)

      learning_step = (
          tf.train.AdamOptimizer(learning_rate, beta1=beta1, name=name)
                  .minimize(loss, global_step=global_step, var_list=variables)
      )
      return learning_step

    G_optimizer = make_optimizer(G_loss, self.G.variables, name='Adam_G')
    D_Y_optimizer = make_optimizer(D_Y_loss, self.D_Y.variables, name='Adam_D_Y')
    F_optimizer =  make_optimizer(F_loss, self.F.variables, name='Adam_F')
    D_X_optimizer = make_optimizer(D_X_loss, self.D_X.variables, name='Adam_D_X')

    with tf.control_dependencies([G_optimizer, D_Y_optimizer, F_optimizer, D_X_optimizer]):
      return tf.no_op(name='optimizers')

  def discriminator_loss(self, D, y, fake_y, use_lsgan=True):
    """ Note: default: D(y).shape == (batch_size,5,5,1),
                       fake_buffer_size=50, batch_size=1
    Args:
      D: discriminator object
      y: 4D tensor (batch_size, image_size, image_size, 3), real samples
      fake_y: 4D tensor, previously generated fake samples (from the image pool)
    Returns:
      loss: scalar
    """
    if use_lsgan:
      # use mean squared error
      error_real = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))
      error_fake = tf.reduce_mean(tf.square(D(fake_y)))
    else:
      # use cross entropy
      error_real = -tf.reduce_mean(ops.safe_log(D(y)))
      error_fake = -tf.reduce_mean(ops.safe_log(1-D(fake_y)))
    loss = (error_real + error_fake) / 2
    return loss

  def _discriminator_loss(self, disc_real, disc_fake, use_lsgan=True):
    """ Note: default: D(y).shape == (batch_size,5,5,1),
                       fake_buffer_size=50, batch_size=1
    Args:
      disc_real: discriminator output on real samples
      disc_fake: discriminator output on pooled fake samples
    Returns:
      loss: scalar
    """
    if use_lsgan:
      # use mean squared error
      error_real = tf.reduce_mean(tf.squared_difference(disc_real, REAL_LABEL))
      error_fake = tf.reduce_mean(tf.square(disc_fake))
    else:
      # use cross entropy
      error_real = -tf.reduce_mean(ops.safe_log(disc_real))
      error_fake = -tf.reduce_mean(ops.safe_log(1-disc_fake))
    loss = (error_real + error_fake) / 2
    return loss

  def generator_loss(self, D, fake_y, use_lsgan=True):
    """  fool discriminator into believing that G(x) is real
    """
    if use_lsgan:
      # use mean squared error
      loss = tf.reduce_mean(tf.squared_difference(D(fake_y), REAL_LABEL))
    else:
      # heuristic, non-saturating loss
      loss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2
    return loss

  def _generator_loss(self, disc_fake, use_lsgan=True):
    """  fool discriminator into believing that G(x) is real
    """
    if use_lsgan:
      # use mean squared error
      loss = tf.reduce_mean(tf.squared_difference(disc_fake, REAL_LABEL))
    else:
      # heuristic, non-saturating loss
      loss = -tf.reduce_mean(ops.safe_log(disc_fake)) / 2
    return loss

  def cycle_consistency_loss(self, G, F, x, y):
    """ cycle consistency loss (L1 norm)
    """
    forward_loss = tf.reduce_mean(tf.abs(F(G(x))-x))
    backward_loss = tf.reduce_mean(tf.abs(G(F(y))-y))
    loss = self.lambda1*forward_loss + self.lambda2*backward_loss
    return loss

  def _cycle_consistency_loss(self, fake_xyx, fake_yxy, x, y):
    """ cycle consistency loss (L1 norm)
    """
    forward_loss = tf.reduce_mean(tf.abs(fake_xyx-x))
    backward_loss = tf.reduce_mean(tf.abs(fake_yxy-y))
    loss = self.lambda1*forward_loss + self.lambda2*backward_loss
    return loss


lucasgreene commented on May 25, 2024

Also, I believe here:

      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )

On every iteration we are generating fake images for the discriminator, and then running all of the losses and gradient updates in a second sess.run call. This means that we are running the generators twice for every gradient update. I believe you will get the exact same functionality with:

      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      # prime previously generated images
      fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

      while not coord.should_stop():
        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, fake_y_val, fake_x_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )

Except you won't have to do two generations for every gradient update. And now that I think about it, the first code might be wrong anyway: when you generate the fake images with fake_y_val, fake_x_val = sess.run([fake_y, fake_x]) you are using the previously generated images. But for the pool code to work as expected you should be feeding in the current fake images 50% of the time. In the first implementation you are never feeding in the current images.


vanhuyz commented on May 25, 2024

Nice! Let me dig into it.

About the image pool: the current images are fed 50% of the time, as in: https://github.com/vanhuyz/CycleGAN-TensorFlow/blob/master/utils.py#L49-L57
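
For anyone following along, here is a minimal sketch of that query logic (hypothetical names, roughly the standard CycleGAN pool behavior described above: return the current image 50% of the time, otherwise swap it for a random one from the pool; see the linked utils.py for the repository's actual implementation):

    import random
    import copy

    class ImagePool:
      def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

      def query(self, image):
        if self.pool_size == 0:
          return image
        if len(self.images) < self.pool_size:
          # pool not full yet: remember the current image and train on it
          self.images.append(image)
          return image
        if random.random() > 0.5:
          # 50%: replace a random past image with the current one, train on the old one
          idx = random.randrange(self.pool_size)
          old = copy.copy(self.images[idx])
          self.images[idx] = image
          return old
        # 50%: train on the current image
        return image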


lucasgreene commented on May 25, 2024

Right, so my point about the pool is that in the current code you have two calls to sess.run in every loop. That means that in every loop two sets of images (each set has an X image and a Y image) are pulled from your queue runner. The first call generates fake X and Y images from images A1 and B1; the second call calculates the loss by pulling another set of images A2 and B2. You feed in the previously generated fake images, which are either a random one from the replay pool (50%) or the images generated from A1 and B1 (50%). So 50% of the time it is a random one from the past, and 50% of the time it is the one from one timestep back, but it is never from the current timestep. I don't think that's the functionality you want, right? Shouldn't it be 50% from the past and 50% current? In addition, you are generating two sets of fake images for every gradient update, so you're doing more work than you have to.

     while not coord.should_stop():
        # get previously generated images
        # Makes a call to image queue
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        # makes a call to image queue
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )

