yoseflab / mrvi_archive
Multi-resolution Variational Inference
License: BSD 3-Clause "New" or "Revised" License
Hi! I'm using your package extensively in my PhD project; thanks for publishing it! I believe the most exciting discoveries are to be made with large datasets. Unfortunately, the package's current design does not scale well in terms of RAM usage. This line of code in particular worries me (and breaks my pipeline):
def compute_distance_matrix_from_representations(
...
pairwise_dists = np.zeros((n_cells, n_donors, n_donors))
For a dataset with ~1 million cells and 1000 donors, it tries to allocate 8.75 TiB of memory. This makes MrVI barely applicable to the most exciting scenarios!
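The figure follows directly from the shape of that array: n_cells × n_donors × n_donors entries at 8 bytes each, since np.zeros defaults to float64. The sketch below does the arithmetic (dense_tensor_bytes is a hypothetical helper, not part of MrVI). Exactly 10^6 cells and 1000 donors come to about 7.28 TiB, so the reported 8.75 TiB would correspond to roughly 1.2 million cells. Storing float32 would halve either number but would not change the quadratic growth in the number of donors.

```python
def dense_tensor_bytes(n_cells: int, n_donors: int, itemsize: int = 8) -> int:
    """Bytes needed by a dense (n_cells, n_donors, n_donors) array.

    itemsize=8 matches the float64 default of np.zeros; use 4 for float32.
    """
    return n_cells * n_donors * n_donors * itemsize


TIB = 2 ** 40  # bytes per tebibyte

size = dense_tensor_bytes(1_000_000, 1000)
print(f"float64: {size / TIB:.2f} TiB")      # float64: 7.28 TiB
print(f"float32: {size / 2 / TIB:.2f} TiB")  # float32: 3.64 TiB
```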
I wonder whether users even need this full tensor. Maybe it makes sense to aggregate the distances somehow? For example, report a single n_donors × n_donors matrix with the average distance, which could be computed far more memory-efficiently.
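A minimal sketch of the averaging idea, assuming the per-cell donor representations are available as an array of shape (n_cells, n_donors, latent_dim) and that the distance is Euclidean. The layout, the metric and the helper name mean_donor_distance_matrix are assumptions for illustration, not MrVI's API. Distances are accumulated chunk by chunk, so peak memory scales with chunk_size × n_donors² rather than n_cells × n_donors².

```python
import numpy as np


def mean_donor_distance_matrix(reps, chunk_size=64):
    """Average over cells of the per-cell donor-by-donor Euclidean distances.

    Hypothetical helper. Assumes `reps` has shape (n_cells, n_donors, latent_dim),
    i.e. one latent vector per (cell, donor) pair. Only a few
    (chunk_size, n_donors, n_donors) buffers exist at any time, so the full
    (n_cells, n_donors, n_donors) tensor is never allocated.
    """
    n_cells, n_donors, _ = reps.shape
    total = np.zeros((n_donors, n_donors), dtype=np.float64)
    diag = np.arange(n_donors)
    for start in range(0, n_cells, chunk_size):
        chunk = np.asarray(reps[start:start + chunk_size], dtype=np.float64)
        sq_norms = np.einsum("cdk,cdk->cd", chunk, chunk)  # (c, n_donors)
        gram = chunk @ chunk.transpose(0, 2, 1)              # (c, n_donors, n_donors)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b avoids a (c, n, n, latent_dim) array
        sq_dists = sq_norms[:, :, None] + sq_norms[:, None, :] - 2.0 * gram
        np.maximum(sq_dists, 0.0, out=sq_dists)  # clip tiny negatives from round-off
        sq_dists[:, diag, diag] = 0.0            # a donor's distance to itself is zero
        total += np.sqrt(sq_dists).sum(axis=0)
    return total / n_cells
```

Because `reps` is only sliced, it could also be a NumPy memmap, so the representations themselves need not fit in RAM either. Other aggregates (for example a median, or averages within cell clusters) would need a different accumulator.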
I'd be glad to hear your thoughts on this and see the improvement of your wonderful package.
Thank you for this module, which is highly relevant to my experiment: I have replicates of the same subject at different sites. Unfortunately, I get the error "expected scalar type Double but found Float" when I try to train the model after setting up the data (the same data works with scVI without errors, so I assume the data is fine). After searching online, I understand this might be caused by a dtype mismatch in PyTorch. One suggestion I found: "You need to cast your tensors to float32, either with dtype='float32' or calling float() on your input tensors."
I would be highly obliged if you could look into this.
cheers
shobhit
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/_model.py", line 157, in train
super().train(**train_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/model/base/_training_mixin.py", line 77, in train
return runner()
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainrunner.py", line 82, in __call__
self.trainer.fit(self.training_plan, self.data_splitter)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainer.py", line 188, in fit
super().fit(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
self._call_and_handle_interrupt(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
results = self._run_stage()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
return self._run_train()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
self.fit_loop.run()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
batch_output = self.batch_loop.run(kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
outputs = self.optimizer_loop.run(optimizers, kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
self.trainer._call_lightning_module_hook(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1705, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 113, in wrapper
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/adam.py", line 118, in step
loss = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
closure_result = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 132, in closure
step_output = self._step_fn()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 407, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 358, in training_step
return self.model.training_step(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 351, in training_step
_, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 282, in forward
return self.module(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 276, in forward
return _generic_forward(
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 837, in _generic_forward
inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/module.py", line 122, in inference
x_feat = self.x_featurizer(x)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Double but found Float
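The last frames show the mismatch inside F.linear, i.e. between the input x and the parameters of the first layer of x_featurizer. PyTorch layers are float32 by default, so one plausible reading (an assumption, not a confirmed diagnosis) is that x arrives as float64 because the counts matrix in the AnnData object is stored as float64. A minimal way to try the suggestion quoted above is to cast the matrix before setting up the data; as_float32 is a hypothetical helper.

```python
import numpy as np


def as_float32(X):
    """Return X cast to float32 (hypothetical helper).

    Works for dense NumPy arrays and SciPy sparse matrices, since both expose
    `.dtype` and `.astype`. Returns X unchanged if it is already float32.
    """
    if X.dtype == np.float32:
        return X
    return X.astype(np.float32)
```

Usage, assuming the counts live in adata.X: run `adata.X = as_float32(adata.X)` before the setup call; if the data was set up from a layer instead, cast that layer. If the error persists, the float64 tensor is coming from somewhere else (for example a covariate), and that input would need the same treatment.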
First of all, thank you for a great tool! Unfortunately, it is impossible to import it with the newest version of scvi-tools. When I run
import mrvi
I get the following error:
...
----> 5 from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
6 from scvi.nn import one_hot
7 from torch.distributions import kl_divergence as kl
ImportError: cannot import name 'LossRecorder' from 'scvi.module.base'
The reason is this line in your code:
https://github.com/YosefLab/mrvi/blob/d1934e4889bbf383e411d2d39558488e1568fb0c/mrvi/_module.py#L5
In scvi-tools v0.20, LossRecorder was renamed to LossOutput. I suggest checking the scvi-tools version in the code before importing LossRecorder, or simply pinning scvi-tools to <0.20 in the requirements.
Have a great day!
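One way to follow the first suggestion without parsing version strings is to try the new name and fall back to the old one. The sketch below uses a hypothetical helper, import_first, built on importlib; only the two class names come from the report above. The renamed classes may not be drop-in replacements (constructor arguments can differ between versions), so the call sites that build the loss object should be checked as well, whereas pinning scvi-tools below 0.20 sidesteps that entirely.

```python
import importlib


def import_first(module_name, *names):
    """Return the first attribute in `names` that `module_name` provides.

    Hypothetical helper for code that must run across library versions in
    which a class was renamed.
    """
    module = importlib.import_module(module_name)
    for name in names:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"{module_name} provides none of: {', '.join(names)}")
```

Usage for this case: `LossClass = import_first("scvi.module.base", "LossOutput", "LossRecorder")`, which picks LossOutput on scvi-tools 0.20 and later and LossRecorder on older versions.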