
Personalized Federated Learning under Mixture of Distributions

This repository is the official implementation of Personalized Federated Learning under Mixture of Distributions.

The recent trend towards Personalized Federated Learning (PFL) has garnered significant attention as it allows for the training of models that are tailored to each client while maintaining data privacy. However, current PFL techniques primarily focus on modeling the conditional distribution heterogeneity (i.e. concept shift), which can result in suboptimal performance when the distribution of input data across clients diverges (i.e. covariate shift). Additionally, these techniques often lack the ability to adapt to unseen data, further limiting their effectiveness in real-world scenarios.

To address these limitations, we propose a novel approach, FedGMM, which uses Gaussian mixture models (GMMs) to fit the input data distributions across diverse clients. The model parameters are estimated by maximum likelihood using a federated Expectation-Maximization algorithm, which is solved in closed form and does not assume gradient similarity. Furthermore, FedGMM can adapt to new clients with minimal overhead, and it also enables uncertainty quantification. Empirical evaluations on synthetic and benchmark datasets demonstrate the superior performance of our method in both PFL classification and novel sample detection.
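To make the idea of a federated EM with a closed-form M-step concrete, here is a minimal pure-Python sketch. It is not the repository's implementation: it covers only the input-density part (the method also involves the learner ensemble set by --n_learners), works in one dimension, and assumes Gaussian components shared by all clients with per-client mixing weights. Each client runs the E-step on its own data and sends only sufficient statistics (sums of responsibilities, of r·x and of r·x²); the server turns the aggregated sums into new means and variances in closed form, with no gradients involved. All function names are hypothetical.

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def client_em_step(data, weights, means, variances):
    """Local E-step plus the local part of the M-step on one client.

    Returns the client's updated mixing weights (kept locally) and the
    sufficient statistics (sum r, sum r*x, sum r*x^2) sent to the server.
    """
    k = len(means)
    n_k, s1, s2 = [0.0] * k, [0.0] * k, [0.0] * k
    for x in data:
        logs = [math.log(weights[j]) + log_normal(x, means[j], variances[j]) for j in range(k)]
        norm = log_sum_exp(logs)
        for j in range(k):
            r = math.exp(logs[j] - norm)  # responsibility of component j for x
            n_k[j] += r
            s1[j] += r * x
            s2[j] += r * x * x
    new_weights = [max(n / len(data), 1e-12) for n in n_k]
    return new_weights, (n_k, s1, s2)


def server_m_step(all_stats, k, min_var=1e-6):
    """Closed-form M-step for the shared Gaussian components."""
    means, variances = [], []
    for j in range(k):
        n = max(sum(stats[0][j] for stats in all_stats), 1e-12)
        mean = sum(stats[1][j] for stats in all_stats) / n
        second_moment = sum(stats[2][j] for stats in all_stats) / n
        means.append(mean)
        variances.append(max(second_moment - mean ** 2, min_var))
    return means, variances


def federated_em(client_data, means, variances, rounds=50):
    """Alternate local E-steps and server M-steps for a fixed number of rounds."""
    k = len(means)
    weights = [[1.0 / k] * k for _ in client_data]
    for _ in range(rounds):
        all_stats = []
        for t, data in enumerate(client_data):
            weights[t], stats = client_em_step(data, weights[t], means, variances)
            all_stats.append(stats)
        means, variances = server_m_step(all_stats, k)
    return weights, means, variances


def log_density(x, weights, means, variances):
    """Mixture log-density under one client's weights; low values flag unusual inputs."""
    return log_sum_exp([math.log(w) + log_normal(x, m, v)
                        for w, m, v in zip(weights, means, variances)])


rng = random.Random(0)
# Two clients share components near -4 and +4 but mix them in different proportions.
client_a = [rng.gauss(-4, 1) for _ in range(300)] + [rng.gauss(4, 1) for _ in range(100)]
client_b = [rng.gauss(-4, 1) for _ in range(50)] + [rng.gauss(4, 1) for _ in range(350)]
weights, means, variances = federated_em([client_a, client_b], [-1.0, 1.0], [1.0, 1.0])
print("shared means:", [round(m, 2) for m in means])
print("client weights:", [[round(w, 2) for w in ws] for ws in weights])
```

The per-client weights are where personalization lives in this toy model, and `log_density` hints at how a fitted mixture can score new samples: a point far from every component (e.g. x = 20 here) gets a much lower log-density than typical data, which is one simple way to use the mixture for novelty detection.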

Requirements

To install requirements:

pip install -r requirements.txt

Usage

We provide code to simulate federated training of machine learning models. The core objects are Aggregator and Client; different federated learning algorithms can be implemented by revising the local update method Client.step() and/or the aggregation protocol defined in Aggregator.mix() and Aggregator.update_client().
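The division of labour between Client.step() and Aggregator.mix() / Aggregator.update_client() can be illustrated with a toy FedAvg loop. The method names mirror the README, but ToyClient, ToyAggregator and every body and signature below are hypothetical; the real classes in client.py and aggregator.py work with models, data iterators and learner ensembles. Each client holds one scalar parameter fitted to its local data by a few gradient steps.

```python
class ToyClient:
    """Hypothetical client: one scalar parameter fitted to local data."""

    def __init__(self, data, lr=0.1, local_steps=5):
        self.data = data
        self.lr = lr
        self.local_steps = local_steps
        self.param = 0.0

    def step(self):
        # Local update: a few gradient steps on the mean squared error to the local data.
        for _ in range(self.local_steps):
            grad = sum(2.0 * (self.param - x) for x in self.data) / len(self.data)
            self.param -= self.lr * grad


class ToyAggregator:
    """Hypothetical FedAvg-style aggregator over ToyClient objects."""

    def __init__(self, clients):
        self.clients = clients
        self.global_param = 0.0

    def mix(self):
        # One communication round: local updates, sample-weighted averaging, broadcast.
        for client in self.clients:
            client.step()
        total = sum(len(c.data) for c in self.clients)
        self.global_param = sum(c.param * len(c.data) for c in self.clients) / total
        for client in self.clients:
            self.update_client(client)

    def update_client(self, client):
        # Overwrite the client's parameter with the aggregated one.
        client.param = self.global_param


aggregator = ToyAggregator([ToyClient([1.0, 2.0, 3.0]), ToyClient([10.0])])
for _ in range(50):
    aggregator.mix()
print(round(aggregator.global_param, 4))  # 4.0
```

Here the aggregate converges to the mean of all four samples. Switching algorithms means touching only one of the two hooks: for example, FedProx changes the local objective in step() by adding a proximal term, while clustered or personalized methods change how mix() combines and redistributes parameters.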

In addition to the trivial baseline consisting of training models locally without any collaboration, this repository supports the following federated learning algorithms:

Datasets

We provide five federated benchmark datasets spanning a wide range of machine learning tasks: image classification (CIFAR10 and CIFAR100), handwritten character recognition (EMNIST and FEMNIST), and language modelling (Shakespeare), in addition to a synthetic dataset.

The Shakespeare dataset (resp. FEMNIST) was naturally partitioned by assigning all lines from the same character (resp. all images from the same writer) to the same client. We created federated versions of CIFAR10 and EMNIST by distributing samples with the same label across the clients according to a symmetric Dirichlet distribution with parameter 0.4. For CIFAR100, we exploited the availability of "coarse" and "fine" labels, using a two-stage Pachinko allocation method to assign 600 samples to each of the 100 clients.
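A minimal sketch of the Dirichlet label split described above, assuming only a list of integer labels. `dirichlet_partition` is a hypothetical helper, and the repository's own splitting code may differ in details such as rounding and train/test handling. For each label, the samples are shuffled and divided among the clients in proportions drawn from a symmetric Dirichlet(alpha), sampled as Gamma(alpha, 1) draws normalized to sum to one; smaller alpha gives more skewed clients.

```python
import random
from collections import defaultdict


def dirichlet_partition(labels, n_clients, alpha=0.4, seed=0):
    """Split sample indices so each label's samples are spread over clients by Dir(alpha)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, y in enumerate(labels):
        by_label[y].append(idx)

    clients = [[] for _ in range(n_clients)]
    for indices in by_label.values():
        rng.shuffle(indices)
        # A symmetric Dirichlet sample: independent Gamma(alpha, 1) draws, normalized.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(n_clients)]
        total = sum(gammas)
        # Turn cumulative proportions into cut points over this label's indices.
        start, cumulative = 0, 0.0
        for client_id, g in enumerate(gammas):
            cumulative += g / total
            if client_id == n_clients - 1:
                end = len(indices)  # the last client takes any rounding remainder
            else:
                end = int(round(cumulative * len(indices)))
            clients[client_id].extend(indices[start:end])
            start = end
    return clients


labels = [i % 10 for i in range(1000)]  # toy labels: 10 classes, 100 samples each
parts = dirichlet_partition(labels, n_clients=8, alpha=0.4)
print([len(p) for p in parts])
```

Every sample ends up on exactly one client, but with alpha = 0.4 most clients receive the bulk of their data from only a few labels, which is the label skew the benchmark is designed to create.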

The following table summarizes the datasets and models:

| Dataset | Task | Model |
|---|---|---|
| FEMNIST | Handwritten character recognition | 2-layer CNN + 2-layer FFN |
| EMNIST | Handwritten character recognition | 2-layer CNN + 2-layer FFN |
| CIFAR10 | Image classification | MobileNet-v2 |
| CIFAR100 | Image classification | MobileNet-v2 |
| Shakespeare | Next character prediction | Stacked LSTM |
| Synthetic dataset | Binary classification | Linear model |

See the README.md file of the respective dataset, i.e., data/$DATASET, for instructions on generating the data.

PCA projection

The PCA projection file for each dataset can be retrieved by uncommenting lines 42-81 in the run_experiment.py file.
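As a rough illustration of what such a projection contains: the commented-out block quoted in the issues section encodes images with a pretrained ResNet-18 and stores the right singular vectors V of the resulting features. The numpy sketch below performs the analogous computation on synthetic features and keeps the top-k directions. It centers the features first, as standard PCA does (the quoted snippet applies SVD to the raw features), and the dictionary stored in the pickle is a hypothetical layout, not the repository's PCA.pkl format.

```python
import io
import pickle

import numpy as np


def fit_pca_basis(features):
    """Return the feature mean and V, whose columns are principal directions."""
    mean = features.mean(axis=0)
    # Thin SVD of the centred data: centred = U @ diag(S) @ vh, S in descending order.
    _, _, vh = np.linalg.svd(features - mean, full_matrices=False)
    return mean, vh.T  # numpy returns V transposed; torch.svd returns V itself


def project(features, mean, v, k):
    """Coordinates of each sample along the top-k principal directions."""
    return (features - mean) @ v[:, :k]


rng = np.random.default_rng(0)
# Synthetic stand-in for encoder features: 200 samples, 16 dims with decaying scale.
features = rng.normal(size=(200, 16)) * np.linspace(5.0, 0.1, 16)
mean, v = fit_pca_basis(features)

# Serialize the basis, analogous to writing a PCA.pkl file (the dict layout is hypothetical).
buffer = io.BytesIO()
pickle.dump({"mean": mean, "V": v}, buffer)

projected = project(features, mean, v, k=4)
print(projected.shape)  # (200, 4)
```

Because the singular values come out in descending order, the first projected coordinate carries the most variance; keeping only a small k gives a compact input representation of the kind a Gaussian mixture can be fitted to.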

Training

To run on one dataset with a specific federated learning method, specify the name of the dataset (experiment) and the method, and configure all other hyper-parameters (all hyper-parameter values are listed in the appendix of the paper):

python3 run_experiment.py cifar10 FedGMM \
    --n_learners 3 \
    --n_gmm 3 \
    --n_rounds 200 \
    --bz 128 \
    --lr 0.01 \
    --lr_scheduler multi_step \
    --log_freq 5 \
    --device cuda \
    --optimizer sgd \
    --seed 1234 \
    --logs_root ./logs \
    --verbose 1

The test and training accuracy and loss will be saved in the specified log path.

We provide example scripts to run the paper's experiments in the scripts/ directory.

Evaluation

We give instructions to run experiments on the CIFAR-10 dataset as an example (the same holds for the other datasets). First go to ./data/cifar10 and follow the instructions in README.md to download and partition the dataset.

All experiments generate TensorBoard log files (logs/cifar10) that you can inspect with TensorBoard.

Average performance of personalized models

Run the following scripts. They generate TensorBoard logs that you can use to make plots or to get the values presented in Table 2.

# run FedAvg
echo "Run FedAvg"
python run_experiment.py cifar10 FedAvg --n_learners 1 --n_rounds 200 --bz 128 --lr 0.01 \
 --lr_scheduler multi_step --log_freq 5 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run FedAvg + local adaptation
echo "Run FedAvg + local adaptation"
python run_experiment.py cifar10 FedAvg --n_learners 1 --locally_tune_clients --n_rounds 201 --bz 128 \
 --lr 0.001 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run training using local data only
echo "Run Local"
python run_experiment.py cifar10 local --n_learners 1 --n_rounds 201 --bz 128 --lr 0.03 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run Clustered FL
echo "Run Clustered FL"
python run_experiment.py cifar10 clustered --n_learners 1 --n_rounds 201 --bz 128 --lr 0.003 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run FedProx
echo "Run FedProx"
python run_experiment.py cifar10 FedProx --n_learners 1 --n_rounds 201 --bz 128 --lr 0.01 --mu 1.0 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer prox_sgd --seed 1234 --verbose 1

# run pFedMe
echo "Run pFedMe"
python run_experiment.py cifar10 pFedMe --n_learners 1 --n_rounds 201 --bz 128 --lr 0.001 --mu 1.0 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer prox_sgd --seed 1234 --verbose 1

# run FedEM
echo "Run FedEM"
python run_experiment.py cifar10 FedEM --n_learners 3 --n_rounds 201 --bz 128 --lr 0.03 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1
 
# run FedGMM
echo "Run FedGMM"
python run_experiment.py cifar10 FedGMM --n_learners 3 --n_gmm 3 --n_rounds 201 --bz 128 --lr 0.03 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

Reference

@inproceedings{zhangICMLFedGMM,
  title={Personalized Federated Learning under Mixture of Distributions},
  author={Wu, Yue and Zhang, Shuaicheng and Yu, Wenchao and Liu, Yanchi and Gu, Quanquan and Zhou, Dawei and Chen, Haifeng and Cheng, Wei},
  booktitle={International Conference on Machine Learning (ICML'23)},
  year={2023}
}


fedgmm_icml2023's Issues

n_clusters

If we set n_clusters=3 (n_component=3) when generating non-iid clients, how should n_learner and n_gmm be set, in theory, to get the best results? Is n_learner=3 and n_gmm=3 right?
Please help me understand the relation between n_clusters and n_learner, n_gmm.

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[128, 32, 32, 3] to have 3 channels, but got 32 channels instead

Dear author, I uncommented lines 42-81 in the run_experiment.py file to get the PCA projection file, and I changed the line to all_data_tensor = all_data_tensor.view(-1,3,32,32) since I run the experiment on CIFAR-10.

# all_data_tensor = []
# for cur_data in train_iterators:
#     all_data_tensor.append(cur_data.dataset.data)
# all_data_tensor = torch.cat(all_data_tensor, dim=0)
#
# model = models.resnet18(pretrained=True)
#
# del model.fc
# all_data_tensor = all_data_tensor.view(-1,1,32,32)
# x = all_data_tensor
# if all_data_tensor.shape[1] == 1:
#     x = all_data_tensor.repeat(1, 3, 1, 1)
# x = model.conv1(x.float())
# x = model.bn1(x)
# x = model.relu(x)
# x = model.maxpool(x)
#
# x = model.layer1(x)
# x = model.layer2(x)
# x = model.layer3(x)
# x = model.layer4(x)
#
# # Extract the feature maps produced by the encoder
# encoder_output = x.squeeze()
# U, S, V = torch.svd(encoder_output)
# global PCA_V
# PCA_V = V
# print(PCA_V.size())
# with open(f"data/cifar10/all_data/PCA.pkl" , 'wb') as f:
#     pickle.dump(PCA_V, f)
# raise
#
#
# encoder_output = encoder_output.view(encoder_output.size(0), -1)
# pca_transformer = PCA(n_components=emb_size)
# # Fit the PCA transformer to your data
#
# X_pca = pca_transformer.fit_transform(encoder_output.detach().numpy())
# # Convert the resulting principal components to a PyTorch tensor
# projected = torch.from_numpy(X_pca).float().cuda()

However, I now get the following error:

==> Clients initialization..
===> Building data iterators..
100%|██████████| 32/32 [00:00<00:00, 456.40it/s]
===> Initializing clients..
100%|██████████| 32/32 [00:44<00:00, 1.41s/it]
==> Test Clients initialization..
===> Building data iterators..
100%|██████████| 8/8 [00:00<00:00, 465.28it/s]
===> Initializing clients..
100%|██████████| 8/8 [00:08<00:00, 1.11s/it]
Traceback (most recent call last):
  File "run_experiment.py", line 244, in <module>
    run_experiment(args)
  File "run_experiment.py", line 185, in run_experiment
    get_aggregator(
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/utils/utils.py", line 662, in get_aggregator
    return ACGCentralizedAggregator(
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/aggregator.py", line 412, in __init__
    super().__init__(clients,
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/aggregator.py", line 123, in __init__
    self.write_logs()
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/aggregator.py", line 581, in write_logs
    self.update_test_clients()
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/aggregator.py", line 434, in update_test_clients
    client.update_sample_weights()
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/client.py", line 202, in update_sample_weights
    self.samples_weights = self.learners_ensemble.calc_samples_weights(self.val_iterator)
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/learners/learners_ensemble.py", line 805, in calc_samples_weights
    all_losses = self.gather_losses(iterator).T # n * m2
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/learners/learners_ensemble.py", line 872, in gather_losses
    all_losses[learner_id] = learner.gather_losses(iterator)
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/learners/learner.py", line 248, in gather_losses
    y_pred = self.model(x)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torchvision/models/mobilenetv2.py", line 198, in forward
    return self._forward_impl(x)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torchvision/models/mobilenetv2.py", line 191, in _forward_impl
    x = self.features(x)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM32/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[128, 32, 32, 3] to have 3 channels, but got 32 channels instead
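Both this error and the one in the next issue are consistent with CIFAR-10 images reaching torch in channels-last order, (N, 32, 32, 3), while nn.Conv2d and torchvision's Normalize expect channels-first, (N, C, H, W). Assuming that is the cause, the key detail is that view/reshape only reinterprets the existing memory, whereas permute actually moves the channel axis. The numpy sketch below uses toy sizes, not the repository's data pipeline, to show the difference; in torch the corresponding call would be x.permute(0, 3, 1, 2).

```python
import numpy as np

# Toy batch stored channels-last, (N, H, W, C), like the input[128, 32, 32, 3] in the error.
n, h, w, c = 2, 4, 4, 3
nhwc = np.arange(n * h * w * c, dtype=np.float64).reshape(n, h, w, c)

# Reinterpreting the same buffer (what .view(-1, 3, 32, 32) does in torch) keeps the
# memory order, so each "channel" plane mixes pixels and colours.
reinterpreted = nhwc.reshape(n, c, h, w)

# Moving the axis (torch: x.permute(0, 3, 1, 2)) gives genuine colour planes.
nchw = nhwc.transpose(0, 3, 1, 2)

print(reinterpreted.shape == nchw.shape)                  # True
print(np.array_equal(nchw[:, 0], nhwc[..., 0]))           # True
print(np.array_equal(reinterpreted[:, 0], nhwc[..., 0]))  # False

# Per-channel statistics, as used by Normalize, are taken over axis 1 in NCHW.
print(np.allclose(nchw.mean(axis=(0, 2, 3)), nhwc.mean(axis=(0, 1, 2))))  # True
```

The reshaped array has the right shape, so a shape check alone would not catch the problem; only the transposed version keeps each colour channel intact.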

RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 1

Run FedEM
==> Clients initialization..
===> Building data iterators..
Traceback (most recent call last):
  File "run_experiment.py", line 244, in <module>
    run_experiment(args)
  File "run_experiment.py", line 142, in run_experiment
    clients = init_clients(args_,
  File "run_experiment.py", line 31, in init_clients
    get_loaders(
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/utils/utils.py", line 375, in get_loaders
    inputs, targets = get_cifar10(dist_shift, dp)
  File "/mnt/traffic1/data/jxt/Personalized_Federated_Learning/FedGMM/datasets.py", line 1170, in get_cifar10
    cifar10_data = d_norm(cifar10_data)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 221, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/mnt/traffic/home/jxt/miniconda3/envs/FedGMM/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 336, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 1

I am training on the cifar10 dataset and run into this problem.
