
peterbjorgensen / msgnet


Tensorflow implementation of message passing neural networks for molecules and materials

License: MIT License

Python 100.00%
machine-learning materials-science molecules tensorflow graphs

msgnet's Introduction

Msgnet

TensorFlow implementation of message passing neural networks for molecules and materials. The framework implements the SchNet model, its extension with an edge update network (NMP-EDGE), and the model used in "Materials property prediction using symmetry-labeled graphs as atomic-position independent descriptors".

The implementation does not currently support training on forces, but this may be added in the future. For a more complete implementation of the SchNet model, see schnetpack.

The main difference between msgnet and schnetpack is that msgnet follows a message passing architecture and can therefore be more flexible in some cases, e.g., it can be trained on graphs rather than on structures with full spatial information.
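To illustrate why a message passing formulation can work on plain graphs, here is a minimal sketch of one message passing step on a graph given only as an edge list with per-edge features. The edge features could be interatomic distances when full spatial information is available, or numeric labels when only a graph is known. All names (`message_passing_step`, `edge_update_step`) are hypothetical and not part of msgnet's API, and the multiply-and-add rules are placeholders for the learned neural networks; the edge update only loosely mirrors the NMP-EDGE idea of updating edge states from their endpoint nodes.

```python
from typing import List, Tuple

Edge = Tuple[int, int]  # (sender, receiver) node indices


def message_passing_step(
    node_states: List[float],
    edges: List[Edge],
    edge_features: List[float],
) -> List[float]:
    """Return updated node states after one round of message passing.

    Placeholder message for edge v -> w: h_v * e_vw.
    Placeholder update: h_w' = h_w + sum of incoming messages.
    """
    incoming = [0.0] * len(node_states)
    for (sender, receiver), e in zip(edges, edge_features):
        incoming[receiver] += node_states[sender] * e
    return [h + m for h, m in zip(node_states, incoming)]


def edge_update_step(
    node_states: List[float],
    edges: List[Edge],
    edge_features: List[float],
) -> List[float]:
    """Return updated edge states; placeholder rule e_vw' = e_vw + h_v * h_w."""
    return [e + node_states[s] * node_states[r] for (s, r), e in zip(edges, edge_features)]


# Usage: a three-node chain 0-1-2, with each bond stored in both directions.
h = [1.0, 2.0, 3.0]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
e = [0.5, 0.5, 1.0, 1.0]  # e.g. distances, or numeric edge labels
new_h = message_passing_step(h, edges, e)  # [2.0, 5.5, 5.0]
new_e = edge_update_step(h, edges, e)      # [2.5, 2.5, 7.0, 7.0]
```

Nothing in either function refers to atomic positions, which is the property that lets the same machinery run on graphs without spatial information.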

Install

Install the dependency

Set the datadir variable in src/msgnet/defaults.py to a preferred path in which the datasets will be saved.

Then run python setup.py install or python setup.py install --user to install the module.

Install datasets

Run the script src/scripts/get_qm9.py to download the QM9 dataset.

Run the script src/scripts/get_matproj.py MATPROJ_API_KEY, where MATPROJ_API_KEY is your API key, to download the Materials Project dataset. You need to create a user account and obtain an API key from Materials Project.

Run python2 src/scripts/get_oqmd.py to convert the OQMD database into an ASE database. You must first download and install the OQMD database on your machine yourself. The OQMD API is only compatible with Python 2, so after running the script you must manually move oqmd12.db to the datadir set in src/msgnet/defaults.py.

Running the model

To train the model used in the NMP-EDGE paper:

python runner.py --cutoff const 100 --readout sumscalar --num_passes 3 --update_edges --node_embedding_size 64 --dataset qm9 --edge_idx 0 --edge_expand 0.0,0.1,15.0 --learning_rate 5e-4 --target U0
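The `--edge_expand 0.0,0.1,15.0` argument appears to describe an expansion of interatomic distances into a set of Gaussian basis functions, as used in SchNet. The sketch below assumes the three numbers are the start, step and stop of the Gaussian centers, and uses an assumed width `gamma = 10.0`; neither the interpretation nor the width is taken from msgnet's source, and the function names are hypothetical.

```python
import math
from typing import List, Tuple


def parse_expand_spec(spec: str) -> Tuple[float, float, float]:
    """Parse a 'start,step,stop' string (assumed format) into floats."""
    start, step, stop = (float(x) for x in spec.split(","))
    return start, step, stop


def gaussian_expand(
    distance: float, start: float, step: float, stop: float, gamma: float = 10.0
) -> List[float]:
    """Expand a distance into Gaussians exp(-gamma * (d - mu_k)^2).

    Centers mu_k run from start to stop (inclusive) in increments of step.
    gamma is an assumed width, not a value read from msgnet.
    """
    n_centers = int(round((stop - start) / step)) + 1
    centers = [start + k * step for k in range(n_centers)]
    return [math.exp(-gamma * (distance - mu) ** 2) for mu in centers]


features = gaussian_expand(1.5, *parse_expand_spec("0.0,0.1,15.0"))
print(len(features))                  # 151 centers
print(features.index(max(features)))  # 15, the center at 1.5
```

Using `round` when counting centers avoids losing the last center to floating-point error in `(stop - start) / step`.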

To train the model on OQMD structures using the Voronoi graph with symmetry labels:

python runner.py --fold 0 --cutoff voronoi 0.2 --readout avgscalar --num_passes 3 --node_embedding_size 256 --dataset oqmd12 --learning_rate 0.0001 --edge_idx 5 6 7 8 9 10 11 12 13 --update_edges
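The QM9 and OQMD commands differ in their `--readout` setting (`sumscalar` versus `avgscalar`). A plausible reading of these names, which is an assumption rather than something stated in this README, is that per-atom scalar outputs are summed per structure in one case and averaged in the other. Summing suits extensive targets such as total energies, while averaging suits intensive, per-atom quantities. Below is a minimal pure-Python sketch of both pooling modes, with a hypothetical `readout` function and a `graph_index` list mapping each atom to its structure.

```python
from typing import List


def readout(
    atom_outputs: List[float],
    graph_index: List[int],
    num_graphs: int,
    mode: str = "sum",
) -> List[float]:
    """Pool per-atom scalars into one value per structure by summing or averaging."""
    totals = [0.0] * num_graphs
    counts = [0] * num_graphs
    for y, g in zip(atom_outputs, graph_index):
        totals[g] += y
        counts[g] += 1
    if mode == "sum":
        return totals
    if mode == "avg":
        # Every structure is assumed to contain at least one atom.
        return [t / c for t, c in zip(totals, counts)]
    raise ValueError(f"unknown readout mode: {mode!r}")


atom_outputs = [1.0, 2.0, 3.0, 4.0, 6.0]
graph_index = [0, 0, 0, 1, 1]  # atoms 0-2 in structure 0, atoms 3-4 in structure 1
print(readout(atom_outputs, graph_index, 2, "sum"))  # [6.0, 10.0]
print(readout(atom_outputs, graph_index, 2, "avg"))  # [2.0, 5.0]
```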

After training has finished, get the test set results by running:

python predict_with_model.py --modelpath logs/path/to/model/model.ckpt-STEP.meta --output modeloutput.txt --split test

Future Development

The model is implemented to avoid any padding or masking. This is achieved by reshaping the variable-length inputs into the first dimension of the tensors, which is usually the batch dimension. However, this means that the conventional TensorFlow methods for handling datasets as streams cannot be used. If the framework is still in use in the future, I plan to convert it into a TensorFlow Keras model once the RaggedTensor implementation is fully supported.
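The padding-free layout can be sketched as follows: the nodes of all graphs in a batch are concatenated along the first dimension, each graph's edge indices are shifted by the number of nodes preceding it, and a `graph_index` list records which graph each node belongs to, so that per-graph results can later be recovered with a segment-sum style operation (for example `tf.math.unsorted_segment_sum` in TensorFlow). The `batch_graphs` function and its input format are hypothetical and only illustrate the idea; they do not reproduce msgnet's input pipeline.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # (sender, receiver)


def batch_graphs(
    graphs: Sequence[Tuple[List[int], List[Edge]]],
) -> Tuple[List[int], List[Edge], List[int]]:
    """Concatenate graphs without padding.

    Each graph is (node_features, edges) with edges using local node indices.
    Returns the concatenated nodes, edges re-indexed into the concatenated
    node list, and graph_index giving the graph id of every node.
    """
    nodes: List[int] = []
    edges: List[Edge] = []
    graph_index: List[int] = []
    offset = 0
    for g, (features, local_edges) in enumerate(graphs):
        nodes.extend(features)
        edges.extend((s + offset, r + offset) for s, r in local_edges)
        graph_index.extend([g] * len(features))
        offset += len(features)
    return nodes, edges, graph_index


# Two molecules of different size; node features are atomic numbers here.
h2 = ([1, 1], [(0, 1), (1, 0)])
water = ([8, 1, 1], [(0, 1), (1, 0), (0, 2), (2, 0)])
nodes, edges, graph_index = batch_graphs([h2, water])
# nodes       == [1, 1, 8, 1, 1]
# edges       == [(0, 1), (1, 0), (2, 3), (3, 2), (2, 4), (4, 2)]
# graph_index == [0, 0, 1, 1, 1]
```

Because the batch size in the first dimension now varies with the number of atoms, a fixed-shape streaming pipeline does not fit this layout directly, which is the limitation described above.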

msgnet's People

Contributors

peterbjorgensen

msgnet's Issues

ASE databases scalar properties

Is there a reason that in get_matproj.py the scalar properties like delta_e and band_gap are saved as key_value_pairs and in get_qm9.py the 17 properties including tag and index are saved in the "data" dictionary, while they could also be stored as key_value_pairs, if I understand the ase.db documentation correctly? This seems to make the input pipeline a bit more complicated.

To clarify, the edge update functionality is only in `msgnet`, not in `schnetpack`, correct?

