Yet to do

  • Include more architectures for the critic network
  • Add more comments in functions
  • Test the code when X and Z are images
  • Add unit tests

Variational Bounds on Mutual Information

Throughout this repo, I offer a ready-to-use TensorFlow implementation of state-of-the-art variational methods for mutual information estimation. This includes (the corresponding bounds are sketched below):

  • MINE estimator [1], based on the Donsker-Varadhan representation of the KL divergence
  • NWJ estimator [2], based on the f-divergence variational representation of the KL divergence
  • NCE estimator [3], based on the noise contrastive estimation principle
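
For reference, the lower bounds optimized by these estimators can be sketched as follows, where T denotes the critic network and the expectations are over the joint p(x,z) and the product of marginals p(x)p(z). This is only a reminder of the standard forms from [1-3]; the exact expressions implemented in the code may differ in minor details (e.g. bias-correction terms):

    I_{DV}(X;Z)  \;\ge\; \mathbb{E}_{p(x,z)}[T(x,z)] \;-\; \log \mathbb{E}_{p(x)p(z)}\big[e^{T(x,z)}\big]                                    % MINE
    I_{NWJ}(X;Z) \;\ge\; \mathbb{E}_{p(x,z)}[T(x,z)] \;-\; \mathbb{E}_{p(x)p(z)}\big[e^{T(x,z)-1}\big]                                       % NWJ
    I_{NCE}(X;Z) \;\ge\; \mathbb{E}\Big[\tfrac{1}{K}\sum_{i=1}^{K}\log \tfrac{e^{T(x_i,z_i)}}{\tfrac{1}{K}\sum_{j=1}^{K} e^{T(x_i,z_j)}}\Big]  % NCE / InfoNCE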

Getting Started

Defining the estimator

To define an estimator, one needs to provide a few arguments.

    from estimator import Mi_estimator
    my_mi_estimator = Mi_estimator(regu_name = ...,                   # estimator to use (e.g. 'mine')
                                   dim_x = ...,                       # shape of x samples, as a list
                                   dim_z = ...,                       # shape of z samples, as a list
                                   batch_size= ...,
                                   critic_layers=[256, 256, 256],     # hidden layer sizes of the critic network
                                   critic_lr=1e-4,                    # learning rate of the critic
                                   critic_activation='relu',
                                   critic_type='joint',
                                   ema_decay=0.99,                    # decay of the exponential moving average
                                   negative_samples=1)

Measuring the MI between static data

For simple MI estimation between two random variables:
    import numpy as np
    data_size = 100000   # number of paired samples (illustrative value)
    dim_x = 10
    dim_z = 15
    x_data = np.random.rand(data_size, dim_x)
    z_data = np.random.rand(data_size, dim_z)
    
    my_mi_estimator = Mi_estimator(regu_name = 'mine',
                                   dim_x = [dim_x],
                                   dim_z = [dim_z],
                                   batch_size= 128,
                                   critic_layers=[256, 256, 256],
                                   critic_lr=1e-4, 
                                   critic_activation='relu',
                                   critic_type='joint',
                                   ema_decay=0.99,
                                   negative_samples=1)
    mi_estimate = my_mi_estimator.fit(x_data, z_data)

Using MI as a regularizer in a TensorFlow graph

For use inside a TensorFlow graph (as a regularization term, for instance):
    import tensorflow as tf
    batch_size = 128
    dim = 10
    x = tf.placeholder(tf.float32, shape=[batch_size, dim])
    z = tf.placeholder(tf.float32, shape=[batch_size, dim])
    my_mi_estimator = Mi_estimator(regu_name = 'mine',
                                   dim_x = [dim],
                                   dim_z = [dim],
                                   batch_size= 128,
                                   critic_layers=[256, 256, 256],
                                   critic_lr=1e-4, 
                                   critic_activation='relu',
                                   critic_type='joint',
                                   ema_decay=0.99,
                                   negative_samples=1)
    estimator_train_op, estimator_quantities = my_mi_estimator(x, z)
    
    # Define regularized loss
    loss = loss_unregularized + beta * estimator_quantities['mi_for_grads']
    
    # Define main training op
    main_train_op = optimizer.minimize(loss, var_list=...)
  
    ...
    # At every training iteration:
    
    # Perform the main training operation to optimize the loss
    sess.run([main_train_op], feed_dict={x: ..., z: ..., ...})
    # Then update the estimator (critic)
    sess.run(estimator_train_op, feed_dict={x: ..., z: ...})

End-to-End Examples

We provide code to showcase the two functionalities described above.

Estimation of mutual information for static data

In the case of correlated Gaussian random variables, the true mutual information is known in closed form, which lets us compare the estimators against the ground truth.
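
As a minimal illustration (not necessarily the exact construction used in the demo script), one can generate componentwise-correlated Gaussians with correlation rho, for which the ground-truth MI has a simple closed form:

    import numpy as np

    data_size, dim, rho = 100000, 20, 0.9   # illustrative values
    x_data = np.random.randn(data_size, dim)
    noise = np.random.randn(data_size, dim)
    z_data = rho * x_data + np.sqrt(1.0 - rho ** 2) * noise

    # For this construction, I(X;Z) = -dim/2 * log(1 - rho^2)
    true_mi = -0.5 * dim * np.log(1.0 - rho ** 2)

The estimator from the previous sections can then be fitted on (x_data, z_data) and its output compared against true_mi.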

First, go to the example directory:

cd examples/correlated_gaussians/

Then, run the tests to make sure the code runs without errors:

python3 run_test.py

Finally, to run the experiments, you can check all the available options in "demo_gaussian.py", and loop over any parameters by modifying the header of run_exp.py. When that is done, simply run:

python3 run_exp.py

After training, plots are available in /plots/. They show the estimated MI after 10 epochs of training on a 100k-sample dataset with batch size 128, for the three estimation methods.

Using MI as a regularization term in a loss

The most interesting use case of these bounds is mutual information maximization. A typical example is reducing mode collapse in GANs: the mutual information I(Z;X) is used as a proxy for the entropy H(X) of the generator output, where X is the generator output and Z the noise vector, so that maximizing I(X;Z) encourages a higher-entropy (more diverse) generator.

loss_regularized = loss_gan - beta * I(X;Z)

Use adaptive gradient clipping!

A naive differentiation of the previously defined loss_regularized term will probably fail to yield the expected result. The most likely reason is that the gradients of I(X;Z) are on a much larger scale than those of loss_gan. To keep the gradient scales balanced, one must use the adaptive gradient clipping proposed in [1]. This simple trick scales the MI term so that its gradient norm does not exceed that of the original loss. Concretely, one minimizes instead:

scale = min( ||G(loss)||, ||G(I(X;Z))|| ) / ||G(I(X;Z))||
loss_regularized = loss_gan - beta * scale * I(X;Z)

where G(.) denotes the gradient with respect to the relevant parameters (e.g., the generator's variables).
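
A minimal TensorFlow sketch of this scaling is given below. It assumes loss_gan, beta, generator_vars, optimizer and the estimator_quantities dictionary from the previous section are already defined; apart from estimator_quantities['mi_for_grads'], these names are placeholders for the illustration, not part of the repo's API.

    mi_term = estimator_quantities['mi_for_grads']

    def grad_norm(tensor, variables):
        # Global norm of the gradients of `tensor` w.r.t. `variables`
        grads = [g for g in tf.gradients(tensor, variables) if g is not None]
        return tf.global_norm(grads)

    g_loss = grad_norm(loss_gan, generator_vars)
    g_mi = grad_norm(mi_term, generator_vars)

    # Scale the MI term so that its gradient norm never exceeds that of loss_gan;
    # stop_gradient keeps the scale itself out of the differentiation.
    scale = tf.stop_gradient(tf.minimum(g_loss, g_mi) / (g_mi + 1e-8))
    loss_regularized = loss_gan - beta * scale * mi_term
    main_train_op = optimizer.minimize(loss_regularized, var_list=generator_vars)

The main training op then minimizes loss_regularized while estimator_train_op keeps updating the critic, as in the previous section.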

We provide a simple 2D example, referred to as the "25 Gaussians" experiment, where the target distribution is a mixture of 25 Gaussians arranged on a grid. With the provided generator and discriminator architectures, the plain GAN exhibits clear mode collapse, covering only a subset of the 25 modes. Adding the MI regularization substantially reduces this mode collapse (see the plots produced by the example).

To see the code for this example, first go to the example directory:

cd examples/gan/

Then, run the tests to make sure the code runs without errors:

python3 run_test.py

Finally, to run the experiments, you can check all the available options in "demo_gaussian.py", and loop over any parameters by modifying the header of run_exp.py. Then simply run:

python3 run_exp.py

Authors

  • Malik Boudiaf

References

[1] Ishmael Belghazi, Sai Rajeswar, Aristide Baratin, R. Devon Hjelm, and Aaron C. Courville. MINE: Mutual Information Neural Estimation. CoRR, https://arxiv.org/abs/1801.04062

[2] Sebastian Nowozin, Botond Cseke, Ryota Tomioka. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. https://arxiv.org/abs/1606.00709

[3] Aaron van den Oord, Yazhe Li, Oriol Vinyals. Representation Learning with Contrastive Predictive Coding. https://arxiv.org/abs/1807.03748
