JiahuiYu commented on May 21, 2024

Your usage is not quite right. In TensorFlow code, the graph-building function should be called only once, unless you deliberately want to reuse the graph. I've modified the code for your case; please use the following:

    import argparse
    import time

    import cv2
    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.contrib)

    from inpaint_model import InpaintCAModel

    parser = argparse.ArgumentParser()
    parser.add_argument('--flist', type=str,
                        help='text file with one "image mask output" path triple per line')
    parser.add_argument('--image_height', type=int, help='resize height (multiple of 4)')
    parser.add_argument('--image_width', type=int, help='resize width (multiple of 4)')
    parser.add_argument('--checkpoint_dir', type=str, help='pretrained checkpoint directory')
    args = parser.parse_args()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    # Build the graph ONCE with a placeholder; image and mask sit side by side, hence width * 2.
    model = InpaintCAModel()
    input_image_ph = tf.placeholder(
        tf.float32, shape=(1, args.image_height, args.image_width*2, 3))
    output = model.build_server_graph(input_image_ph)
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)

    # Load the pretrained weights ONCE by assigning every variable from the checkpoint.
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        vname = var.name
        from_name = vname
        var_value = tf.contrib.framework.load_variable(
            args.checkpoint_dir, from_name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')

    with open(args.flist, 'r') as f:
        lines = f.read().splitlines()
    t = time.time()
    for line in lines:
        image, mask, out = line.split()

        image = cv2.imread(image)
        mask = cv2.imread(mask)
        image = cv2.resize(image, (args.image_width, args.image_height))
        mask = cv2.resize(mask, (args.image_width, args.image_height))

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 4
        # No-op when image_height and image_width are multiples of 4; otherwise
        # the cropped shape no longer matches the placeholder and feeding fails.
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        # Run inference by feeding the placeholder; no graph is rebuilt here.
        result = sess.run(output, feed_dict={input_image_ph: input_image})
        print('Processed: {}'.format(out))
        cv2.imwrite(out, result[0][:, :, ::-1])

    print('Time total: {}'.format(time.time() - t))
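The input layout used in the code above (image and mask placed side by side, which is why the placeholder width is `args.image_width*2`) and the output rescaling can be sketched in plain NumPy. This is a minimal sketch assuming 3-channel uint8 arrays such as those returned by `cv2.imread`; the helper names `make_input` and `to_uint8` are hypothetical and not part of the repository.

```python
import numpy as np


def make_input(image, mask, grid=4):
    """Crop image and mask to a multiple of `grid`, then place them side by side.

    Returns a float32 batch of shape (1, H, 2*W, 3), which matches the
    placeholder shape (1, image_height, image_width*2, 3) when the resized
    height and width are already multiples of `grid`.
    """
    assert image.shape == mask.shape
    h, w, _ = image.shape
    image = image[:h // grid * grid, :w // grid * grid, :]
    mask = mask[:h // grid * grid, :w // grid * grid, :]
    # Add a batch axis, then concatenate along the width axis (axis=2).
    batch = np.concatenate([image[None], mask[None]], axis=2)
    return batch.astype(np.float32)


def to_uint8(output):
    """Map network output from [-1, 1] to [0, 255], clamping like tf.saturate_cast."""
    scaled = (output + 1.0) * 127.5
    return np.clip(scaled, 0, 255).astype(np.uint8)
```

Because the crop is applied before concatenation, the mask half always lines up pixel for pixel with the image half.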

from generative_inpainting.

JiahuiYu commented on May 21, 2024

"We have not found perceptual loss (reconstruction loss on VGG features), style loss (squared Frobenius norm of Gram matrix computed on the VGG features) [21] and total variation (TV) loss bring noticeable improvements for image inpainting in our framework, thus are not used."

You will need to implement a VGG16 perceptual loss yourself.
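For reference, the three losses named in the quote can be written down as below. This is a minimal NumPy sketch, not code from the repository: it assumes the lists of (H, W, C) feature maps have already been extracted from a few chosen VGG16 layers (for example with `tf.keras.applications.VGG16`, not shown). The L1 distance for the feature reconstruction term and the Gram matrix normalization by H*W*C are choices made here, not requirements.

```python
import numpy as np


def perceptual_loss(feats_pred, feats_true):
    """Reconstruction loss on features: mean L1 distance, summed over layers."""
    return sum(np.mean(np.abs(p - t)) for p, t in zip(feats_pred, feats_true))


def gram_matrix(feat):
    """Gram matrix of an (H, W, C) feature map, normalized by H*W*C."""
    h, w, c = feat.shape
    f = feat.reshape(h * w, c)
    return f.T @ f / (h * w * c)


def style_loss(feats_pred, feats_true):
    """Squared Frobenius norm of the Gram-matrix difference, summed over layers."""
    return sum(np.sum((gram_matrix(p) - gram_matrix(t)) ** 2)
               for p, t in zip(feats_pred, feats_true))


def tv_loss(img):
    """Anisotropic total variation: mean absolute difference of neighboring pixels."""
    dh = np.abs(img[1:, :, :] - img[:-1, :, :])
    dw = np.abs(img[:, 1:, :] - img[:, :-1, :])
    return dh.mean() + dw.mean()
```

In a training setup these terms would be weighted and added to the existing reconstruction and adversarial losses; the weights are left open here.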

TrinhQuocNguyen commented on May 21, 2024

Oh, thank you, I have found the answer: just set the parameter reuse=tf.AUTO_REUSE:

    output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)

TensorFlow will then reuse the existing variables instead of trying to create them again.
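One caveat about this approach: in TensorFlow 1.x, reuse=tf.AUTO_REUSE makes repeated calls share the same variables, but each call to build_server_graph still adds a new set of operations to the graph, so the graph keeps growing inside a per-image loop. The toy model below illustrates that distinction in plain Python; `ToyGraph` and `build_toy_net` are hypothetical stand-ins and do not reproduce TensorFlow's actual implementation.

```python
class ToyGraph:
    """Toy stand-in for a TF1 graph: variables are keyed by name, ops accumulate."""

    def __init__(self):
        self.variables = {}
        self.ops = []

    def get_variable(self, name, auto_reuse=False):
        if name in self.variables:
            if not auto_reuse:
                # TF1's tf.get_variable raises a ValueError here when reuse is off.
                raise ValueError("Variable {} already exists".format(name))
            return self.variables[name]
        self.variables[name] = [0.0]  # stand-in for the weight tensor
        return self.variables[name]


def build_toy_net(graph, auto_reuse):
    """Hypothetical network builder: one shared weight, one new op per call."""
    weight = graph.get_variable("conv1/kernel", auto_reuse=auto_reuse)
    graph.ops.append(("conv", id(weight)))


if __name__ == "__main__":
    g = ToyGraph()
    for _ in range(3):  # one build per image, as in a per-image loop
        build_toy_net(g, auto_reuse=True)
    print(len(g.variables), len(g.ops))  # prints 1 3
```

The variable count stays at one while the op count grows with every call, which is why building the graph once and feeding a placeholder scales better.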

Bingmang commented on May 21, 2024

These codes should be added to the master branch 😍 😍 😍

JeremyCJM commented on May 21, 2024

In the code JiahuiYu posted above, the line

    output = model.build_server_graph(input_image_ph)

should pass FLAGS as its first argument:

    output = model.build_server_graph(FLAGS, input_image_ph)

JiahuiYu commented on May 21, 2024

It would be even more efficient to build the graph ONCE with a placeholder and feed your images through sess.run. See the related issue #8.
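The difference between rebuilding the graph for every image and building it once can be illustrated without TensorFlow. The sketch below is an analogy, not the repository's code: a hypothetical `build_fn` stands in for the expensive graph construction and checkpoint loading, and the built function is cached per input shape, since a placeholder with a fixed shape only accepts inputs of that size.

```python
import numpy as np


class BuildOnceRunner:
    """Call an expensive build_fn(shape) once per input shape, then reuse it."""

    def __init__(self, build_fn):
        self.build_fn = build_fn  # returns a callable that processes one batch
        self.cache = {}
        self.num_builds = 0

    def run(self, batch):
        key = batch.shape
        if key not in self.cache:
            # Expensive step; in TF1 this corresponds to building the graph
            # and restoring the checkpoint.
            self.cache[key] = self.build_fn(key)
            self.num_builds += 1
        # Cheap step; in TF1 this corresponds to sess.run with a feed_dict.
        return self.cache[key](batch)


if __name__ == "__main__":
    runner = BuildOnceRunner(lambda shape: (lambda x: x * 2.0))
    for _ in range(5):
        runner.run(np.ones((1, 128, 512, 3), dtype=np.float32))
    print(runner.num_builds)  # prints 1
```

Resizing every image to one fixed size, as in the placeholder-based script, reduces this to a single build.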

TrinhQuocNguyen commented on May 21, 2024

Hello JiahuiYu,
Thank you for your quick response. Did you mean sess.run?
I'm reading your source code to understand what you have done.

JiahuiYu commented on May 21, 2024

Sorry, typo.

TrinhQuocNguyen commented on May 21, 2024

Hello JiahuiYu,
Thank you for your response. I'm building the graph.
In the inpaint.yml file, under the "# loss legacy" line, I found the VGG_MODEL_FILE setting you configured. I have read your paper, and it does not mention transfer learning, so I wonder whether we can use the VGG16 network for transfer learning?
Thank you for your help.

TrinhQuocNguyen commented on May 21, 2024

Thank you for your fast response.
I used your pretrained model for transfer learning, and it saved me a lot of training time on a new dataset.
I am reading your paper again, I think it's a great paper.

TrinhQuocNguyen commented on May 21, 2024

Hello Jiahuiyu,
Thank you for your awesome code. I have tried to modify the code to build the graph once, but unfortunately I could not get it to work.
I see that you use the build_server_graph function, but I don't fully understand it. Could you please share some code that builds the graph and then feeds the images into it one by one?
Thank you in advance.

TrinhQuocNguyen commented on May 21, 2024

Here is my code at the moment, which uses a for loop:

    import datetime
    import os

    import cv2
    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API

    from inpaint_model import InpaintCAModel

    # assumes `args` provides test_dir and checkpoint_dir (e.g. from argparse)

    # prepare folder paths
    input_folder = args.test_dir + "/input"
    mask_folder = args.test_dir + "/mask"
    output_folder = (args.test_dir + "/output_" + args.checkpoint_dir.split("/")[1] + "_"
                     + datetime.datetime.now().strftime("%Y%m%d%H%M%S"))

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # session configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    model = InpaintCAModel()

    dir_files = os.listdir(input_folder)
    dir_files.sort()

    for file_inter in dir_files:
        # Note: this opens a new session, adds a new copy of the network to the
        # graph, and reloads every weight from the checkpoint for each image.
        sess = tf.Session(config=sess_config)

        base_file_name = os.path.basename(file_inter)

        image = cv2.imread(input_folder + "/" + base_file_name)
        mask = cv2.imread(mask_folder + "/" + base_file_name)

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 1
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        input_image = tf.constant(input_image, dtype=tf.float32)
        output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)
        output = (output + 1.) * 127.5
        output = tf.reverse(output, [-1])
        output = tf.saturate_cast(output, tf.uint8)
        # load pretrained model
        vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        assign_ops = []
        for var in vars_list:
            vname = var.name
            from_name = vname
            var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
            assign_ops.append(tf.assign(var, var_value))
        sess.run(assign_ops)
        print('Model loaded.')
        result = sess.run(output)

        # write to output folder
        cv2.imwrite(output_folder + "/" + base_file_name, result[0][:, :, ::-1])
        sess.close()

TrinhQuocNguyen commented on May 21, 2024

Hi JiahuiYu ,
Thank you very much for your code and your contribution. I am so excited to check it out.
Thank you again 😄 😄 😄 😄 😄 😄 😄 😄 😄 😄

TrinhQuocNguyen commented on May 21, 2024

Hi JiahuiYu ,
wow, it worked. Thank you very much, you have saved me tons of time. 😍 😍 😍

JiahuiYu commented on May 21, 2024

No problem. :)

TianLuluC commented on May 21, 2024

These codes should be added to the master branch 😍 😍 😍

@Bingmang Is the code added to the for loop of test.py? Thank you

JiahuiYu commented on May 21, 2024

I have made this thread open so others can have a reference.

zylxadz commented on May 21, 2024

@TrinhQuocNguyen Thank you very much for the discussion about training a new model!
Could you give me more detailed instructions on how to pre-train a model with transfer learning? Thanks a lot!

minushuang commented on May 21, 2024

great!

arnavmehta7 commented on May 21, 2024

Hey, I have been trying for days to customize part of the model. Can you explain how to access the model and run model.summary()?
