

wbhu commented on August 25, 2024

Hi @sooyeonshin,

There were bugs in the version of the source code updated two weeks ago. I am training the fixed model right now; you could also try the updated code.

Thanks.

from dncnn-tensorflow.

sooyeonshin commented on August 25, 2024

Oh, the latest version seems good.

Thank you for answering my question.

This is just my advice, but it would be better to modify the save/load model part.

The save routine is coded like this:

```python
# save the model
if np.mod(iter_num, self.save_every_epoch) == 0:
    ...
```

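iter_num counts iterations, while self.save_every_epoch is meant as a number of epochs, so the check mixes units. I think you should either change iter_num to epoch (if you want to save the model every N epochs) or change self.save_every_epoch to a number of iterations.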
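To see why the units matter, here is a small plain-Python sketch (no TensorFlow) that lists the iter_num values at which a save would fire under each condition. It assumes both checks run once at the end of each epoch, as in the loop posted below; the function name save_points and the sizes used are hypothetical, chosen only for illustration.

```python
def save_points(num_epochs, num_batch, save_every_epoch, use_epoch):
    """Return the iter_num values at which self.save(iter_num) would fire.

    Both checks run once at the end of each epoch.
    use_epoch=False mimics the original condition
        np.mod(iter_num, self.save_every_epoch) == 0
    use_epoch=True mimics the suggested condition
        np.mod(epoch, self.save_every_epoch) == 0
    """
    saves = []
    iter_num = 0
    for epoch in range(num_epochs):
        iter_num += num_batch  # one full pass over num_batch batches
        counter = epoch if use_epoch else iter_num
        if counter % save_every_epoch == 0:
            saves.append(iter_num)
    return saves


# Hypothetical sizes: 6 epochs, 6 batches per epoch, save_every_epoch = 3.
print(save_points(6, 6, 3, use_epoch=False))  # [6, 12, 18, 24, 30, 36]
print(save_points(6, 6, 3, use_epoch=True))   # [6, 24]
```

With the iteration counter, whether a save happens depends on numBatch rather than on save_every_epoch; in this example it saves after every epoch instead of every third one.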

Also, your code has a method that saves the model, but nothing loads a pretrained model before training.

So I tried coding the load part like this (add self.ckpt_train in the initialization):

```python
import os
import time

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.train.latest_checkpoint, Saver.restore)

# Inside train(): data, numBatch, test_data, merged, writer, start_time,
# iter_num and add_noise are defined earlier, as in the original method.
if self.ckpt_train:  # set ckpt_train=True to resume from a pretrained model
    model_dir = "%s_%s_%s" % (self.trainset, self.batch_size, self.patch_size)
    checkpoint_dir = os.path.join(self.ckpt_dir, model_dir)
    full_path = tf.train.latest_checkpoint(checkpoint_dir)
    if full_path is None:
        raise IOError("[!] no checkpoint found in %s" % checkpoint_dir)
    self.saver.restore(self.sess, full_path)
    print("[*] loaded checkpoint model = %s" % full_path)
    # assumes checkpoint paths end in "-<global_step>", e.g. ".../model-4200"
    global_step = int(os.path.basename(full_path).split('-')[-1])
    start_epoch = global_step // numBatch
    start_step = global_step % numBatch
    print("[*] Model restore finished, current global step: %d" % global_step)
    print("[*] previous final epoch: %d" % start_epoch)
    print("[*] start step: %d" % start_step)
    iter_num = global_step
else:  # ckpt_train=False: train from scratch, as in the original routine
    start_epoch = 0
    start_step = 0

for epoch in range(start_epoch, self.epoch):  # start at the restored epoch
    # skip already-trained batches only in the first resumed epoch
    first_batch = start_step if epoch == start_epoch else 0
    for batch_id in range(first_batch, numBatch):
        batch_images = data[batch_id * self.batch_size:(batch_id + 1) * self.batch_size, :, :, :]
        batch_images = np.array(batch_images / 255.0, dtype=np.float32)  # normalize the data to 0-1
        train_images = add_noise(batch_images, self.sigma, self.sess)
        _, loss, summary = self.sess.run([self.train_step, self.loss, merged],
                                         feed_dict={self.X: train_images, self.X_: batch_images})
        print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f"
              % (epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
        iter_num += 1
        writer.add_summary(summary, iter_num)
    if np.mod(epoch, self.eval_every_epoch) == 0:
        self.evaluate(epoch, iter_num, test_data)  # test_data value range is 0-255
    # save the model every save_every_epoch epochs
    if np.mod(epoch, self.save_every_epoch) == 0:
        self.save(iter_num)
print("[*] Finish training.")
```
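The resume arithmetic (global_step // numBatch and global_step % numBatch) can be checked without TensorFlow. The plain-Python sketch below parses the step from a checkpoint path, assuming the "<prefix>-<global_step>" naming that the split('-') call relies on, and then lists the (epoch, batch_id) pairs a resumed loop would visit. The helper names and the checkpoint path are hypothetical. It compares applying start_step only to the first resumed epoch with applying it to every epoch; only the former covers exactly the batches that remain.

```python
import os


def parse_global_step(ckpt_path):
    # Assumes checkpoint paths end in "-<global_step>".
    return int(os.path.basename(ckpt_path).split('-')[-1])


def resumed_batches(global_step, num_batch, num_epochs, reset_start_step=True):
    """List the (epoch, batch_id) pairs a resumed training loop would visit."""
    start_epoch = global_step // num_batch
    start_step = global_step % num_batch
    visited = []
    for epoch in range(start_epoch, num_epochs):
        if epoch == start_epoch or not reset_start_step:
            first = start_step  # skip batches already trained
        else:
            first = 0           # later epochs start from the first batch
        for batch_id in range(first, num_batch):
            visited.append((epoch, batch_id))
    return visited


step = parse_global_step("checkpoint/DnCNN-tensorflow-23")  # hypothetical path
print(step)  # 23
# 5 epochs x 10 batches = 50 steps in total; 50 - 23 = 27 remain.
print(len(resumed_batches(step, 10, 5)))                          # 27
print(len(resumed_batches(step, 10, 5, reset_start_step=False)))  # 21
```

Without the reset, every remaining epoch skips its first start_step batches, so those batches are never trained again after a resume.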

This may be unnecessary interference on my part.

Anyway, I got a lot of help from your work. Thanks a lot.

wbhu commented on August 25, 2024

Hi @sooyeonshin,

Thanks for your advice. You could open a pull request with the improved code; it would be greatly appreciated.
