
Leveraging Recursive Gumbel-Max Trick for Approximate Inference in Combinatorial Spaces, NeurIPS 2021

License: Apache License 2.0

Languages: Python 20.71%, C++ 5.13%, Jupyter Notebook 74.16%

Topics: pytorch, gumbel, gumbel-max, discrete-variables, gradient-estimation, score-function-estimator, neurips-2021, gumbel-distribution, binary-tree, arborescence

pytorch-recursive-gumbel-max-trick's Introduction

Leveraging Recursive Gumbel-Max Trick for Approximate Inference in Combinatorial Spaces

This repository contains the PyTorch implementation of the main algorithms and usage examples from our paper Leveraging Recursive Gumbel-Max Trick for Approximate Inference in Combinatorial Spaces.

The repository contains code for four structured variables:

  • Arborescence (Edmonds' algorithm)
  • Binary tree (divide-and-conquer algorithm)
  • Perfect matching ('crossing' algorithm)
  • Spanning tree (Kruskal's algorithm).

For each variable, the implementation provides:

  • Sampling the structured variable together with its execution trace
  • Computing the log-probability of the execution trace
  • Sampling from the conditional distribution of the exponentials given the execution trace
  • A toy optimization experiment.

In addition, the repository contains implementations of several gradient estimators.
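
As a point of reference, the sketch below illustrates the pattern that all four variables share, using a simple recursive application of the Gumbel-Max trick. This is illustrative code only, not the repository's API: repeatedly perturbing the logits, taking the argmax, and masking out the winner samples an ordered subset; the sequence of winners plays the role of the execution trace, and its log-probability is a sum of log-softmax terms that can be computed and differentiated without any relaxation.

import torch

def sample_trace(logits, k):
    # Recursively apply the Gumbel-Max trick: perturb, take the argmax,
    # then mask the winner out before the next round. The sequence of
    # winners is the "execution trace"; its log-probability is a sum of
    # log-softmax terms over the remaining items, so no relaxation is needed.
    masked = logits
    trace, log_prob = [], 0.0
    for _ in range(k):
        gumbels = -torch.log(-torch.log(torch.rand_like(masked)))
        winner = torch.argmax(masked + gumbels)
        log_prob = log_prob + torch.log_softmax(masked, dim=-1)[winner]
        trace.append(winner.item())
        masked = masked.clone()
        masked[winner] = float('-inf')   # exclude the winner from later rounds
    return trace, log_prob

logits = torch.randn(6, requires_grad=True)
trace, log_prob = sample_trace(logits, k=3)
print(trace, log_prob)    # e.g. [4, 1, 5] and the log-probability of that order
log_prob.backward()       # exact score-function gradient of the trace log-probability
print(logits.grad)

Because the log-probability of the trace is exact, it can be plugged directly into a score-function estimator; the repository implements this machinery for the four structured variables listed above.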

Requirements

  • Python 3.8
  • PyTorch 1.8

Abstract

Structured latent variables allow incorporating meaningful prior knowledge into deep learning models. However, learning with such variables remains challenging because of their discrete nature. Nowadays, the standard learning approach is to define a latent variable as a perturbed algorithm output and to use a differentiable surrogate for training. In general, the surrogate puts additional constraints on the model and inevitably leads to biased gradients. To alleviate these shortcomings, we extend the Gumbel-Max trick to define distributions over structured domains. We avoid the differentiable surrogates by leveraging the score function estimators for optimization. In particular, we highlight a family of recursive algorithms with a common feature we call stochastic invariant. The feature allows us to construct reliable gradient estimates and control variates without additional constraints on the model. In our experiments, we consider various structured latent variable models and achieve results competitive with relaxation-based counterparts.
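
To see how a score-function estimator replaces a differentiable surrogate, the toy sketch below (illustrative only, with no claims about the paper's exact estimators or control variates) samples a categorical variable with the plain Gumbel-Max trick, treats the loss as a black box, and estimates the gradient of the expected loss as loss · grad log p(sample), i.e. plain REINFORCE.

import torch

torch.manual_seed(0)

logits = torch.zeros(5, requires_grad=True)             # categorical parameters
optimizer = torch.optim.SGD([logits], lr=0.1)
target = torch.tensor([0., 0., 1., 0., 0.])              # toy objective: prefer class 2

for step in range(500):
    # Gumbel-Max sampling: argmax of logits perturbed with Gumbel noise.
    gumbels = -torch.log(-torch.log(torch.rand_like(logits)))
    sample = torch.argmax(logits + gumbels)

    # Black-box loss: no gradient path to the logits through the sample.
    one_hot = torch.nn.functional.one_hot(sample, 5).float()
    loss = (one_hot - target).pow(2).sum()

    # Score-function (REINFORCE) surrogate: loss * log p(sample).
    log_prob = torch.log_softmax(logits, dim=-1)[sample]
    surrogate = loss.detach() * log_prob

    optimizer.zero_grad()
    surrogate.backward()
    optimizer.step()

print(torch.softmax(logits, dim=-1))   # mass should concentrate on class 2

The paper's estimators additionally use control variates constructed from the stochastic invariant to reduce the variance of this basic estimator; the sketch above keeps only the core idea.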

Citation

@article{struminsky2021leveraging,
  title={Leveraging Recursive Gumbel-Max Trick for Approximate Inference in Combinatorial Spaces},
  author={Struminsky, Kirill and Gadetsky, Artyom and Rakitin, Denis and Karpushkin, Danil and Vetrov, Dmitry P},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

