FedKNOW

1 Introduction

FedKNOW is designed to achieve state-of-the-art performance (in accuracy, time, and communication cost) in the federated continual learning setting. It currently supports eight image classification networks: ResNet, ResNeXt, MobileNet, WideResNet, SENet, ShuffleNet, Inception and DenseNet.

  • ResNet: this model consists of multiple convolutional and pooling layers that extract information from the image. Deep networks typically suffer from vanishing (or exploding) gradients and from performance degradation as depth grows. ResNet therefore uses BatchNorm to alleviate vanishing (exploding) gradients and adds residual connections to alleviate the performance degradation.
  • Inception: To extract high-dimensional features, a typical approach is to use a deeper network, which makes the network too large. To address this issue, the Inception Module widens the network instead, which can maintain performance while reducing the number of parameters. The Inception Module first uses 1x1 convolutions to aggregate information, then uses multi-scale convolutions to obtain multi-scale feature maps, and finally concatenates them to produce the final feature map.
  • ResNeXt: ResNeXt combines Inception and ResNet. It first simplifies the Inception Module so that every branch has the same structure, and then builds the network in ResNet style.
  • WideResNet: WideResNet widens the residual blocks of ResNet (more channels per layer) while reducing its depth, which improves performance. In addition, WideResNet uses Dropout regularization to further improve generalization.
  • MobileNet: MobileNet is a lightweight convolutional network that relies heavily on depthwise separable convolutions.
  • SENet: SENet introduces channel attention so that the network can focus on the more important features. In SENet, a Squeeze-and-Excitation Module takes the output of a block as input, produces a channel-wise importance vector, and multiplies it with the original output of the block to strengthen the important channels and weaken the unimportant ones.
  • ShuffleNet: ShuffleNet is a lightweight network. It introduces pointwise group convolution and channel shuffle to greatly reduce the computation cost: it replaces the 1x1 convolution of the ResBlock with a group convolution and adds a channel shuffle to it.
  • DenseNet: DenseNet extends ResNet by adding connections between all blocks to aggregate multi-scale features.
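
Two of the building blocks named above can be made concrete without a deep learning framework. The following is a minimal sketch in plain Python, not code from this repository. `conv_params` and `depthwise_separable_params` count the weights (biases ignored) of a standard k x k convolution and of a depthwise convolution followed by a 1x1 pointwise convolution, as used in MobileNet. `channel_shuffle` shows the channel reordering used in ShuffleNet, applied to a list of channel labels. All function names are illustrative.

```python
def conv_params(c_in, c_out, k):
    """Weights in a standard k x k convolution (bias ignored)."""
    return c_in * c_out * k * k


def depthwise_separable_params(c_in, c_out, k):
    """Weights in a depthwise k x k conv followed by a 1x1 pointwise conv."""
    depthwise = c_in * k * k   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 convolution that mixes channels
    return depthwise + pointwise


def channel_shuffle(channels, groups):
    """Interleave channels across groups.

    Equivalent to reshaping to (groups, n // groups), transposing,
    and flattening again.
    """
    n = len(channels)
    if n % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    per_group = n // groups
    return [channels[g * per_group + i]
            for i in range(per_group)
            for g in range(groups)]


if __name__ == "__main__":
    print(conv_params(64, 128, 3))                 # 73728
    print(depthwise_separable_params(64, 128, 3))  # 8768
    print(channel_shuffle(list(range(6)), 2))      # [0, 3, 1, 4, 2, 5]
```

For 64 input channels, 128 output channels and a 3x3 kernel, the separable variant needs roughly 8x fewer weights, which is where the "lightweight" label comes from.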

2 How to get started

2.1 Setup

Requirements

  • Edge devices such as Jetson AGX, Jetson TX2, Jetson Xavier NX and Jetson Nano
  • Linux and Windows
  • Python 3.6+
  • PyTorch 1.9+
  • CUDA 10.2+

Preparing the virtual environment

  1. Create a conda environment and activate it.

    conda create -n FedKNOW python=3.7
    conda activate FedKNOW
  2. Install PyTorch 1.9+ from the official website. An NVIDIA graphics card and PyTorch with CUDA are recommended.

  3. Clone this repository and install the dependencies.

    git clone https://github.com/LINC-BIT/FedKNOW.git
    cd FedKNOW
    pip install -r requirements.txt
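
After the setup steps, it can help to confirm that the environment matches the requirements listed above. The script below is a small sketch, not part of this repository: `check_environment` is a hypothetical helper that reports the Python version, whether PyTorch can be found, and whether CUDA is visible to it. It uses only `torch.__version__` and `torch.cuda.is_available()`, and it still runs if PyTorch is not installed.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 6)):
    """Return a dict describing whether the basic requirements appear to be met."""
    report = {
        "python_version": "%d.%d" % sys.version_info[:2],
        "python_ok": sys.version_info[:2] >= min_python,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "torch_version": None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the script also runs without PyTorch

        report["torch_version"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Run it inside the activated conda environment. If `cuda_available` is False on a machine with an NVIDIA GPU, the installed PyTorch build is probably CPU-only.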

2.2 Usage

  • Single device

    Run FedKNOW or the baselines:

    python single/main_FedKNOW.py --alg fedknow --dataset [dataset] --model [model]
    --num_users [num_users]  --shard_per_user [shard_per_user] --frac [frac] 
    --local_bs [local_bs] --lr [lr] --task [task] --epochs [epochs]  --local_ep 
    [local_ep] --local_local_ep [local_local_ep] --store_rate [store_rate] 
    --select_grad_num [select_grad_num] --gpu [gpu]

    Arguments:

    • alg: the algorithm, e.g. FedKNOW, FedRep, GEM

    • dataset: the dataset, e.g. cifar100, FC100, CORe50, MiniImagenet, Tinyimagenet

    • model: the model, e.g. 6-Layers CNN, ResNet18

    • num_users: the number of clients

    • shard_per_user: the number of classes in each client

    • frac: the fraction of clients participating in training in each communication round

    • local_bs: the batch size in each client

    • lr: the learning rate

    • task: the number of tasks

    • epochs: the number of communication rounds between the clients and the server

    • local_ep: the number of training epochs on each client

    • local_local_ep: the number of updates of the local parameters on each client

    • store_rate: the store rate of model parameters in FedKNOW

    • select_grad_num: the number of old gradients selected in FedKNOW

    • gpu: GPU id

      For more details, see utils/option.py. The configurations of all algorithms are located in scripts/single.sh.

  • Multiple devices

    1. Limit the network bandwidth to emulate real long-distance transmission:

      sudo wondershaper [adapter] [download rate] [upload rate]
    2. Launch the server:

      python multi/server.py --num_users [num_users] --frac [frac] --ip [ip]
    3. Launch the clients:

      python multi/main_FedKNOW.py --client_id [client_id] --alg [alg]
      --dataset [dataset] --model [model]  --shard_per_user [shard_per_user] 
      --local_bs [local_bs] --lr [lr] --task [task] --epochs [epochs]  --local_ep 
      [local_ep] --local_local_ep [local_local_ep]  --store_rate [store_rate] 
      --select_grad_num [select_grad_num] --gpu [gpu] --ip [ip]

      Arguments:

      • client_id: the id of the client

      • ip: IP address of the server

        The other arguments are the same as in the single-device setting. For more details, see utils/option.py. The configurations of all algorithms are located in scripts/multi.sh.
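
To make the argument list above easier to picture, here is a hedged sketch of how such options could be declared with `argparse`. It is not the contents of utils/option.py: every type and default value below is an assumption chosen for illustration, and the round count is spelled `--epochs` as in the experiment commands of this README. The helper `clients_per_round` is also hypothetical and shows one common way to turn `frac` into a number of sampled clients.

```python
import argparse


def build_parser():
    """Illustrative parser; every type and default here is an assumption."""
    parser = argparse.ArgumentParser(description="Sketch of the options listed above")
    parser.add_argument("--alg", type=str, default="FedKNOW", help="algorithm name")
    parser.add_argument("--dataset", type=str, default="cifar100", help="dataset name")
    parser.add_argument("--model", type=str, default="6_layerCNN", help="model name")
    parser.add_argument("--num_users", type=int, default=20, help="number of clients")
    parser.add_argument("--shard_per_user", type=int, default=5, help="classes per client")
    parser.add_argument("--frac", type=float, default=0.4, help="fraction of clients per round")
    parser.add_argument("--local_bs", type=int, default=32, help="client batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--task", type=int, default=10, help="number of tasks")
    parser.add_argument("--epochs", type=int, default=150, help="communication rounds")
    parser.add_argument("--local_ep", type=int, default=1, help="local epochs per round")
    parser.add_argument("--local_local_ep", type=int, default=1, help="local parameter updates")
    parser.add_argument("--store_rate", type=float, default=0.1, help="parameter store rate")
    parser.add_argument("--select_grad_num", type=int, default=10, help="old gradients selected")
    parser.add_argument("--gpu", type=int, default=0, help="GPU id")
    return parser


def clients_per_round(num_users, frac):
    """Number of clients sampled in each round, never fewer than one."""
    return max(round(frac * num_users), 1)


if __name__ == "__main__":
    args = build_parser().parse_args(["--num_users", "20", "--frac", "0.4"])
    print(clients_per_round(args.num_users, args.frac))  # 8
```

With `--num_users 20 --frac 0.4`, as used in the experiments of this README, eight clients would take part in each round under this reading of `frac`.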

3 Supported models in image classification

     Model Name                    Data                          Script
☑    6-layer CNN (NeurIPS'2020)    Cifar100, FC100, CORe50       Demo
☑    ResNet (CVPR'2016)            MiniImageNet, TinyImageNet    Demo
☑    MobileNetV2 (CVPR'2018)       MiniImageNet                  Demo
☑    ResNeXt (CVPR'2017)           MiniImageNet                  Demo
☑    InceptionV3 (CVPR'2016)       MiniImageNet                  Demo
☑    WideResNet (BMVC'2016)        MiniImageNet                  Demo
☑    ShuffleNetV2 (ECCV'2018)      MiniImageNet                  Demo
☑    DenseNet (CVPR'2017)          MiniImageNet                  Demo
☑    SENet (CVPR'2018)             MiniImageNet                  Demo

4 Experiments

4.1 Under different workloads (model and dataset)

  1. Run

    Launch the server:

    python multi/server.py --epochs=150 --num_users=20 --frac=0.4 --ip=127.0.0.1:8000

    Launch the clients:

    • 6-layer CNN on Cifar100

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=6_layerCNN --dataset=cifar100 --num_classes=100 --task=10 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-4 --ip=127.0.0.1:8000
      done
    • 6-layer CNN on FC100

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=6_layerCNN --dataset=FC100 --num_classes=100 --task=10 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-4 --ip=127.0.0.1:8000
      done
    • 6-layer CNN on CORe50

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=6_layerCNN --dataset=CORe50 --num_classes=550 --task=11 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-4 --ip=127.0.0.1:8000
      done
    • ResNet18 on MiniImageNet

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --ip=127.0.0.1:8000
      done
    • ResNet18 on TinyImageNet

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=ResNet18 --dataset=TinyImageNet --num_classes=200 --task=20 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --ip=127.0.0.1:8000
      done

    Note: The server and the clients must use the same IP address. --ip=127.0.0.1:8000 is for testing locally. If there are multiple edge devices, use --ip=<IP of the server>.

  2. Result

    • The accuracy trend over time under different workloads (X-axis represents the time and Y-axis represents the inference accuracy)
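
The client launch loops in this section can also be driven from Python. The sketch below is an illustration, not a script shipped with this repository. `client_command` is a hypothetical helper that builds the same argument list as the shell commands, and `launch_clients` starts every client as a separate background process with `subprocess.Popen`. Starting the clients concurrently is an assumption about local testing on one machine; as written, the shell loops start them one after another. The script is expected to be run from the repository root, and the `__main__` part only prints the commands as a dry run.

```python
import subprocess
import sys


def client_command(client_id, model, dataset, num_classes, task, lr, optim,
                   lr_decay, ip="127.0.0.1:8000", alg="FedKNOW"):
    """Build the argv list for one client, mirroring the shell commands."""
    return [
        sys.executable, "multi/main_FedKNOW.py",
        f"--client_id={client_id}",
        f"--model={model}",
        f"--dataset={dataset}",
        f"--num_classes={num_classes}",
        f"--task={task}",
        f"--alg={alg}",
        f"--lr={lr}",
        f"--optim={optim}",
        f"--lr_decay={lr_decay}",
        f"--ip={ip}",
    ]


def launch_clients(num_clients, **settings):
    """Start all clients in the background and return their exit codes."""
    procs = [subprocess.Popen(client_command(i, **settings))
             for i in range(num_clients)]
    return [proc.wait() for proc in procs]


if __name__ == "__main__":
    # Dry run: print the commands for the 6-layer CNN on Cifar100 workload.
    cifar100 = dict(model="6_layerCNN", dataset="cifar100", num_classes=100,
                    task=10, lr="0.001", optim="Adam", lr_decay="1e-4")
    for i in range(20):
        print(" ".join(client_command(i, **cifar100)))
```

Replace the dry-run loop with `launch_clients(20, **cifar100)` once the server is running.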

4.2 Under different network bandwidths

  1. Run

    Limit the network bandwidth of the server:

    # Cap both the download rate and the upload rate at 1000KB/s.
    # In practice the limit is not exact, so adjust the values as needed.
    sudo wondershaper [adapter] 1000 1000 

    Check the network state of the server:

    sudo nload -m

    Launch the server:

    python multi/server.py --epochs=150 --num_users=20 --frac=0.4 --ip=127.0.0.1:8000

    Launch the clients:

    • 6-layer CNN on Cifar100

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=6_layerCNN --dataset=cifar100 --num_classes=100 --task=10 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-4 --ip=127.0.0.1:8000
      done
    • 6-layer CNN on FC100

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=6_layerCNN --dataset=FC100 --num_classes=100 --task=10 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-4 --ip=127.0.0.1:8000
      done
    • 6-layer CNN on CORe50

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=6_layerCNN --dataset=CORe50 --num_classes=550 --task=11 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-4 --ip=127.0.0.1:8000
      done
    • ResNet18 on MiniImageNet

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --ip=127.0.0.1:8000
      done
    • ResNet18 on TinyImageNet

      for ((i=0;i<20;i++));
      do
          python multi/main_FedKNOW.py --client_id=$i --model=ResNet18 --dataset=TinyImageNet --num_classes=200 --task=20 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --ip=127.0.0.1:8000
      done
  2. Result

    • The communication time under different workloads and maximal network bandwidth 1MB/s (X-axis represents the dataset and Y-axis represents the communication time)

    • The communication time under different network bandwidths (X-axis represents the network bandwidth and Y-axis represents the communication time)
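
As a rough sanity check on the communication times, the cost of sending one full model can be estimated from its parameter count and the bandwidth cap. The sketch below is a back-of-the-envelope calculation, not a measurement from this project: it assumes 4 bytes per parameter (float32), ignores protocol overhead and compression, and uses a round, illustrative parameter count. The helper `kilobytes_to_kilobits` is included because traffic shapers differ in the unit they expect, so check the unit of the tool you use.

```python
def model_size_bytes(num_params, bytes_per_param=4):
    """Size of a model's parameters, assuming float32 by default."""
    return num_params * bytes_per_param


def transfer_seconds(num_bytes, bandwidth_bytes_per_s):
    """Ideal time to send num_bytes over a link with the given bandwidth."""
    if bandwidth_bytes_per_s <= 0:
        raise ValueError("bandwidth must be positive")
    return num_bytes / bandwidth_bytes_per_s


def kilobytes_to_kilobits(kilobytes):
    """Convert kilobytes (per second) to kilobits (per second)."""
    return kilobytes * 8


if __name__ == "__main__":
    # Illustrative parameter count, not a number taken from this repository.
    size = model_size_bytes(11_000_000)
    print(size)                               # 44000000
    print(transfer_seconds(size, 1_000_000))  # 44.0
    print(kilobytes_to_kilobits(1000))        # 8000
```

Under these assumptions, one upload of an 11-million-parameter model over a 1MB/s link takes about 44 seconds, which shows why reducing the amount of transmitted data matters at low bandwidth.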

4.3 Large scale

  1. Run

    # 50 clients
    python single/main_FedKNOW.py --epochs=150 --num_users=50 --frac=0.4 --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 
    # 100 clients
    python single/main_FedKNOW.py --epochs=150 --num_users=100 --frac=0.4 --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 
  2. Result

    • The accuracy under 50 clients and 100 clients (X-axis represents the task and Y-axis represents the accuracy)

    • The average forgetting rate under 50 clients and 100 clients (X-axis represents the task and Y-axis represents the average forgetting rate)
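
The two metrics plotted here can be computed from a table of per-task accuracies. The following sketch uses a definition of average forgetting that is common in the continual learning literature: for each earlier task, the best accuracy reached before the current task minus the accuracy now, averaged over the earlier tasks. Whether this repository computes the metric in exactly this way is an assumption. The function names and the example accuracy table are made up for illustration.

```python
def average_accuracy(acc, k):
    """Mean accuracy over tasks 0..k after training on task k.

    acc[i][j] is the accuracy on task j measured after training on task i.
    """
    return sum(acc[k][j] for j in range(k + 1)) / (k + 1)


def average_forgetting(acc, k):
    """Mean drop from the best earlier accuracy for tasks 0..k-1."""
    if k == 0:
        return 0.0  # nothing can be forgotten after the first task
    total = 0.0
    for j in range(k):
        best_before = max(acc[i][j] for i in range(j, k))
        total += best_before - acc[k][j]
    return total / k


if __name__ == "__main__":
    # Made-up accuracies for three tasks (row i: after training on task i).
    acc = [
        [0.90],
        [0.80, 0.85],
        [0.70, 0.80, 0.90],
    ]
    print(round(average_accuracy(acc, 2), 3))
    print(round(average_forgetting(acc, 2), 3))
```

In the example, task 0 drops from 0.90 to 0.70 and task 1 from 0.85 to 0.80, so the average forgetting after the third task is (0.20 + 0.05) / 2.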

4.4 Long task sequence

  1. Run

    # dataset = MiniImageNet + TinyImageNet + cifar100 + FC100, task = 80, per_task_class = 5
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ResNet18 --dataset=All --num_classes=400 --task=80 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 
  2. Result

    • The average accuracy under 80 tasks (X-axis represents the task and Y-axis represents the accuracy)

    • The average forgetting rate under 80 tasks (X-axis represents the task and Y-axis represents the average forgetting rate)

    • The time under 80 tasks (X-axis represents the task and Y-axis represents the time on current task)
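
The long task sequence splits 400 classes into 80 tasks of 5 classes each. The sketch below shows that arithmetic with a hypothetical helper, `split_classes_into_tasks`. Assigning contiguous class ids to each task is an assumption made for clarity; the repository may order or shuffle the classes differently.

```python
def split_classes_into_tasks(num_classes, num_tasks):
    """Partition class ids 0..num_classes-1 into equally sized tasks."""
    if num_tasks <= 0 or num_classes % num_tasks != 0:
        raise ValueError("num_classes must be a positive multiple of num_tasks")
    per_task = num_classes // num_tasks
    return [list(range(t * per_task, (t + 1) * per_task))
            for t in range(num_tasks)]


if __name__ == "__main__":
    tasks = split_classes_into_tasks(400, 80)
    print(len(tasks))   # 80
    print(tasks[0])     # [0, 1, 2, 3, 4]
    print(tasks[-1])    # [395, 396, 397, 398, 399]
```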

4.5 Under different parameter settings

  1. Run

    # store_rate = 0.05
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --store_rate=0.05
    # store_rate = 0.1
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --store_rate=0.1
    # store_rate = 0.2
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ResNet18 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5 --store_rate=0.2
  2. Result

    • The accuracy under different parameter storage ratios (X-axis represents the task and Y-axis represents the accuracy)

    • The time under different parameter storage ratios (X-axis represents the task and Y-axis represents the time on current task)
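
This README does not spell out what `store_rate` does beyond calling it the store rate of model parameters. One plausible reading, stated here as an assumption and not as the project's implementation, is that only a fraction of the weights is kept per task, for example the ones with the largest magnitude. The sketch below illustrates that reading on a plain list of numbers; `keep_top_fraction` is a hypothetical name.

```python
def keep_top_fraction(weights, store_rate):
    """Keep the largest-magnitude fraction of weights and zero out the rest.

    This magnitude-based rule is an assumption used for illustration only.
    """
    if not 0.0 < store_rate <= 1.0:
        raise ValueError("store_rate must be in (0, 1]")
    k = max(1, round(store_rate * len(weights)))
    ranked = sorted(range(len(weights)),
                    key=lambda i: abs(weights[i]), reverse=True)
    keep = set(ranked[:k])
    return [w if i in keep else 0.0 for i, w in enumerate(weights)]


if __name__ == "__main__":
    weights = [0.5, -2.0, 0.1, 1.5, -0.3, 0.05, 0.9, -1.1, 0.2, 0.0]
    print(keep_top_fraction(weights, 0.2))
    # [0.0, -2.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Under this reading, a larger `store_rate` keeps more information per task at the cost of more storage and computation, which is the trade-off the two result plots in this section explore.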

4.6 Applicability on different networks

  1. Run

    # WideResNet50
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=WideResNet --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5
    # ResNeXt50
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ResNeXt --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5
    # ResNet152
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ResNet152 --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5
    # SENet18
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=SENet --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0008 --optim=SGD --lr_decay=1e-5
    # MobileNetV2
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=MobileNet --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-5
    # ShuffleNetV2
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=ShuffleNet --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0005 --optim=Adam --lr_decay=1e-5
    # InceptionV3
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=Inception --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.0005 --optim=Adam --lr_decay=1e-5
    # DenseNet
    python single/main_FedKNOW.py --epochs=150 --num_users=20 --frac=0.4 --model=DenseNet --dataset=MiniImageNet --num_classes=100 --task=10 --alg=FedKNOW --lr=0.001 --optim=Adam --lr_decay=1e-5
  2. Result

    • The accuracy on different networks (X-axis represents the task and Y-axis represents the accuracy)
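
The eight commands in this section differ only in the model name, the learning rate and the optimizer. The sketch below collects those values, copied from the commands above, into one table and rebuilds each command from it. The names `MODEL_SETTINGS` and `single_device_command` are illustrative and do not exist in this repository.

```python
import sys

# Learning rate and optimizer per model, as used in the commands above.
MODEL_SETTINGS = {
    "WideResNet": {"lr": "0.0008", "optim": "SGD"},
    "ResNeXt": {"lr": "0.0008", "optim": "SGD"},
    "ResNet152": {"lr": "0.0008", "optim": "SGD"},
    "SENet": {"lr": "0.0008", "optim": "SGD"},
    "MobileNet": {"lr": "0.001", "optim": "Adam"},
    "ShuffleNet": {"lr": "0.0005", "optim": "Adam"},
    "Inception": {"lr": "0.0005", "optim": "Adam"},
    "DenseNet": {"lr": "0.001", "optim": "Adam"},
}


def single_device_command(model):
    """Build the argv list for one single-device run on MiniImageNet."""
    settings = MODEL_SETTINGS[model]
    return [
        sys.executable, "single/main_FedKNOW.py",
        "--epochs=150", "--num_users=20", "--frac=0.4",
        f"--model={model}", "--dataset=MiniImageNet",
        "--num_classes=100", "--task=10", "--alg=FedKNOW",
        f"--lr={settings['lr']}", f"--optim={settings['optim']}",
        "--lr_decay=1e-5",
    ]


if __name__ == "__main__":
    # Dry run: print one command per model instead of executing it.
    for name in MODEL_SETTINGS:
        print(" ".join(single_device_command(name)))
```

Keeping the per-model settings in one place makes it harder for a learning rate or optimizer to drift between runs.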

Contributors

linc-bit, crawler995, luopanyaxin
