
applicaai / successive-halving-topk


A fast and highly accurate differentiable Top-k operator from the "Successive Halving Top-k Operator" AAAI'21 paper.

License: MIT License



Successive Halving Top-k Operator

This repository contains a demonstrative implementation of the Successive Halving Top-k Operator, complementing the Applica.ai publication accepted at AAAI 2021. The paper is available on arXiv, as a PDF, and on the conference site. Cite us as:

@article{pietruszka2020successive,
      title={Successive Halving Top-k Operator},
      volume={35},
      url={https://ojs.aaai.org/index.php/AAAI/article/view/17931},
      number={18},
      journal={Proceedings of the AAAI Conference on Artificial Intelligence},
      author={Pietruszka, Michał and Borchmann, Łukasz and Graliński, Filip},
      year={2021},
      month={May},
      pages={15869-15870}
}
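The operator's name suggests a tournament: starting from n candidates, each round halves the pool by softly merging pairs until k remain. The plain-Python sketch below illustrates that idea only. It is not the repository's implementation and not necessarily the paper's exact formulation: the hypothetical `successive_halving_topk` function, its pairing rule (sort, then match the i-th best candidate with the (i + m/2)-th best), the softmax weighting, and the `sharpness` parameter are all assumptions. Nothing is differentiated here; in PyTorch, gradients would flow through the soft merge weights.

```python
import math


def soft_pair_merge(a, b, sa, sb, sharpness):
    """Softly merge two candidates, weighting each by a softmax over their scores."""
    # Subtract the larger score before exponentiating for numerical stability.
    m = max(sa, sb)
    wa = math.exp(sharpness * (sa - m))
    wb = math.exp(sharpness * (sb - m))
    wa, wb = wa / (wa + wb), wb / (wa + wb)
    merged = [wa * x + wb * y for x, y in zip(a, b)]
    return merged, wa * sa + wb * sb


def successive_halving_topk(vectors, scores, k, sharpness=20.0):
    """Hypothetical sketch: reduce n candidates to k by repeated soft pairwise matches.

    Assumes n / k is a power of two, so every round halves the pool exactly.
    """
    n = len(vectors)
    ratio = n // k if k > 0 else 0
    if k <= 0 or n % k or ratio & (ratio - 1):
        raise ValueError("this sketch assumes n / k is a power of two")
    vecs, scs = list(vectors), list(scores)
    while len(vecs) > k:
        # Rank candidates by score; the ranking itself is a hard, non-soft step.
        order = sorted(range(len(vecs)), key=lambda i: scs[i], reverse=True)
        half = len(vecs) // 2
        new_vecs, new_scs = [], []
        for i in range(half):
            # Match the i-th best with the (i + half)-th best, so in the
            # hard limit (large sharpness) the top half wins every match.
            top, low = order[i], order[i + half]
            v, s = soft_pair_merge(vecs[top], vecs[low], scs[top], scs[low], sharpness)
            new_vecs.append(v)
            new_scs.append(s)
        vecs, scs = new_vecs, new_scs
    # Outputs come out roughly in descending score order, so index 0 approximates top-1.
    return vecs, scs


vecs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0]]
out_v, out_s = successive_halving_topk(vecs, [0.9, 0.1, 0.5, 0.2], k=2, sharpness=50.0)
print([round(s, 3) for s in out_s])  # [0.9, 0.5]
```

With a large `sharpness` each soft match approaches a hard comparison, so the result approaches exact top-k; with a small one the outputs blend neighbouring candidates. This mirrors the trade-off the repository's comment describes for `base`, though how `base` is actually used is not stated here.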

Reproduce

You can reproduce figures from the paper by:

  1. Generating files with performance metrics (csv format) with ./benchmarker/benchmark.py.
  2. Making figures from these csv files with ./plotters/make_figures.py.

A precomputed results file, ./benchmark_log_16003623822_cuda:0.csv, is provided and is used by default.
Note: the benchmark runs on 'cpu' by default, but a 'cuda' version is available in pooler_arena/trainer/benchmark.py.
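The benchmark step amounts to timing a selection routine for several input sizes and writing one csv row per run. The sketch below is not ./benchmarker/benchmark.py: the hypothetical `benchmark` function, its column names (`n`, `k`, `seconds`), and the use of `heapq.nlargest` as a stand-in for the operator under test are all assumptions made for illustration.

```python
import csv
import heapq
import io
import random
import time


def benchmark(sizes, k, repeats=3, seed=0):
    """Time a hard top-k selection for each n; return csv text with hypothetical columns."""
    rng = random.Random(seed)
    out = io.StringIO()
    writer = csv.DictWriter(out, fieldnames=["n", "k", "seconds"])
    writer.writeheader()
    for n in sizes:
        scores = [rng.random() for _ in range(n)]
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            heapq.nlargest(k, scores)  # stand-in for the operator being measured
            # Keep the fastest run to reduce noise from other processes.
            best = min(best, time.perf_counter() - start)
        writer.writerow({"n": n, "k": k, "seconds": f"{best:.6f}"})
    return out.getvalue()


print(benchmark([1024, 8192], k=256))
```

Reporting the minimum over several repeats is a common way to reduce timing noise; a real benchmark of the differentiable operator would likely also record the device and an approximation-quality metric.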

Example

You may also want to use this approach in your own code. A short guide follows; the same code is in ./examples/minimal_example.py.

1. Create a top-k operator that selects k out of n vectors.

from topk_arena.models.successive_halving_topk import TopKOperator, TopKConfig
import torch

# Input your settings
k = 256     # your k
n = 8192    # your n
depth = 32  # dimensionality of the representations (vectors, embeddings, etc.)

# Build TopK operator and configure it.
topk = TopKOperator()
cfg = TopKConfig(input_len=n,
                 pooled_len=k,
                 base=20,       # larger values give a better approximation but may be unstable
                 )
topk.set_config(cfg)
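If every round halves the number of candidates, as the operator's name suggests, selecting k = 256 out of n = 8192 takes log2(8192 / 256) = log2(32) = 5 rounds. The `halving_rounds` helper below is hypothetical (it is not part of `topk_arena`), and whether `TopKOperator` itself requires n / k to be a power of two is an assumption not stated in this README.

```python
def halving_rounds(n, k):
    """Number of halving rounds needed to go from n candidates down to k."""
    ratio, rem = divmod(n, k)
    # An exact halving schedule needs n / k to be a positive power of two.
    if rem or ratio < 1 or ratio & (ratio - 1):
        raise ValueError("n / k must be a power of two for an exact halving schedule")
    return ratio.bit_length() - 1


print(halving_rounds(8192, 256))  # 5: 8192 -> 4096 -> 2048 -> 1024 -> 512 -> 256
```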

2. Prepare the data (here, random embeddings in [-1, 1) and random scores in [0, 1)).

embeddings = torch.rand((1, n, depth)) * 2 - 1
scores = torch.rand((1, n, 1))

3. Select with the Successive Halving Top-k operator.

out_embs, out_scores = topk(embeddings, scores)
out_scores.unsqueeze_(2)
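For comparison, the hard (non-differentiable) selection that the operator approximates keeps the k highest-scoring vectors in descending score order. Below is a plain-Python sketch with a hypothetical `hard_topk` helper; in PyTorch, `torch.topk` returns the corresponding indices directly.

```python
import heapq


def hard_topk(vectors, scores, k):
    """Return the k highest-scoring vectors and their scores, best first."""
    # heapq.nlargest yields the indices ordered by descending score.
    idx = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    return [vectors[i] for i in idx], [scores[i] for i in idx]


vecs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
top_vecs, top_scores = hard_topk(vecs, [0.2, 0.9, 0.4], k=2)
print(top_vecs, top_scores)  # [[0.0, 1.0], [0.5, 0.5]] [0.9, 0.4]
```

Because this selection is a discrete choice of indices, its output does not change smoothly with the scores; a soft operator replaces that choice with weighted combinations so that gradients can reach the scorer.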

4. Let's see how good the approximation was.

We compare the highest-scoring input vector with the first output vector, which should approximate it.

top1_hard = embeddings[0, scores.argmax(1).squeeze(), :]
top1_soft = out_embs[0, 0, :]
assert top1_hard.shape == top1_soft.shape
cosine_sim = torch.cosine_similarity(top1_hard, top1_soft, dim=0)   # this should be ~1.0
print(f'Approximation quality of Successive Halving TopK for top-1,'
      f' as measured by cosine similarity is {cosine_sim.item()}.')

The output should look something like this (the exact value varies with the random input):

Approximation quality of Successive Halving TopK for top-1, as measured by cosine similarity is 0.9996941685676575.
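The quality check relies on cosine similarity, cos(u, v) = u·v / (‖u‖ ‖v‖), which reaches 1.0 only when the two vectors point in the same direction. Below is a plain-Python sketch of what `torch.cosine_similarity` computes for two 1-D vectors; the `eps` guard mimics PyTorch's protection against zero norms, but its exact placement here is a simplification.

```python
import math


def cosine_similarity(u, v, eps=1e-8):
    """Cosine of the angle between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # eps avoids division by zero when either vector is all zeros.
    return dot / max(norm_u * norm_v, eps)


print(cosine_similarity([3.0, 4.0], [6.0, 8.0]))   # 1.0 (same direction)
print(cosine_similarity([1.0, 0.0], [0.0, 3.0]))   # 0.0 (orthogonal)
print(cosine_similarity([1.0, 0.0], [-2.0, 0.0]))  # -1.0 (opposite direction)
```

Since the measure ignores vector length, a value near 1.0 means the soft output points almost exactly where the true top-1 vector does, even if its magnitude differs slightly.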

This repository will hopefully solve your problems! :)

Disclaimer: this is not an official Applica.ai product (experimental or otherwise).
