adobe-research / metaaf

Control adaptive filters with neural networks.

Home Page: https://jmcasebeer.github.io/projects/metaaf

Python 97.28% Shell 1.42% Jupyter Notebook 1.30%
beamforming acoustic-echo-cancellation adaptive-filtering adaptive-filters dereverberation gsc linear-prediction system-identification blind-equalization weighted-prediction-error dsp jax signal-processing digital-signal-processing

metaaf's Issues

Could not get the same AEC results shown on the demo page with the provided pretrained models

Excellent work! Thanks for sharing the code base and pretrained models.

I would like to try the AEC performance of Meta-AF using your pretrained models. To make sure that I am using them correctly, I downloaded the wav files of the first double-talk sample on your demo website and ran AEC with the pretrained model v0.1.0_models/aec/aec_16_dt_c/2022_04_10_15_57_12/epoch_230.pkl. However, the AEC result I get is much worse than the one on the demo website. Could you please help me out? The test code I used:

import os
from aec_eval import get_system_ckpt
import numpy as np
import librosa
import soundfile as sf

ckpt_dir = "v0.1.0_models/aec/"
name = "aec_16_dt_c"
date = "2022_04_10_15_57_12"
epoch = 230

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
    model_type="egru",
    system_len=None,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

out_dir = "metaAF_res"
os.makedirs(out_dir, exist_ok=True)

u, _ = librosa.load("u.mp3", sr=fs)
d, _ = librosa.load("d.mp3", sr=fs)
s, _ = librosa.load("s.mp3", sr=fs)
e = d - s

d_input = {"u": u[None, :, None], "d": d[None, :, None],
           "s": s[None, :, None], "e": e[None, :, None]
           }
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[
    0
]
pred = np.array(pred[0, :, 0])

sf.write(os.path.join(out_dir, "_out.wav"), pred, fs)

Looking forward to hearing from you, thanks!
Best,
Fei
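
To put a number on "much worse", it may help to compare the echo return loss enhancement (ERLE) of your output and of the demo output over the same far-end-only segment (during double-talk the near-end speech stays in the output, so ERLE is less meaningful there). This is only a minimal sketch, assuming `d` is the microphone signal and `pred` is the AEC output as 1-D arrays at the same sample rate; the helper name `erle_db` is illustrative and not part of metaaf.

```python
import numpy as np


def erle_db(d, pred, eps=1e-10):
    """ERLE in dB: 10*log10(mean(d^2) / mean(pred^2)) over the overlapping samples."""
    n = min(len(d), len(pred))  # guard against small length mismatches
    d = np.asarray(d[:n], dtype=np.float64)
    pred = np.asarray(pred[:n], dtype=np.float64)
    return 10.0 * np.log10((np.mean(d**2) + eps) / (np.mean(pred**2) + eps))


# Toy check: attenuating the microphone signal by 10x gives about 20 dB.
rng = np.random.default_rng(0)
d = rng.standard_normal(16000)
print(round(erle_db(d, 0.1 * d), 1))
```

Running the same function on the demo page's output file and on your `_out.wav` gives two values that can be compared directly.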

Replicated results don't match the demo.

Hello! Thanks for sharing the pre-trained models and demos.
I would like to replicate the demo results using a pretrained model. I used the data from the first row of the double-talk examples and converted the mp3 files to wav (single channel, 16000 Hz, 16 bit) for convenience. Based on the titles of the audio files downloaded from the demo page, I selected the same pkl file to process the original speech. However, the spectrograms on the demo page differ significantly from those I generated with the pretrained model. I have checked every step and can't find the reason. Could you help me understand why?

Model tag: v1.0.1
The code I used is below:

import os
from aec_eval import get_system_ckpt
import numpy as np
import librosa
import soundfile as sf

ckpt_dir = "v1.0.1_models/aec/"
name = "meta_aec_16_combo_rl_4_1024_512_r2"
date = "2022_10_19_23_43_22"
epoch = 110

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

out_dir = "metaAF_output"
os.makedirs(out_dir, exist_ok=True)

u, _ = librosa.load("u.wav", sr=fs)
d, _ = librosa.load("d.wav", sr=fs)
s, _ = librosa.load("s.wav", sr=fs)
e = d - s

d_input = {"u": u[None, :, None], "d": d[None, :, None],
           "s": s[None, :, None], "e": e[None, :, None]
           }
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
pred = np.array(pred[0, :, 0])

sf.write(os.path.join(out_dir, "_out.wav"), pred, fs)

Looking forward to hearing from you, thanks!

Some modules can't be found, please help

1. Traceback (most recent call last):
  File "/home/easymoney/kk/MetaAF/metaaf/core.py", line 6, in <module>
    from metaaf.optimizer_utils import FeatureContainer
ModuleNotFoundError: No module named 'metaaf'

2. ModuleNotFoundError Traceback (most recent call last)
/home/easymoney/kk/MetaAF/examples/sysid_demo.ipynb Cell 1 line 1
----> 1 get_ipython().run_line_magic('load_ext', 'lab_black')
      2 get_ipython().run_line_magic('load_ext', 'autoreload')
      3 get_ipython().run_line_magic('autoreload', '2')

File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2432, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
2430 kwargs['local_ns'] = self.get_local_scope(stack_depth)
2431 with self.builtin_trap:
-> 2432 result = fn(*args, **kwargs)
2434 # The code below prevents the output from being displayed
2435 # when using magics with decorator @output_can_be_silenced
2436 # when the last Python token in the expression is a ';'.
2437 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False):

File ~/.local/lib/python3.10/site-packages/IPython/core/magics/extension.py:33, in ExtensionMagics.load_ext(self, module_str)
31 if not module_str:
32 raise UsageError('Missing module name.')
---> 33 res = self.shell.extension_manager.load_extension(module_str)
35 if res == 'already loaded':
36 print("The %s extension is already loaded. To reload it, use:" % module_str)

File ~/.local/lib/python3.10/site-packages/IPython/core/extensions.py:76, in ExtensionManager.load_extension(self, module_str)
69 """Load an IPython extension by its module name.
...
File <frozen importlib._bootstrap>:1027, in _find_and_load(name, import_)

File <frozen importlib._bootstrap>:1004, in _find_and_load_unlocked(name, import_)

ModuleNotFoundError: No module named 'lab_black'
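
Both tracebacks say the same thing: the interpreter running the code cannot import the named module. A quick way to confirm which ones are missing in the active environment is sketched below; the helper name `missing_modules` is illustrative only. As assumptions rather than confirmed fixes: `metaaf` usually becomes importable after installing the repository into the environment (for example with `pip install -e .` from the repo root), and `lab_black` is, to my knowledge, provided by the `nb_black` package. Since `lab_black` only auto-formats notebook cells, deleting the `%load_ext lab_black` line is another option.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# An empty list means both modules are importable from this interpreter.
print(missing_modules(["metaaf", "lab_black"]))
```

Run it with the same Python (or notebook kernel) that produced the error, since VS Code and the terminal may point at different environments.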

You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers when running hoaec_eval.py

I want to run hoaec_eval.py through the r_es_ckpts_run.sh_ bash script to try out the algorithms, and I get the error "You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers" when running it. The code is practically the same as in the repository; I only changed the constants in config.py to match my folders. Here is the complete traceback. Hope you can help!

Storing AEC outputs...
/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8 (`cpuset` is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
  0%|          | 0/32 [00:00<?, ?it/s]
/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_gru.py:114: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(broadcast, nest)
  0%|          | 0/32 [00:11<?, ?it/s]
Traceback (most recent call last):
  File "/Users/agus/work/eye-predict/audioEngineering/External-repos/MetaAFPackage/hoaec_eval.py", line 225, in <module>
    preds = system.infer(data)[0]
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/meta.py", line 825, in infer
    out, aux = fit_infer(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/api.py", line 1564, in vmap_f
    out_flat = batching.batch(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/api.py", line 526, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/batching.py", line 233, in process_call
    vals_out = call_primitive.bind(f_, *vals, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
    compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 248, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 293, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2167, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2117, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 471, in update
    update, state = optimizer.apply(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 234, in _fwd
    return optimizer(x, h, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 183, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 149, in preprocess_flatten
    input_stack_flat = self.in_coupling_conv(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/basic.py", line 123, in __call__
    out = layer(out, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/conv.py", line 200, in __call__
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 515, in get_parameter
    param = init(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_utils.py", line 12, in complex_variance_scaling
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 215, in __call__
    return TruncatedNormal(stddev=stddev)(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 114, in __call__
    unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 965, in next_rng_key
    return next_rng_key_internal()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 1003, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 923, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
jax._src.traceback_util.UnfilteredStackTrace: ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/agus/work/eye-predict/audioEngineering/External-repos/MetaAFPackage/hoaec_eval.py", line 225, in <module>
    preds = system.infer(data)[0]
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/meta.py", line 825, in infer
    out, aux = fit_infer(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 471, in update
    update, state = optimizer.apply(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 234, in _fwd
    return optimizer(x, h, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 183, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 149, in preprocess_flatten
    input_stack_flat = self.in_coupling_conv(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/basic.py", line 123, in __call__
    out = layer(out, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/conv.py", line 200, in __call__
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 515, in get_parameter
    param = init(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_utils.py", line 12, in complex_variance_scaling
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 215, in __call__
    return TruncatedNormal(stddev=stddev)(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 114, in __call__
    unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 965, in next_rng_key
    return next_rng_key_internal()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 1003, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 923, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

Compatibility with x64 CPU

Hello, I have an AMD graphics card and can't install GPU-enabled JAX because it is not available for AMD GPUs on Windows. Is this program compatible with x64-based CPUs, or only x86? I have installed the CPU build of JAX and just want to know whether it is compatible.
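
For context, x64 (also called x86-64 or AMD64) is the 64-bit variant of x86, not a separate family, and other logs on this page show JAX printing "No GPU/TPU found, falling back to CPU" and then continuing. A small sketch to see what your own machine and JAX install report follows; `describe_runtime` is an illustrative helper, not part of metaaf, and the JAX part is skipped if JAX is not installed.

```python
import platform


def describe_runtime():
    """Collect the CPU architecture and, if JAX is importable, the devices it sees."""
    info = {"machine": platform.machine(), "system": platform.system()}
    try:
        import jax  # optional: only reported when installed

        info["jax_devices"] = [str(dev) for dev in jax.devices()]
    except ImportError:
        info["jax_devices"] = None
    return info


print(describe_runtime())
```

If `jax_devices` lists a CPU device, JAX itself is working on that machine; whether every metaaf script runs at a useful speed on CPU is a separate question this check does not answer.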

Why not use the near-end speech s as the target?

Thanks for sharing your excellent work!
Could you please explain why the clean or noisy near-end speech is not used as the network target? And why does using the mic signal d as the target cover double-talk scenes? Thank you.

rir_idx index out of range

git tag : v1.0.0
model : 1.0.0 checkpoint
dataset: AEC-challenge and RIRS US mirror
content of the zoo/__config__.py file:

AEC_DATA_DIR = "/home/username/personal/dataset/AEC-Challenge/"
RIR_DATA_DIR = "/home/username/personal/rirs_noises/RIRS_NOISES/"
.
.

Cannot run aec_eval.py using pretrained model on the datasets

(metaenv) username@hostname:~/personal/MetaAF1.0.0/MetaAF-1.0.0/zoo/aec$ python aec_eval.py --name meta_aec_16_combo_rl_4_1024_512  --date 2022_08_29_01_10_31  --epoch 110 --ckpt_dir ~/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec
{'ckpt_dir': '/home/username/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec',
 'date': '2022_08_29_01_10_31',
 'epoch': 110,
 'name': 'meta_aec_16_combo_rl_4_1024_512',
 'out_dir': './meta_outputs',
 'save_metrics': False,
 'save_outputs': False,
 'system_len': None,
 'true_rir_len': None,
 'universal': False}
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
 --- Testing Dataset ---
 --- DT_False_SC_False_NL_False_TR_DEFAULT_SL_DEFAULT ---
  0%|                                                                       | 0/500 [00:00<?, ?it/s]0, 0
  0%|                                                                       | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "aec_eval.py", line 305, in <module>
    system, fit_infer, test_dataset, out_dir, test_dataset_names[i], eval_kwargs
  File "aec_eval.py", line 151, in run_eval
    data = test_dataset[i]
  File "/home/username/personal/MetaAF/zoo/aec/aec.py", line 285, in __getitem__
    data_dict = self.load_from_idx(idx)
  File "/home/username/personal/MetaAF/zoo/aec/aec.py", line 233, in load_from_idx
    w, _ = sf.read(self.rirs[rir_idx])
IndexError: list index out of range
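
The IndexError at `self.rirs[rir_idx]` means the list of RIR files is shorter than the index being requested, and one common reason is that the loader found few or no files under the configured directory. How aec.py builds `self.rirs` is not shown here, so the following is only a generic sanity check on the two paths from the config quoted above; `count_audio_files` is an illustrative helper, and the set of extensions is an assumption.

```python
from pathlib import Path


def count_audio_files(root, suffixes=(".wav", ".flac")):
    """Count files under root (recursively) whose extension is in suffixes."""
    root = Path(root)
    if not root.is_dir():
        return 0  # missing or mistyped directory
    return sum(
        1 for p in root.rglob("*") if p.is_file() and p.suffix.lower() in suffixes
    )


# Paths copied from the config quoted in this issue.
for name, path in [
    ("AEC_DATA_DIR", "/home/username/personal/dataset/AEC-Challenge/"),
    ("RIR_DATA_DIR", "/home/username/personal/rirs_noises/RIRS_NOISES/"),
]:
    print(name, count_audio_files(path))
```

A count of 0 for either path would point to a wrong directory level or an incomplete download rather than a problem in the model checkpoint.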

AEC Data Setup Issues

I have downloaded the "AEC-Challenge-main" and "RIRS_NOISES" datasets and set their paths in the config.
I would like to try the AEC performance of Meta-AF using your pre-trained models. When I run this command from the tutorial:
(!python /content/MetaAF/zoo/aec/aec.py --n_frames 1 --window_size 2048 --hop_size 1024 --n_in_chan 1 --n_out_chan 1 --is_real --n_devices 1 --batch_size 64 --total_epochs 1000 --val_period 10 --reduce_lr_patience 1 --early_stop_patience 4 --name meta_aec_demo --unroll 16 --extra_signals ude --random_roll --outer_loss log_self_mse --double_talk --dataset nonlinear)
I see three choices:
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
If I select 1 and register for a W&B account, then after entering the W&B code I receive (wandb.errors.CommError: Permission denied, ask the project owner to grant you access).
When I select 3, I receive (RuntimeError: Error opening '/content/AEC-Challenge main/datasets/synthetic/farend_speech/farend_speech_fileid_0.wav': File contains data in an unknown format.)
Could you please help me solve this error?

Originally posted by @Alirezanezamdoost in #4 (comment)
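
On the "unknown format" error: one possibility worth ruling out is that the .wav file on disk is not audio at all but a small Git LFS pointer, which can happen when a dataset repository is downloaded as a zip or cloned without Git LFS. That is an assumption about this setup, not something the log confirms. A minimal sketch that inspects the first bytes of a file follows; `sniff_audio_file` is an illustrative helper.

```python
import os
import tempfile


def sniff_audio_file(path):
    """Classify a file by its first bytes: real WAV, Git LFS pointer, or unknown."""
    with open(path, "rb") as f:
        head = f.read(64)
    if head[:4] == b"RIFF" and head[8:12] == b"WAVE":
        return "wav"
    if head.startswith(b"version https://git-lfs"):
        return "git-lfs pointer"
    return "unknown"


# Self-contained demo on a fake pointer file.
with tempfile.TemporaryDirectory() as tmp:
    pointer = os.path.join(tmp, "pointer.wav")
    with open(pointer, "wb") as f:
        f.write(b"version https://git-lfs.github.com/spec/v1\n")
    print(sniff_audio_file(pointer))  # git-lfs pointer
```

Calling `sniff_audio_file` on the farend_speech_fileid_0.wav path from the error message shows which case applies; a pointer file would need the real audio fetched before training can read it.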

Inference question

Thank you for sharing your work.

Is there an easier way to run inference on specific files than hacking aec_eval.py? E.g. in case you have a very long mic, reference pair that you'd like to process with one of your pre-trained AEC models.

I've got a few ideas.

My idea is to shift the focus from linear frequency-domain adaptive filters to non-linear adaptive filters and to apply them to continuous signals. I don't know whether this is possible, and I hope you can clear up my confusion. If it is possible, what should I change to make it work? Thanks.

Running system.infer in the AEC task raises ValueError: 'TimeChanCoupledGRU ...'

Following Fig. 4 of the paper "Meta-AF: Meta-Learning for Adaptive Filters", I wrote a function that takes two lists, u and d, as input, and I want it to return the prediction from the pretrained model.
But when I run it, I get the error ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64'). Do you know why?
Also, the prediction corresponds to y in the paper, right?

I copied aec_eval.py to aec_get_output.py and added a couple of functions; otherwise it is mostly the same.

 .
 .
 . 

def get_output(system, fit_infer, data_dict, out_dir, eval_kwargs, fs=16000):
    """
    given lists of d and u in data_dict,
    return the system prediction
    """
    u = get_u(data_dict) # [0.1, 0.2, .... , 0.1]  # list of 5000 elements
    d = get_d(data_dict) # [0.1, 0.2, .... , 0.1]  # list of 5000 elements
    print(f'u_len :{len(u)}, d_len: {len(d)}')
    e = [0]
    s = [0]
    max_len = len(u)
    u = np.pad(u, (0, max(0, max_len - len(u))), "wrap")
    d = np.pad(d, (0, max(0, max_len - len(d))), "wrap")
    e = np.pad(e, (0, max(0, max_len - len(e))), "wrap")
    s = np.pad(s, (0, max(0, max_len - len(s))), "wrap")
    u_new = u[:, None]
    d_new = d[:, None]
    e_new = e[:, None]
    s_new = s[:, None]
    print(f'Shapes :u {u_new.shape}, d {d_new.shape}, e {e_new.shape}, s {s_new.shape}')

    d_input = {"u": u_new[None], "d": d_new[None], "s": s_new[None], "e": e_new[None]}
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
    print(pred)
    return pred  # the caller assigns the result, so return it instead of only printing

if __name__ == "__main__":

    # get checkpoint description from user
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="")
    parser.add_argument("--date", type=str, default="")
    parser.add_argument("--epoch", type=int, default=0)
    parser.add_argument("--ckpt_dir", type=str, default="./meta_ckpts")

    # get evaluation conditions from user
    parser.add_argument("--universal", action="store_true", default=False)
    parser.add_argument("--system_len", type=int, default=None)

    # these will only get set if universal is false
    parser.add_argument("--true_rir_len", type=int, default=None)

    # decide what to save
    parser.add_argument("--out_dir", type=str, default="./meta_outputs")
    parser.add_argument("--save_outputs", action="store_true")
    parser.add_argument("--save_metrics", action="store_true")

    eval_kwargs = vars(parser.parse_args())
    pprint.pprint(eval_kwargs)

    # # build the checkpoint path
    ckpt_loc = os.path.join(
        eval_kwargs["ckpt_dir"], eval_kwargs["name"], eval_kwargs["date"]
    )
    epoch = int(eval_kwargs["epoch"])
    print(f'checkpoint location : {ckpt_loc}')
    

    # # load the checkpoint and kwargs file
    system, kwargs, outer_learnable = get_system_ckpt(
        ckpt_loc,
        epoch,
        system_len=eval_kwargs["system_len"],
    )
    fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)

    # # build the outputs path
    out_dir = os.path.join(
        eval_kwargs["out_dir"],
        eval_kwargs["name"],
        eval_kwargs["date"],
        f"epoch_{epoch}",
    )
    if eval_kwargs["save_outputs"] or eval_kwargs["save_metrics"]:
        os.makedirs(out_dir, exist_ok=True)
    print(f'output dir: {out_dir}')

    # # name the filter and rir lengths
    true_rir_len = (
        "DEFAULT"
        if eval_kwargs["true_rir_len"] is None
        else eval_kwargs["true_rir_len"]
    )

    print(f'true RIR length : {true_rir_len}')
    system_len = (
        "DEFAULT" if eval_kwargs["system_len"] is None else eval_kwargs["system_len"]
    )
    print(f'system len : {system_len}')
    data_dict = get_data_dict('/home/burro/Downloads/Dataset.csv')
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)

error log:

(metaenv) burro@hostname:~/personal/MetaAF1.0.0/MetaAF-1.0.0/zoo/aec$ python aec_get_output.py --name meta_aec_16_combo_rl_4_1024_512 --date 2022_08_29_01_10_31 --epoch 110 --ckpt_dir ~/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec
{'ckpt_dir': '/home/burro/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec',
 'date': '2022_08_29_01_10_31',
 'epoch': 110,
 'name': 'meta_aec_16_combo_rl_4_1024_512',
 'out_dir': './meta_outputs',
 'save_metrics': False,
 'save_outputs': False,
 'system_len': None,
 'true_rir_len': None,
 'universal': False}
checkpoint location : /home/burro/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec/meta_aec_16_combo_rl_4_1024_512/2022_08_29_01_10_31
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
output dir: ./meta_outputs/meta_aec_16_combo_rl_4_1024_512/2022_08_29_01_10_31/epoch_110
true RIR length : DEFAULT
system len : DEFAULT
u_len :5000, d_len: 5000
Shapes :(5000, 1), (5000, 1), (5000, 1), (5000, 1)
Traceback (most recent call last):
  File "aec_get_output.py", line 354, in <module>
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)
  File "aec_get_output.py", line 224, in get_output
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "/home/burro/personal/MetaAF/metaaf/meta.py", line 804, in infer
    filter_s, filter_p, preprocess_s, postprocess_s, batch, key
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 534, in fit_single
    batch_state, batch_hop, jnp.array(subkeys)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/api.py", line 1686, in vmap_f
    ).call_wrapped(*args_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/api.py", line 626, in cache_miss
    top_trace.process_call(primitive, fun_, tracers, params))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/batching.py", line 377, in process_call
    vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/core.py", line 2019, in bind
    outs = top_trace.process_call(self, fun_, tracers, params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/core.py", line 715, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 250, in _xla_call_impl
    keep_unused=keep_unused)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 237, in _xla_call_impl_lazy
    *arg_specs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 360, in _xla_callable_uncached
    keep_unused, *arg_specs).compile().unsafe_call
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 446, in lower_xla_callable
    fun, pe.debug_info_final(fun, "jit"))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 445, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/util.py", line 78, in safe_map
    return list(map(f, *args))
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 183, in update
    **optimizer_kwargs,
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 184, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 451, in apply_fn
    out = f(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 124, in _timechancoupled_gru_fwd
    return optimizer(x, h, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 83, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 67, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 603, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 709, in get_parameter
    f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64')

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "aec_get_output.py", line 354, in <module>
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)
  File "aec_get_output.py", line 224, in get_output
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "/home/burro/personal/MetaAF/metaaf/meta.py", line 804, in infer
    filter_s, filter_p, preprocess_s, postprocess_s, batch, key
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 534, in fit_single
    batch_state, batch_hop, jnp.array(subkeys)
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 445, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 183, in update
    **optimizer_kwargs,
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 184, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 451, in apply_fn
    out = f(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 124, in _timechancoupled_gru_fwd
    return optimizer(x, h, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 83, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 67, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 603, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 709, in get_parameter
    f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64')
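In this Haiku message, the "retrieved" shape is the one stored in the loaded parameters, and `shape=[...]` is what the linear layer asks for given the input it sees at run time. A quick way to check what a checkpoint actually contains is to walk the parameter tree and print every leaf shape. The sketch below is only illustrative: `list_param_shapes`, `FakeArray`, and `example_params` are hypothetical names, and it assumes the loaded parameters form a nested mapping of arrays (whether `outer_learnable` has exactly that layout is an assumption).

```python
from collections import namedtuple
from collections.abc import Mapping

# Stand-in for an array leaf. Real checkpoints would hold JAX/NumPy arrays,
# which expose the same `.shape` attribute.
FakeArray = namedtuple("FakeArray", ["shape"])


def list_param_shapes(params, prefix=""):
    """Return {"path/to/param": shape} for every leaf that has a .shape."""
    shapes = {}
    for name, value in params.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(value, Mapping):
            # Recurse into nested modules.
            shapes.update(list_param_shapes(value, path))
        elif hasattr(value, "shape"):
            shapes[path] = tuple(value.shape)
    return shapes


# Hypothetical parameter tree that mimics the name from the traceback.
example_params = {
    "TimeChanCoupledGRU/~/linear": {
        "w": FakeArray((20, 32)),
        "b": FakeArray((32,)),
    },
}

for path, shape in list_param_shapes(example_params).items():
    print(path, shape)
```

Comparing the printed first dimension of `linear/w` with the feature size the optimizer receives at inference would show which side is off. Note that the traceback above loads `metaaf` from `/home/burro/personal/MetaAF/metaaf` while the script runs from the `MetaAF-1.0.0` tree; whether that is related to the mismatch is not established here.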

with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')

Hello, when I run the code from the second closed issue, which uses the pretrained models for AEC, I get the error below. Can you help me with this?
Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'ElementWiseGRU/~/linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\base.py", line 685, in get_parameter
raise ValueError(
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\basic.py", line 179, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\basic.py", line 126, in __call__
out = layer(out, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 71, in preprocess_flatten
return self.in_lin(input_stack_flat)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 80, in __call__
rnn_in = self.preprocess_flatten(x, extra_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 122, in _fwd
return optimizer(x, h, extra_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\transform.py", line 456, in apply_fn
out = f(*args, **kwargs)
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\transform.py", line 183, in apply_fn
out, state = f.apply(params, None, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 212, in update
update, state = optimizer.apply(
^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\example_libraries\optimizers.py", line 199, in tree_update
new_states = map(partial(update, i), grad_flat, states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 462, in online_step
opt_s = opt_update(0, filter_features, opt_s)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 549, in fit_single
cur_out, loss, batch_state = batch_step(
^^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\meta.py", line 825, in infer
out, aux = fit_infer(
^^^^^^^^^^
File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\start.py", line 38, in <module>
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 88, in _run_code
exec(code, run_globals)
File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 198, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: 'ElementWiseGRU/~/linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')
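Both reports in this thread end with the same kind of message, so it can help to pull the two shapes out and see which axis disagrees. The snippet below is a small sketch using only the standard library; `parse_shape_mismatch` is a hypothetical helper, and the regular expression is written against the message text quoted in these tracebacks, not against any documented Haiku format.

```python
import re

# Pattern for the Haiku message quoted in the tracebacks in this thread.
_MISMATCH = re.compile(
    r"'(?P<name>[^']+)' with retrieved shape \((?P<stored>[^)]*)\) "
    r"does not match shape=\[(?P<requested>[^\]]*)\]"
)


def _to_shape(text):
    """Turn '20, 32' into (20, 32)."""
    return tuple(int(part) for part in text.split(",") if part.strip())


def parse_shape_mismatch(message):
    """Return (name, stored_shape, requested_shape, differing_axes) or None."""
    match = _MISMATCH.search(message)
    if match is None:
        return None
    stored = _to_shape(match.group("stored"))
    requested = _to_shape(match.group("requested"))
    differing = [i for i, (a, b) in enumerate(zip(stored, requested)) if a != b]
    return match.group("name"), stored, requested, differing


messages = [
    "ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) "
    "does not match shape=[16, 32] dtype=dtype('complex64')",
    "ValueError: 'ElementWiseGRU/~/linear/w' with retrieved shape (4, 32) "
    "does not match shape=[5, 32] dtype=dtype('complex64')",
]

for message in messages:
    print(parse_shape_mismatch(message))
```

In both messages only the first axis of `linear/w` differs (the input feature size), while the second axis (32) agrees. That suggests the number of features fed to the optimizer network differs from what the checkpoint was saved with; the thread does not say why.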
