Thank you for this module, which is highly relevant to my experiment: I have replicates of the same subject at different sites. Unfortunately, I get the error "expected scalar type Double but found Float" when I try to train the model after setting up the data. The same data works with scVI without any errors, so I assume the data itself is fine. After searching online, I gather that this may be caused by a dtype mismatch in PyTorch. One suggestion I found was: "You need to cast your tensors to float32, either with dtype='float32' or by calling float() on your input tensors."
I would be very grateful if you could look into this.
Cheers,
Shobhit
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/_model.py", line 157, in train
super().train(**train_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/model/base/_training_mixin.py", line 77, in train
return runner()
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainrunner.py", line 82, in __call__
self.trainer.fit(self.training_plan, self.data_splitter)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainer.py", line 188, in fit
super().fit(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
self._call_and_handle_interrupt(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
results = self._run_stage()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
return self._run_train()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
self.fit_loop.run()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
batch_output = self.batch_loop.run(kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
outputs = self.optimizer_loop.run(optimizers, kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
self.trainer._call_lightning_module_hook(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1705, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 113, in wrapper
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/adam.py", line 118, in step
loss = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
closure_result = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 132, in closure
step_output = self._step_fn()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 407, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 358, in training_step
return self.model.training_step(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 351, in training_step
_, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 282, in forward
return self.module(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 276, in forward
return _generic_forward(
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 837, in _generic_forward
inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/module.py", line 122, in inference
x_feat = self.x_featurizer(x)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Double but found Float