I just updated to the most recent code and I can no longer train models. Training gets about 7% of the way through the first epoch and then crashes. Here's the full error message from the log.
File "/home/peastman/workspace/torchmd-net/scripts/train.py", line 172, in <module>
main()
File "/home/peastman/workspace/torchmd-net/scripts/train.py", line 165, in main
trainer.fit(model, data)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
self._run(model)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
self.dispatch()
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
self.accelerator.start_training(self)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
self._results = trainer.run_stage()
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
return self.run_train()
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 869, in run_train
self.train_loop.run_training_epoch()
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 490, in run_training_epoch
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 731, in run_training_batch
self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 432, in optimizer_step
using_lbfgs=is_lbfgs,
File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 111, in optimizer_step
super().optimizer_step(*args, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
optimizer.step(closure=lambda_closure, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/optim/optimizer.py", line 89, in wrapper
return func(*args, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/optim/adamw.py", line 65, in step
loss = closure()
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 726, in train_step_and_backward_closure
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 814, in training_step_and_backward
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 280, in training_step
training_step_output = self.trainer.accelerator.training_step(args)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
return self.training_type_plugin.training_step(*args)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 319, in training_step
return self.model(*args, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 46, in forward
output = self.module.training_step(*inputs, **kwargs)
File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 44, in training_step
return self.step(batch, mse_loss, 'train')
File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 60, in step
pred, deriv = self(batch.z, batch.pos, batch.batch)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 41, in forward
return self.model(z, pos, batch=batch)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/peastman/workspace/torchmd-net/torchmdnet/models/output_modules.py", line 60, in forward
x, z, pos, batch = self.representation_model(z, pos, batch=batch)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/peastman/workspace/torchmd-net/torchmdnet/models/torchmd_gn.py", line 112, in forward
edge_index, edge_weight = self.distance(pos, batch)
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/peastman/workspace/torchmd-net/torchmdnet/models/utils.py", line 182, in forward
max_num_neighbors=self.max_num_neighbors)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch_cluster/radius.py", line 53, in radius_graph
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
~~~ <--- HERE
deg = x.new_zeros(batch_size, dtype=torch.long)
RuntimeError: CUDA error: device-side assert triggered
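Since CUDA device-side asserts are reported asynchronously, the stack above may not point at the kernel that actually failed; re-running with `CUDA_LAUNCH_BLOCKING=1`, or on CPU, usually surfaces the real error in the right place. For reference, here is a minimal CPU sketch of the failing `radius_graph` call; the positions, batch indices, cutoff, and neighbor cap are made-up values for illustration, not values taken from the real run:

```python
import torch
from torch_cluster import radius_graph

# Hypothetical stand-ins for batch.pos and batch.batch at the failing step:
# ten atoms split across two molecules in the batch.
pos = torch.randn(10, 3)
batch = torch.tensor([0] * 6 + [1] * 4)

# The same preconditions radius_graph relies on internally: one batch index
# per position, and non-negative, sorted batch indices.
assert pos.size(0) == batch.numel()
assert batch.min() >= 0

# The call made in torchmdnet/models/utils.py, but on CPU, where bad
# inputs fail with a readable Python error instead of a CUDA assert.
edge_index = radius_graph(pos, r=5.0, batch=batch, max_num_neighbors=32)
print(edge_index.shape)  # (2, num_edges)
```

Note that the `<--- HERE` marker sits on `batch_x.max()`, which is just a reduction over the batch tensor, so the assert that actually fired may belong to an earlier kernel (for example an out-of-range index into an embedding table); a CPU run or blocking launches should narrow that down.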