
liuchang97 / multi-domain-belief-tracking


This project forked from osmanio2/multi-domain-belief-tracking


Large-Scale Multi-Domain Belief Tracking with Knowledge Sharing

This repository implements the model proposed in the paper "Large-Scale Multi-Domain Belief Tracking with Knowledge Sharing".

Requirements

Python 3+ with pip

Setup

To install the Python modules required to train and run the model:

pip install -r requirements.txt

Preprocessing

To download and preprocess the multi-domain belief tracking dialogues and the pre-trained word embeddings:

python preprocess.py

Training

To train the model:

python main.py train [--args=value]

Supported arguments include:

  • --num_hid: the size of the hidden layers - default is 50 (the paper uses 100)
  • --bidir/--no-bidir: bidirectional vs. forward-only encoding - default is bidirectional (only valid for RNN encoders)
  • --net_type: the type of the feature encoders: gru for GRU, cnn for CNN, lstm for LSTM - default is lstm (see the paper)
  • --batch_size: the batch size - default is 64
  • --dev: the device used to train the model (cpu or gpu) - default is gpu
  • --model_url: path to save the model to, or to resume training from - default is models/model-1
  • --graph_url: path to save the TensorBoard graph of evaluation metrics (cross-entropy, accuracy, etc.) - default is graphs/graph-1
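To make the flag names and defaults above concrete, here is a minimal sketch of an argparse parser that mirrors them. It is an illustration only, not the repository's own parser: the function name build_parser and the positional "mode" argument are assumptions, and argparse.BooleanOptionalAction (used to get both --bidir and --no-bidir) requires Python 3.9 or later.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags listed in this README.

    Flag names and defaults are copied from the README; everything else
    (function name, help texts, positional 'mode') is an assumption and
    NOT the repository's actual implementation.
    """
    parser = argparse.ArgumentParser(prog="main.py")
    # First positional argument selects what main.py does: train or test.
    parser.add_argument("mode", choices=["train", "test"])
    parser.add_argument("--num_hid", type=int, default=50,
                        help="hidden layer size (the paper uses 100)")
    # BooleanOptionalAction (Python 3.9+) registers both --bidir and --no-bidir.
    parser.add_argument("--bidir", action=argparse.BooleanOptionalAction,
                        default=True, help="bidirectional RNN encoders")
    parser.add_argument("--net_type", choices=["gru", "cnn", "lstm"],
                        default="lstm", help="feature encoder type")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--dev", choices=["cpu", "gpu"], default="gpu")
    parser.add_argument("--model_url", default="models/model-1")
    parser.add_argument("--graph_url", default="graphs/graph-1")
    return parser


if __name__ == "__main__":
    # Parse the same arguments as the CNN example in this README.
    args = build_parser().parse_args(
        ["train", "--batch_size=8", "--net_type=cnn", "--dev=gpu"])
    print(args.net_type, args.batch_size, args.bidir)  # cnn 8 True
```

Any flag left out falls back to the README's default, which is why --bidir prints True in the example even though it was not passed.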

For example, to train the CNN variant of the model with a batch size of 8 on the GPU:

python main.py train --batch_size=8 --net_type=cnn --dev=gpu

Testing

To evaluate the model on the test dataset:

python main.py test [--args=value]

This accepts the same arguments as training, except --dev, because testing always runs on the CPU. It generates a log file at results/log-1.txt containing the original dialogues together with the true labels and the model predictions.

There is currently a bug that shuffles the model predictions across dialogues. To work around it, test with a batch size of 1, i.e. --batch_size=1.
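If you script several train/test runs, the two testing rules above (no --dev, and --batch_size=1 as a workaround for the shuffling bug) are easy to forget. The sketch below assembles the command lines while applying them. build_command is a hypothetical helper, not part of the repository, and it only builds the argument list; it does not run main.py.

```python
import shlex
from typing import List, Optional


def build_command(mode: str, net_type: str = "lstm",
                  batch_size: Optional[int] = None,
                  dev: Optional[str] = None) -> List[str]:
    """Assemble a 'python main.py ...' command following the README's rules.

    Hypothetical helper for illustration; not part of the repository.
    """
    if mode not in ("train", "test"):
        raise ValueError(f"unknown mode: {mode!r}")
    cmd = ["python", "main.py", mode, f"--net_type={net_type}"]
    if mode == "test":
        # Workaround for the prediction-shuffling bug: always test with batch size 1.
        batch_size = 1
        # --dev is not accepted for testing (the CPU is always used), so drop it.
        dev = None
    if batch_size is not None:
        cmd.append(f"--batch_size={batch_size}")
    if dev is not None:
        cmd.append(f"--dev={dev}")
    return cmd


print(shlex.join(build_command("train", net_type="cnn", batch_size=8, dev="gpu")))
print(shlex.join(build_command("test", net_type="cnn", batch_size=8, dev="gpu")))
```

The returned list can be passed directly to subprocess.run(cmd, check=True). The helper silently overrides the batch size for testing; raising an error instead would be a reasonable alternative if you prefer explicit failures.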

Contributors

osmanio2
