
Find gravitational wave signals from binary black hole collisions.

License: GNU General Public License v3.0

Python 100.00%
tensorflow deep-learning keras machine-learning g2net black-holes tukey-window bandpass-filter bandpass-filters spectrograms

g2net's Introduction

G2Net - Gravitational Wave Detection


Table of Contents
  1. About the Project
  2. Getting Started
  3. Contributing
  4. Acknowledgments

About the Project

G2Net is not only a network of Gravitational Waves, Geophysics and Machine Learning experts; it was also released as a Kaggle competition. Its origins date back to the discovery of Gravitational Waves (GW) in 2015 (The Sound of Two Black Holes Colliding). The aim of the competition was to detect GW signals from the mergers of binary black holes. Specifically, participants were expected to create and train a model to analyse synthetic GW time-series data from a network of Earth-based detectors (LIGO Hanford, LIGO Livingston and Virgo). The implementations in this repository reached the top 8% of the ranking (AUC score on the test set) under certain settings, which does not mean they cannot be improved further.

Contents

The Model

The model implemented for the competition (see the image below) has been created following an end-to-end philosophy, meaning that even the time-series pre-processing logic is included as part of the model and may be made trainable. For more details about the building blocks of the model, refer to any of the Colab guides provided by the project.

G2Net Model
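To give a feel for the pre-processing stages the model embeds (time window, bandpass filter, then a spectrogram-like transform), here is a framework-agnostic NumPy sketch. In the repository these stages are (possibly trainable) Keras layers; the cutoff frequencies and taper fraction below are illustrative assumptions, not the project's settings:

```python
import numpy as np

def tukey_window(n, alpha=0.5):
    """Tukey (tapered cosine) window; alpha is the tapered fraction."""
    t = np.linspace(0.0, 1.0, n)
    w = np.ones(n)
    lo = t < alpha / 2.0
    hi = t > 1.0 - alpha / 2.0
    w[lo] = 0.5 * (1.0 + np.cos(np.pi * (2.0 * t[lo] / alpha - 1.0)))
    w[hi] = 0.5 * (1.0 + np.cos(np.pi * (2.0 * t[hi] / alpha - 2.0 / alpha + 1.0)))
    return w

def bandpass(x, fs, f_lo=20.0, f_hi=500.0):
    """Zero out FFT bins outside [f_lo, f_hi] Hz (hard brick-wall filter)."""
    spec = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(x.size, d=1.0 / fs)
    spec[(freqs < f_lo) | (freqs > f_hi)] = 0.0
    return np.fft.irfft(spec, n=x.size)

# Window the strain first to suppress edge artefacts, then filter.
fs = 2048
x = np.random.default_rng(0).normal(size=fs)
y = bandpass(tukey_window(x.size) * x, fs)
```

In the actual model the window shape and filter band can be learned, since the layers are differentiable.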

Back to top

Major Files

The major project source code files are listed below in a tree-like fashion:

    G2Net
      └───src
          │   config.py
          │   main.py
          ├───ingest
          │       DatasetGeneratorTF.py
          │       NPYDatasetCreator.py
          │       TFRDatasetCreator.py
          ├───models
          │       ImageBasedModels.py
          ├───preprocess
          │       Augmentation.py
          │       Preprocessing.py
          │       Spectrogram.py
          ├───train
          │       Acceleration.py
          │       Losses.py
          │       Schedulers.py
          └───utilities
                  GeneralUtilities.py
                  PlottingUtilities.py

The most important elements in the project are outlined and described as follows:

  • config.py: Contains a configuration class with the parameters used by the model or the training process and other data ingestion options.
  • main.py: Implements the functionality to train and predict with the model locally in GPU/CPU.
  • Ingest module:
    • NPYDatasetCreator.py: Implements the logic to standardise the full dataset in a multiprocessing fashion.
    • TFRDatasetCreator.py: Implements the logic to standardise, encode, create and decode TensorFlow records.
    • DatasetGeneratorTF.py: Includes a class implementing functionality to create TensorFlow Datasets pipelines from both TensorFlow records and NumPy files.
  • Models module:
    • ImageBasedModels.py: Includes a Keras model based on 2D convolutions, preceded by a pre-processing phase culminating in the generation of a spectrogram or similar representation. The 2D convolutional model is here an EfficientNet v2.
  • Preprocess module:
    • Augmentation.py: Implements several augmentations in the form of Keras layers, including Gaussian noise, spectral masking (TPU-compatible and TPU-incompatible versions) and channel permutation.
    • Preprocessing.py: Implements several preprocessing layers in the form of trainable Keras layers, including time windows (TPU-incompatible Tukey window and generic TPU-compatible window), bandpass filtering and spectral whitening.
    • Spectrogram.py: Includes a TensorFlow version of CQT1992v2, originally implemented in nnAudio with PyTorch. Being in the form of a Keras layer, it also adds functionality to adapt the output range to the one recommended for numerical stability with 2D convolutional models.
  • Train module:
    • Acceleration.py: Includes the logic to automatically configure the TPU if any.
    • Losses.py: Implements a differentiable loss whose minimisation directly maximises the AUC score.
    • Schedulers.py: Implements a wrapper to make the CosineDecayRestarts learning rate scheduler compatible with ReduceLROnPlateau.
  • Utilities module:
    • GeneralUtilities.py: General utilities used throughout the project, mainly to perform automatic tensor broadcasting and to compute the mean and standard deviation of a dataset with multiprocessing capabilities.
    • PlottingUtilities.py: Includes all the logic behind the plots.
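On the loss in Losses.py: a common way to build a differentiable surrogate whose minimisation maximises AUC is to penalise every mis-ordered positive/negative score pair. The NumPy sketch below illustrates the idea only; the function name, margin and exponent are assumptions, not the repository's actual implementation:

```python
import numpy as np

def pairwise_auc_loss(y_true, y_score, margin=0.2, p=2):
    """Pairwise AUC surrogate: penalise each positive/negative pair whose
    positive score does not exceed the negative score by at least `margin`.
    The power hinge keeps the loss differentiable almost everywhere."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]              # all positive-negative pairs
    return np.mean(np.maximum(0.0, margin - diff) ** p)
```

Perfectly separated scores (with margin to spare) incur zero loss, while swapped pairs are penalised, so gradient descent directly pushes positives above negatives — exactly what AUC measures.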

Back to top

Dependencies

Among others, the project has been built around the following major Python libraries (check config/g2net.yml for a full list of dependencies with tested versions):

  • TensorFlow (version 2.x)

Back to top

Getting Started

Locally

Installation

In order to use the project locally (tested on Windows), just follow two steps:

  1. Clone the project:
  git clone https://github.com/salvaba94/G2Net.git
  2. Assuming that Anaconda Prompt is installed, run the following command to install the dependencies:
  conda env create --file g2net.yml

Back to top

Coding

To experiment locally:

  1. First, you'll need to manually download the Competition Data, as the code will not do it for you (to avoid connectivity problems while downloading a heavy dataset). Paste the contents into the raw_data folder.
  2. The controls of the code are in src/config.py. Make sure that, the first time you run the code, either the GENERATE_TFR or the GENERATE_NPY flag is set to True. This will generate standardised datasets as TensorFlow records or NumPy files, respectively.
  3. Set these flags to False and, with the FROM_TFR flag, make sure that you are reading the data in the format you generated.
  4. You are ready to play with the rest of the options!
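The flag names above come from src/config.py; the sketch below is a hypothetical, minimal stand-in for such a configuration class (the dataclass form and the defaults are assumptions) just to make the first-run/later-run toggle concrete:

```python
from dataclasses import dataclass

@dataclass
class Config:
    GENERATE_TFR: bool = True    # first run: write standardised TFRecords
    GENERATE_NPY: bool = False   # first run: write standardised NumPy files
    FROM_TFR: bool = True        # read TFRecords (False -> read NumPy files)

# First run: generate the standardised dataset in TFRecord format.
first_run = Config(GENERATE_TFR=True, FROM_TFR=True)

# Later runs: stop generating, keep reading the format you produced.
later_runs = Config(GENERATE_TFR=False, GENERATE_NPY=False, FROM_TFR=True)
```

The key invariant is that FROM_TFR must match whichever format was generated on the first run.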

Back to top

Troubleshooting

If by any chance you experience a NotImplementedError (see below), it is an incompatibility issue between the installed TensorFlow and NumPy versions. It is caused by a change in the exception type raised, which leaves the error uncaught.

  NotImplementedError: Cannot convert a symbolic Tensor (gradient_tape/model/bandpass/irfft_2/add:0) to a numpy array. 
  This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

The origin is at line 867 of tensorflow/python/framework/ops.py. The issue is solved by replacing

  def __array__(self):
    raise NotImplementedError(
        "Cannot convert a symbolic Tensor ({}) to a numpy array."
        " This error may indicate that you're trying to pass a Tensor to"
        " a NumPy call, which is not supported".format(self.name))

by

  def __array__(self):
    raise TypeError(
        "Cannot convert a symbolic Tensor ({}) to a numpy array."
        " This error may indicate that you're trying to pass a Tensor to"
        " a NumPy call, which is not supported".format(self.name))

Back to top

In Colab

Alternatively, feel free to follow the ad-hoc guides in Colab:

  • (full version)
  • (short version)

Important note: as the notebooks connect to your Google Drive to save trained models, copy them to your Drive and run them from there, not from the link. In any case, Google will notify you that the notebooks have been loaded from GitHub and not from your Drive.

Back to top

Contributing

Any contributions are greatly appreciated. If you have suggestions that would make the project better, fork the repository and create a pull request, or simply open an issue. If you decide to follow the first procedure, here is a reminder of the steps:

  1. Fork the project.
  2. Create your branch:
  git checkout -b branchname
  3. Commit your changes:
  git commit -m "Add some amazing feature"
  4. Push to the branch:
  git push origin branchname
  5. Open a pull request.

Back to top

Acknowledgements

Back to top

If you like the project and/or find any of its contents useful, don't forget to give it a star! It means a lot to me 😄

g2net's People

Contributors

salvaba94


g2net's Issues

Question regarding test prediction submission csv file.

Great work on this repo!

My submissions are no better than 0.5 AUC despite the training & validation going smoothly (~0.86 to 0.88). This seems to imply that the id and prediction are not matching up. However, when I check submission.csv the ids are in the correct order, which indicates the predictions are being scrambled in some way.

Is there a difference in TensorFlow/Keras Flatten or .predict() that is version dependent?
Or is something else going on?

Thanks!
Screenshot 2022-11-29 at 2 09 14 PM
