
ws-choi / conditioned-source-separation-lasaft


A PyTorch implementation of the paper: "LaSAFT: Latent Source Attentive Frequency Transformation for Conditioned Source Separation" (ICASSP 2021)

License: MIT License

Python 0.09% Jupyter Notebook 99.91%
source-separation pytorch pytorch-lightning musdb18

conditioned-source-separation-lasaft's Introduction

LaSAFT: Latent Source Attentive Frequency Transformation for Conditioned Source Separation

Updates

MDX Challenge (Leaderboard A)

| Model | Conditioned? (O = yes, X = no) | Vocals | Drums | Bass | Other | Song |
|---|---|---|---|---|---|---|
| Demucs++ | X | 7.968 | 8.037 | 8.115 | 5.193 | 7.328 |
| KUILAB-MDX-Net | X | 8.901 | 7.173 | 7.232 | 5.636 | 7.236 |
| Kazane Team | X | 7.686 | 7.018 | 6.993 | 4.901 | 6.649 |
| LASAFT-Net-v2.0 | O | 7.354 | 5.996 | 5.894 | 4.595 | 5.960 |
| LaSAFT-Net-v1.2 | O | 7.275 | 5.935 | 5.823 | 4.557 | 5.897 |
| Demucs48-HQ | X | 6.496 | 6.509 | 6.470 | 4.018 | 5.873 |
| LaSAFT-Net-v1.1 | O | 6.685 | 5.272 | 5.498 | 4.121 | 5.394 |
| XUMXPredictor | X | 6.341 | 5.807 | 5.615 | 3.722 | 5.372 |
| UMXPredictor | X | 5.999 | 5.504 | 5.357 | 3.309 | 5.042 |


Check separated samples on this demo page!

An official PyTorch implementation of the paper "LaSAFT: Latent Source Attentive Frequency Transformation for Conditioned Source Separation" (accepted to ICASSP 2021; slides available).

Demonstration: A Pretrained Model


Interactive Demonstration - Colab Link

  • including how to download and use the pretrained model

Quickstart: How to use Pretrained Models

1. Install LaSAFT.

2. Load a Pretrained Model.

from lasaft.pretrained import PreTrainedLaSAFTNet
model = PreTrainedLaSAFTNet(model_name='lasaft_large_2020')

3. Call model.separate_track.

# audio must be a NumPy array holding a stereo audio track,
# with dtype float32 and shape (T, 2).

vocals = model.separate_track(audio, 'vocals', overlap_ratio=0.5)
drums = model.separate_track(audio, 'drums', overlap_ratio=0.5)
bass = model.separate_track(audio, 'bass', overlap_ratio=0.5)
other = model.separate_track(audio, 'other', overlap_ratio=0.5)

4. Example code

python inference_example.py  assets\footprint.mp3
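separate_track expects a float32 NumPy array of shape (T, 2), but audio readers often return mono arrays, channels-first (2, T) arrays, or integer PCM. The following is a minimal conversion sketch, assuming NumPy. The helper name to_lasaft_input is hypothetical and is not part of the lasaft package.

```python
import numpy as np


def to_lasaft_input(audio):
    """Convert an audio array to the (T, 2) float32 layout separate_track expects.

    Hypothetical helper, not part of lasaft. Accepts mono (T,),
    channels-first (2, T) or channels-last (T, 2) arrays.
    """
    audio = np.asarray(audio)
    # Signed integer PCM (e.g. int16 from a WAV reader) is scaled to [-1, 1).
    if np.issubdtype(audio.dtype, np.signedinteger):
        audio = audio / float(np.iinfo(audio.dtype).max + 1)
    if audio.ndim == 1:
        # Mono: duplicate the single channel into left and right.
        audio = np.stack([audio, audio], axis=1)
    elif audio.ndim == 2 and audio.shape[0] == 2 and audio.shape[1] != 2:
        # Channels-first (2, T): transpose to channels-last (T, 2).
        audio = audio.T
    if audio.ndim != 2 or audio.shape[1] != 2:
        raise ValueError(f"expected a stereo track, got shape {audio.shape}")
    return np.ascontiguousarray(audio, dtype=np.float32)
```

Usage: vocals = model.separate_track(to_lasaft_input(audio), 'vocals', overlap_ratio=0.5).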

Step-by-Step Tutorials

1. Installation

We highly recommend installing the environment with the scripts below, even though we also provide pip-requirements.txt.

conda env create -f lasaft_env_gpu.yaml -n lasaft
conda activate lasaft
pip install -r requirements.txt

2. Dataset: Musdb18

LaSAFT was trained/evaluated on the Musdb18 dataset.

We provide wrapper packages to efficiently load musdb18 tracks as pytorch tensors.

You can also find useful scripts for downloading and preprocessing Musdb18 (or its 7s-samples).
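The uncompressed MUSDB18-HQ release uses train/ and test/ folders. Each of these contains one folder per track, and every track folder holds mixture.wav, vocals.wav, drums.wav, bass.wav and other.wav. The sketch below checks a directory against that layout before you use it as the data path. It uses only the standard library. The helper is hypothetical, and it is an assumption that the LaSAFT loaders expect exactly this layout.

```python
from pathlib import Path

STEMS = ("mixture", "vocals", "drums", "bass", "other")


def find_incomplete_tracks(musdb_root):
    """Return {track_dir: [missing stems]} for every track lacking a stem file.

    Hypothetical helper; assumes the MUSDB18-HQ layout
    <root>/{train,test}/<track name>/<stem>.wav.
    """
    root = Path(musdb_root)
    problems = {}
    for split in ("train", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems[str(split_dir)] = ["<split folder missing>"]
            continue
        for track_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            missing = [s for s in STEMS if not (track_dir / f"{s}.wav").is_file()]
            if missing:
                problems[str(track_dir)] = missing
    return problems
```

An empty result means every track folder has all five stems.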

4. Logging (mandatory): wandb

This project uses wandb. Currently, this setting is mandatory.

To use it, copy your wandb API key from the wandb website:

wandb login -> settings -> Danger Zone -> API keys

Then paste it into .env (a template is provided in ./.env.sample, shown below):

wandb_api_key= [YOUR WANDB API KEY] # go wandb.ai/settings and copy your key
data_dir= [Your MUSDBHQ Data PATH] # Your Musdb data directory. must be an absolute path.
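Before training, you can sanity-check the file. The sketch below parses lines in the KEY= value # comment format shown above and verifies that both keys are set and that data_dir is absolute. It uses only the standard library. The helpers are hypothetical; this sketch does not show how LaSAFT itself loads .env.

```python
from pathlib import Path


def read_env(path):
    """Parse simple 'KEY= value  # comment' lines into a dict.

    Note: everything after '#' is treated as a comment, so values
    containing '#' are not supported by this sketch.
    """
    values = {}
    for raw in Path(path).read_text(encoding="utf-8").splitlines():
        line = raw.split("#", 1)[0].strip()  # drop trailing comments
        if not line or "=" not in line:
            continue
        key, value = line.split("=", 1)
        values[key.strip()] = value.strip()
    return values


def check_env(values):
    """Raise if a required key is empty or data_dir is not an absolute path."""
    for key in ("wandb_api_key", "data_dir"):
        if not values.get(key):
            raise KeyError(f"{key} is missing or empty in .env")
    if not Path(values["data_dir"]).is_absolute():
        raise ValueError("data_dir must be an absolute path")
```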

5. Training

  • The following command trains a U-Net with LaSAFT+GPoCM using the default hyperparameters.

    python train.py trainer.gpus=1 dataset.batch_size=6
  • train.py can train several of the models described in the paper [1].

    • It exposes several options, including PyTorch Lightning trainer parameters.
    • model/conditioned_separation: CUNET_TFC_FiLM, CUNET_TFC_FiLM_LaSAFT, CUNET_TFC_FiLM_TDF, CUNET_TFC_GPoCM, CUNET_TFC_GPoCM_LaSAFT, CUNET_TFC_GPoCM_LightSAFT, CUNET_TFC_GPoCM_TDF, default, lasaft_net, lightsaft_net
  • An example of Training/Validation loss (see wandb report for more details)

Examples

  • Table 1 in [1]

    • FiLM CUNet

      python train.py model=conditioned_separation/CUNET_TFC_FiLM dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
    • FiLM CUNet + TDF

      python train.py model=conditioned_separation/CUNET_TFC_FiLM_TDF dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
    • FiLM CUNet + LaSAFT

      python train.py model=conditioned_separation/CUNET_TFC_FiLM_LaSAFT dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
    • GPoCM CUNet

      python train.py model=conditioned_separation/CUNET_TFC_GPoCM dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
    • GPoCM CUNet + TDF

      python train.py model=conditioned_separation/CUNET_TFC_GPoCM_TDF dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
    • GPoCM CUNet + LaSAFT (* proposed model)

      python train.py model=conditioned_separation/CUNET_TFC_GPoCM_LaSAFT dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
    • GPoCM CUNet + LightSAFT

      python train.py model=conditioned_separation/CUNET_TFC_GPoCM_LightSAFT dataset.batch_size=8 trainer.precision=16 trainer.gpus=1 training.patience=10 training.lr=0.001 logger=wandb
  • Table 2 in [1] (Multi-GPUs Version)

    • GPoCM CUNet + LaSAFT (* proposed model)
      python train.py model=conditioned_separation/CUNET_TFC_GPoCM_LaSAFT trainer=four_2080tis model.n_blocks=9 model.num_tdfs=6 model.embedding_dim=64 dataset.n_fft=4096 dataset.hop_length=1024 trainer.deterministic=True training.patience=10 training.lr=0.001 training.auto_lr_schedule=True logger=wandb training.run_id=lasaft-2020
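The Table 1 commands differ only in the model override. If you want to reproduce them in a loop, a small generator avoids copy-paste errors. This is a sketch using only the standard library. train_command is a hypothetical helper, and the shared overrides are copied from the commands above.

```python
import shlex

# Overrides shared by every Table 1 command above.
COMMON = {
    "dataset.batch_size": 8,
    "trainer.precision": 16,
    "trainer.gpus": 1,
    "training.patience": 10,
    "training.lr": 0.001,
    "logger": "wandb",
}

TABLE1_MODELS = [
    "CUNET_TFC_FiLM",
    "CUNET_TFC_FiLM_TDF",
    "CUNET_TFC_FiLM_LaSAFT",
    "CUNET_TFC_GPoCM",
    "CUNET_TFC_GPoCM_TDF",
    "CUNET_TFC_GPoCM_LaSAFT",
    "CUNET_TFC_GPoCM_LightSAFT",
]


def train_command(model, **overrides):
    """Build the argv list for one `python train.py` run.

    Extra overrides replace shared ones, e.g. **{"trainer.gpus": 4}.
    """
    args = {"model": f"conditioned_separation/{model}", **COMMON, **overrides}
    return ["python", "train.py"] + [f"{k}={v}" for k, v in args.items()]


if __name__ == "__main__":
    for name in TABLE1_MODELS:
        # Print each command; pass the list to subprocess.run(...) to launch it.
        print(shlex.join(train_command(name)))
```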

Tunable hyperparameters (output of python train.py --help)

train is powered by Hydra.

== Configuration groups ==
Compose your configuration from those groups (group=option)

dataset: default
eval: default
model/conditioned_separation: CUNET_TFC_FiLM, CUNET_TFC_FiLM_LaSAFT, CUNET_TFC_FiLM_TDF, CUNET_TFC_GPoCM, CUNET_TFC_GPoCM_LaSAFT, CUNET_TFC_GPoCM_LightSAFT, CUNET_TFC_GPoCM_TDF, base, film, gpocm, lasaft_net, lightsaft_net, tfc
trainer: default
training: default
training/train_loss: distortion, dsr, ldsr, ncs, ncs_44100, ndsr, ndsr_44100, nlcs, raw_and_spec, raw_l1, raw_l2, raw_mse, sdr, sdr_like, spec_l1, spec_l2, spec_mse
training/val_loss: distortion, dsr, ldsr, ncs, ncs_44100, ndsr, ndsr_44100, nlcs, raw_and_spec, raw_l1, raw_l2, raw_mse, sdr, sdr_like, spec_l1, spec_l2, spec_mse


== Config ==
Override anything in the config (foo.bar=value)

trainer:
  _target_: pytorch_lightning.Trainer
  checkpoint_callback: true
  callbacks: null
  default_root_dir: null
  gradient_clip_val: 0.0
  process_position: 0
  num_nodes: 1
  num_processes: 1
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  log_gpu_memory: null
  progress_bar_refresh_rate: 1
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: 1
  max_epochs: 1
  min_epochs: 1
  max_steps: null
  min_steps: null
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
  val_check_interval: 1.0
  flush_logs_every_n_steps: 100
  log_every_n_steps: 50
  accelerator: ddp
  sync_batchnorm: false
  precision: 16
  weights_summary: top
  weights_save_path: null
  num_sanity_val_steps: 2
  truncated_bptt_steps: null
  resume_from_checkpoint: null
  profiler: null
  benchmark: false
  deterministic: false
  reload_dataloaders_every_epoch: false
  auto_lr_find: false
  replace_sampler_ddp: true
  terminate_on_nan: false
  auto_scale_batch_size: false
  prepare_data_per_node: true
  amp_backend: native
  amp_level: O2
  move_metrics_to_cpu: false
dataset:
  _target_: lasaft.data.data_provider.DataProvider
  musdb_root: etc/musdb18_dev_wav
  batch_size: 8
  num_workers: 0
  pin_memory: true
  num_frame: 128
  hop_length: 1024
  n_fft: 2048
model:
  spec_type: complex
  spec_est_mode: mapping
  n_blocks: 7
  input_channels: 4
  internal_channels: 24
  first_conv_activation: relu
  last_activation: identity
  t_down_layers: null
  f_down_layers: null
  control_vector_type: embedding
  control_input_dim: 4
  embedding_dim: 32
  condition_to: decoder
  unfreeze_stft_from: -1
  control_n_layer: 4
  control_type: dense
  pocm_type: matmul
  pocm_norm: batch_norm
  _target_: lasaft.source_separation.conditioned.cunet.models.dcun_tfc_gpocm_lasaft.DCUN_TFC_GPoCM_LaSAFT_Framework
  n_internal_layers: 5
  kernel_size_t: 3
  kernel_size_f: 3
  bn_factor: 16
  min_bn_units: 16
  tfc_tdf_bias: false
  tfc_tdf_activation: relu
  num_tdfs: 6
  dk: 32
training:
  train_loss:
    _target_: lasaft.source_separation.conditioned.loss_functions.Conditional_Spectrogram_Loss
    mode: mse
  val_loss:
    _target_: lasaft.source_separation.conditioned.loss_functions.Conditional_RAW_Loss
    mode: l1
  ckpt_root_path: etc/checkpoints
  log: true
  run_id: ${now:%Y-%m-%d}/${now:%H-%M-%S}
  save_weights_only: false
  optimizer: adam
  lr: 0.001
  auto_lr_schedule: true
  save_top_k: 5
  patience: 10
  seed: 2020
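According to the help text above, you can override any key with foo.bar=value. The sketch below illustrates what such an override does to the nested config: it walks the dotted path and replaces the leaf value. This is a simplified pure-Python illustration, not Hydra's implementation. Hydra and OmegaConf parse values more completely, validate against the schema, and use +key=value to add keys that are not already in the config.

```python
import copy


def _parse_value(text):
    """Tiny value parser: null/true/false, int, float, otherwise a string."""
    lowered = text.lower()
    if lowered == "null":
        return None
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_override(config, override):
    """Return a copy of `config` with one 'a.b.c=value' override applied.

    Unknown keys are rejected, mirroring Hydra's default behavior for
    overrides without a '+' prefix.
    """
    path, _, raw_value = override.partition("=")
    keys = path.split(".")
    result = copy.deepcopy(config)
    node = result
    for key in keys[:-1]:
        node = node[key]  # KeyError if the group does not exist
    if keys[-1] not in node:
        raise KeyError(f"unknown config key: {path}")
    node[keys[-1]] = _parse_value(raw_value)
    return result
```

For example, apply_override(cfg, "dataset.n_fft=4096") mirrors the dataset.n_fft=4096 override in the Table 2 command.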

6. Evaluation

python eval.py pretrained=lasaft_large_2021 overlap_ratio=0.5

see result here

You can cite this paper as follows:

@INPROCEEDINGS{9413896,
  author={Choi, Woosung and Kim, Minseok and Chung, Jaehwa and Jung, Soonyoung},
  booktitle={ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 
  title={Lasaft: Latent Source Attentive Frequency Transformation For Conditioned Source Separation}, 
  year={2021},
  volume={},
  number={},
  pages={171-175},
  doi={10.1109/ICASSP39728.2021.9413896}}

LaSAFT: Latent Source Attentive Frequency Transformation

GPoCM: Gated Point-wise Convolutional Modulation

Reference

[1] Woosung Choi, Minseok Kim, Jaehwa Chung, and Soonyoung Jung, "LaSAFT: Latent Source Attentive Frequency Transformation for Conditioned Source Separation," arXiv preprint arXiv:2010.11631, 2020.

Other Links


conditioned-source-separation-lasaft's Issues

Overlapping sliding window for separation

The current separate_track function iterates over small, disjoint pieces of the whole track.
An overlapping sliding window, as in the STFT, might improve the separation quality.
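One way to implement this is a generic weighted overlap-add scheme. The track is split into fixed-length windows whose start positions advance by hop = window_len * (1 - overlap_ratio). Each window is processed, and the outputs are cross-faded with a Hann taper, normalized by the summed taper weights. The sketch below assumes NumPy and uses a process callable as a stand-in for the separation model. It is a generic technique, not the code behind separate_track's overlap_ratio argument, and that mapping from overlap_ratio to hop is an assumption.

```python
import numpy as np


def overlap_add(track, process, window_len=4096, overlap_ratio=0.5):
    """Run `process` on overlapping windows of `track` (shape (T, C)) and
    cross-fade the outputs back into a (T, C) array.

    `process` stands in for the separation model (hypothetical interface):
    it maps a (window_len, C) chunk to an array of the same shape.
    """
    if not 0.0 <= overlap_ratio < 1.0:
        raise ValueError("overlap_ratio must be in [0, 1)")
    T, C = track.shape
    hop = max(1, int(window_len * (1.0 - overlap_ratio)))
    # Enough windows for the last one to reach the end of the track.
    n_windows = int(np.ceil(max(T - window_len, 0) / hop)) + 1
    padded_len = (n_windows - 1) * hop + window_len
    padded = np.zeros((padded_len, C), dtype=track.dtype)
    padded[:T] = track

    # Strictly positive Hann taper, so every covered sample gets nonzero weight.
    fade = np.hanning(window_len + 2)[1:-1][:, None]
    out = np.zeros((padded_len, C))
    weight = np.zeros((padded_len, 1))
    for i in range(n_windows):
        start = i * hop
        stop = start + window_len
        out[start:stop] += fade * process(padded[start:stop])
        weight[start:stop] += fade
    # Normalize by the summed taper so overlapping contributions average out.
    out /= weight
    return out[:T].astype(track.dtype)
```

With an identity process, the output reproduces the input, which is a quick way to check that the windowing and normalization are consistent.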

Enhancement request: progress bar

Just a suggestion. It would be useful for those who do not have a local GPU (me, for instance), because the current process can take a very long time depending on the audio duration. Thank you very much.

Memory footprint in Google Colab

Thank you so much for this great model! Wonderful job! I have a small question about the memory required for separation. The model seems to use a lot of memory. In Google Colab (the free version, because the Pro version is not available to European users), I have to split the audio of a full song (longer than about 1 min to 1 min 30 s) and resample it from high resolution (96000 Hz) to low resolution (44100 Hz).

The current Jupyter notebook only shows processing of very short samples (a YouTube video). I've slightly modified the code to allow using audio from Google Drive, but it seems limited to low-resolution, short audio files unless the audio is split and merged in a separate step. Spleeter (Deezer) resolved the same RAM limitation with a similar method, with some constraints (zero padding that must be removed from the audio) (issue here: deezer/spleeter#391 (comment)).

Has anyone already done this?

how to run in distributed mode

Hi, I have not used pytorch_lightning before. I am now working with a large amount of data. Your demo only shows how to run on 1 GPU, and I would like to run on multiple GPUs. How can I do it?

I simply set gpus to 4, but it raised errors.

Thank you very much!

Why does LaSAFT only use a CPU instead of a GPU?

Hello

I wonder whether LaSAFT can separate music using a GPU. (I'm not referring to model training.)

I have already done everything I could: I installed CUDA, cuDNN and TensorFlow-GPU, and LaSAFT still uses only the CPU to separate songs :(

I look forward to your feedback, thank you very much :D

cuda version of separate_track with minibatch

separate_track does not have an explicit option for choosing the device. Proposed signatures:

  • def separate_track(track, instrument, cuda=False)
  • def separate_track(track, instrument, cuda=False, batch_size=1)
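To support the proposed batch_size option, you can cut the track into equal-length chunks and stack them into minibatches, so a GPU processes several chunks per forward pass. The following is a minimal NumPy sketch. iter_minibatches and stitch are hypothetical names. The device transfer is only described in a comment and is not executed here.

```python
import numpy as np


def iter_minibatches(track, chunk_len, batch_size=1):
    """Yield (start_indices, batch) pairs; batch has shape (n, chunk_len, C), n <= batch_size.

    The (T, C) track is zero-padded at the end so it splits into whole chunks.
    In a PyTorch version, each batch would be moved to the chosen device with
    torch.from_numpy(batch).to(device) before the forward pass.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    T, C = track.shape
    n_chunks = max(1, -(-T // chunk_len))  # ceiling division
    padded = np.zeros((n_chunks * chunk_len, C), dtype=track.dtype)
    padded[:T] = track
    chunks = padded.reshape(n_chunks, chunk_len, C)
    for first in range(0, n_chunks, batch_size):
        batch = chunks[first:first + batch_size]
        starts = [i * chunk_len for i in range(first, first + len(batch))]
        yield starts, batch


def stitch(outputs, total_len):
    """Concatenate per-batch outputs of shape (n, chunk_len, C) into (total_len, C)."""
    joined = np.concatenate([b.reshape(-1, b.shape[-1]) for b in outputs], axis=0)
    return joined[:total_len]
```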

Question about loss=nan.

Hello

I would like to know why the loss becomes nan when the run reaches epoch 115. I checked the checkpoints, and the last saved ckpt is from epoch 79.

I used the example you gave: python main.py --problem_name conditioned_separation --mode train --run_id lasaft_net --musdb_root etc/musdb18_dev_wav --gpus 1 --precision 16 --batch_size 6 --num_workers 0 --pin_memory True --save_top_k 3 --save_weights_only True --patience 10 --lr 0.001 --model CUNET_TFC_GPoCM_LaSAFT

I look forward to your feedback, thank you very much.

How to train the code on local?

Hello, I'm using Ubuntu 20.04 and want to train this project on my PC.
I have downloaded MUSDB18-HQ with all the .wav files to my PC and created a conda environment with python=3.8 and torch=1.8.1, as shown in the README.
In which path should I put my dataset?
In data/musdb_wrapper.py I see:

    @staticmethod
    def add_data_provider_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--musdb_root', type=str, default='etc/musdb18_samples_wav/')
        return parser

Does this mean I need to rename the "musdb18hq" folder (with subfolders "train" and "test") to "musdb18_samples_wav" and put it into the "etc" folder? Or should I move all the song folders into "musdb18_samples_wav" without the "train" and "test" subfolders?
Thank you.

length difference between input and output signal

I trained my own model and followed the quick-start demo "Quickstart: How to use Pretrained Models". However, the input and output signals have different lengths. Is this expected?

When I trained the model, I tuned the trim length, hop length and window length.

Thank you.
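STFT-based models often pad the input to a whole number of frames or chunks, and that padding can remain in the output. The sketch below trims or zero-pads an estimate back to the input length. It is a hypothetical post-processing helper, not part of lasaft. It assumes the extra or missing samples are at the end of the output, which depends on how the model pads.

```python
import numpy as np


def match_length(estimate, target_len):
    """Trim or zero-pad `estimate` (shape (T', C)) along time to target_len.

    Assumes any extra or missing samples are at the end of the estimate.
    """
    current = estimate.shape[0]
    if current >= target_len:
        return estimate[:target_len]
    pad = np.zeros((target_len - current,) + estimate.shape[1:], dtype=estimate.dtype)
    return np.concatenate([estimate, pad], axis=0)
```

Usage: vocals = match_length(model.separate_track(audio, 'vocals', overlap_ratio=0.5), audio.shape[0]).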

Error when loading the model

First of all, thanks for open-sourcing your work.

I'm trying to run your Colab template, but the cell that loads the pretrained parameters fails with this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-5-1c462f4386f0> in <module>()
----> 1 model = model.load_from_checkpoint('pretrained/gpocm_lasaft.ckpt')

2 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1050         if len(error_msgs) > 0:
   1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1054 

RuntimeError: Error(s) in loading state_dict for DCUN_TFC_GPoCM_LaSAFT_Framework:
	Missing key(s) in state_dict: "spec2spec.first_conv.0.weight", "spec2spec.first_conv.0.bias", "spec2spec.first_conv.1.weight", "spec2spec.first_conv.1.bias", "spec2spec.first_conv.1.running_mean", "spec2spec.first_conv.1.running_var", "spec2spec.encoders.0.keys", "spec2spec.encoders.0.tfc.H.0.0.weight", "spec2spec.encoders.0.tfc.H.0.0.bias", "spec2spec.encoders.0.tfc.H.0.1.weight", "spec2spec.encoders.0.tfc.H.0.1.bias", "spec2spec.encoders.0.tfc.H.0.1.running_mean", "spec2spec.encoders.0.tfc.H.0.1.running_var", "spec2spec.encoders.0.tfc.H.1.0.weight", "spec2spec.encoders.0.tfc.H.1.0.bias", "spec2spec.encoders.0.tfc.H.1.1.weight", "spec2spec.encoders.0.tfc.H.1.1.bias", "spec2spec.encoders.0.tfc.H.1.1.running_mean", "spec2spec.encoders.0.tfc.H.1.1.running_var", "spec2spec.encoders.0.tfc.H.2.0.weight", "spec2spec.encoders.0.tfc.H.2.0.bias", "spec2spec.encoders.0.tfc.H.2.1.weight", "spec2spec.encoders.0.tfc.H.2.1.bias", "spec2spec.encoders.0.tfc.H.2.1.running_mean", "spec2spec.encoders.0.tfc.H.2.1.running_var", "spec2spec.encoders.0.tfc.H.3.0.weight", "spec2spec.encoders.0.tfc.H.3.0.bias", "spec2spec.encoders.0.tfc.H.3.1.weight", "spec2spec.encoders.0.tfc.H.3.1.bias", "spec2spec.encoders.0.tfc.H.3.1.running_mean", "spec2spec.encoders.0.tfc.H.3.1.running_var", "spec2spec.encoders.0.tfc.H.4.0.weight", "spec2spec.encoders.0.tfc.H.4.0.bias", "spec2spec.encoders.0.tfc.H.4.1.weight", "spec2spec.encoders.0.tfc.H.4.1.bias", "spec2spec.encoders.0.tfc.H.4.1.running_mean", "spec2spec.en...
	Unexpected key(s) in state_dict: "conditional_spec2spec.first_conv.0.weight", "conditional_spec2spec.first_conv.0.bias", "conditional_spec2spec.first_conv.1.weight", "conditional_spec2spec.first_conv.1.bias", "conditional_spec2spec.first_conv.1.running_mean", "conditional_spec2spec.first_conv.1.running_var", "conditional_spec2spec.first_conv.1.num_batches_tracked", "conditional_spec2spec.encoders.0.keys", "conditional_spec2spec.encoders.0.tfc.H.0.0.weight", "conditional_spec2spec.encoders.0.tfc.H.0.0.bias", "conditional_spec2spec.encoders.0.tfc.H.0.1.weight", "conditional_spec2spec.encoders.0.tfc.H.0.1.bias", "conditional_spec2spec.encoders.0.tfc.H.0.1.running_mean", "conditional_spec2spec.encoders.0.tfc.H.0.1.running_var", "conditional_spec2spec.encoders.0.tfc.H.0.1.num_batches_tracked", "conditional_spec2spec.encoders.0.tfc.H.1.0.weight", "conditional_spec2spec.encoders.0.tfc.H.1.0.bias", "conditional_spec2spec.encoders.0.tfc.H.1.1.weight", "conditional_spec2spec.encoders.0.tfc.H.1.1.bias", "conditional_spec2spec.encoders.0.tfc.H.1.1.running_mean", "conditional_spec2spec.encoders.0.tfc.H.1.1.running_var", "conditional_spec2spec.encoders.0.tfc.H.1.1.num_batches_tracked", "conditional_spec2spec.encoders.0.tfc.H.2.0.weight", "conditional_spec2spec.encoders.0.tfc.H.2.0.bias", "conditional_spec2spec.encoders.0.tfc.H.2.1.weight", "conditional_spec2spec.encoders.0.tfc.H.2.1.bias", "conditional_spec2spec.encoders.0.tfc.H.2.1.running_mean", "conditional_spec2spec.encoders.0.tfc....

It fails both on Google Colab and when running locally on my machine.

I really appreciate any help you can provide.

Geraldo
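In the error above, the missing and unexpected key lists differ by a prefix. The checkpoint stores conditional_spec2spec.*, but the model code expects spec2spec.*. This pattern usually means the checkpoint and the code come from different versions, so matching the versions is the cleaner fix. As a workaround, assuming the prefix is the only difference, you can rename the keys before loading. rename_prefix is a hypothetical helper. torch.load and load_state_dict are standard PyTorch, and PyTorch Lightning checkpoints keep the weights under the 'state_dict' key.

```python
def rename_prefix(state_dict, old="conditional_spec2spec.", new="spec2spec."):
    """Return a new dict where keys starting with `old` start with `new` instead.

    Keys without the prefix are kept unchanged.
    """
    return {
        (new + key[len(old):]) if key.startswith(old) else key: value
        for key, value in state_dict.items()
    }

# Usage sketch with PyTorch (not executed here):
#   ckpt = torch.load('pretrained/gpocm_lasaft.ckpt', map_location='cpu')
#   model.load_state_dict(rename_prefix(ckpt['state_dict']))
```

If other differences remain, load_state_dict will still report them, so check its error message after renaming.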

CSVLogger does not work

Hi, I would like to log the experiments locally only, so I changed the logger to CSVLogger in lasaft/trainer.py:

log = args['log']
if log == 'False':
    args['logger'] = False
elif log == 'wandb':
    args['logger'] = WandbLogger(project='lasaft_exp', tags=[model_name], offline=False, name=run_id)
    args['logger'].log_hyperparams(model.hparams)
    args['logger'].watch(model, log='all')
elif log == 'tensorboard':
    raise NotImplementedError
else:
    args['logger'] = True  # default
    default_log_path = os.path.join(ckpt_path, 'lightning_logs')
    args['logger'] = CSVLogger(default_log_path, version='0')
    mkdir_if_not_exists(default_log_path)

I also set progress_bar_refresh_rate = 0. When I run main.py, it shows this error:

Traceback (most recent call last):
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 644, in run_train
    self.train_loop.run_training_epoch()
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 564, in run_training_epoch
    self.trainer.run_evaluation(on_epoch=True)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in run_evaluation
    self.evaluation_loop.on_evaluation_end()
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 100, in on_evaluation_end
    self.trainer.call_hook('on_validation_end', *args, **kwargs)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1101, in call_hook
    trainer_hook(*args, **kwargs)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py", line 183, in on_validation_end
    callback.on_validation_end(self, self.lightning_module)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 212, in on_validation_end
    self.save_checkpoint(trainer, pl_module)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 259, in save_checkpoint
    self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 566, in _save_top_k_checkpoints
    self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 608, in _update_best_and_save
    self._save_model(filepath, trainer, pl_module)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 335, in _save_model
    self.save_function(filepath, self.save_weights_only)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/properties.py", line 327, in save_checkpoint
    self.checkpoint_connector.save_checkpoint(filepath, weights_only)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 408, in save_checkpoint
    atomic_save(checkpoint, filepath)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py", line 63, in atomic_save
    torch.save(checkpoint, bytesbuffer)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/torch/serialization.py", line 372, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/torch/serialization.py", line 476, in _save
    pickler.dump(obj)
TypeError: cannot pickle '_csv.writer' object

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 60, in <module>
    trainer.train(parser.parse_args(), hp)
  File "/home/feitao/Projects/music_unmix/lasaft/lasaft/source_separation/conditioned/scripts/trainer.py", line 177, in train
    trainer.fit(model, training_dataloader, validation_dataloader)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 513, in fit
    self.dispatch()
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in dispatch
    self.accelerator.start_training(self)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 111, in start_training
    self._results = trainer.run_train()
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 676, in run_train
    self.train_loop.on_train_end()
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 134, in on_train_end
    self.check_checkpoint_callback(should_update=True, is_last=True)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 164, in check_checkpoint_callback
    cb.on_validation_end(self.trainer, model)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 212, in on_validation_end
    self.save_checkpoint(trainer, pl_module)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 259, in save_checkpoint
    self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 566, in _save_top_k_checkpoints
    self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 608, in _update_best_and_save
    self._save_model(filepath, trainer, pl_module)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 335, in _save_model
    self.save_function(filepath, self.save_weights_only)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/properties.py", line 327, in save_checkpoint
    self.checkpoint_connector.save_checkpoint(filepath, weights_only)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 408, in save_checkpoint
    atomic_save(checkpoint, filepath)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py", line 63, in atomic_save
    torch.save(checkpoint, bytesbuffer)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/torch/serialization.py", line 372, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/opt/conda/envs/lasaft/lib/python3.8/site-packages/torch/serialization.py", line 476, in _save
    pickler.dump(obj)
TypeError: cannot pickle '_csv.writer' object

I found that if I don't use any logger, i.e. I only use:

log = args['log']
if log == 'False':
    args['logger'] = False
elif log == 'wandb':
    args['logger'] = WandbLogger(project='lasaft_exp', tags=[model_name], offline=False, name=run_id)
    args['logger'].log_hyperparams(model.hparams)
    args['logger'].watch(model, log='all')
elif log == 'tensorboard':
    raise NotImplementedError
else:
    args['logger'] = True  # default
    # default_log_path = os.path.join(ckpt_path, 'lightning_logs')
    # args['logger'] = CSVLogger(default_log_path, version='0')
    # mkdir_if_not_exists(default_log_path)

Then the code works perfectly (except that I cannot track the loss). Can anyone help? Thank you.
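The traceback ends with torch.save failing to pickle a _csv.writer. This suggests that the checkpoint being saved holds a reference to the CSV logger's writer. One plausible path, which is an assumption not confirmed here, is that the args dict containing args['logger'] ends up in the model's saved hyperparameters. The standard-library sketch below reproduces the pickling error and shows a hypothetical filter that drops unpicklable values from such a dict before it is saved.

```python
import csv
import io
import pickle


def picklable_only(hparams):
    """Return a copy of `hparams` without values that pickle cannot serialize."""
    kept = {}
    for key, value in hparams.items():
        try:
            pickle.dumps(value)
        except (TypeError, pickle.PicklingError, AttributeError):
            continue  # e.g. logger objects holding csv writers or open files
        kept[key] = value
    return kept


# A csv.writer cannot be pickled, which is the error in the traceback above.
writer = csv.writer(io.StringIO())
args = {"lr": 0.001, "batch_size": 8, "logger": writer}
try:
    pickle.dumps(args)
except TypeError as err:
    print("pickling failed:", err)

clean = picklable_only(args)
print(sorted(clean))  # ['batch_size', 'lr']
```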

Pretrained models

  1. Are lasaft_large_2020.ckpt and lasaft_large_2021.ckpt trained on "train" part of musdb18 or on full musdb18 ("train" and "test")?
  2. What is the difference between these models?

Validation time is too long

I would like to know why evaluation takes so long. Is this due to a GPU issue? I am using an RTX 3090, and Volatile GPU-Util stays between 80% and 99%.

Pruning and Quantization

Hi there!

Thank you very much for open sourcing code and such a great paper, awesome results!
I was wondering whether you have tried pruning or quantizing the model.

Best,
Rich

PROBLEMS AND QUESTIONS

Hi, how are you? Sorry to bother you.

I would like to know how to train a model on my own dataset. I have already trained my own models with tsurumeso's vocal-remover, and now I would like to train a new model with the wonderful LaSAFT.

The vocal-remover dataset layout looks like this:

path/to/dataset/
  +- instruments/
  |    +- 01_foo_inst.wav
  |    +- 02_bar_inst.mp3
  |    +- ...
  +- mixtures/
       +- 01_foo_mix.wav
       +- 02_bar_mix.mp3
       +- ...

1. Does LaSAFT also allow training a model using only mixtures and instrumentals?

2. I would like to congratulate you on reaching the top of the Music Demixing (MDX) Challenge. You are a genius; you deserve the award and much more!

3. Also, have you trained any other .ckpt models for LaSAFT during the past months? If so, could you share them with those of us who appreciate your work so much?

4. About a problem: why does LaSAFT make the accompaniment/stems crash during the song? This happens with both the 2-stem and the 4-stem versions.

Sorry for the many questions, but this is the only place where I can get in touch with you. Your work is beautiful.

I look forward to hearing from you,

Yours truly,
Lucas Rodrigues.
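LaSAFT is trained as a source-conditioned separator on MUSDB18 stems, so question 1 depends on how the training code loads targets. One common way to adapt a mixture/instrumental dataset is to derive a vocal target as mixture minus instrumental. This works only if the two files are sample-aligned and identically processed. The sketch below pairs files by their numeric prefix in the layout above and computes that difference with NumPy. The pairing rule and both helpers are hypothetical, and the sketch leaves out audio decoding (for example, of the .mp3 files).

```python
from pathlib import Path

import numpy as np

AUDIO_EXTS = {".wav", ".mp3", ".flac"}


def pair_files(dataset_dir):
    """Pair mixtures/<id>_*_mix.* with instruments/<id>_*_inst.* by the id prefix."""
    root = Path(dataset_dir)

    def by_id(folder):
        files = {}
        for p in sorted((root / folder).iterdir()):
            if p.suffix.lower() in AUDIO_EXTS:
                files[p.name.split("_", 1)[0]] = p  # '01_foo_mix.wav' -> '01'
        return files

    mixes, insts = by_id("mixtures"), by_id("instruments")
    return [(mixes[k], insts[k]) for k in sorted(mixes.keys() & insts.keys())]


def derive_vocals(mixture, instrumental):
    """Approximate the vocal stem as mixture - instrumental (both (T, 2) arrays)."""
    n = min(len(mixture), len(instrumental))  # tolerate small length differences
    return mixture[:n] - instrumental[:n]
```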

colab error

from lasaft.source_separation.conditioned.cunet.models.dcun_tfc_gpocm_lasaft import DCUN_TFC_GPoCM_LaSAFT_Framework

args = {}

# FFT params
args['n_fft'] = 4096
args['hop_length'] = 1024
args['num_frame'] = 128

# SVS Framework
args['spec_type'] = 'complex'
args['spec_est_mode'] = 'mapping'

# Other Hyperparams
args['optimizer'] = 'adam'
args['lr'] = 0.0001
args['dev_mode'] = False
args['train_loss'] = 'spec_mse'
args['val_loss'] = 'raw_l1'

# DenseNet Hyperparams
args['n_blocks'] = 9
args['input_channels'] = 4
args['internal_channels'] = 24
args['first_conv_activation'] = 'relu'
args['last_activation'] = 'identity'
args['t_down_layers'] = None
args['f_down_layers'] = None
args['tif_init_mode'] = None

# TFC_TDF Block's Hyperparams
args['n_internal_layers'] = 5
args['kernel_size_t'] = 3
args['kernel_size_f'] = 3
args['tfc_tdf_activation'] = 'relu'
args['bn_factor'] = 16
args['min_bn_units'] = 16
args['tfc_tdf_bias'] = True
args['num_tdfs'] = 6
args['dk'] = 32

args['control_vector_type'] = 'embedding'
args['control_input_dim'] = 4
args['embedding_dim'] = 64
args['condition_to'] = 'decoder'

args['control_n_layer'] = 4
args['control_type'] = 'dense'
args['pocm_type'] = 'matmul'
args['pocm_norm'] = 'batch_norm'

args['auto_lr_schedule'] = False

model = DCUN_TFC_GPoCM_LaSAFT_Framework(**args)
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-7-d966a8f14f28> in <module>()
----> 1 from lasaft.source_separation.conditioned.cunet.models.dcun_tfc_gpocm_lasaft import DCUN_TFC_GPoCM_LaSAFT_Framework
      2 
      3 args = {}
      4 
      5 # FFT params

10 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/apply_func.py in <module>()
     23 TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
     24 if TORCHTEXT_AVAILABLE:
---> 25     from torchtext.data import Batch
     26 else:
     27     Batch = type(None)

ImportError: cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)


evaluation result

Hi, I have already trained a model on the dataset samples you provided, but I can't find where the evaluation results are. Thank you very much for your help.
