Comments (5)
Looking more closely, I'm seeing things like:
tf.summary.histogram('D_Y/true', self.D_Y(y))
tf.summary.histogram('D_Y/fake', self.D_Y(self.G(x)))
tf.summary.histogram('D_X/true', self.D_X(x))
tf.summary.histogram('D_X/fake', self.D_X(self.F(y)))
I'm pretty sure that every time you call G(), F(), D_Y(), or D_X() you are constructing a separate piece of the graph. So when you run these histograms you have to compute the entire iteration again, rather than just reusing the tensors you have already built for the training ops. I'm pretty sure that means this code is doing at least 2x the work it should be. Am I way off base here?
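To make that concrete, here is a minimal sketch of the two patterns (disc_y is a name I'm introducing here; I'm assuming D_Y builds new ops on every call, as these wrapper classes appear to):

  # pattern 1: constructs a second copy of the D_Y subgraph just for the summary
  tf.summary.histogram('D_Y/true', self.D_Y(y))

  # pattern 2: builds the subgraph once and reuses the resulting tensor everywhere
  disc_y = self.D_Y(y)
  error_real = tf.reduce_mean(tf.squared_difference(disc_y, REAL_LABEL))
  tf.summary.histogram('D_Y/true', disc_y)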
OK, changing these pieces seems to give about a 2x speedup. The code is below. Note that I've changed a few things for reading my own custom data. The functions I modified are prefixed with _, so model() becomes _model(). I think the various loss functions are calculated correctly, but let me know if not.
import tensorflow as tf
import ops
import utils
from reader import Reader
from discriminator import Discriminator
from generator import Generator

REAL_LABEL = 0.9

class CycleGAN:
  def __init__(self,
               x_data_dir,
               y_data_dir,
               batch_size=1,
               image_size=256,
               use_lsgan=True,
               norm='instance',
               lambda1=10.0,
               lambda2=10.0,
               learning_rate=2e-4,
               beta1=0.5,
               ngf=64
              ):
"""
Args:
X_train_file: string, X tfrecords file for training
Y_train_file: string Y tfrecords file for training
batch_size: integer, batch size
image_size: integer, image size
lambda1: integer, weight for forward cycle loss (X->Y->X)
lambda2: integer, weight for backward cycle loss (Y->X->Y)
use_lsgan: boolean
norm: 'instance' or 'batch'
learning_rate: float, initial learning rate for Adam
beta1: float, momentum term of Adam
ngf: number of gen filters in first conv layer
"""
    self.lambda1 = lambda1
    self.lambda2 = lambda2
    self.use_lsgan = use_lsgan
    use_sigmoid = not use_lsgan
    self.batch_size = batch_size
    self.image_size = image_size
    self.learning_rate = learning_rate
    self.beta1 = beta1
    self.x_data_dir = x_data_dir
    self.y_data_dir = y_data_dir

    self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training')

    self.G = Generator('G', self.is_training, ngf=ngf, norm=norm, image_size=image_size)
    self.D_Y = Discriminator('D_Y',
        self.is_training, norm=norm, use_sigmoid=use_sigmoid)
    self.F = Generator('F', self.is_training, norm=norm, image_size=image_size)
    self.D_X = Discriminator('D_X',
        self.is_training, norm=norm, use_sigmoid=use_sigmoid)

    # placeholders for pooled fake images, fed back in for the discriminator updates
    self.fake_x = tf.placeholder(tf.float32,
        shape=[batch_size, image_size, image_size, 3])
    self.fake_y = tf.placeholder(tf.float32,
        shape=[batch_size, image_size, image_size, 3])
  def _model(self):
    print("Using faster model")
    reader1 = Reader(self.x_data_dir, name="X")
    reader2 = Reader(self.y_data_dir, name="Y")
    self.x = x = reader1.feed()
    self.y = y = reader2.feed()

    # build each generator/discriminator subgraph exactly once and reuse
    # the resulting tensors in the losses and summaries below
    fake_xy = self.G(x)         # G: X -> Y
    fake_xyx = self.F(fake_xy)  # reconstruction X -> Y -> X
    fake_yx = self.F(y)         # F: Y -> X
    fake_yxy = self.G(fake_yx)  # reconstruction Y -> X -> Y
    disc_y = self.D_Y(y)
    disc_xy = self.D_Y(fake_xy)
    disc_x = self.D_X(x)
    disc_yx = self.D_X(fake_yx)
    disc_fake_x = self.D_X(self.fake_x)
    disc_fake_y = self.D_Y(self.fake_y)

    print("Creating losses")
    cycle_loss = self._cycle_consistency_loss(fake_xyx, fake_yxy, x, y)

    # X -> Y
    G_gan_loss = self._generator_loss(disc_xy, use_lsgan=self.use_lsgan)
    G_loss = G_gan_loss + cycle_loss
    D_Y_loss = self._discriminator_loss(disc_y, disc_fake_y, use_lsgan=self.use_lsgan)

    # Y -> X
    F_gan_loss = self._generator_loss(disc_yx, use_lsgan=self.use_lsgan)
    F_loss = F_gan_loss + cycle_loss
    D_X_loss = self._discriminator_loss(disc_x, disc_fake_x, use_lsgan=self.use_lsgan)

    # summary
    tf.summary.histogram('D_Y/true', disc_y)
    tf.summary.histogram('D_Y/fake', disc_xy)
    tf.summary.histogram('D_X/true', disc_x)
    tf.summary.histogram('D_X/fake', disc_yx)
    tf.summary.scalar('loss/G', G_gan_loss)
    tf.summary.scalar('loss/D_Y', D_Y_loss)
    tf.summary.scalar('loss/F', F_gan_loss)
    tf.summary.scalar('loss/D_X', D_X_loss)
    tf.summary.scalar('loss/cycle', cycle_loss)
    tf.summary.image('X/generated', utils.batch_convert2int(fake_xy))
    tf.summary.image('X/reconstruction', utils.batch_convert2int(fake_xyx))
    tf.summary.image('Y/generated', utils.batch_convert2int(fake_yx))
    tf.summary.image('Y/reconstruction', utils.batch_convert2int(fake_yxy))

    return G_loss, D_Y_loss, F_loss, D_X_loss, fake_xy, fake_yx
  def model(self):
    reader1 = Reader(self.x_data_dir, name="X")
    reader2 = Reader(self.y_data_dir, name="Y")
    self.x = x = reader1.feed()
    self.y = y = reader2.feed()

    print("Creating cycle loss")
    cycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)

    # X -> Y
    print("Creating fake y")
    fake_y = self.G(x)
    G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan)
    G_loss = G_gan_loss + cycle_loss
    D_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y, use_lsgan=self.use_lsgan)

    # Y -> X
    print("Creating fake x")
    fake_x = self.F(y)
    F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan)
    F_loss = F_gan_loss + cycle_loss
    D_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan)

    # summary
    tf.summary.histogram('D_Y/true', self.D_Y(y))
    tf.summary.histogram('D_Y/fake', self.D_Y(self.G(x)))
    tf.summary.histogram('D_X/true', self.D_X(x))
    tf.summary.histogram('D_X/fake', self.D_X(self.F(y)))
    tf.summary.scalar('loss/G', G_gan_loss)
    tf.summary.scalar('loss/D_Y', D_Y_loss)
    tf.summary.scalar('loss/F', F_gan_loss)
    tf.summary.scalar('loss/D_X', D_X_loss)
    tf.summary.scalar('loss/cycle', cycle_loss)
    tf.summary.image('X/generated', utils.batch_convert2int(self.G(x)))
    tf.summary.image('X/reconstruction', utils.batch_convert2int(self.F(self.G(x))))
    tf.summary.image('Y/generated', utils.batch_convert2int(self.F(y)))
    tf.summary.image('Y/reconstruction', utils.batch_convert2int(self.G(self.F(y))))

    return G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x
  def optimize(self, G_loss, D_Y_loss, F_loss, D_X_loss):
    def make_optimizer(loss, variables, name='Adam'):
      """ Adam optimizer with learning rate 0.0002 for the first 100k steps (~100 epochs)
          and a linearly decaying rate that goes to zero over the next 100k steps
      """
      global_step = tf.Variable(0, trainable=False)
      starter_learning_rate = self.learning_rate
      end_learning_rate = 0.0
      start_decay_step = 100000
      decay_steps = 100000
      beta1 = self.beta1
      learning_rate = (
          tf.where(
              tf.greater_equal(global_step, start_decay_step),
              tf.train.polynomial_decay(starter_learning_rate, global_step-start_decay_step,
                                        decay_steps, end_learning_rate,
                                        power=1.0),
              starter_learning_rate
          )
      )
      tf.summary.scalar('learning_rate/{}'.format(name), learning_rate)

      learning_step = (
          tf.train.AdamOptimizer(learning_rate, beta1=beta1, name=name)
            .minimize(loss, global_step=global_step, var_list=variables)
      )
      return learning_step

    G_optimizer = make_optimizer(G_loss, self.G.variables, name='Adam_G')
    D_Y_optimizer = make_optimizer(D_Y_loss, self.D_Y.variables, name='Adam_D_Y')
    F_optimizer = make_optimizer(F_loss, self.F.variables, name='Adam_F')
    D_X_optimizer = make_optimizer(D_X_loss, self.D_X.variables, name='Adam_D_X')

    with tf.control_dependencies([G_optimizer, D_Y_optimizer, F_optimizer, D_X_optimizer]):
      return tf.no_op(name='optimizers')
  def discriminator_loss(self, D, y, fake_y, use_lsgan=True):
    """ Note: default: D(y).shape == (batch_size,5,5,1),
        fake_buffer_size=50, batch_size=1
    Args:
      D: discriminator object
      y: 4D tensor (batch_size, image_size, image_size, 3), real images
      fake_y: 4D tensor, pooled fake images
    Returns:
      loss: scalar
    """
    if use_lsgan:
      # use mean squared error
      error_real = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))
      error_fake = tf.reduce_mean(tf.square(D(fake_y)))
    else:
      # use cross entropy
      error_real = -tf.reduce_mean(ops.safe_log(D(y)))
      error_fake = -tf.reduce_mean(ops.safe_log(1-D(fake_y)))
    loss = (error_real + error_fake) / 2
    return loss
  def _discriminator_loss(self, disc_real, disc_fake, use_lsgan=True):
    """ Note: default: disc_real.shape == (batch_size,5,5,1),
        fake_buffer_size=50, batch_size=1
    Args:
      disc_real: discriminator output on real images
      disc_fake: discriminator output on (pooled) fake images
    Returns:
      loss: scalar
    """
    if use_lsgan:
      # use mean squared error
      error_real = tf.reduce_mean(tf.squared_difference(disc_real, REAL_LABEL))
      error_fake = tf.reduce_mean(tf.square(disc_fake))
    else:
      # use cross entropy
      error_real = -tf.reduce_mean(ops.safe_log(disc_real))
      error_fake = -tf.reduce_mean(ops.safe_log(1-disc_fake))
    loss = (error_real + error_fake) / 2
    return loss
  def generator_loss(self, D, fake_y, use_lsgan=True):
    """ fool discriminator into believing that G(x) is real
    """
    if use_lsgan:
      # use mean squared error
      loss = tf.reduce_mean(tf.squared_difference(D(fake_y), REAL_LABEL))
    else:
      # heuristic, non-saturating loss
      loss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2
    return loss
  def _generator_loss(self, disc_fake, use_lsgan=True):
    """ fool discriminator into believing that G(x) is real
    Args:
      disc_fake: discriminator output on generated images
    """
    if use_lsgan:
      # use mean squared error
      loss = tf.reduce_mean(tf.squared_difference(disc_fake, REAL_LABEL))
    else:
      # heuristic, non-saturating loss
      loss = -tf.reduce_mean(ops.safe_log(disc_fake)) / 2
    return loss
  def cycle_consistency_loss(self, G, F, x, y):
    """ cycle consistency loss (L1 norm)
    """
    forward_loss = tf.reduce_mean(tf.abs(F(G(x))-x))
    backward_loss = tf.reduce_mean(tf.abs(G(F(y))-y))
    loss = self.lambda1*forward_loss + self.lambda2*backward_loss
    return loss
  def _cycle_consistency_loss(self, fake_xyx, fake_yxy, x, y):
    """ cycle consistency loss (L1 norm)
    """
    forward_loss = tf.reduce_mean(tf.abs(fake_xyx-x))
    backward_loss = tf.reduce_mean(tf.abs(fake_yxy-y))
    loss = self.lambda1*forward_loss + self.lambda2*backward_loss
    return loss
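For context, here is a minimal sketch of how _model() and optimize() fit together in a training script (the data paths are placeholders and the session/queue-runner boilerplate is omitted):

graph = tf.Graph()
with graph.as_default():
  cycle_gan = CycleGAN(x_data_dir='data/X', y_data_dir='data/Y')  # placeholder paths
  G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan._model()
  optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
  summary_op = tf.summary.merge_all()

The fake_y and fake_x here are the fake_xy and fake_yx tensors returned by _model(); they are what gets run and fed through the image pool in the training loop below.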
Also, I believe here:
fake_Y_pool = ImagePool(FLAGS.pool_size)
fake_X_pool = ImagePool(FLAGS.pool_size)

while not coord.should_stop():
  # get previously generated images
  fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

  # train
  _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
      sess.run(
          [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
          feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                     cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
      )
  )
on every iteration we are generating fake images for the discriminators and then running all of the losses and gradient updates in a second sess.run call. This means we are running the generators twice for every gradient update. I believe you will get the exact same functionality with:
fake_Y_pool = ImagePool(FLAGS.pool_size)
fake_X_pool = ImagePool(FLAGS.pool_size)

# prime the previously generated images
fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

while not coord.should_stop():
  # train, fetching this step's fake images in the same call
  _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, fake_y_val, fake_x_val, summary = (
      sess.run(
          [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, summary_op],
          feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                     cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
      )
  )
Except you won't have to do 2 generations for every gradient update. And now that I think about it, the first code might be wrong anyway: when you generate the fake images with fake_y_val, fake_x_val = sess.run([fake_y, fake_x]), you are using the last generated images. But for the pool code to work as expected, you should be feeding in the current fake images 50% of the time. In the first implementation you are never feeding in the current images.
Nice! Let me dig into it.
About the image pool, the current images are fed 50% of the time, as in: https://github.com/vanhuyz/CycleGAN-TensorFlow/blob/master/utils.py#L49-L57
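For anyone who doesn't want to follow the link, the query logic is roughly this (a paraphrase of the usual CycleGAN image-pool trick, not the exact contents of utils.py; the pooled images are the numpy arrays returned by sess.run):

import random

class ImagePool:
  def __init__(self, pool_size):
    self.pool_size = pool_size
    self.images = []

  def query(self, image):
    if self.pool_size == 0:
      return image
    if len(self.images) < self.pool_size:
      # pool not full yet: remember the current image and return it
      self.images.append(image)
      return image
    if random.random() > 0.5:
      # swap the current image in for a random old one and return the old one
      random_id = random.randrange(0, self.pool_size)
      old = self.images[random_id].copy()
      self.images[random_id] = image.copy()
      return old
    # otherwise pass the current image straight through
    return image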
Right, so my point about the pool is that in the current code you have 2 calls to sess.run in every loop. That means that in every loop, 2 sets of images (each set has an X image and a Y image) are pulled from your queue runner. The first call generates fake images from images A1 and B1; the second call calculates the losses by pulling another set of images, A2 and B2. You feed in the previously generated fake images, which are either (50% of the time) a random image from the replay pool or (50% of the time) the images generated from A1 and B1. So half the time it is a random one from the past, and half the time it is the one from 1 timestep back, but it is never from the current timestep. I don't think that's the functionality you want, right? Shouldn't it be 50% from the past and 50% current? In addition, you are generating two sets of fake images for every gradient update, so you are actually doing more work than you have to (see the timeline sketch after the code below).
while not coord.should_stop():
  # get previously generated images
  # makes a call to the image queue
  fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

  # train
  # makes another call to the image queue
  _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
      sess.run(
          [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
          feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                     cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
      )
  )
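Spelling out the timing as a pseudo-timeline (comments only, just the order of events as I understand it; A/B denote images drawn from the X and Y queues):

# iteration t:
#   sess.run call 1: queue yields (A1, B1) -> computes fakes G(A1) and F(B1)
#   sess.run call 2: queue yields (A2, B2) -> losses and gradients use (A2, B2),
#     while feed_dict supplies pool.query(G(A1)) and pool.query(F(B1)):
#       50% of the time: a random old fake from the pool's history
#       50% of the time: the fake from call 1, i.e. one timestep back
#   so the discriminators never see the fakes generated from (A2, B2) itself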