adobe-research / metaaf


Control adaptive filters with neural networks.

Home Page: https://jmcasebeer.github.io/projects/metaaf

Python 97.28% Shell 1.42% Jupyter Notebook 1.30%
beamforming acoustic-echo-cancellation adaptive-filtering adaptive-filters dereverberation gsc linear-prediction system-identification blind-equalization weighted-prediction-error

metaaf's Introduction

Meta-AF: Meta-Learning for Adaptive Filters

Jonah Casebeer¹*, Nicholas J. Bryan², and Paris Smaragdis¹

¹ Department of Computer Science, University of Illinois at Urbana-Champaign
² Adobe Research, Lead advisor
*Work performed while an intern at Adobe Research

Demo Video

Abstract

Adaptive filtering algorithms are pervasive throughout signal processing and have had a material impact on a wide variety of domains including audio processing, telecommunications, biomedical sensing, astrophysics and cosmology, seismology, and many more. Adaptive filters typically operate via specialized online, iterative optimization methods such as least-mean squares or recursive least squares and aim to process signals in unknown or nonstationary environments. Such algorithms, however, can be slow and laborious to develop, require domain expertise to create, and necessitate mathematical insight for improvement. In this work, we seek to improve upon hand-derived adaptive filter algorithms and present a comprehensive framework for learning online, adaptive signal processing algorithms or update rules directly from data. To do so, we frame the development of adaptive filters as a meta-learning problem in the context of deep learning and use a form of self-supervision to learn online iterative update rules for adaptive filters. To demonstrate our approach, we focus on audio applications and systematically develop meta-learned adaptive filters for five canonical audio problems including system identification, acoustic echo cancellation, blind equalization, multi-channel dereverberation, and beamforming. We compare our approach against common baselines and/or recent state-of-the-art methods. We show we can learn high-performing adaptive filters that operate in real-time and, in most cases, significantly outperform each method we compare against -- all using a single general-purpose configuration of our approach.

For more details, please see "Meta-AF: Meta-Learning for Adaptive Filters", Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, arXiv, 2022, or watch our talk:

Lecture Video

If you use ideas or code from this work, please cite our paper:

@article{casebeer2022meta,
  title={Meta-AF: Meta-Learning for Adaptive Filters},
  author={Casebeer, Jonah and Bryan, Nicholas J and Smaragdis, Paris},
  journal={arXiv preprint arXiv:2204.11942},
  year={2022}
}

Demos

For audio demonstrations of this work and the metaaf package in action, please see our demo website, which has demos for all five core adaptive filtering problems.

Code

We open-source all code for this work via our metaaf Python pip package. metaaf supports meta-learning optimizers for nearly arbitrary adaptive filters with any differentiable objective. It automatically manages online overlap-save and overlap-add for single- and multi-channel, single- and multi-frame filters. We also include generic implementations of LMS, RMSProp, NLMS, and RLS for benchmarking. Finally, metaaf includes generic GRU-based optimizers, which are compatible with any filter defined in the metaaf format. Below, you can find example usage, usage for several common adaptive filter tasks (in the adaptive filter zoo), and installation instructions.
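As a point of reference for the LMS/NLMS baselines mentioned above, here is a minimal time-domain NLMS system identifier in NumPy. It is a generic textbook sketch, not metaaf's implementation (metaaf's baselines may operate on buffered, STFT-domain signals). The function name nlms_sysid, the step size mu=0.5, and the regularizer eps are illustrative assumptions; the data follows the same toy setup as the SystemIDDataset in Example Usage.

```python
import numpy as np

def nlms_sysid(u, d, order, mu=0.5, eps=1e-6):
    """Hypothetical helper: identify an FIR system from input u and output d with NLMS."""
    w = np.zeros(order)          # current filter estimate
    x = np.zeros(order)          # last `order` input samples, newest first
    err = np.zeros(len(u))
    for n in range(len(u)):
        x = np.roll(x, 1)
        x[0] = u[n]
        y = w @ x                # filter output
        err[n] = d[n] - y        # a-priori error
        # normalized gradient step: LMS step divided by the input energy
        w += mu * err[n] * x / (x @ x + eps)
    return w, err

rng = np.random.default_rng(0)
order, N = 32, 4096
w_true = rng.normal(size=order) / order
u = rng.normal(size=N)
d = np.convolve(w_true, u)[:N]   # noise-free output of the unknown system
w_hat, err = nlms_sysid(u, d, order)
print(np.max(np.abs(w_hat - w_true)))  # small after convergence
```

Dividing by x @ x makes the effective step size insensitive to the input level. Meta-AF replaces hand-designed update rules like this one with rules learned from data.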

The metaaf package is small, about a dozen files, yet it supports much more functionality than we demo here. The core meta-learning code is in core.py, the buffered and online filter implementations are in filter.py, and the RNN-based optimizers are in optimizer_gru.py and optimizer_fgru.py. meta.py contains the class that manages training, and the remaining files hold utilities and generic implementations of baseline optimizers.

Installation

To install the metaaf Python package, you need a working JAX installation. You can set one up by following the official directions here. Below are the commands we use to set up a new conda environment called metaenv, in which we install metaaf and its dependencies.

GPU Setup

### GPU
# Install all the cuda and cudnn prerequisites
conda create -yn metaenv python=3.7 &&
conda install -yn metaenv cudatoolkit=11.1.1 -c pytorch -c conda-forge &&
conda install -yn metaenv cudatoolkit-dev=11.1.1 -c pytorch -c conda-forge &&
conda install -yn metaenv cudnn=8.2 -c nvidia -c pytorch -c anaconda -c conda-forge &&
conda install -yn metaenv pytorch cpuonly -c pytorch -y
conda activate metaenv

# Actually install jax
# You may need to change the cuda/cudnn version numbers depending on your machine
pip install jax[cuda11_cudnn82]==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_releases.html  

# Install Haiku
pip install git+https://github.com/deepmind/[email protected]

CPU Setup

### CPU. x86 only
conda create -yn metaenv python=3.7 && 
conda install -yn metaenv pytorch torchvision torchaudio -c pytorch && 
conda install -yn metaenv pytorch cpuonly -c pytorch -y
conda activate metaenv

# Actually install jax (CPU-only build, so there are no cuda/cudnn versions to match)
pip install --upgrade pip
pip install --upgrade "jax[cpu]"==0.3.15

# Install Haiku
pip install git+https://github.com/deepmind/[email protected]

Finally, with the prerequisites installed, you can install metaaf by cloning the repo, moving into its base directory, and running pip install -e ./. This pip install adds the remaining dependencies. To run the demo notebook, you also need to run:

# Add the conda env to your jupyter session
conda install -yn metaenv ipykernel 
ipython kernel install --user --name=metaenv

# Install plotting
pip install matplotlib

# Install widgets for a progress bar
pip install ipywidgets

Example Usage

The metaaf package provides several modules to facilitate training. The first is MetaAFTrainer, a class that manages training. To use MetaAFTrainer, we need to define a filter architecture and a dataset. metaaf adopts several conventions to simplify training and automate procedures like buffering. Below, we walk through this process on a toy system-identification task and explain the automatic argparse utilities. For a full-scale example, see the next section, which describes the Meta-AF Zoo.

First, create a dataset using a regular PyTorch Dataset. Each item must be a dictionary with two keys: "signals" and "metadata". The arrays in "signals" are automatically indexed and sliced and should have shape (samples, channels).

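import numpy as np
from torch.utils.data import Dataset
from metaaf.data import NumpyLoader  # assumed module path for metaaf's NumPy-collating loader

class SystemIDDataset(Dataset):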
    def __init__(self, N=4096, sys_order=32):
        self.N = N
        self.sys_order = sys_order

    def __len__(self):
        return 256

    def __getitem__(self, idx):
        # the system
        w = np.random.normal(size=self.sys_order) / self.sys_order

        # the input
        u = np.random.normal(size=self.N)

        # the output
        d = np.convolve(w, u)[: self.N]

        return {
            "signals": {
                "u": u[:, None], # time X channels
                "d": d[:, None], # time X channels
            },  
            "metadata": {},
        }
train_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
val_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
test_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
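The dataset builds d by truncating the full convolution np.convolve(w, u) to N samples. The check below, with hypothetical smaller sizes and a hypothetical helper fir_filter, confirms that this is the same as causal FIR filtering of u by w, that is, d[n] = sum_k w[k] u[n - k] with u[n] = 0 for n < 0.

```python
import numpy as np

def fir_filter(w, u):
    """Hypothetical helper: causal FIR filtering, d[n] = sum_k w[k] * u[n - k]."""
    d = np.zeros(len(u))
    for n in range(len(u)):
        for k in range(min(len(w), n + 1)):  # skip taps that reach before time 0
            d[n] += w[k] * u[n - k]
    return d

rng = np.random.default_rng(0)
sys_order, N = 32, 256            # smaller N than the dataset default, to keep the loop fast
w = rng.normal(size=sys_order) / sys_order
u = rng.normal(size=N)
d_conv = np.convolve(w, u)[:N]    # the dataset's construction of d
d_loop = fir_filter(w, u)
print(np.allclose(d_conv, d_loop))  # True
print(u[:, None].shape)             # (256, 1): time x channels, as the dataset returns
```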

Next, define the filter. Here we inherit from metaaf's overlap-save (OLS) module, OverlapSave. The filter can return either the current result, which is buffered automatically, or a dictionary. A returned dictionary must have the key "out", whose value is buffered; all other keys are stacked and returned.

import haiku as hk
import jax.numpy as jnp
from metaaf.filter import OverlapSave

# the filter inherits from the overlap-save module
class SystemID(OverlapSave, hk.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # select the analysis window (rectangular)
        self.analysis_window = jnp.ones(self.window_size)

    # Since we use the OLS base class, u and d are STFT-domain inputs.
    # The filter must take the same inputs provided in its _fwd function.
    def __ols_call__(self, u, d, metadata):
        # collect a buffer sized anti-aliased filter
        w = self.get_filter(name="w")

        # this is n_frames x n_freq x channels or 1 x F x 1 here
        return (w * u)[0]
    
    @staticmethod
    def add_args(parent_parser):
        return super(SystemID, SystemID).add_args(parent_parser)

    @staticmethod
    def grab_args(kwargs):
        return super(SystemID, SystemID).grab_args(kwargs)
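The OverlapSave base class handles block-wise FFT processing so the filter above only has to multiply spectra. For readers new to the technique, here is a generic textbook overlap-save sketch in NumPy. It is not metaaf's implementation: metaaf works with learned STFT-domain filters and an analysis window, while this sketch uses a fixed FIR filter. The helper name overlap_save, an FFT size of twice the hop, and a filter no longer than one hop are all assumptions made for illustration.

```python
import numpy as np

def overlap_save(u, w, hop):
    """Hypothetical helper: block FIR filtering of u by w via overlap-save.

    Assumes len(w) <= hop and an FFT size of 2 * hop, so the last `hop`
    samples of each block's circular convolution equal the linear convolution.
    """
    assert len(w) <= hop, "filter must fit in one hop to avoid time aliasing"
    n_fft = 2 * hop
    W = np.fft.rfft(w, n=n_fft)        # spectrum of the zero-padded filter
    buf = np.zeros(n_fft)              # previous hop + current hop of input
    blocks = []
    for start in range(0, len(u) - hop + 1, hop):
        # slide the buffer: drop the oldest hop, append the newest hop
        buf = np.concatenate([buf[hop:], u[start:start + hop]])
        y = np.fft.irfft(np.fft.rfft(buf) * W, n=n_fft)
        # keep the valid last half; the first half is aliased or was already output
        blocks.append(y[hop:])
    return np.concatenate(blocks)

rng = np.random.default_rng(0)
hop, N = 256, 4096                     # hypothetical sizes; N is a multiple of hop
w = rng.normal(size=32) / 32
u = rng.normal(size=N)
y_ols = overlap_save(u, w, hop)
y_ref = np.convolve(w, u)[:N]          # direct causal FIR filtering
print(np.allclose(y_ols, y_ref))       # True
```

Overlap-save trades a per-sample convolution for one FFT, one spectral multiply, and one inverse FFT per hop, which is why frequency-domain adaptive filters are typically built on it.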

Haiku transforms modules into pure functions, so we need to provide a wrapper function that constructs the filter and calls it. The wrapper function MUST take as input the same named signals your dataset returns.

def _SystemID_fwd(u, d, metadata=None, init_data=None, **kwargs):
    f = SystemID(**kwargs)
    return f(u=u, d=d)

Next, define the adaptive filter loss, here just the MSE. The filter loss must have this signature so that metaaf can automatically take its gradient and pass it to the optimizer.

def filter_loss(out, data_samples, metadata):
    e = out - data_samples["d"]
    return jnp.vdot(e, e) / (e.size)
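For complex STFT-domain errors, jnp.vdot(e, e) flattens e and conjugates its first argument, so the loss above equals the mean squared magnitude of the error (returned as a complex scalar with zero imaginary part). The NumPy check below illustrates this, since np.vdot has the same semantics; the error shape (1 frame x 257 bins x 1 channel) is a hypothetical example.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical complex STFT-domain error: 1 frame x 257 bins x 1 channel
e = rng.normal(size=(1, 257, 1)) + 1j * rng.normal(size=(1, 257, 1))

loss_vdot = np.vdot(e, e) / e.size   # vdot flattens e and conjugates the first argument
loss_mse = np.mean(np.abs(e) ** 2)   # mean squared magnitude of the error

print(np.isclose(loss_vdot.real, loss_mse))  # True
print(abs(loss_vdot.imag) < 1e-12)           # True: e^H e is real
```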

We can construct the meta-train and meta-val losses in a similar fashion.

def meta_train_loss(losses, outputs, data_samples, metadata, outer_index, outer_learnable):
    out = jnp.concatenate(outputs["out"], 0)
    return jnp.log(jnp.mean(jnp.abs(out - data_samples["d"]) ** 2) + 1e-9)

def meta_val_loss(losses, outputs, data_samples, metadata, outer_learnable):
    out = jnp.reshape(
        outputs["out"],
        (outputs["out"].shape[0], -1, outputs["out"].shape[-1]),
    )
    d = data_samples["d"]
    min_len = min(out.shape[1], d.shape[1])
    return jnp.log(jnp.mean(jnp.abs(out[:, :min_len] - d[:, :min_len]) ** 2) + 1e-9)
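The validation loss flattens the framed outputs back into a (batch, time, channels) signal, trims output and target to a common length, and takes the log of the MSE; the 1e-9 keeps the log finite when the error is exactly zero. The NumPy sketch below mirrors those steps, assuming for illustration that outputs["out"] has shape (batch, frames, hop, channels); the helper name log_mse and all sizes are hypothetical.

```python
import numpy as np

def log_mse(out, d, eps=1e-9):
    """Hypothetical helper: log MSE between (batch, time, channels) arrays, trimmed to a common length."""
    min_len = min(out.shape[1], d.shape[1])  # framed outputs can be shorter than d
    err = out[:, :min_len] - d[:, :min_len]
    return np.log(np.mean(np.abs(err) ** 2) + eps)  # eps keeps the log finite at zero error

rng = np.random.default_rng(0)
d = rng.normal(size=(2, 4096, 1))                  # hypothetical target: batch=2, 1 channel
framed = d[:, : 15 * 256].reshape(2, 15, 256, 1)   # hypothetical output: 15 frames of 256 samples
out = framed.reshape(framed.shape[0], -1, framed.shape[-1])  # flatten frames back into time
print(out.shape)        # (2, 3840, 1)
print(log_mse(out, d))  # equals np.log(1e-9), since out matches the first 3840 samples of d
```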

With everything defined, we can set up the MetaAFTrainer and start training.

import argparse
import jax
from metaaf.meta import MetaAFTrainer
from metaaf.optimizer_gru import EGRU

# Collect arguments
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser = EGRU.add_args(parser)
parser = SystemID.add_args(parser)
parser = MetaAFTrainer.add_args(parser)
kwargs = vars(parser.parse_args())

# Setup trainer
system = MetaAFTrainer(
    _filter_fwd=_SystemID_fwd,
    filter_kwargs=SystemID.grab_args(kwargs),
    filter_loss=filter_loss,
    meta_train_loss=meta_train_loss,
    meta_val_loss=meta_val_loss,
    optimizer_kwargs=EGRU.grab_args(kwargs),
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)
# Train
key = jax.random.PRNGKey(0)
outer_learned, losses = system.train(
    **MetaAFTrainer.grab_args(kwargs),
    key=key,
)

That's it! For more advanced options, check out the zoo, where we demonstrate callbacks, customized filters, and more.

Meta-AF Zoo

The Meta-AF Zoo contains implementations of system identification, acoustic echo cancellation, equalization, weighted prediction error dereverberation, and a generalized sidelobe canceller beamformer, all in the metaaf framework. You can find instructions for how to run, evaluate, and set up those models here. For trained weights and tuned baselines, please see the tagged release zip file here.

License

All core utility code within the metaaf folder is licensed via the University of Illinois Open Source License. All code within the zoo folder and model weights are licensed via the Adobe Research License. Copyright (c) Adobe Systems Incorporated. All rights reserved.

Related Works

An extension of this work that uses metaaf:

"Meta-Learning for Adaptive Filters with Higher-Order Frequency Dependencies", Junkai Wu, Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, IWAENC, 2022.

@inproceedings{wu2022metalearning,
  title={Meta-Learning for Adaptive Filters with Higher-Order Frequency Dependencies},
  author={Wu, Junkai and Casebeer, Jonah and Bryan, Nicholas J. and Smaragdis, Paris},
  booktitle={IEEE International Workshop on Acoustic Signal Enhancement (IWAENC)},
  year={2022},
}

An early version of this work:

"Auto-DSP: Learning to Optimize Acoustic Echo Cancellers", Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, WASPAA, 2021.

@inproceedings{casebeer2021auto,
  title={Auto-DSP: Learning to Optimize Acoustic Echo Cancellers},
  author={Casebeer, Jonah and Bryan, Nicholas J. and Smaragdis, Paris},
  booktitle={2021 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
  pages={291--295},
  year={2021},
  organization={IEEE}
}

metaaf's People

Contributors

jmcasebeer, njb


metaaf's Issues

Could not get the same AEC results shown on the demo page with the provided pretrained models

Excellent work! Thanks for sharing the code base and pretrained models.

I would like to try the AEC performance of Meta-AF using your pretrained models. To make sure I use them correctly, I downloaded the wav files of the first double-talk sample on your demo website and ran AEC with the pretrained model v0.1.0_models/aec/aec_16_dt_c/2022_04_10_15_57_12/epoch_230.pkl. However, my AEC result is much worse than the one on the demo website. Could you please help me out? Here is the test code I used:

import os
from aec_eval import get_system_ckpt
import numpy as np
import librosa
import soundfile as sf

ckpt_dir = "v0.1.0_models/aec/"
name = "aec_16_dt_c"
date = "2022_04_10_15_57_12"
epoch = 230

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
    model_type="egru",
    system_len=None,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

out_dir = "metaAF_res"
os.makedirs(out_dir, exist_ok=True)

u, _ = librosa.load("u.mp3", sr=fs)
d, _ = librosa.load("d.mp3", sr=fs)
s, _ = librosa.load("s.mp3", sr=fs)
e = d - s

d_input = {"u": u[None, :, None], "d": d[None, :, None],
           "s": s[None, :, None], "e": e[None, :, None]
           }
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[
    0
]
pred = np.array(pred[0, :, 0])

sf.write(os.path.join(out_dir, f"_out.wav"), pred, fs)

Looking forward to hearing from you, thanks!
Best,
Fei

with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')

Hello, when I run the code from the second closed issue, which uses the pre-trained AEC models, I get this error. Can you help me with this?
Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'ElementWiseGRU//linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\base.py", line 685, in get_parameter
raise ValueError(
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\basic.py", line 179, in call
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\basic.py", line 126, in call
out = layer(out, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 71, in preprocess_flatten
return self.in_lin(input_stack_flat)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 80, in call
rnn_in = self.preprocess_flatten(x, extra_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 122, in _fwd
return optimizer(x, h, extra_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\transform.py", line 456, in apply_fn
out = f(*args, **kwargs)
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku_src\transform.py", line 183, in apply_fn
out, state = f.apply(params, None, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 212, in update
update, state = optimizer.apply(
^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\example_libraries\optimizers.py", line 199, in tree_update
new_states = map(partial(update, i), grad_flat, states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 462, in online_step
opt_s = opt_update(0, filter_features, opt_s)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 549, in fit_single
cur_out, loss, batch_state = batch_step(
^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\meta.py", line 825, in infer
out, aux = fit_infer(
^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\start.py", line 38, in
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 88, in _run_code
exec(code, run_globals)
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 198, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: 'ElementWiseGRU//linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')?

AEC Data Setup Issues

I have downloaded the AEC-Challenge-main dataset and the RIRS_NOISES dataset, and I have set the paths in the config.
I would like to try the AEC performance of Meta-AF using your pre-trained models. When I run the command from the tutorial:
(!python /content/MetaAF/zoo/aec/aec.py --n_frames 1 --window_size 2048 --hop_size 1024 --n_in_chan 1 --n_out_chan 1 --is_real --n_devices 1 --batch_size 64 --total_epochs 1000 --val_period 10 --reduce_lr_patience 1 --early_stop_patience 4 --name meta_aec_demo --unroll 16 --extra_signals ude --random_roll --outer_loss log_self_mse --double_talk --dataset nonlinear)
I see three choices:
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
If I select 1, register for a W&B account, and enter the wandb code, I receive wandb.errors.CommError: Permission denied, ask the project owner to grant you access.
When I select 3, I receive RuntimeError: Error opening '/content/AEC-Challenge main/datasets/synthetic/farend_speech/farend_speech_fileid_0.wav': File contains data in an unknown format.
Could you please help me to solve this error?

Originally posted by @Alirezanezamdoost in #4 (comment)

Some modules can't be found, please help

1.Traceback (most recent call last):
File "/home/easymoney/kk/MetaAF/metaaf/core.py", line 6, in
from metaaf.optimizer_utils import FeatureContainer
ModuleNotFoundError: No module named 'metaaf'
2.ModuleNotFoundError Traceback (most recent call last)
/home/easymoney/kk/MetaAF/examples/sysid_demo.ipynb Cell 1 line 1
----> 1 get_ipython().run_line_magic('load_ext', 'lab_black')
      2 get_ipython().run_line_magic('load_ext', 'autoreload')
      3 get_ipython().run_line_magic('autoreload', '2')

File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2432, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
2430 kwargs['local_ns'] = self.get_local_scope(stack_depth)
2431 with self.builtin_trap:
-> 2432 result = fn(*args, **kwargs)
2434 # The code below prevents the output from being displayed
2435 # when using magics with decorator @output_can_be_silenced
2436 # when the last Python token in the expression is a ';'.
2437 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False):

File ~/.local/lib/python3.10/site-packages/IPython/core/magics/extension.py:33, in ExtensionMagics.load_ext(self, module_str)
31 if not module_str:
32 raise UsageError('Missing module name.')
---> 33 res = self.shell.extension_manager.load_extension(module_str)
35 if res == 'already loaded':
36 print("The %s extension is already loaded. To reload it, use:" % module_str)

File ~/.local/lib/python3.10/site-packages/IPython/core/extensions.py:76, in ExtensionManager.load_extension(self, module_str)
69 """Load an IPython extension by its module name.
...
File :1027, in find_and_load(name, import)

File :1004, in find_and_load_unlocked(name, import)

ModuleNotFoundError: No module named 'lab_black'

Why not use the near-end speech s as the target?

Thanks for sharing your excellent work!
Could you please explain why you don't use the clean or noisy near-end speech as the network target? And why does using the mic signal d as the target cover double-talk scenarios? Thank you.

Compatibility with x64 CPU

Hello, I have an AMD graphics card and can't install JAX with GPU support, because it is not available for Windows on AMD GPUs. Is this program compatible with x64-based CPUs, or just x86? I have installed JAX for CPU and just want to know if it is compatible.

I've got a few ideas.

I have an idea: change your specialized focus on linear frequency-domain (FD) adaptive filters to nonlinear adaptive filters, and use them for continuous signals. I don't know if it's possible; I hope you can clear up my confusion. If so, what should I do to improve it? Thanks.

You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers when running hoaec_eval.py

I wish to run hoaec_eval.py through the r_es_ckpts_run.sh_ bash script to try out the algorithms, and I get the "You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers when running hoaec_eval.py" error. The code is practically the same as in the repository; I only changed the config.py constant values to match my folders. Here's the complete traceback. Hope you can help!

Storing AEC outputs...
/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8 (`cpuset` is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
  0%|          | 0/32 [00:00<?, ?it/s]
/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_gru.py:114: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(broadcast, nest)
  0%|          | 0/32 [00:11<?, ?it/s]
Traceback (most recent call last):
  File "/Users/agus/work/eye-predict/audioEngineering/External-repos/MetaAFPackage/hoaec_eval.py", line 225, in <module>
    preds = system.infer(data)[0]
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/meta.py", line 825, in infer
    out, aux = fit_infer(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/api.py", line 1564, in vmap_f
    out_flat = batching.batch(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/api.py", line 526, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/batching.py", line 233, in process_call
    vals_out = call_primitive.bind(f_, *vals, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
    compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 248, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 293, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2167, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2117, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 471, in update
    update, state = optimizer.apply(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 234, in _fwd
    return optimizer(x, h, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 183, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 149, in preprocess_flatten
    input_stack_flat = self.in_coupling_conv(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/basic.py", line 123, in __call__
    out = layer(out, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/conv.py", line 200, in __call__
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 515, in get_parameter
    param = init(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_utils.py", line 12, in complex_variance_scaling
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 215, in __call__
    return TruncatedNormal(stddev=stddev)(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 114, in __call__
    unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 965, in next_rng_key
    return next_rng_key_internal()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 1003, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 923, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
jax._src.traceback_util.UnfilteredStackTrace: ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/agus/work/eye-predict/audioEngineering/External-repos/MetaAFPackage/hoaec_eval.py", line 225, in <module>
    preds = system.infer(data)[0]
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/meta.py", line 825, in infer
    out, aux = fit_infer(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 471, in update
    update, state = optimizer.apply(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 234, in _fwd
    return optimizer(x, h, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 183, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 149, in preprocess_flatten
    input_stack_flat = self.in_coupling_conv(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/basic.py", line 123, in __call__
    out = layer(out, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/conv.py", line 200, in __call__
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 515, in get_parameter
    param = init(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_utils.py", line 12, in complex_variance_scaling
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 215, in __call__
    return TruncatedNormal(stddev=stddev)(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 114, in __call__
    unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 965, in next_rng_key
    return next_rng_key_internal()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 1003, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 923, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

Inference question

Thank you for sharing your work.

Is there an easier way to run inference on specific files than hacking aec_eval.py? For example, when you have a very long mic/reference pair that you'd like to process with one of your pre-trained AEC models.
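A minimal sketch of packing a long mic/reference pair into a single inference batch. It assumes `system.infer` takes the layout used in the `get_output` snippet posted further down this page: a dict whose `"signals"` entry holds `u` (the reference), `d` (the mic) and placeholder `s` and `e`, each shaped `(batch, time, channels)`, plus an empty `"metadata"` dict. The helper name `make_infer_input` is hypothetical, and filling `s` and `e` with zeros assumes they are not needed to compute the filter output.

```python
import numpy as np


def make_infer_input(u, d):
    """Pack a reference signal u and a mic signal d into a batch of one.

    Hypothetical helper: mirrors the dict layout from the get_output snippet
    on this page; s and e are zero placeholders (an assumption).
    """
    u = np.asarray(u, dtype=np.float32)
    d = np.asarray(d, dtype=np.float32)
    if u.ndim != 1 or d.ndim != 1:
        raise ValueError("expected mono 1-D signals")

    # Trim both signals to a common length so the samples stay aligned.
    n = min(len(u), len(d))
    u, d = u[:n], d[:n]
    zeros = np.zeros(n, dtype=np.float32)

    def as_batch(x):
        # (time,) -> (batch=1, time, channels=1)
        return x[None, :, None]

    signals = {
        "u": as_batch(u),
        "d": as_batch(d),
        "s": as_batch(zeros),
        "e": as_batch(zeros),
    }
    return {"signals": signals, "metadata": {}}
```

With `soundfile`, the inputs could come from `u, fs = sf.read("ref.wav")` and `d, _ = sf.read("mic.wav")` (placeholder file names), and the result would then be passed as `system.infer(batch, fit_infer=fit_infer)`. Trimming to the shorter signal, rather than wrap-padding it, avoids repeating the start of the shorter file at its end.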

When running system.infer on the AEC task: ValueError: 'TimeChanCoupledGRU ...'

Following Fig. 4 of the Meta-AF: Meta-Learning for Adaptive Filters paper, I wrote a function that takes two lists, u and d, as input and should return the prediction from the pretrained model.
But when I run it, it raises ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64'). Do you know why?
Also, the prediction corresponds to y in the paper, right?

I copied aec_eval.py to aec_get_output.py and added a couple of functions; otherwise it is mostly the same.

 .
 .
 . 

def get_output(system, fit_infer, data_dict, out_dir, eval_kwargs, fs=16000):
    """
    given lists of d and u in data_dict,
    return the system prediction
    """
    u = get_u(data_dict) # [0.1, 0.2, .... , 0.1]  # list of 5000 elements
    d = get_d(data_dict) # [0.1, 0.2, .... , 0.1]  # list of 5000 elements
    print(f'u_len :{len(u)}, d_len: {len(d)}')
    e = [0]
    s = [0]
    max_len = len(u)
    u = np.pad(u, (0, max(0, max_len - len(u))), "wrap")
    d = np.pad(d, (0, max(0, max_len - len(d))), "wrap")
    e = np.pad(e, (0, max(0, max_len - len(e))), "wrap")
    s = np.pad(s, (0, max(0, max_len - len(s))), "wrap")
    u_new = u[:, None]
    d_new = d[:, None]
    e_new = e[:, None]
    s_new = s[:, None]
    print(f'Shapes :u {u_new.shape}, d {d_new.shape}, e {e_new.shape}, s {s_new.shape}')

    d_input = {"u": u_new[None], "d": d_new[None], "s": s_new[None], "e": e_new[None]}
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
    print(pred)
    return pred

if __name__ == "__main__":

    # get checkpoint description from user
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="")
    parser.add_argument("--date", type=str, default="")
    parser.add_argument("--epoch", type=int, default=0)
    parser.add_argument("--ckpt_dir", type=str, default="./meta_ckpts")

    # get evaluation conditions from user
    parser.add_argument("--universal", action="store_true", default=False)
    parser.add_argument("--system_len", type=int, default=None)

    # these will only get set if universal is false
    parser.add_argument("--true_rir_len", type=int, default=None)

    # decide what to save
    parser.add_argument("--out_dir", type=str, default="./meta_outputs")
    parser.add_argument("--save_outputs", action="store_true")
    parser.add_argument("--save_metrics", action="store_true")

    eval_kwargs = vars(parser.parse_args())
    pprint.pprint(eval_kwargs)

    # build the checkpoint path
    ckpt_loc = os.path.join(
        eval_kwargs["ckpt_dir"], eval_kwargs["name"], eval_kwargs["date"]
    )
    epoch = int(eval_kwargs["epoch"])
    print(f'checkpoint location : {ckpt_loc}')
    

    # load the checkpoint and kwargs file
    system, kwargs, outer_learnable = get_system_ckpt(
        ckpt_loc,
        epoch,
        system_len=eval_kwargs["system_len"],
    )
    fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)

    # build the outputs path
    out_dir = os.path.join(
        eval_kwargs["out_dir"],
        eval_kwargs["name"],
        eval_kwargs["date"],
        f"epoch_{epoch}",
    )
    if eval_kwargs["save_outputs"] or eval_kwargs["save_metrics"]:
        os.makedirs(out_dir, exist_ok=True)
    print(f'output dir: {out_dir}')

    # name the filter and rir lengths
    true_rir_len = (
        "DEFAULT"
        if eval_kwargs["true_rir_len"] is None
        else eval_kwargs["true_rir_len"]
    )

    print(f'true RIR length : {true_rir_len}')
    system_len = (
        "DEFAULT" if eval_kwargs["system_len"] is None else eval_kwargs["system_len"]
    )
    print(f'system len : {system_len}')
    data_dict = get_data_dict('/home/burro/Downloads/Dataset.csv')
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)

Error log:

(metaenv) burro@hostname:~/personal/MetaAF1.0.0/MetaAF-1.0.0/zoo/aec$ python aec_get_output.py --name meta_aec_16_combo_rl_4_1024_512 --date 2022_08_29_01_10_31 --epoch 110 --ckpt_dir ~/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec
{'ckpt_dir': '/home/burro/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec',
 'date': '2022_08_29_01_10_31',
 'epoch': 110,
 'name': 'meta_aec_16_combo_rl_4_1024_512',
 'out_dir': './meta_outputs',
 'save_metrics': False,
 'save_outputs': False,
 'system_len': None,
 'true_rir_len': None,
 'universal': False}
checkpoint location : /home/burro/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec/meta_aec_16_combo_rl_4_1024_512/2022_08_29_01_10_31
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
output dir: ./meta_outputs/meta_aec_16_combo_rl_4_1024_512/2022_08_29_01_10_31/epoch_110
true RIR length : DEFAULT
system len : DEFAULT
u_len :5000, d_len: 5000
Shapes :(5000, 1), (5000, 1), (5000, 1), (5000, 1)
Traceback (most recent call last):
  File "aec_get_output.py", line 354, in <module>
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)
  File "aec_get_output.py", line 224, in get_output
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "/home/burro/personal/MetaAF/metaaf/meta.py", line 804, in infer
    filter_s, filter_p, preprocess_s, postprocess_s, batch, key
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 534, in fit_single
    batch_state, batch_hop, jnp.array(subkeys)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/api.py", line 1686, in vmap_f
    ).call_wrapped(*args_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/api.py", line 626, in cache_miss
    top_trace.process_call(primitive, fun_, tracers, params))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/batching.py", line 377, in process_call
    vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/core.py", line 2019, in bind
    outs = top_trace.process_call(self, fun_, tracers, params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/core.py", line 715, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 250, in _xla_call_impl
    keep_unused=keep_unused)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 237, in _xla_call_impl_lazy
    *arg_specs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 360, in _xla_callable_uncached
    keep_unused, *arg_specs).compile().unsafe_call
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 446, in lower_xla_callable
    fun, pe.debug_info_final(fun, "jit"))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 445, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/util.py", line 78, in safe_map
    return list(map(f, *args))
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 183, in update
    **optimizer_kwargs,
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 184, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 451, in apply_fn
    out = f(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 124, in _timechancoupled_gru_fwd
    return optimizer(x, h, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 83, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 67, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 603, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 709, in get_parameter
    f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64')

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "aec_get_output.py", line 354, in <module>
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)
  File "aec_get_output.py", line 224, in get_output
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "/home/burro/personal/MetaAF/metaaf/meta.py", line 804, in infer
    filter_s, filter_p, preprocess_s, postprocess_s, batch, key
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 534, in fit_single
    batch_state, batch_hop, jnp.array(subkeys)
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 445, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 183, in update
    **optimizer_kwargs,
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 184, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 451, in apply_fn
    out = f(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 124, in _timechancoupled_gru_fwd
    return optimizer(x, h, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 83, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 67, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 603, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 709, in get_parameter
    f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64')
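The error says the checkpoint holds `TimeChanCoupledGRU/~/linear/w` with shape `(20, 32)` while the module being built requests `[16, 32]`; since Haiku's `Linear` creates `w` as `[input_size, output_size]`, the input size of that layer differs between the saved model and the one being constructed. Haiku stores parameters as a nested mapping of module name to parameter name to array, so a sketch like the one below could list every difference at once. It assumes you can obtain a reference tree, for example from the `init` of a freshly constructed optimizer; `flatten_params` and `diff_param_shapes` are hypothetical helpers, not part of metaaf. Parameters reported as missing are ones Haiku would try to create during `apply`, which appears to be what triggers the PRNGKey traceback earlier on this page.

```python
import numpy as np


def flatten_params(params):
    """Map 'module/param' -> shape for a Haiku-style two-level dict."""
    flat = {}
    for module_name, module_params in params.items():
        for param_name, value in module_params.items():
            flat[f"{module_name}/{param_name}"] = tuple(np.shape(value))
    return flat


def diff_param_shapes(saved, expected):
    """Compare a saved parameter tree against an expected one.

    Returns (missing, unexpected, mismatched):
      missing    - paths the model expects but the checkpoint lacks
      unexpected - paths in the checkpoint the model does not use
      mismatched - {path: (saved_shape, expected_shape)}
    """
    s, e = flatten_params(saved), flatten_params(expected)
    missing = sorted(e.keys() - s.keys())
    unexpected = sorted(s.keys() - e.keys())
    mismatched = {
        k: (s[k], e[k]) for k in sorted(s.keys() & e.keys()) if s[k] != e[k]
    }
    return missing, unexpected, mismatched
```

Usage: `missing, unexpected, mismatched = diff_param_shapes(saved_params, fresh_params)`. A non-empty `mismatched` usually means the model was built with a different configuration (or code version) than the one the checkpoint was trained with.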

rir_idx index out of range

Git tag: v1.0.0
Model: 1.0.0 checkpoint
Dataset: AEC-Challenge and RIRS (US mirror)
Contents of the zoo/__config__.py file:

AEC_DATA_DIR = "/home/username/personal/dataset/AEC-Challenge/"
RIR_DATA_DIR = "/home/username/personal/rirs_noises/RIRS_NOISES/"
.
.

Cannot run aec_eval.py using the pretrained model on the datasets

(metaenv) username@hostname:~/personal/MetaAF1.0.0/MetaAF-1.0.0/zoo/aec$ python aec_eval.py --name meta_aec_16_combo_rl_4_1024_512  --date 2022_08_29_01_10_31  --epoch 110 --ckpt_dir ~/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec
{'ckpt_dir': '/home/username/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec',
 'date': '2022_08_29_01_10_31',
 'epoch': 110,
 'name': 'meta_aec_16_combo_rl_4_1024_512',
 'out_dir': './meta_outputs',
 'save_metrics': False,
 'save_outputs': False,
 'system_len': None,
 'true_rir_len': None,
 'universal': False}
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
 --- Testing Dataset ---
 --- DT_False_SC_False_NL_False_TR_DEFAULT_SL_DEFAULT ---
  0%|                                                                       | 0/500 [00:00<?, ?it/s]0, 0
  0%|                                                                       | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "aec_eval.py", line 305, in <module>
    system, fit_infer, test_dataset, out_dir, test_dataset_names[i], eval_kwargs
  File "aec_eval.py", line 151, in run_eval
    data = test_dataset[i]
  File "/home/username/personal/MetaAF/zoo/aec/aec.py", line 285, in __getitem__
    data_dict = self.load_from_idx(idx)
  File "/home/username/personal/MetaAF/zoo/aec/aec.py", line 233, in load_from_idx
    w, _ = sf.read(self.rirs[rir_idx])
IndexError: list index out of range
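The `IndexError` is raised at `self.rirs[rir_idx]`, so the list of RIR files built by the dataset is shorter than the requested index; an empty list, for example when `RIR_DATA_DIR` does not point at the directory the loader searches, would produce exactly this. A quick sanity check, sketched below under the assumption that the RIRs are `.wav` files somewhere beneath `RIR_DATA_DIR` (the exact subfolders the loader expects are not shown here), is to count them before running the evaluation. `find_audio_files` is a hypothetical helper.

```python
from pathlib import Path


def find_audio_files(root, suffixes=(".wav",)):
    """Recursively list audio files under root (hypothetical helper).

    Matching is case-insensitive on the file suffix; raises if root
    is not an existing directory.
    """
    root = Path(root).expanduser()
    if not root.is_dir():
        raise FileNotFoundError(f"not a directory: {root}")
    wanted = {s.lower() for s in suffixes}
    return sorted(
        p for p in root.rglob("*") if p.is_file() and p.suffix.lower() in wanted
    )
```

Calling `len(find_audio_files(RIR_DATA_DIR))` with the value from `zoo/__config__.py` and getting zero, or far fewer files than expected, would point at the dataset path rather than at the model.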
