Neural Factor Trees for Graph Node Classification

This repository contains code to train and evaluate the graph node classification model described in the "Node Classification in Random Trees" paper by Nuijten and Menkovski. The model can be trained on the SST Dataset.

Features

  • Multi-GPU training
  • Multithreaded data fetching
  • Tensorboard logging
  • CLI for hyperparameter tuning

Requirements:

  • python
  • numpy
  • pytorch
  • dgl
  • pytorch-lightning
  • (optional) TensorBoard

Please install pytorch and dgl by hand with pip, since the correct builds of both packages depend on your CUDA version.
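Before training, it can help to confirm that the packages above are importable in the active environment. The following is a minimal sketch and not part of the repository: the import names (`torch` for pytorch, `pytorch_lightning` for pytorch-lightning) and the helper `missing` are assumptions made for illustration.

```python
import importlib.util

# Assumed import names for the packages in the requirements list.
REQUIRED = ["numpy", "torch", "dgl", "pytorch_lightning"]
OPTIONAL = ["tensorboard"]


def missing(modules):
    """Return the module names that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("missing required:", missing(REQUIRED))
    print("missing optional:", missing(OPTIONAL))
```

An empty list for the required packages means all of them were found. The check does not verify that the installed pytorch and dgl builds match your CUDA version.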

Project Overview

This repository contains code to build and train a Neural Markov Tree. By parameterizing a Gibbs distribution that factorizes over a Markov network, we are able to estimate the joint probability distribution over vertex labels.

The project is structured as follows. train.py contains the training logic and invokes model/graph_pruning.py, which contains the actual module. Elementary submodules live in model/modules.py, so that model/graph_pruning.py contains only the novel graph pruning logic. dataloader.py contains the data fetching scripts and is invoked by multiple threads through pytorch's built-in DataLoader. The entire training loop is started by main.py, which parses the CLI arguments and trains the model accordingly.
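To make the idea of a Gibbs distribution over vertex labels concrete, here is a toy sketch. It is not the repository's model: the three-vertex tree, the unary and pairwise log-potentials, and all names below are hypothetical, and the potentials are fixed numbers where a trained network would supply them. The normalization is done by brute-force enumeration, which is only feasible for very small graphs.

```python
import itertools
import math

# Hypothetical toy tree: 3 vertices, edges (0, 1) and (1, 2), binary labels.
EDGES = [(0, 1), (1, 2)]
NUM_LABELS = 2

# Hypothetical log-potentials, hard-coded here for illustration only.
UNARY = [[0.5, -0.5], [0.0, 0.0], [-1.0, 1.0]]  # UNARY[vertex][label]
PAIRWISE = [[1.0, -1.0], [-1.0, 1.0]]           # PAIRWISE[label_u][label_v], shared by all edges


def score(labels):
    """Sum of log-potentials (the negative energy) of one full labeling."""
    total = sum(UNARY[v][y] for v, y in enumerate(labels))
    total += sum(PAIRWISE[labels[u]][labels[v]] for u, v in EDGES)
    return total


def joint_distribution():
    """Normalize exp(score) over every labeling to obtain the joint distribution."""
    labelings = list(itertools.product(range(NUM_LABELS), repeat=len(UNARY)))
    weights = [math.exp(score(y)) for y in labelings]
    z = sum(weights)  # partition function
    return {y: w / z for y, w in zip(labelings, weights)}


if __name__ == "__main__":
    for labeling, prob in sorted(joint_distribution().items(), key=lambda kv: -kv[1]):
        print(labeling, round(prob, 4))
```

Because the distribution factorizes over vertices and edges, each labeling's unnormalized weight is a product of local terms. Only the partition function couples all labelings together.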

Usage

The model training is invoked by running main.py with command line parameters. For a complete overview of what command line parameters are implemented, run:

python3 main.py -h

An example command would be the following:

python3 main.py --hidden_dim 64 --num_gnn_steps 6 --num_layers 3 --epochs 100 --workers 14

This trains a model for 100 epochs with a hidden dimension of 64, 6 message passing steps, and a 3-layer MLP as the transfer function, using 14 data fetching threads.
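For readers unfamiliar with how such flags are typically parsed, the sketch below reproduces the five options from the example command with argparse. It is an illustration under assumptions, not the contents of main.py: the default values, help texts, and the function name `build_parser` are hypothetical, and the real script may define additional options (run it with -h to see them).

```python
import argparse


def build_parser():
    """Build a parser for the flags shown in the example command (defaults are assumed)."""
    parser = argparse.ArgumentParser(description="Sketch of a training CLI.")
    parser.add_argument("--hidden_dim", type=int, default=64,
                        help="size of the hidden representation")
    parser.add_argument("--num_gnn_steps", type=int, default=6,
                        help="number of message passing steps")
    parser.add_argument("--num_layers", type=int, default=3,
                        help="number of layers in the transfer-function MLP")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--workers", type=int, default=1,
                        help="number of data fetching workers")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```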

Viewing logs and performance

Since we use PyTorch Lightning, the metrics we specify are logged automatically. To upload an experiment to TensorBoard, run:

tensorboard dev upload --logdir tb_logs/GraphPruner --name "GraphPruning"

Hyperparameters are logged as well and appear under the hyperparameter tab, which makes it possible to track the individual settings of the different experiments.

Contributors

  • wouterwln
