
fed_cvae's Introduction

FedCVAE

Description

This repository includes implementations of a variety of established Federated Learning methods as well as two novel methods, FedCVAE-KD and FedCVAE-Ens. Please see the accompanying paper "Data-Free One-Shot Federated Learning Under Very High Statistical Heterogeneity" (ICLR 2023) for further details. If you found our implementation helpful, consider citing our work:

@inproceedings{heinbaugh2023datafree,
  title     = {Data-Free One-Shot Federated Learning Under Very High Statistical Heterogeneity},
  author    = {Clare Elizabeth Heinbaugh and Emilio Luz-Ricca and Huajie Shao},
  booktitle = {The Eleventh International Conference on Learning Representations},
  year      = {2023},
  url       = {https://openreview.net/forum?id=_hb4vM3jspB}
}

Prerequisites

  1. Python 3.9.x+
  2. pip

Set up

Install Python 3.9 and pip. We recommend using pyenv to manage Python versions. Create and activate a new virtual environment, then run:

pip3 install -r requirements.txt

This will install the necessary dependencies.

Algorithms

The datasets available for benchmarking are MNIST, FashionMNIST, and SVHN. All examples below use MNIST.

Pass --dataset fashion to use FashionMNIST or --dataset svhn to use SVHN.

Unachievable Ideal

Run the following from the command line.

python3 main.py --algorithm central --dataset mnist --sample_ratio 0.1 --glob_epochs 5 --should_log 1 

Because we are not training this model in a distributed manner, the number of global epochs simply refers to the number of training epochs for the centralized model.

FedAvg

Run the following from the command line.

python3 main.py --algorithm fedavg --dataset mnist --num_users 10 --alpha 0.1 --sample_ratio 0.1 --glob_epochs 5 --local_epochs 1 --should_log 1 --use_adam 1

One-shot ensembled FL

Run the following from the command line.

python3 main.py --algorithm oneshot --dataset mnist --num_users 5 --alpha 1.0 --sample_ratio 0.1 --local_epochs 5 --should_log 1 --one_shot_sampling random --user_data_split 0.9 --K 3 --use_adam 1

--one_shot_sampling can take on the following values:

  • random (sample a random subset of K users to ensemble)
  • validation (split each user's data into training and validation and choose the K best scoring user models on the validation set)
  • data (choose the K users with the most data)
  • all (ensemble all user models)

You can also adjust the model-specific parameters --K to change the number of sampled users for ensembling and --user_data_split to adjust the user train/validation split. Note that K must be less than or equal to the number of users.
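As a rough illustration of these schemes, selecting K of the user models might look like the following sketch (the users, data_sizes, and val_scores names are hypothetical stand-ins, not the repository's API):

import random

def select_users(users, K, scheme, data_sizes=None, val_scores=None):
    """Choose K user models to ensemble (illustrative sketch only)."""
    if scheme == "all":
        return users
    if scheme == "random":
        return random.sample(users, K)
    if scheme == "data":
        # the K users with the most local training data
        return sorted(users, key=lambda u: data_sizes[u], reverse=True)[:K]
    if scheme == "validation":
        # the K user models scoring best on their held-out validation split
        return sorted(users, key=lambda u: val_scores[u], reverse=True)[:K]
    raise ValueError(f"Unknown sampling scheme: {scheme}")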

By default, one-shot ensembled FL only trains for 1 global epoch.

FedCVAE-Ens

Run the following from the command line.

python3 main.py --algorithm fedcvaeens --dataset mnist --num_users 5 --alpha 1.0 --sample_ratio 0.1 --local_epochs 5 --should_log 1 --z_dim 50 --beta 1.0 --classifier_num_train_samples 1000 --classifier_epochs 5 --uniform_range "(-1.0, 1.0)" --use_adam 1       

You can adjust the model-specific parameters --z_dim to change the latent vector dimension and --beta to change the weight of the KL divergence loss. Modify --classifier_num_train_samples to change the number of generated samples used to train the server classifier and --classifier_epochs to adjust the server classifier's training time. Modify --uniform_range to change the uniform range from which the decoder draws latent samples.
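For intuition, drawing latent vectors from the configured uniform range might look like this sketch (function and argument names are illustrative, not the repository's exact API):

import torch

def sample_latents(num_samples, z_dim, uniform_range=(-1.0, 1.0)):
    """Draw latent vectors uniformly from the configured range (illustrative)."""
    low, high = uniform_range
    return torch.rand(num_samples, z_dim) * (high - low) + low

# e.g., 1000 latent vectors of dimension 50 for generating classifier training data
z = sample_latents(1000, z_dim=50, uniform_range=(-1.0, 1.0))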

By default, FedCVAE-Ens only trains for 1 global epoch.

FedCVAE-KD

Run the following from the command line.

python3 main.py --algorithm fedcvaekd --dataset mnist --num_users 5 --alpha 1.0 --sample_ratio 0.1 --glob_epochs 5 --local_epochs 5 --should_log 1 --z_dim 50 --beta 1.0 --classifier_num_train_samples 1000 --classifier_epochs 5 --decoder_num_train_samples 1000 --decoder_epochs 5 --uniform_range "(-1.0, 1.0)" --use_adam 1  

You can adjust the model-specific parameters --z_dim to change the latent vector dimension and --beta to change the weight of the KL divergence loss. Modify --classifier_num_train_samples to change the number of generated samples used to train the server classifier and --classifier_epochs to adjust the server classifier's training time. Modify --decoder_num_train_samples to change the number of generated samples used to train the server decoder and --decoder_epochs to adjust the server decoder's training time. Modify --uniform_range to change the uniform range from which the decoder draws latent samples.

Experiments

  1. --should_weight_exp: Turn on (1) or off (0) weighting when averaging models.
  2. --should_initialize_models_same: Turn on (1) or off (0) initializing all user models with the same weights.
  3. --should_avg_exp: Turn on (1) or off (0) averaging all user decoders for the server decoder (FedCVAE-KD-specific).
  4. --should_fine_tune_exp: Turn on (1) or off (0) fine-tuning the server decoder (FedCVAE-KD-specific).
  5. --heterogeneous_models_exp: Choose whether to use heterogeneous models or not. Pass in a string containing which versions of the CVAE to use; a string of length 1 yields homogeneous models. Version 0 is the standard CVAE, version 1 is a smaller alternative, and version 2 is ResNet-based. E.g., "012".
  6. --transform_exp: Choose whether to apply transforms for FedCVAE-KD with SVHN (1) or not (0).

Logging

  1. Enable logging by adding the command line argument --should_log 1 to python3 main.py.
  2. Run tensorboard --logdir=runs and navigate to http://localhost:6006/.
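Under the hood, the --should_log flag presumably drives a TensorBoard SummaryWriter writing under runs/; a minimal sketch of that pattern (tags and run names here are illustrative, not the repository's exact ones):

from torch.utils.tensorboard import SummaryWriter

# Runs written under ./runs are picked up by `tensorboard --logdir=runs`
writer = SummaryWriter(log_dir="runs/example_run")  # run name is illustrative
for epoch in range(5):
    accuracy = 0.0  # placeholder for the evaluated global model accuracy
    writer.add_scalar("global_accuracy", accuracy, epoch)
writer.close()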

Format

  1. Run black . from the repo root.
  2. Run isort . also from the repo root.

fed_cvae's Issues

Implement FashionMNIST

We need more datasets than just MNIST.

  • FashionMNIST is pulled from PyTorch.
  • FashionMNIST is selectable as a command line argument.

`FedVAE`: PMF calculation doesn't always sum to one

When using python3 main.py --algorithm fedvae --num_users 2 --alpha 0.1 --sample_ratio 0.25 --glob_epochs 2 --local_epochs 3 --should_log 1 --z_dim 50 --beta 1.0 --classifier_num_train_samples 1000 --classifier_epochs 5 --decoder_num_train_samples 1000 --decoder_epochs 5, the label distributions do not always sum to one.

We need to ensure that label distributions always sum to one. This may be a Python floating-point precision issue.
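If this is indeed a floating-point issue, one possible fix (a sketch, not the repository's code) is to renormalize each label distribution before it is used:

import numpy as np

def normalize_pmf(label_counts):
    """Renormalize so the label distribution sums to exactly one."""
    pmf = np.asarray(label_counts, dtype=np.float64)
    pmf = pmf / pmf.sum()
    assert np.isclose(pmf.sum(), 1.0)  # guard against residual drift
    return pmf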

Hyperparameter tuning for FedVAE

Follow the protocol detailed in blue in this doc.

  • Look up how other FL papers tune - should we re-tune for every level of alpha (heterogeneity)? Should we split out a validation dataset to tune?

Describe the findings and paste TensorBoard plots into this document. Consider writing a shell script to automate experimentation.

Implement "Unachievable Ideal" (Centralized Model)

Train the global classifier with the training data selected according to the sampling ratio.

  • Add an if-statement in the Data class to not separate the data according to number of users (just hand back a single dataset of all available training data according to the sampling ratio).
  • Add an if-statement in main to only train the global classifier. Another file should probably exist to train this model.
  • Log this to tensorboard as central_model_sampling_ratio=x.x_number_of_epochs=xxx.

Tweak knowledge distillation procedure for FedVAE

Currently, FedVAE generates an equal number of samples from each user (teacher decoder).

  • It may, however, be beneficial to sample according to the number of training data points seen by each user, so that users who saw substantially less data (more likely to happen with lower alpha) have less of an effect during fine-tuning (see the sketch after this list).
  • Also, it may help to perform weighted parameter averaging for the server decoder's initialization.
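A sketch of the proportional sampling idea from the first point (names are hypothetical, not the repository's implementation):

def samples_per_user(data_sizes, total_samples):
    """Allocate generated samples proportionally to each user's local data size."""
    total_data = sum(data_sizes)
    counts = [round(total_samples * n / total_data) for n in data_sizes]
    counts[-1] += total_samples - sum(counts)  # absorb rounding drift in the last user
    return counts

# e.g., three users holding 800, 150, and 50 examples, with 1000 generated samples
print(samples_per_user([800, 150, 50], 1000))  # -> [800, 150, 50]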

Implement `FedVAE`

Implement our algorithm as shown in the pipeline below:

[Pipeline diagram]

  • Test the kaiming weight initialization script
  • Create file structure for VAE (decoder, encoder, view, linear_predict, etc.)

Random seeds check

For final experiments, we'd like to show the stability of FedVAE. To do this, we should run the model several times with different weight initializations but with the same dataset split--the random seed shouldn't affect how the dataset is distributed.

  • Separate the random seeds used for dataset and model.
  • Check that changing the model seed changes performance but leaves the dataset split the same - can check by inspecting the non-IID dataset visualization, which should be identical.

Make learning rate a passable parameter

We want learning rate to be passed in with command line arguments.

  • All hard-coded instances of learning rate are passed in via command line. This includes:
    • Local (user) learning rate, which is used by all algorithms.
    • Global decoder aggregation learning rate for FedVAE.

Try changing local optimizers to SGD

Currently, we use Adam as the local optimizer, but this diverges from the standard in the literature. (This is largely because Adam introduces additional hyperparameters, complicating the tuning process.) Try switching the local optimizer to SGD with no momentum and re-tuning learning rates.
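A minimal sketch of the swap in PyTorch (the learning rates shown are placeholders that would still need re-tuning):

import torch
import torch.nn as nn

model = nn.Linear(784, 10)  # stand-in for a user model

# current default: Adam
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# proposed: plain SGD with no momentum
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.0)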

Implement the Basic One-Shot Algorithm

Implement the one-shot FL algorithm described in Guha et al. (2019)--just the ensembling version that doesn't require auxiliary data.

  • Add a command line argument to specify the sampling method for users. Add methods in the extended server class that allow for sampling via validation, amount of data, random sampling, or all users.
  • Implement an extended server class that overwrites the create_users method in the base class to split data subsets into additional training/validation subsets (if necessary for the sampling scheme).
  • The train method should allow all users to train and then should have selected users upload their models.
  • The evaluation step for this algorithm (the "global model") should just be based on the ensembled predictions (majority vote over classes or average logits) of selected/uploaded user models.
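For the last point, a sketch of the two ensembling options (assuming a list of selected PyTorch user models; not the repository's exact code):

import torch

@torch.no_grad()
def ensemble_predict(models, x, method="avg_logits"):
    """Combine predictions from the selected user models (illustrative sketch)."""
    logits = torch.stack([m(x) for m in models])  # (num_models, batch, num_classes)
    if method == "avg_logits":
        return logits.mean(dim=0).argmax(dim=1)   # average logits, then take argmax
    votes = logits.argmax(dim=2)                  # per-model class predictions
    return votes.mode(dim=0).values               # majority vote over models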

Implement Sampling a Fraction of Users For Extended Communication

Standard FL algorithms classically only involve a fraction of users during each communication round.

  • Add the fraction of users to sample as a command line arg in main.py and integrate this into server.py by adding a base method that (uniformly) randomly samples users according to this fraction (see the sketch after this list).
  • Every model that extends server.py should be able to use this fraction of selected users, although one-shot methods should set this to 1.0 by default since they involve all users.
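A sketch of such a base sampling method (hypothetical names; one-shot methods would simply pass a fraction of 1.0):

import math
import random

def sample_users(users, sample_frac):
    """Uniformly sample a fraction of users for this communication round."""
    num_selected = max(1, math.ceil(sample_frac * len(users)))
    return random.sample(users, num_selected)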

Make the unachievable ideal appear as a line in tensorboard

For every experiment, we want the hard-coded converged ideal value to appear as a line across the top of our plot.

  • Add the hard-coded value to main.py.
  • Log a new run with the writer that just logs the same hard-coded value as many times as args.glob_epochs.
  • End this run and start a new one.

Note: This is purely for easier visualization, not a true experiment.

Re-tune `FedVAE` hyperparameters

Wait for issues #41, #43, #44.

Re-tune FedVAE hyperparameters for just one epoch, since the previous best hyperparameters are from the few-shot setting.

  • Maybe rename this algorithm and refactor (call this OneFedVAE?) - it's one-shot and has one decoder

`FedVAE`: implement augmented classifier training scheme

Currently, in server_fed_vae.py we only sample latent variables from a tight uniform distribution to obtain high-quality samples for classifier training. It may help to either:

  1. Sample a small portion of zs from a multivariate normal, or
  2. Sample from a wider uniform distribution (either for all samples or for a small portion),

to obtain a wider variety of intra-class variation for classifier training. It's likely that increasing the number of samples used for classifier training will also be necessary.

A similar approach may help for the knowledge-distillation fine-tuning for the server VAE as well.
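A sketch of the mixed sampling idea (the fraction and names are illustrative assumptions):

import torch

def sample_augmented_latents(num_samples, z_dim, uniform_range=(-1.0, 1.0), normal_frac=0.1):
    """Mix tight-uniform latent samples with a small portion of standard-normal samples."""
    num_normal = int(normal_frac * num_samples)
    low, high = uniform_range
    z_uniform = torch.rand(num_samples - num_normal, z_dim) * (high - low) + low
    z_normal = torch.randn(num_normal, z_dim)
    return torch.cat([z_uniform, z_normal], dim=0)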

One-shot FedVAE

We want to show that few-shot federated learning is a better setting for our model than one-shot. Thus, create a new algorithm that implements FedVAE as a one-shot algorithm.

See this paper for reference.

  • Add a new command line argument to select for onefedvae
  • Update the run name that gets saved to Tensorboard with onefedvae parameters.
  • Extend ServerFedVAE to ServerOneFedVAE and overwrite the server classifier training to sample a new dataset from collected decoders (same as ServerFedVAE for decoder knowledge distillation) and then train the classifier on this dataset.
  • Because this is a one-shot model, make sure we ignore the added global epochs parameter and only run for one epoch (see the oneshot algorithm as a reference).
  • Update the README.md with instructions to run.
  • For verification, local epochs should be very high so that user models can converge.

`FedVAE`: possibly re-initialize classifier each global round

Currently, the classifier is trained continually across all global epochs. However, our pipeline schematic indicates that it should only be trained after all communication is done. Add the ability to re-initialize the server classifier's weights each round (essentially training from scratch each round) to see if this makes a difference at all.

  • Motivation: initial samples from the aggregated decoder may not be very high quality, which could direct the classifier's weights toward a bad part of the loss space.
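One way to do this in PyTorch is to reset every layer that defines reset_parameters; a minimal sketch (assuming the classifier is a standard nn.Module):

import torch.nn as nn

def reinitialize(model: nn.Module) -> None:
    """Re-initialize weights of layers that define reset_parameters (e.g., Linear, Conv2d)."""
    for layer in model.modules():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()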

Local computation experiments

Make it possible to easily test local computation amounts without re-starting the run. After each local training, communicate upwards to the server and log test results, then don't communicate downwards, but run another local epoch and repeat.

Clean up classifier model

We should use a more standard classifier architecture, like the following from McMahan et al. (2017): "a CNN with two 5x5 convolution layers (the first with 32 channels, the second with 64, each followed with 2x2 max pooling), a fully connected layer with 512 units and ReLu activation, and a final softmax output layer (1,663,370 total parameters)."
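A sketch of that architecture in PyTorch for 28x28 single-channel inputs (the padding choice is an assumption, so the parameter count may differ slightly from the quoted 1,663,370):

import torch.nn as nn

class McMahanCNN(nn.Module):
    """CNN described in McMahan et al. (2017); layer details here are a best guess."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),  # 7x7 after two 2x2 poolings of a 28x28 input
            nn.ReLU(),
            nn.Linear(512, num_classes),  # softmax is applied in the loss
        )

    def forward(self, x):
        return self.classifier(self.features(x))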

Re-tune `OneFedVAE`

Wait for #41, #43, #44.

Re-tune OneFedVAE hyperparameters since the previous best hyperparameters are from previous hyperparameter tuning runs.

  • Rename this algorithm (OneFedVAE is too confusing - maybe FedVAE-Ens) and refactor the code

Implement the application dataset

Wait until after meeting with Jay on 8/16.

Implement the chest x-ray dataset used in this paper.

  • Check that our VAE architecture is powerful enough to capture this data - train a single centralized CVAE and check samples! If it isn't good enough, settle on alternative architecture.

There is no condition in the encoder?

Hi, I have a small question. As far as I understand CVAEs, the encoder requires labels as inputs. Why does the conditional encoder in the code only take images as input?

Add the ability to do a weighted average of user weights

In the original FL paper (McMahan et al. (2017)), weights are averaged proportionally to the number of samples each user has in its local dataset. Currently, we do an unweighted average of user weights.

  • Make it possible to do a weighted average based on the number of data samples in the average_weights function of utils.py (see Algorithm 1 of McMahan et al. (2017) for details).
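A sketch of the weighted version, weighting each user's state dict by its local dataset size (the signature is illustrative and may not match utils.py):

import copy

def average_weights(state_dicts, data_sizes=None):
    """Average user model weights, optionally weighted by local dataset size."""
    if data_sizes is None:
        weights = [1.0 / len(state_dicts)] * len(state_dicts)
    else:
        total = sum(data_sizes)
        weights = [n / total for n in data_sizes]
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        # .float() promotes integer buffers so the weighted sum is well-defined
        avg[key] = sum(w * sd[key].float() for w, sd in zip(weights, state_dicts))
    return avg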

Clean up printing and logging for all algorithms

After all algorithms are implemented: clean up the hyperparameters that are printed/logged to tensorboard in main.py. As an example, for the unachievable ideal (centralized model), alpha and number of local epochs should not be printed/logged.
