vaseline555 / federated-learning-in-pytorch
Handy PyTorch implementation of Federated Learning (for your painless research)
License: MIT License
While setting up my CUDA GPU and installing PyTorch to run main.py, I came across this error:
File "main.py", line 57, in <module>
central_server.setup()
File "/content/Federated-Averaging-PyTorch/src/server.py", line 85, in setup
init_net(self.model, **self.init_config)
File "/content/Federated-Averaging-PyTorch/src/utils.py", line 77, in init_net
model = nn.DataParallel(model, gpu_ids)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 142, in __init__
_check_balance(self.device_ids)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 23, in _check_balance
dev_props = _get_devices_properties(device_ids)
File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 487, in _get_devices_properties
return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 487, in <listcomp>
return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 470, in _get_device_attr
return get_member(torch.cuda)
File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 487, in <lambda>
return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
File "/usr/local/lib/python3.7/dist-packages/torch/cuda/__init__.py", line 361, in get_device_properties
raise AssertionError("Invalid device id")
AssertionError: Invalid device id
How can I fix the device ID issue?
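The assertion fires when nn.DataParallel is handed a device id that does not exist on the machine (for example gpu_ids=[0, 1] on a single-GPU Colab runtime). A defensive sketch of the idea (wrap_model is a hypothetical helper, not part of the repo) that filters the id list against the devices actually present:

```python
import torch
import torch.nn as nn

def wrap_model(model, gpu_ids):
    # Keep only device ids that exist on this machine; fall back to CPU.
    valid_ids = [i for i in gpu_ids if i < torch.cuda.device_count()]
    if valid_ids:
        model = model.to(f'cuda:{valid_ids[0]}')
        if len(valid_ids) > 1:
            # DataParallel only makes sense with two or more valid devices.
            model = nn.DataParallel(model, device_ids=valid_ids)
    return model

net = wrap_model(nn.Linear(4, 2), gpu_ids=[0, 1])
```

Equivalently, adjusting the GPU-id arguments passed to main.py so they match torch.cuda.device_count() on your runtime should make the original init_net call succeed.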
Hi, thanks for your efforts in providing such a neat and handy framework.
However, it seems _request with retain_model=False is never called under the global evaluation setting. As a result, the sampled client won't download the global model, since client.model is not None.
Please correct me if there is anything I overlooked.
Thanks!
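If that reading is right, the fix would be to force the download whenever retention is off. A minimal sketch of the intended behavior (Client here is a stand-in class, not the repo's implementation):

```python
import copy

class Client:
    # Stand-in for the repo's client: just holds a model reference.
    def __init__(self):
        self.model = None

def broadcast(global_model, clients, retain_model=False):
    # With retain_model=False, every sampled client should replace its
    # local model with a fresh copy of the global one -- even if
    # client.model is already set from a previous round.
    for client in clients:
        if not retain_model or client.model is None:
            client.model = copy.deepcopy(global_model)

clients = [Client() for _ in range(3)]
clients[0].model = object()     # stale model left over from an earlier round
broadcast({'w': 1.0}, clients)  # all three now hold a copy of the global model
```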
Hello. Thanks for the nice repo.
I have a question about the implementation of aggregation on the server. According to the McMahan et al. (2017) paper, the server does w_{t+1} <- sum over k in S_t of (n_k / n) * w_{t+1}^k, i.e. a plain weighted average of the returned client models.
If I understand it correctly, there is a server-side optimizer in your code and your implementation can be translated into
This looks like a momentum update of the server's model. Could you please explain the motivation behind this?
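For reference, the two update rules being compared can be sketched side by side. This is a sketch of the FedAvg and server-momentum (FedAvgM-style) formulations, not the repo's actual code, with weights shown as scalars for brevity:

```python
def fedavg_aggregate(client_ws, client_ns):
    # Plain FedAvg: weighted average of client models by sample count.
    total = sum(client_ns)
    return sum((n / total) * w for w, n in zip(client_ws, client_ns))

def server_momentum_step(global_w, client_ws, client_ns, buf, beta=0.9, lr=1.0):
    # Server-optimizer view: treat (global - average) as a pseudo-gradient
    # and push it through a momentum update, as in FedAvgM.
    delta = global_w - fedavg_aggregate(client_ws, client_ns)
    buf = beta * buf + delta
    return global_w - lr * buf, buf

# With beta=0 and lr=1 the momentum step reduces to plain FedAvg:
new_w, buf = server_momentum_step(2.0, [0.0, 2.0], [1, 1], buf=0.0, beta=0.0)
```

Presumably the motivation is exactly this generality: routing aggregation through a server-side optimizer recovers vanilla FedAvg when momentum is zero and the server learning rate is one, while allowing FedAvgM/FedOpt-style variants otherwise.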
Thank you for your nice work.
It seems that the URL for downloading the Sent140 dataset no longer works. How can I fix this issue?
Fantastic work!
I'm a freshman in the FL field.
When I use your project, it raises the error AttributeError: 'FedavgServer' object has no attribute 'lr_scheduler'.
Hoping to get your help!
Many thanks.
Kindest regards,
Jinfeng
Here is a fix provided by ChatGPT:
The way your code is structured right now, it submits one job to the ThreadPoolExecutor and then immediately waits for that job's result before proceeding to the next iteration of the loop. This is due to the .result() call right after the .submit() call, which blocks execution until the future is done and returns the result.
If you want to run the tasks in parallel, you should submit all the tasks first, and then wait for the results. Here is how you can adjust your code to do that:
results = []
futures = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(ids), os.cpu_count() - 1)) as workhorse:
    for idx in TqdmToLogger(ids, logger=logger, desc=f'[{self.args.algorithm.upper()}] [Round: {str(self.round).zfill(4)}] ...update clients... ', total=len(ids)):
        future = workhorse.submit(__update_clients, self.clients[idx])
        futures.append(future)
    for future in concurrent.futures.as_completed(futures):
        result = future.result()
        results.append(result)
In this snippet, all tasks are first submitted to the executor and their corresponding Future objects are stored in the futures list. concurrent.futures.as_completed(futures) then yields the futures as they complete, and each result is retrieved and appended to the results list. This way, the tasks run in parallel (up to the number specified by max_workers), and you can collect the results as they become available.
Good evening, sir.
I come across this error:
"RecursionError: maximum recursion depth exceeded while calling a Python object"
whenever I run main.py, even after installing all the dependencies.
Is there anything I can do to make the code work?
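Without the full traceback it is hard to say which call is cycling, but a small diagnostic wrapper (diagnose_recursion is a hypothetical helper, not part of the repo) can point at the frame that repeats:

```python
import traceback

def diagnose_recursion(fn):
    # Call fn() and, on RecursionError, report the most frequent
    # (file, function) frame in the traceback -- usually the cycle.
    try:
        fn()
        return None
    except RecursionError as err:
        tb = traceback.extract_tb(err.__traceback__)
        frames = [(f.filename, f.name) for f in tb]
        return max(set(frames), key=frames.count)

def boom():          # stand-in for whatever main.py is recursing through
    return boom()

culprit = diagnose_recursion(boom)
```

Reading the repeated frame at the bottom of the real traceback serves the same purpose. Raising sys.setrecursionlimit is only a stopgap if the recursion is genuinely deep but finite; an actual cycle needs the calling code fixed.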
Thanks for the well-organized code base. However, when I run the example experiments, it reports that two arguments are missing. Are their default values missing?
TypeError: _request() missing 2 required positional arguments: 'participated' and 'retain_model'
Hi, thank you for this repository. It has been really helpful.
I wanted to know if you have any plans to also include the Ditto algorithm in your framework, as it could be helpful for people trying to implement it.
DITTO Paper
Again, thanks for your time :)
Looking forward to a walkthrough and clear code for the CIFAR dataset! Thanks!
I ran the FedAvg code with the CNN2 given in model.py on the CIFAR-10 dataset. I also excluded the model initialization in server.py, and all of the clients (only 10) were set to update and upload their models to the server. However, over about 100 rounds, the accuracy only rises to around 70% and does not improve afterwards. I wonder if there is anything I've missed or mistaken. Could anyone please offer me some advice?
First of all, amazing work! :)
Is there a way to set a different number of clients when using a LEAF dataset? For example, if I use FEMNIST from LEAF, the number of clients is 3597. Is it possible to modify this due to time and resource constraints?
Thanks a lot!
In "utils.py", Torch Vision modules are imported via from torchvision import datasets, transforms. However, line 106 of the same file refers to torchvision.datasets, yet there is no name torchvision in scope. I fixed it by rewriting the import statement to just import torchvision. Now it works.
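The pitfall is general Python import behavior, reproducible with any package: from pkg import mod binds only mod in the current namespace, never pkg itself. A stdlib illustration using os in place of torchvision:

```python
# `from pkg import mod` binds only `mod`, not `pkg` itself.
from os import path

try:
    os.path.join('a', 'b')       # NameError: name 'os' is not defined
except NameError as err:
    print('dotted access fails:', err)

# Fix 1: import the package so the dotted name resolves.
import os
assert os.path is path

# Fix 2: keep the original import and use the bound name directly.
joined = path.join('a', 'b')
```

Either fix works in utils.py: add import torchvision alongside the existing import, or change the offending reference to plain datasets.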
Please, can you provide more information about dataset locations, such as the exact file names to download, the website, etc.? Thank you!
Hello! I was interested in potentially running this across a network connection, with the server on one device and clients running on other devices.
I looked through the commands directory and the Readme.md, and also read through main.py; it seems the existing examples assume that all code runs on the same machine. Is there a way to run the server and clients separately?
Hello, I am working on non-IID partition methods in FL, and have some problems with the Dirichlet partition.
In ./src/loaders/split.py, it seems that it:
1. Does not consider the situation where a client requires more samples of a class than remain in that class.
2. Does not properly update the remaining class_indices:
class_indices[required_class] = class_indices[required_class][:required_counts[idx]]
Maybe it should be:
class_indices[required_class] = class_indices[required_class][required_counts[idx]:]
3. If satisfied_counts is less than ideal_counts in the first while loop, then in the second while loop the client is simply given another ideal_counts samples (sampled = np.random.choice(args.num_classes, ideal_counts, p=cat_param)), which may lead to the client receiving far more samples than ideal_counts.
Thanks for your code, and I hope you can help me with these concerns.
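One common way to avoid both the over-draw and the stale-pool problems described above is to split each class's index pool by Dirichlet proportions directly, consuming every slice exactly once. A sketch under that formulation (dirichlet_partition is illustrative, not the repo's function):

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    # Label-skew partition: each class's index pool is split across clients
    # by proportions drawn from Dir(alpha). Every slice is carved out of the
    # pool exactly once, so no index is handed out twice and a client can
    # never receive more samples of a class than remain.
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Cut points inside this class's shuffled index pool.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices

parts = dirichlet_partition([0] * 60 + [1] * 40, num_clients=4, alpha=0.5)
```

Because the cuts exactly tile each class's pool, the per-client totals always sum to the dataset size, with no index duplicated or dropped.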