License: Apache License 2.0

Python 100.00%

fedrolex's Introduction

FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction

Code for paper:

FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction
Samiul Alam, Luyang Liu, Ming Yan, and Mi Zhang.
NeurIPS 2022.

The repository is built upon HeteroFL.

Overview

Most cross-device federated learning studies focus on the model-homogeneous setting, where the global server model and local client models are identical. However, such a constraint not only excludes low-end clients that could otherwise make unique contributions to model training, but also prevents clients from training large models due to on-device resource bottlenecks. We propose FedRolex, a partial training-based approach that enables model-heterogeneous FL and can train a global server model larger than the largest client model.

*Figure: Comparison of FedRolex with existing partial training-based methods.*

The key difference between FedRolex and existing partial training-based methods is how the sub-models are extracted for each client over communication rounds in the federated training process. Specifically, instead of extracting sub-models in either random or static manner, FedRolex proposes a rolling sub-model extraction scheme, where the sub-model is extracted from the global server model using a rolling window that advances in each communication round. Since the window is rolling, sub-models from different parts of the global model are extracted in sequence in different rounds. As a result, all the parameters of the global server model are evenly trained over the local data of client devices.
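The rolling-window idea can be sketched as follows. This is a minimal illustration, not the repository's exact implementation; the function name and the assumption that the window advances by one unit per round are ours.

```python
import numpy as np

def rolling_submodel_indices(global_width: int, client_width: int,
                             round_idx: int) -> np.ndarray:
    """Return indices of the hidden units extracted for a client in a round.

    The window's start offset advances each communication round and wraps
    around, so every part of the global model is trained over time.
    """
    start = round_idx % global_width
    return np.arange(start, start + client_width) % global_width

# Example: a global layer of width 8, a client that can only hold width 4.
print(rolling_submodel_indices(8, 4, 0))  # [0 1 2 3]
print(rolling_submodel_indices(8, 4, 6))  # [6 7 0 1]  (window wraps around)
```

Because the window wraps, all global parameters are covered evenly across rounds, unlike static extraction (always the same indices) or random extraction (uneven coverage).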

*Figure: Overview of the FedRolex rolling sub-model extraction scheme.*

Video Brief

Click the figure to watch this short video explaining our work.

*(SlidesLive video link)*

Usage

Setup

pip install -r requirements.txt

Training

Train RESNET-18 model on CIFAR-10 dataset.

python main_resnet.py --data_name CIFAR10 \
                      --model_name resnet18 \
                      --control_name 1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
                      --exp_name roll_test \
                      --algo roll \
                      --g_epoch 3200 \
                      --l_epoch 1 \
                      --lr 2e-4 \
                      --schedule 1200 \
                      --seed 31 \
                      --num_experiments 3 \
                      --devices 0 1 2

data_name: CIFAR10 or CIFAR100
model_name: resnet18 or vgg
control_name: 1_{num users}_{fraction of participating users}_{iid or non-iid-{num classes}}_{dynamic or fix}_{heterogeneity distribution}_{batch norm (bn) or group norm (gn)}_{scalar, 1 or 0}_{masked cross entropy, 1 or 0}
exp_name: string naming the experiment
algo: roll, random, or static
g_epoch: number of global epochs
l_epoch: number of local epochs
lr: learning rate
schedule: lr schedule, a space-separated list of integers less than g_epoch
seed: integer random seed
num_experiments: integer; runs num_experiments trials, incrementing the seed each time
devices: indices of the GPUs to use
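Since the control_name string packs several settings into one underscore-separated value, assembling it programmatically avoids typos. The helper below is a hypothetical convenience, not part of the repository:

```python
def make_control_name(num_users, participate_frac, data_split, mode,
                      distribution, norm, scalar, masked_ce):
    """Assemble a control_name string in the format described above."""
    return "_".join(["1", str(num_users), str(participate_frac), data_split,
                     mode, distribution, norm, str(scalar), str(masked_ce)])

# Reproduces the control_name from the ResNet-18 command above.
print(make_control_name(100, 0.1, "non-iid-2", "dynamic",
                        "a1-b1-c1-d1-e1", "bn", 1, 1))
# 1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1
```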

To train a Transformer model on the StackOverflow dataset, use main_transformer.py instead.

python main_transformer.py --data_name Stackoverflow \
                           --model_name transformer \
                           --control_name 1_100_0.1_iid_dynamic_a6-b10-c11-d18-e55_bn_1_1 \
                           --exp_name roll_so_test \
                           --algo roll \
                           --g_epoch 1500 \
                           --l_epoch 1 \
                           --lr 2e-4 \
                           --schedule 600 1000 \
                           --seed 31 \
                           --num_experiments 3 \
                           --devices 0 1 2 3 4 5 6 7

To train in a data- and model-homogeneous setting, the command would look like this:

python main_resnet.py --data_name CIFAR10 \
                      --model_name resnet18 \
                      --control_name 1_100_0.1_iid_dynamic_a1_bn_1_1 \
                      --exp_name homogeneous_largest_low_heterogeneity \
                      --algo static \
                      --g_epoch 3200 \
                      --l_epoch 1 \
                      --lr 2e-4 \
                      --schedule 800 1200 \
                      --seed 31 \
                      --num_experiments 3 \
                      --devices 0 1 2

To reproduce the results in Table 3 of the paper, run the following commands:

CIFAR-10

python main_resnet.py --data_name CIFAR10 \
                      --model_name resnet18 \
                      --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
                      --exp_name homogeneous_largest_low_heterogeneity \
                      --algo static \
                      --g_epoch 3200 \
                      --l_epoch 1 \
                      --lr 2e-4 \
                      --schedule 800 1200 \
                      --seed 31 \
                      --num_experiments 5 \
                      --devices 0 1 2

CIFAR-100

python main_resnet.py --data_name CIFAR100 \
                      --model_name resnet18 \
                      --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
                      --exp_name homogeneous_largest_low_heterogeneity \
                      --algo static \
                      --g_epoch 2500 \
                      --l_epoch 1 \
                      --lr 2e-4 \
                      --schedule 800 1200 \
                      --seed 31 \
                      --num_experiments 5 \
                      --devices 0 1 2

StackOverflow

python main_transformer.py --data_name Stackoverflow \
                           --model_name transformer \
                           --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
                           --exp_name roll_so_test \
                           --algo roll \
                           --g_epoch 1500 \
                           --l_epoch 1 \
                           --lr 2e-4 \
                           --schedule 600 1000 \
                           --seed 31 \
                           --num_experiments 5 \
                           --devices 0 1 2 3 4 5 6 7

Note: To reproduce the results based on the real-world distribution in Table 4, use a6-b10-c11-d18-e55 as the heterogeneity distribution.
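The distribution string is easy to decode mechanically. The snippet below splits it into per-tier numbers; interpreting those numbers as client counts per capacity tier (a through e) is our assumption, consistent with the 100 users in the commands above:

```python
def parse_distribution(dist: str) -> dict:
    """Split a heterogeneity-distribution string such as 'a6-b10-c11-d18-e55'
    into {tier: number} pairs."""
    return {part[0]: int(part[1:]) for part in dist.split("-")}

counts = parse_distribution("a6-b10-c11-d18-e55")
print(counts)               # {'a': 6, 'b': 10, 'c': 11, 'd': 18, 'e': 55}
print(sum(counts.values())) # 100, matching the 100 users in the commands above
```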

Citation

If you find this useful for your work, please consider citing:

@InProceedings{alam2022fedrolex,
  title = {FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction},
  author = {Alam, Samiul and Liu, Luyang and Yan, Ming and Zhang, Mi},
  booktitle = {Conference on Neural Information Processing Systems (NeurIPS)},
  year = {2022}
}

