
mrvi_archive's Introduction

Multi-resolution Variational Inference

DEPRECATED: please refer to the current version of the project for up-to-date code.

Multi-resolution Variational Inference (MrVI) is a package for analysis of sample-level heterogeneity in multi-site, multi-sample single-cell omics data. Built with scvi-tools.


To install, run:

pip install mrvi

mrvi.MrVI follows the same API used in scvi-tools.

import mrvi
import anndata

adata = anndata.read_h5ad("path/to/adata.h5ad")
# Sample (e.g. donors, perturbations, etc.) should go in sample_key
# Sites, plates, and other factors should go in categorical_nuisance_keys
mrvi.MrVI.setup_anndata(adata, sample_key="donor", categorical_nuisance_keys=["site"])
mrvi_model = mrvi.MrVI(adata)
mrvi_model.train()
# Get z representation
adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(give_z=True)
# Get u representation
adata.obsm["X_mrvi_u"] = mrvi_model.get_latent_representation(give_z=False)
# Per-cell sample representations, shape (n_cells, n_sample, n_latent)
cell_sample_representations = mrvi_model.get_local_sample_representation()
# Per-cell sample-sample distances, shape (n_cells, n_sample, n_sample)
cell_sample_sample_distances = mrvi_model.get_local_sample_representation(return_distances=True)
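
To make the two output shapes above concrete, here is a minimal sketch on synthetic data. The random array is only a stand-in for the model output, and the Euclidean metric is an assumption made for illustration; it is not necessarily the distance MrVI uses.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_sample, n_latent = 5, 3, 2

# Synthetic stand-in for the (n_cells, n_sample, n_latent) representations.
reps = rng.normal(size=(n_cells, n_sample, n_latent))

# Broadcast to get, for every cell, the difference between each pair of samples.
diff = reps[:, :, None, :] - reps[:, None, :, :]

# Assumed Euclidean metric: one (n_sample, n_sample) distance matrix per cell.
dists = np.linalg.norm(diff, axis=-1)
print(dists.shape)  # (5, 3, 3)
```

Each slice dists[i] is symmetric with a zero diagonal, since it compares the sample-specific representations of cell i with each other.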

Citation

@article {Boyeau2022.10.04.510898,
	author = {Boyeau, Pierre and Hong, Justin and Gayoso, Adam and Jordan, Michael and Azizi, Elham and Yosef, Nir},
	title = {Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics},
	elocation-id = {2022.10.04.510898},
	year = {2022},
	doi = {10.1101/2022.10.04.510898},
	publisher = {Cold Spring Harbor Laboratory},
	abstract = {Contemporary single-cell omics technologies have enabled complex experimental designs incorporating hundreds of samples accompanied by detailed information on sample-level conditions. Current approaches for analyzing condition-level heterogeneity in these experiments often rely on a simplification of the data such as an aggregation at the cell-type or cell-state-neighborhood level. Here we present MrVI, a deep generative model that provides sample-sample comparisons at a single-cell resolution, permitting the discovery of subtle sample-specific effects across cell populations. Additionally, the output of MrVI can be used to quantify the association between sample-level metadata and cell state variation. We benchmarked MrVI against conventional meta-analysis procedures on two synthetic datasets and one real dataset with a well-controlled experimental structure. This work introduces a novel approach to understanding sample-level heterogeneity while leveraging the full resolution of single-cell sequencing data. Competing Interest Statement: N.Y. is an advisor and/or has equity in Cellarity, Celsius Therapeutics, and Rheos Medicine.},
	URL = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898},
	eprint = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898.full.pdf},
	journal = {bioRxiv}
}

mrvi_archive's People

Contributors

justjhong, pre-commit-ci[bot], adamgayoso, martinkim0, pierreboyeau


mrvi_archive's Issues

Improving scalability for large datasets

Hi! I'm using your package a lot for my PhD project, thanks for publishing it! I believe the most exciting discoveries can be made for large datasets. Unfortunately, the current setup of the package doesn't scale very well in terms of RAM usage. Especially this line of code makes me worried (and ruins my pipeline):

    def compute_distance_matrix_from_representations(
    ...
        pairwise_dists = np.zeros((n_cells, n_donors, n_donors))

For a dataset with ~1 million cells and 1000 donors, it tries to allocate 8.75 TiB of memory. This makes MrVI barely applicable to the most exciting scenarios!

I wonder if users even need this tensor. Maybe it makes sense to aggregate distances somehow? For example, report one matrix of size n_donors * n_donors with an average distance. It could then be calculated in a significantly more memory-efficient way.

I'd be glad to hear your thoughts on this and see the improvement of your wonderful package.
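
The aggregation suggested in this issue can be sketched as a running sum over chunks of cells, so that only an n_donors by n_donors accumulator and one small chunk are held in memory at a time. This is a sketch under stated assumptions, not the package's implementation: the function names are hypothetical, the metric is assumed to be Euclidean, and the input is assumed to be an array of shape (n_cells, n_donors, n_latent). For scale, a float64 tensor for exactly 10^6 cells and 1000 donors takes 8e12 bytes, about 7.3 TiB; the 8.75 TiB quoted above would correspond to roughly 1.2 million cells.

```python
import numpy as np


def dense_tensor_bytes(n_cells, n_donors, itemsize=8):
    """Bytes needed for an (n_cells, n_donors, n_donors) array (float64 by default)."""
    return n_cells * n_donors * n_donors * itemsize


def mean_donor_distances(cell_reps, chunk_size=8):
    """Average donor-donor Euclidean distance over cells, computed chunk by chunk.

    cell_reps: array of shape (n_cells, n_donors, n_latent).
    Returns an (n_donors, n_donors) matrix.
    """
    n_cells, n_donors, _ = cell_reps.shape
    total = np.zeros((n_donors, n_donors), dtype=np.float64)
    for start in range(0, n_cells, chunk_size):
        chunk = np.asarray(cell_reps[start:start + chunk_size], dtype=np.float64)
        # Squared norms and Gram matrices give squared distances without
        # building a (chunk, n_donors, n_donors, n_latent) intermediate.
        sq_norms = (chunk ** 2).sum(axis=-1)
        gram = chunk @ chunk.transpose(0, 2, 1)
        sq_dists = sq_norms[:, :, None] + sq_norms[:, None, :] - 2.0 * gram
        np.maximum(sq_dists, 0.0, out=sq_dists)  # clip tiny negative round-off
        total += np.sqrt(sq_dists).sum(axis=0)
    return total / n_cells


print(f"{dense_tensor_bytes(1_000_000, 1000) / 2**40:.2f} TiB")  # 7.28 TiB

rng = np.random.default_rng(0)
toy_reps = rng.normal(size=(10, 4, 3))
avg = mean_donor_distances(toy_reps, chunk_size=3)
print(avg.shape)  # (4, 4)
```

Peak memory here scales with chunk_size * n_donors**2 rather than n_cells * n_donors**2, so the chunk size can be tuned to the available RAM. Averaging over all cells discards the per-cell resolution, so in practice one might average within cell groups instead.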

Import fails with scvi-tools >= 0.20

First of all, thank you for a great tool! Unfortunately, it cannot be imported with the newest version of scvi-tools. When I run

import mrvi

I get the following error:

...
----> 5 from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
      6 from scvi.nn import one_hot
      7 from torch.distributions import kl_divergence as kl

ImportError: cannot import name 'LossRecorder' from 'scvi.module.base' 

The reason is this line in your code:
https://github.com/YosefLab/mrvi/blob/d1934e4889bbf383e411d2d39558488e1568fb0c/mrvi/_module.py#L5

In scvi-tools v0.20, LossRecorder was renamed to LossOutput. I suggest either checking the scvi-tools version in the code before importing LossRecorder, or simply pinning scvi-tools to <0.20 in the requirements.

Have a great day!
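
One way to express the "check before importing" suggestion is to look the class up by name and fall back to the older one. The helper below is a hypothetical sketch (import_first_available is not part of mrvi or scvi-tools) and is demonstrated on the standard library so that it runs without scvi-tools installed. Note that an alias alone may not be enough: the two classes may differ in how they are constructed, so call sites would likely need adjusting as well.

```python
import importlib


def import_first_available(module_name, candidate_names):
    """Return the first attribute from candidate_names that the module exposes."""
    module = importlib.import_module(module_name)
    for name in candidate_names:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"none of {candidate_names} found in {module_name!r}")


# Demonstration on the standard library: the first name is missing,
# so the helper falls back to the second one.
sqrt_fn = import_first_available("math", ["no_such_name", "sqrt"])
print(sqrt_fn(9.0))  # 3.0

# Hypothetical use in mrvi/_module.py (assumption, not executed here):
# LossContainer = import_first_available(
#     "scvi.module.base", ["LossOutput", "LossRecorder"]
# )
```

The simpler alternative mentioned in the issue, pinning scvi-tools to a version below 0.20 in the requirements, avoids touching the code at all.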

RuntimeError: expected scalar type Double but found Float when running model.train()

Thank you for this module, which is highly relevant to my experiment, where I have replicates of the same subject at different sites. Unfortunately, I get "expected scalar type Double but found Float" when I try to train the model after setting up the data. (The same data was used with scVI without errors, so I assume the data is fine.) From searching online, I gather this might be due to a dtype specification in PyTorch. An example suggestion I found: "You need to cast your tensors to float32, either with dtype='float32' or calling float() on your input tensors."
I would be highly obliged if you could look into this.
cheers
shobhit

Traceback (most recent call last):
File "", line 1, in
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/_model.py", line 157, in train
super().train(**train_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/model/base/_training_mixin.py", line 77, in train
return runner()
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainrunner.py", line 82, in call
self.trainer.fit(self.training_plan, self.data_splitter)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainer.py", line 188, in fit
super().fit(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
self._call_and_handle_interrupt(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
results = self._run_stage()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
return self._run_train()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
self.fit_loop.run()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
batch_output = self.batch_loop.run(kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
outputs = self.optimizer_loop.run(optimizers, kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
self.trainer._call_lightning_module_hook(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1705, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 113, in wrapper
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/adam.py", line 118, in step
loss = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
closure_result = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in call
self._result = self.closure(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 132, in closure
step_output = self._step_fn()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 407, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 358, in training_step
return self.model.training_step(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 351, in training_step
_, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 282, in forward
return self.module(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 276, in forward
return _generic_forward(
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 837, in _generic_forward
inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/module.py", line 122, in inference
x_feat = self.x_featurizer(x)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Double but found Float
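
The last frames of the traceback above (x_featurizer, then F.linear) suggest that a float64 tensor and a float32 tensor are meeting in a linear layer. If the float64 values come from the expression matrix, which is an assumption the traceback alone does not confirm, casting the matrix to float32 before setting up the data is one thing to try. A minimal sketch with a hypothetical helper and a toy array standing in for adata.X:

```python
import numpy as np


def as_float32(matrix):
    """Return the matrix as float32; NumPy arrays and SciPy sparse matrices both support astype."""
    if matrix.dtype == np.float32:
        return matrix  # already single precision, nothing to do
    return matrix.astype(np.float32)


# Toy stand-in for adata.X stored in double precision.
counts = np.array([[0, 3], [1, 2]], dtype=np.float64)
counts32 = as_float32(counts)
print(counts.dtype, "->", counts32.dtype)  # float64 -> float32

# Hypothetical use before mrvi.MrVI.setup_anndata (not executed here):
# adata.X = as_float32(adata.X)
```

Small integer counts are represented exactly in float32, so this cast does not change typical count data.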
