
Federated system for privacy-preserving pandemic forecasting; a solution for U.S. PETs Prize Challenge

License: MIT License

Shell 1.30% Python 46.58% Dockerfile 1.42% Go 13.19% Jupyter Notebook 37.51%


MusCAT


This repository provides software for Team MusCAT's solution to the U.S. PETs Prize Challenge (Pandemic Forecasting). Team MusCAT won first place for the white paper (Phase 1) and second place in the final stage (Phase 2) of the Challenge.

Problem Setting

Predictive models that can assess individuals' daily infection risks can be a useful tool for pandemic response. Training these models at scale can be challenging due to difficulties in sharing personally identifying information across different data collection sites. The Challenge task was to develop a federated learning system that can jointly leverage distributed private datasets to train and apply risk prediction models with rigorous privacy protection. See here for the official Challenge page introducing the problem.

Our Technical Approach

We introduce MusCAT, a multi-scale federated system for privacy-preserving pandemic risk prediction. We leverage the key insight that predictive information can be divided into components operating at different scales of the problem, including individual contacts, shared locations, and population-level risks. These components are individually learned using a combination of privacy-enhancing technologies to best optimize the tradeoff between privacy and model accuracy. Based on the Challenge dataset, we show that our solution enables improved risk prediction with formal privacy guarantees, while maintaining practical runtimes even with many federation units.

[Figure: Graphical illustration of MusCAT]

Our white paper describing the solution is available here.

Software Components and Methodology

Our solution is implemented in Python and Go.

  • Centralized solution uses Python:

    • solution_centralized.py represents the entrypoint to the solution. It defines the main functions required by the framework (fit() and predict()) and implements MusCAT's general workflow, similar to the one described in Section 3.4 (Privacy-Preserving Federated System for Individual Risk Prediction) of our manuscript.

    • muscat_model.py constructs the MusCAT model and defines each step of MusCAT's workflow (called by solution_centralized.py). See Sections 3.4 and 5.1 (Centralized performance) for discussions of the workflow and its benchmarks.

  • Federated solution uses both Python and Go. The latter is needed for cryptographic operations and uses a custom fork of the Lattigo library for lattice-based homomorphic encryption, as discussed in Section 5.2 (Federated Performance → Implementation Details).

    • solution_federated.py represents the entrypoint to the solution. It defines the main functions required by the framework (e.g., fit(), configure_fit(), ...) and implements MusCAT's general federated workflow.

      • fit() in class TrainClient implements the core of our model training, executed by the clients, with the computation of global statistics (W0-W3 in Section 3.4) and the Poisson regression (W4).
      • aggregate_fit() in class TrainStrategy defines the operations of the server, i.e., securely aggregating encrypted information for the collaboration among the clients, as described in Section 5 (Experimental Results).
      • fit() and evaluate() in class TestClient implement the clients' part of the inference (W6).
      • configure_fit() and aggregate_fit() define the server functions for the same operations. See Section 5.
    • muscat_model.py constructs the MusCAT model and defines each step of MusCAT's federated workflow (called by solution_federated.py), as described in Section 3.4.

    • muscat_privacy.py contains static parameters and functions specific to Differential Privacy (DP). See Sections 3.4, 4 (Privacy Analysis → DP Training), and 5.2 (Federated Performance → Privacy) for a discussion of DP, its implementation, and its performance.

    • dpmean.py provides multivariate_mean_iterative(), which implements the CoinPress algorithm for private mean estimation (called by solution_federated.py), as described in Section 5.2 (Federated Performance → Privacy).

    • muscat_workflow.py contains static parameters for the secure and plaintext training and testing workflows. It notably defines the training parameters and the order of the rounds to train a model. See Section 3.4 on the workflow.

    • mhe_routines.go represents the Go entrypoint that parses command-line arguments passed to it from Python, and executes a computation corresponding to its step in the Python workflow. This takes the form:

      muscat <command> <arg1> [<arg2> ...]

      where <command> designates a step in the workflow, and <arg1> [<arg2> ...] represents the arguments for that step, each of which is either a path to a data directory or a numeric parameter. It currently supports the setup of the cryptographic parameters and the execution of the Collective Aggregation and Decryption (used during MusCAT's workflow for secure aggregation of the clients' local results by the server), as discussed in Sections 3.4 and 5.

    • mhe/crypto.go contains cryptographic utilities for Multiparty Homomorphic Encryption (MHE, e.g., vectors encryption and decryption), along with some functions to handle disk I/O (e.g., to save and read cryptographic parameters and keys), which is needed for passing data from/to Python. See Section 5.2.1 (Efficiency & Scalability → MHE Operations) on the use of these cryptographic primitives.

    • mhe/protocols.go provides high-level functions that implement a disk-assisted client-server communication protocol. See Section 5.2.1 (Efficiency & Scalability) for details on this protocol implementation.

    • mhe/utilities.go contains auxiliary utilities, including functions to (de)serialize data vectors and matrices from/to disk, in order to pass them from/to Python. See Section 5.2.1 for relevant discussions.

    • go.mod and go.sum configure third-party Go dependencies.
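The command-line convention above (`muscat <command> <arg1> [<arg2> ...]`) can be illustrated with a minimal Python sketch of how a caller might assemble such an invocation. This is not the repository's actual glue code: the helper `build_muscat_command`, the command name `example-step`, and the path `/tmp/client0` are hypothetical and used for illustration only. The real command names are defined in mhe_routines.go.

```python
import shlex
from typing import List, Sequence, Union

# Arguments are either paths to data directories or numeric parameters.
Arg = Union[str, int, float]


def build_muscat_command(
    command: str, args: Sequence[Arg], binary: str = "muscat"
) -> List[str]:
    """Build an argv list of the form: muscat <command> <arg1> [<arg2> ...].

    Hypothetical helper: it only mirrors the CLI shape described in the text.
    """
    if not command:
        raise ValueError("command must be a non-empty string")
    if not args:
        raise ValueError("at least one argument is required")
    # Every argument is passed to the Go binary as a plain string.
    return [binary, command, *(str(a) for a in args)]


# "example-step" and "/tmp/client0" are placeholders, not real MusCAT values.
argv = build_muscat_command("example-step", ["/tmp/client0", 3])
print(shlex.join(argv))  # muscat example-step /tmp/client0 3

# With the Go binary built and on PATH, the call could then be made with:
#   subprocess.run(argv, check=True)
```

Passing the arguments as a list rather than a single shell string avoids quoting problems when a data directory path contains spaces.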

Usage

  1. Make sure you're working on a machine with sufficient memory - 64GB or more is recommended, but it will depend on the overall data size. On macOS, you may need to increase virtual machine memory in Docker Desktop settings.

  2. Prepare a dataset. Sample data can be downloaded from https://net.science/files/resources/datasets/PET_Prize_PandemicForecasting/ (e.g. va_synthetic_population_and_outbreak_ver_1_9.zip).

    After downloading and unpacking a dataset, prepare it by following the pandemic-partitioning-example.ipynb notebook.

  3. Install Docker and run the following command:

    docker run --rm -it --pull always \
      -v "$(pwd)/data/pandemic":/code_execution/data:ro \
      -v "$(pwd)/submission":/code_execution/submission \
      ghcr.io/hhcho/muscat centralized # or federated

  4. The tool output will be stored under the submission/ folder:

    predictions            cpu_metrics.csv.gz      metrics.json
    centralized-test.log   memory_metrics.csv.gz   scoring_payload
    centralized-train.log  process_metrics.log.gz  state
    log.txt                system_metrics.sar.gz
    

    Here, predictions/<submission_type>/predictions.csv provides results from a successful run:

    pid,score
    ...
    195155,0.83926135
    195156,0.8401405
    195157,0.8403996
    195158,0.0
    ...
    

    Similarly to predictions_format.csv, each row represents a person (with a numeric ID) and their risk score of getting infected during the test period, with a higher score corresponding to higher confidence that they become infected.

    If a run fails, you can use log.txt to troubleshoot it.

    *metrics* files contain various internal performance metrics.

    scoring_payload and state store internal state from a run.
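As a minimal sketch of how the predictions file could be inspected, the snippet below parses rows in the `pid,score` format shown above and lists the highest-risk individuals. The helper names `load_scores` and `top_k` are hypothetical, and the inline sample reuses the four rows from the excerpt above rather than a real output file.

```python
import csv
import io
from typing import Iterable, List, Tuple


def load_scores(lines: Iterable[str]) -> List[Tuple[int, float]]:
    """Parse CSV lines with a `pid,score` header into (pid, score) pairs."""
    reader = csv.DictReader(lines)
    return [(int(row["pid"]), float(row["score"])) for row in reader]


def top_k(scores: List[Tuple[int, float]], k: int) -> List[Tuple[int, float]]:
    """Return the k entries with the highest risk score, highest first."""
    return sorted(scores, key=lambda item: item[1], reverse=True)[:k]


# Sample rows copied from the excerpt above; a real run has many more rows.
SAMPLE = """pid,score
195155,0.83926135
195156,0.8401405
195157,0.8403996
195158,0.0
"""

scores = load_scores(io.StringIO(SAMPLE))
print(top_k(scores, 2))  # [(195157, 0.8403996), (195156, 0.8401405)]
```

To read an actual run, pass an open file handle instead of the `io.StringIO` sample, for example one opened with `open(path, newline="")` on the predictions.csv path described above.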

Development

When making changes to the code, you can rebuild the Docker image locally using

docker build --platform linux/amd64 -t ghcr.io/hhcho/muscat .

Questions about software

Contributors

dinvlad, froelich, hhcho
