Comments (18)
Anyone interested in implementing this?
Will need to make sure it's supported with distributed wrappers. Specifically, the DistributedDataset wrapper.
from pytorch-lightning.
Another option would be to just concat the datasets with torch.utils.data.ConcatDataset.
It's kind of a quick fix, but I'm pretty sure the existing DistributedDataset wrapper would then handle it the same as having one dataset.
@sidhanthholalkere yeah, actually that might be a better option.
I like deferring this stuff to PyTorch.
So, to use 2 datasets, the user would do this (according to this issue):
```python
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

train_loader = torch.utils.data.DataLoader(
    ConcatDataset(
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B)
    ),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
```
So it looks like Lightning has to do nothing here, except maybe add a documentation block to help users looking to do this?
Under the Validation step section:
- [Validation with multiple datasets][link to docs with details]
My issue with that method is that some samples in the larger dataset will be left out. The only benefit is that you get to validate a batch from both datasets at the same time, but in my experience only the full validation loss (on the whole validation set) matters, not the per-batch loss.
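To make the truncation concern concrete, here is a minimal, torch-free sketch of the zip-style ConcatDataset from the snippet above (duck-typed for illustration; the real version subclasses torch.utils.data.Dataset):

```python
# Duck-typed sketch of the zip-style ConcatDataset above; illustrative only.
class ZipDataset:
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # one sample from each dataset at the same index
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        # length is capped by the SMALLEST dataset, so any extra
        # samples in the larger datasets are never visited
        return min(len(d) for d in self.datasets)

small = list(range(3))      # 3 samples
large = list(range(100))    # 100 samples
zipped = ZipDataset(small, large)
print(len(zipped))          # 3 -> 97 samples of `large` are dropped
```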
Now, if the user wants separate losses for each set, I think the best option would be:
```python
for dataset_index, dataset in enumerate(datasets):
    for batch_nb, batch in enumerate(dataset):
        model.validation_step(batch, batch_nb, dataset_index)
```
If they just want validation on both sets combined, I still think using the built-in torch.utils.data.ConcatDataset here is better, because it handles some error checking and would essentially do the same thing as options B and C.
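For reference, the built-in `torch.utils.data.ConcatDataset` behaves differently from the zip-style snippet earlier in the thread: it concatenates datasets end-to-end, so its length is the sum and no samples are dropped. A small sketch (assumes PyTorch is installed; the toy datasets are mine):

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset, DataLoader

# Two toy validation sets of different sizes
ds_a = TensorDataset(torch.zeros(3, 2))   # 3 samples of zeros
ds_b = TensorDataset(torch.ones(5, 2))    # 5 samples of ones

# The built-in ConcatDataset places ds_b after ds_a:
# indices 0..2 come from ds_a, indices 3..7 from ds_b
combined = ConcatDataset([ds_a, ds_b])
print(len(combined))  # 8 -> nothing is left out

val_loader = DataLoader(combined, batch_size=4, shuffle=False)
```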
that looks great. @sidhanthholalkere Want to give it a shot and submit a PR?
Files that need to be changed:
- trainer
  - add support for returning either a single or multiple dataloaders from val_dataloader
- tests
  - testModels (there are 2: one should return a single val_dataloader, the other should return 2). A test should be written to make sure both dataloaders are called and used correctly.
I've started implementing this locally. I originally started by having `validate()` take in `dataset_index` so `validation_step()` could have access to it, letting the user name the outputs accordingly, i.e.:

```python
return {'val_loss_{}'.format(dataset_index): whatever}
```
Looking back, it feels like this adds unnecessary complexity, because the user has to decide what to do with the dataset index.
Now I'm trying a new approach where `validate` itself takes the list of val_dataloaders (instead of having to enumerate through them and pass `dataset_index`) and just appends the `dataset_index` to each of the result keys, i.e.:

```python
output = {key + str(dataset_index): value for key, value in output.items()} if len(dataloaders) > 1 else output
```

With this new method there's no need to enumerate through `val_dataloader` when calling `validate`, AND the user doesn't need to handle `dataset_index`.
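A small standalone illustration of the key-suffix idea (the function name and arguments are mine, not the actual trainer internals):

```python
# Hypothetical sketch of the proposed key-suffix approach.
def tag_outputs(output, dataset_index, num_dataloaders):
    # Append the dataloader index to every metric key, but only when
    # there is more than one val dataloader (stays backwards compatible).
    if num_dataloaders > 1:
        return {key + str(dataset_index): value for key, value in output.items()}
    return output

print(tag_outputs({'val_loss': 0.5, 'val_acc': 0.9}, 1, 2))
# {'val_loss1': 0.5, 'val_acc1': 0.9}
```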
What do you think of this new method?
As for tests, do you want me to just check that `model.nb_val_batches` is correct and that `model.validate()` works (since that is all I'm changing, at least in the new method)?
Also, for creating the test with two val_dataloaders, should I just use the default `get_model()` and override `val_dataloader()` to:

```python
@ptl.data_loader
def val_dataloader(self):
    print('val data loader called')
    return [self.__dataloader(train=False) for i in range(2)]
```
Interesting suggestion but I like the version with the dataset index, just pass it after batch_nb. This way it’s intuitive, users don’t have to read docs, and it remains fully flexible.
So, I propose we do something along the lines of:
```python
for dataset_i, dataset in enumerate(val_datasets):
    for batch_nb, batch in enumerate(dataset):
        out = model.validation_step(batch, batch_nb, dataset_i)
```
It’s also backwards compatible and what a researcher would do if they had to implement it on their own.
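One way the pass-it-after-batch_nb signature can stay backwards compatible is a defaulted third argument. A sketch (the class and metric are placeholders, not the merged implementation):

```python
# Hypothetical LightningModule-style method: single-dataloader users
# simply never pass (or read) the third argument.
class MyModel:
    def validation_step(self, batch, batch_nb, dataset_i=0):
        loss = sum(batch) / len(batch)  # stand-in for a real metric
        # name the output per dataloader, as discussed above
        return {'val_loss_{}'.format(dataset_i): loss}

model = MyModel()
print(model.validation_step([1.0, 2.0], 0))     # {'val_loss_0': 1.5}
print(model.validation_step([1.0, 2.0], 0, 1))  # {'val_loss_1': 1.5}
```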
I've added support for multiple val_dataloaders on my fork.
Is there a specific way you want me to write tests?
Also, unfortunately I won't be able to run most of the other tests locally, as I don't have a multi-GPU machine or Apex working for now.
awesome contribution!
for tests, can you add multiple val dataloaders to the exampleModel used in tests? then modify what’s returned to be an accuracy and val loss indexed by the dataset.
add a separate test to trainer where it inits dataloaders, then check that the loaders are correct (we don’t have such test yet).
make sure to also check that all val dataloaders are being wrapped in distributed dataset (there’s a warning for that).
i can run in gpus once you submit.
when you say exampleModel, do you mean the template model (lightning_module_template.py)?
Also, how would you want me to init multiple dataloaders? I could make a child of LightningTemplateModel where val_dataloader returns 2 dataloaders instead of one.
For the warning about val_dataloaders being wrapped, should I write another exception in trainer.py similar to the one that checks if ddp is used and whether the tng_dataloader uses a DistributedSampler?
@sidhanthholalkere
1. this is exampleModel: https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/testing/lm_test_module.py
2. return 2 dataloaders from the module above. That way some of the tests use a different model which returns 1 loader and other tests use this model which returns 2.
   ```python
   return [ds1, ds2]
   ```
3. just use
   ```python
   import warnings
   warnings.warn('something to warn about')
   ```
4. in fact, while you're there, could you remove the exception about distsampler and turn it into a warning? that would solve #81.
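A rough sketch of what such a warning check could look like (the helper name and duck-typed sampler check are illustrative, not Lightning's actual code):

```python
import warnings

def warn_if_not_distributed(val_dataloaders, use_ddp):
    # Hypothetical helper: under ddp, each val dataloader should use a
    # DistributedSampler; warn (rather than raise) when one does not.
    for i, dl in enumerate(val_dataloaders):
        sampler = getattr(dl, 'sampler', None)
        if use_ddp and type(sampler).__name__ != 'DistributedSampler':
            warnings.warn(
                'val_dataloader {} does not use a DistributedSampler; '
                'it will not be sharded across ddp processes'.format(i))

class FakeLoader:        # stand-in for a DataLoader in this sketch
    sampler = None

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    warn_if_not_distributed([FakeLoader()], use_ddp=True)
print(len(caught))  # 1
```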
Ok, I've made the changes. Here is a summary of what I've done:
- Multiple `val_dataloader` support in `trainer.py`
- Added 2 `val_dataloader`s for `lm_test_module.py` (it's just the same one twice)
- Added an output to `validation_step` (`if batch_i % 4 == 0`) that has the losses/accuracies indexed by dataset
- Warning for if `val_dataloader`s are not `DistributedSampler`s and `ddp` is selected
- Test: fit a model with multiple `val_dataloader`s and check if the length of the trainer's `val_dataloader` is 2 (not sure about this test)

Let me know if anything should be changed (it's on my fork if you want to check) before I submit a PR.
Quick nitpick: in `lm_test_module.py`, in `validation_step()`, the acc and loss are named `val_acc` and `loss_val`. Could I change them so the naming is consistent?
Also, in your `test_models.py`, multiple comments say "traning complete", so I can fix that in another PR.
@sidhanthholalkere awesome! probably easier to make edits and comments on the PR itself!
@sidhanthholalkere 4 was taken care of in a different PR by someone yesterday
Any update?
I actually need to use two dataloaders for validation and take the mean of the logits. I tried to return a list of loaders from `val_dataloader`, but it does not work: `TypeError: 'DataLoader' object is not subscriptable`.
not live yet. it's on @sidhanthholalkere's branch. waiting on a PR to merge
Fixing some errors, should be finished soon
@lorenzoFabbri @sidhanthholalkere merged! was not super trivial to verify haha