
vaseline555 / federated-learning-in-pytorch

377 stars · 2 watchers · 87 forks · 158 KB

Handy PyTorch implementation of Federated Learning (for your painless research)

License: MIT License

Python 95.40% Shell 4.60%
federated-learning pytorch fedavg fedprox fedsgd leaf-benchmark deep-learning fedadagrad fedadam fedopt

federated-learning-in-pytorch's People

Contributors

vaseline555


federated-learning-in-pytorch's Issues

AssertionError: Invalid device id

While setting up my CUDA GPU and installing PyTorch to run main.py, I came across this error:

  File "main.py", line 57, in <module>
    central_server.setup()
  File "/content/Federated-Averaging-PyTorch/src/server.py", line 85, in setup
    init_net(self.model, **self.init_config)
  File "/content/Federated-Averaging-PyTorch/src/utils.py", line 77, in init_net
    model = nn.DataParallel(model, gpu_ids)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 142, in __init__
    _check_balance(self.device_ids)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 23, in _check_balance
    dev_props = _get_devices_properties(device_ids)
  File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 487, in _get_devices_properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
  File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 487, in <listcomp>
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
  File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 470, in _get_device_attr
    return get_member(torch.cuda)
  File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 487, in <lambda>
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
  File "/usr/local/lib/python3.7/dist-packages/torch/cuda/__init__.py", line 361, in get_device_properties
    raise AssertionError("Invalid device id")
AssertionError: Invalid device id 

How can I fix the device ID issue?
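
I suspect the gpu_ids passed to nn.DataParallel in utils.py reference GPUs that do not exist on this machine (I ran this on Colab, which usually exposes a single GPU or none). A guard I am considering (a hypothetical sketch, not repo code; safe_data_parallel is my own name):

    import torch
    import torch.nn as nn

    def safe_data_parallel(model, gpu_ids):
        # torch.cuda.device_count() is 0 on a CPU-only runtime, so any
        # requested id outside range(count) triggers "Invalid device id".
        valid_ids = [i for i in gpu_ids if i < torch.cuda.device_count()]
        if not valid_ids:
            return model  # fall back to single-device (CPU) execution
        return nn.DataParallel(model, device_ids=valid_ids)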

Clients won't download the server model under global evaluation

Hi, thanks for your efforts in providing such a neat and handy framework.

However, it seems _request with retain_model=False is never called under the global evaluation setting. As a result, a sampled client never re-downloads the global model, because the download is skipped whenever client.model is not None.
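
In pseudo-form, my reading of the flow (a hypothetical sketch of the suspected logic, not verbatim repo code; global_model stands in for the server-side model):

    import copy

    def _request(client, global_model, participated, retain_model):
        if client.model is None:      # download only when no local copy exists
            client.model = copy.deepcopy(global_model)
        # ... local update / evaluation happens here ...
        if not retain_model:
            client.model = None       # cleared, so the next call re-downloads

If global evaluation never passes retain_model=False, client.model stays set and every later download is silently skipped.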

Please correct me if there is anything I overlooked.

Thanks!

Implemented aggregation differs from the description in the McMahan 2016 paper

Hello. Thanks for the nice repo.

I have a question about the implementation of aggregation on the server. According to the McMahan 2016 paper, the server does
$$w_{t+1}\leftarrow \sum_{k\in S_t} \frac{n_k}{m_t} w_{t+1}^k.$$

If I understand it correctly, there is a server-side optimizer in your code and your implementation can be translated into

$$w_{t+1}\leftarrow w_t - \eta(w_t - \sum_{k\in S_t} \frac{n_k}{m_t} w_{t+1}^k).$$

This looks like a momentum update of the server's model. Could you please explain the motivation behind this?
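
For concreteness, here are the two rules side by side (a minimal sketch over flat weight tensors, not the repo's actual aggregation code). With a server learning rate of 1 the second form reduces exactly to the first, which is the usual "FedOpt" reading of FedAvg:

    import torch

    def fedavg_aggregate(client_weights, client_sizes):
        # McMahan et al. (2016): w_{t+1} = sum_k (n_k / m_t) * w_{t+1}^k
        m = float(sum(client_sizes))
        return sum(w * (n / m) for w, n in zip(client_weights, client_sizes))

    def server_opt_aggregate(w_t, client_weights, client_sizes, server_lr=1.0):
        # Server-optimizer form: w_{t+1} = w_t - eta * (w_t - weighted_avg);
        # eta = 1.0 recovers plain FedAvg, other values damp or boost the step.
        avg = fedavg_aggregate(client_weights, client_sizes)
        return w_t - server_lr * (w_t - avg)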

Sent140 Dataset

Thank you for your nice work.

It seems that the URL for downloading the Sent140 dataset does not work. How can I fix this issue?

Client multithreading is actually not working

Here is a fix provided by ChatGPT:

The way your code is structured right now, it submits one job to the ThreadPoolExecutor and then immediately waits for the result of that job before proceeding to the next iteration of the loop. This is due to the .result() call right after the .submit() call, which blocks the execution until the future is done and returns the result.

If you want to run the tasks in parallel, you should submit all the tasks first, and then wait for the results. Here is how you can adjust your code to do that:

import os
import concurrent.futures

results = []
futures = []

# Cap workers below the CPU count so the main thread keeps a core free.
with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(ids), os.cpu_count() - 1)) as workhorse:
    # First, submit every client update without blocking on any of them...
    for idx in TqdmToLogger(ids, logger=logger, desc=f'[{self.args.algorithm.upper()}] [Round: {str(self.round).zfill(4)}] ...update clients... ', total=len(ids)):
        futures.append(workhorse.submit(__update_clients, self.clients[idx]))

    # ...then collect the results as the futures finish, in completion order.
    for future in concurrent.futures.as_completed(futures):
        results.append(future.result())

In this code snippet, all the tasks are first submitted to the executor and their corresponding Future objects are stored in the futures list. The concurrent.futures.as_completed(futures) function is then used to yield the futures as they are completed, and the results are retrieved and stored in the results list.

This way, the tasks will run in parallel (up to the number specified by max_workers), and you can collect the results as they become available.

Type error for `_request()`

Thanks for the well-organized code base. However, when I run the example experiments, it fails because two positional arguments are missing. Are their default values missing?

TypeError: _request() missing 2 required positional arguments: 'participated' and 'retain_model'

Is there any plan to implement the Ditto algorithm?

Hi, thank you for this repository. It has been really helpful.

I wanted to know if you have any plans to also include the Ditto algorithm, as it could be helpful for people trying to implement it.
Ditto paper
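
For reference, Ditto (Li et al., 2021) personalizes by giving each client $k$ a regularized local objective alongside the global model (my summary of the paper, not anything already in this repo):

$$\min_{v_k} \; F_k(v_k) + \frac{\lambda}{2} \lVert v_k - w^* \rVert^2,$$

where $w^*$ is the global (FedAvg) solution and $\lambda$ controls how closely the personalized model $v_k$ tracks it.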

Again, thanks for your time :)

The accuracy on CIFAR-10 may be low

I ran the FedAvg code with the CNN2 model given in model.py on the CIFAR-10 dataset. I also excluded the model initialization in server.py, and all of the clients (only 10) were set to update and upload their models to the server. However, over about 100 rounds the accuracy only rises to around 70% and does not improve afterwards. I wonder if there is anything I've missed or misconfigured. Could anyone please offer me some advice?

Set a different number of clients when using a LEAF dataset

First of all, amazing work! :)
Is there a way to set a different number of clients when using a LEAF dataset? For example, if I use FEMNIST from LEAF, the number of clients is 3,597. Is it possible to reduce this for time and resource reasons?
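
A simple workaround I can imagine (a hypothetical sketch, not an existing repo option) is to subsample the parsed client list before training starts:

    import random

    # Hypothetical: keep only num_clients of the 3,597 FEMNIST clients;
    # all_clients stands in for whatever list the LEAF loader returns.
    num_clients = 100
    clients = random.sample(all_clients, k=num_clients)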

Thanks a lot!

undefined variable error

In "utils.py", modules of Torch Vision are imported by from torchvision import datasets, transforms.

However, in the line 106 of the same file, these modules are referred to by torchvision.datasets, but there is no variable named by torchvision.

I fixed it by rewrite the import statement by just import torchvision. Now it works.
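
For anyone hitting the same error, the two consistent options (a minimal illustration):

    # Option 1: import the package itself, so torchvision.datasets resolves.
    import torchvision
    dataset_cls = torchvision.datasets.MNIST

    # Option 2: keep the original import and drop the package prefix.
    from torchvision import datasets, transforms
    dataset_cls = datasets.MNIST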

Datasets location

Please, can you provide more information about the datasets' location, such as the exact names of the files to download, the website, etc.? Thank you!

Instructions to run on separate server/client devices?

Hello! I was interested in potentially running this across a network connection, with the server on one device and clients running on other devices.

I looked through the commands directory and the Readme.md, and also read through main.py. It seems the existing examples assume that all code runs on the same machine. Is there a way to run the server and clients separately?

About Dirichlet non-iid Partition

Hello, I am working on non-IID partition methods in FL and have some concerns about the Dirichlet partition.
In ./src/loaders/split.py, it seems that:

  1. It does not handle the case where a client requests more samples of a class than remain in that class.

  2. It does not properly update the remaining class_indices:
    class_indices[required_class] = class_indices[required_class][:required_counts[idx]]
    maybe it should be:
    class_indices[required_class] = class_indices[required_class][required_counts[idx]:]

  3. If satisfied_counts is less than ideal_counts in the first while loop, then in the second while loop the client is simply given another ideal_counts samples (sampled = np.random.choice(args.num_classes, ideal_counts, p=cat_param)), which can leave the client with far more samples than ideal_counts (see the sketch after this list).
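
For reference, here is the class-wise variant I had in mind, which sidesteps all three issues by slicing each class's index pool exactly once (a sketch of the standard approach, not the repo's implementation):

    import numpy as np

    def dirichlet_partition(labels, num_clients, alpha, seed=42):
        # For every class, draw one Dirichlet(alpha) proportion vector and
        # split that class's shuffled indices across clients accordingly, so
        # no index is assigned twice and every index is used exactly once.
        rng = np.random.default_rng(seed)
        client_indices = [[] for _ in range(num_clients)]
        for c in np.unique(labels):
            idx = rng.permutation(np.flatnonzero(labels == c))
            proportions = rng.dirichlet(alpha * np.ones(num_clients))
            cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
            for client_id, part in enumerate(np.split(idx, cuts)):
                client_indices[client_id].extend(part.tolist())
        return client_indices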

Thanks for your code, and I hope you can help me with these concerns.
