
dongjunlee / relation-network-tensorflow


TensorFlow implementation of 'A simple neural network module for relational reasoning' for the bAbI tasks.


Relation Network

TensorFlow implementation of "A simple neural network module for relational reasoning", applied to the bAbI question-answering tasks.

  • Figure: A Simple Neural Network Module for Relational Reasoning (from slides by Xiadong Gu)

Requirements

Project Structure

Project initialized with hb-base.

.
├── config                  # Config files (.yml, .json) used with hb-config
├── data                    # dataset path
├── notebooks               # Prototyping with numpy or tf.InteractiveSession
├── relation_network        # relation network architecture graphs (from input to logits)
│   ├── __init__.py             # Graph logic
│   ├── encoder.py              # Encoder
│   └── relation.py             # RN Module
├── data_loader.py          # raw_data -> processed_data -> generate_batch (using Dataset)
├── hook.py                 # training or test hook features (e.g. print_variables)
├── main.py                 # define experiment_fn
└── model.py                # define EstimatorSpec

Reference: hb-config, Dataset, experiment_fn, EstimatorSpec
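The RN module listed above follows the paper's composite function RN(O) = f(Σ over pairs (i, j) of g(o_i, o_j, q)), where the o_i are object (sentence) encodings and q is the question encoding. The following is a minimal pure-Python sketch of that formula, not the code in relation.py: it uses a single dense layer for each of g and f, and the names `dense`, `init_layer`, and `relation_network` as well as all dimensions are hypothetical.

```python
import itertools
import random


def dense(x, weights, biases, relu=True):
    """One fully connected layer applied to a plain list of floats."""
    out = []
    for w_row, b in zip(weights, biases):
        v = sum(xi * wi for xi, wi in zip(x, w_row)) + b
        out.append(max(0.0, v) if relu else v)
    return out


def init_layer(in_dim, out_dim, rng):
    """Small random weights (one row per output unit) and zero biases."""
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)]
               for _ in range(out_dim)]
    return weights, [0.0] * out_dim


def relation_network(objects, question, g_layer, f_layer):
    """RN(O) = f(sum over all ordered pairs (i, j) of g(o_i, o_j, q))."""
    g_w, g_b = g_layer
    summed = [0.0] * len(g_b)
    for o_i, o_j in itertools.product(objects, repeat=2):
        relation = dense(o_i + o_j + question, g_w, g_b)  # g on one pair
        summed = [s + r for s, r in zip(summed, relation)]
    f_w, f_b = f_layer
    return dense(summed, f_w, f_b, relu=False)  # unnormalized logits


# Hypothetical toy sizes: 5 objects of 4 features, 3 answer classes.
rng = random.Random(0)
obj_dim, q_dim, hidden, n_answers = 4, 4, 8, 3
objects = [[rng.random() for _ in range(obj_dim)] for _ in range(5)]
question = [rng.random() for _ in range(q_dim)]
g_layer = init_layer(2 * obj_dim + q_dim, hidden, rng)
f_layer = init_layer(hidden, n_answers, rng)

logits = relation_network(objects, question, g_layer, f_layer)
print(len(logits))
```

Because the pairwise outputs are summed before f is applied, the result does not depend on the order in which the objects are presented, up to floating-point rounding.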

Todo

  • Train the model on the joint version of bAbI (all 20 tasks simultaneously), using the full dataset of 10K examples per task, as in the paper's experiments.

Config

The config file controls the entire experimental environment.

Example: bAbi_task1.yml

data:
  base_path: 'data/'
  task_path: 'en-10k/'
  task_id: 1
  PAD_ID: 0

model:
  batch_size: 64
  use_pretrained: false             # (true or false)
  embed_dim: 32                     # if use_pretrained: only 50, 100, 200, 300 are available
  encoder_type: uni                 # uni, bi
  cell_type: lstm                   # lstm, gru, layer_norm_lstm, nas
  num_layers: 1
  num_units: 32
  dropout: 0.5

  g_units:
    - 64
    - 64
    - 64
    - 64
  f_units:
    - 64
    - 128


train:
  learning_rate: 0.00003
  optimizer: 'Adam'                # Adagrad, Adam, Ftrl, Momentum, RMSProp, SGD

  train_steps: 200000
  model_dir: 'logs/bAbi_task1'

  save_checkpoints_steps: 1000
  check_hook_n_iter: 1000
  min_eval_frequency: 1

  print_verbose: False
  debug: False

slack:
  webhook_url: ""                  # notifies you via a Slack webhook when training finishes

  • debug mode: uses tfdbg (the TensorFlow debugger)
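To make the `g_units` and `f_units` lists concrete, the sketch below derives the dense-layer shapes they would imply. It is an illustration under stated assumptions, not the repository's code: it assumes each sentence and the question are encoded to `num_units` features, that g receives two objects plus the question, and that f ends in an output layer sized to a made-up `vocab_size`. The helper `mlp_shapes` is hypothetical.

```python
def mlp_shapes(input_dim, units, output_dim=None):
    """Return (in_dim, out_dim) for each dense layer of an MLP."""
    dims = [input_dim] + list(units)
    if output_dim is not None:
        dims.append(output_dim)
    return list(zip(dims[:-1], dims[1:]))


# Values copied from the example config.
model = {"num_units": 32, "g_units": [64, 64, 64, 64], "f_units": [64, 128]}

# Assumptions: object and question encodings both have num_units features,
# and vocab_size is an invented number of answer classes.
object_dim = question_dim = model["num_units"]
vocab_size = 20

g_shapes = mlp_shapes(2 * object_dim + question_dim, model["g_units"])
f_shapes = mlp_shapes(model["g_units"][-1], model["f_units"],
                      output_dim=vocab_size)

print(g_shapes)  # [(96, 64), (64, 64), (64, 64), (64, 64)]
print(f_shapes)  # [(64, 64), (64, 128), (128, 20)]
```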

Usage

Install requirements.

pip install -r requirements.txt

Then, prepare dataset.

sh scripts/fetch_babi_data.sh
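The bAbI files are plain text: every line starts with an ID, an ID of 1 begins a new story, and question lines carry a tab-separated answer and supporting-fact IDs. The sketch below shows one way the raw data could be turned into padded integer sequences. It is not the code in data_loader.py; the functions `tokenize`, `parse_babi`, `encode`, and `pad` are hypothetical, and the three sample lines only illustrate the format.

```python
def tokenize(sentence):
    """Lowercase and split on whitespace, dropping '.' and '?'."""
    return sentence.lower().replace(".", " ").replace("?", " ").split()


def parse_babi(lines):
    """Turn raw bAbI lines into (story, question, answer) triples."""
    examples, story = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        idx, text = line.split(" ", 1)
        if int(idx) == 1:        # line ID 1 starts a new story
            story = []
        if "\t" in text:         # question <TAB> answer <TAB> supporting IDs
            question, answer = text.split("\t")[:2]
            examples.append((list(story), tokenize(question), answer))
        else:
            story.append(tokenize(text))
    return examples


def encode(tokens, vocab):
    """Map tokens to integer IDs starting at 1, so 0 stays free for PAD_ID."""
    return [vocab.setdefault(tok, len(vocab) + 1) for tok in tokens]


def pad(token_ids, length, pad_id=0):
    """Truncate or right-pad a list of IDs to a fixed length."""
    return token_ids[:length] + [pad_id] * max(0, length - len(token_ids))


raw = [
    "1 Mary moved to the bathroom.",
    "2 John went to the hallway.",
    "3 Where is Mary? \tbathroom\t1",
]
examples = parse_babi(raw)
story, question, answer = examples[0]

vocab = {}
question_ids = pad(encode(question, vocab), 5)
print(question, answer, question_ids)
```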

Finally, train and evaluate the model.

python main.py --config bAbi_task1 --mode train_and_evaluate

Experiments modes

✅ : Working
◽ : Not tested yet.

  • evaluate : Evaluate on the evaluation data.
  • extend_train_hooks : Extends the hooks for training.
  • reset_export_strategies : Resets the export strategies with the new_export_strategies.
  • run_std_server : Starts a TensorFlow server and joins the serving thread.
  • test : Tests training, evaluating and exporting the estimator for a single step.
  • train : Fit the estimator using the training data.
  • train_and_evaluate : Interleaves training and evaluation.
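The modes above are selected with the `--mode` flag shown in the Usage command. As a rough sketch of how such a command line could be validated, the snippet below restricts `--mode` to the listed names using `argparse`. This is an assumption about the interface, not the parser in main.py; `build_parser` and `MODES` are hypothetical names.

```python
import argparse

# Mode names copied from the list above.
MODES = [
    "evaluate",
    "extend_train_hooks",
    "reset_export_strategies",
    "run_std_server",
    "test",
    "train",
    "train_and_evaluate",
]


def build_parser():
    """Hypothetical parser mirroring the --config and --mode flags."""
    parser = argparse.ArgumentParser(description="Relation Network runner")
    parser.add_argument("--config", required=True,
                        help="config name, e.g. bAbi_task1")
    parser.add_argument("--mode", required=True, choices=MODES,
                        help="experiment mode to run")
    return parser


args = build_parser().parse_args(
    ["--config", "bAbi_task1", "--mode", "train_and_evaluate"])
print(args.config, args.mode)
```

Using `choices` makes an unknown mode fail immediately with a usage message instead of surfacing later during training.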

TensorBoard

tensorboard --logdir logs

  • bAbi_task1

images

Reference

Author

Dongjun Lee ([email protected])
