v0lta / pytorch-wavelet-toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.

License: European Union Public License 1.2

Python 100.00%
wavelet-transform pytorch wavelet-packets wavelet wavelet-analysis fast-wavelet-transform matrix-fwt

pytorch-wavelet-toolbox's People

Contributors

cthoyt, felixblanke, felixdivo, kgasenzer, loki-veera, v0lta

pytorch-wavelet-toolbox's Issues

Wavelet packet transform - urgent help required

Hello, I am working on a research project and wanted to ask what exactly the wavelet packet transform outputs. Can the outputs be concatenated directly with the output of a convolution layer? I need to fuse frequency-domain features with spatial-domain features, so I need more information about what the wavelet packet transform outputs look like.
Thank you in advance for your reply.
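
As a rough orientation for the question above: a 2D wavelet packet tree at level L has 4**L leaf nodes, addressed by strings over "a", "h", "v", "d" (the key scheme used by the reproduction script in the next issue). Stacking these leaves along a new axis gives a tensor that resembles a convolutional feature map. The sketch below only does the bookkeeping. The helper names are hypothetical, and the size rule is the one PyWavelets documents for non-periodic padding modes, assumed here to carry over to ptwt.

```python
from itertools import product


def packet_keys(level: int) -> list[str]:
    """Return all 2D wavelet packet leaf keys at the given level."""
    return ["".join(node) for node in product("ahvd", repeat=level)]


def coeff_len(data_len: int, filt_len: int) -> int:
    """Coefficient length after one analysis step (non-periodic padding)."""
    return (data_len + filt_len - 1) // 2


def stacked_shape(batch, height, width, level, filt_len):
    """Shape obtained by stacking all leaves of one level along axis 1."""
    for _ in range(level):
        height = coeff_len(height, filt_len)
        width = coeff_len(width, filt_len)
    return (batch, 4**level, height, width)


print(packet_keys(1))
# db5 has 10 filter taps; a 128x128 input at level 2 gives 16 packets.
print(stacked_shape(1, 128, 128, level=2, filt_len=10))
```

Whether such a stack can be concatenated with convolutional features then depends on matching the spatial size, for example by choosing the level so that both branches have the same resolution.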

WaveletPacket2D yields incoherent results for certain data shape/wavelet level combinations

Calling ptwt.WaveletPacket2D (and processing its outputs) can raise exceptions that are not self-explanatory, or produce unexpected results, depending on the image size and the wavelet level. The following code snippet demonstrates this behavior with six simple tests on a test image.

import numpy as np
from PIL import Image
import torch
import pywt
import ptwt
from itertools import product

def generate_PIL_img(size=256, channels=3):
    # generate demo image
    img = np.zeros((size, size, channels), dtype=np.uint8)
    img[::size//8] = 255
    img[:,::size//8] = 255
    return Image.fromarray(img)

def wl_transform(img, img_size, max_lev, wavelet_str = "db5", mode = "reflect"):
    # wavelet transform with pipeline: PIL -> numpy -> torch -> numpy
    img = img.resize((img_size,img_size))
    image_batch = np.array(img)[None,:]
    image_batch_tensor = torch.from_numpy(image_batch.astype(np.float32))

    wavelet = pywt.Wavelet(wavelet_str)
    wp_keys = ["".join(node) for node in product(["a", "h", "v", "d"], repeat=max_lev)]

    channels = []
    for channel in range(image_batch_tensor.shape[-1]):
        with torch.no_grad():
            pt_data = image_batch_tensor[:, :, :, channel]
            ptwt_wp_tree = ptwt.WaveletPacket2D(data=pt_data, wavelet=wavelet, mode=mode)
            packet_list = []
            for node in wp_keys:
                packet = torch.squeeze(ptwt_wp_tree["".join(node)], dim=1)
                packet_list.append(packet)
            channel_packets = torch.stack(packet_list, dim=1)
        channels.append(channel_packets)
    packets = torch.stack(channels, -1)

    return packets.numpy()

# TEST 1
# Ok 
wav = wl_transform(generate_PIL_img(), 128, 2)
assert len(wav.shape) == 5

# TEST 2
# Exception 1: AssertionError
wav = wl_transform(generate_PIL_img(), 128, 3)
assert len(wav.shape) == 5 # wav.shape = (23, 64, 23, 3)

# TEST 3
# Ok 
wav = wl_transform(generate_PIL_img(), 256, 3)
assert len(wav.shape) == 5

# TEST 4
# Exception 2: KeyError
wav = wl_transform(generate_PIL_img(), 128, 4) # 'aaaa' not found in ptwt_wp_tree
assert len(wav.shape) == 5

# TEST 5
# Exception 3: AssertionError 
wav = wl_transform(generate_PIL_img(), 256, 4)
assert len(wav.shape) == 5 # wav.shape = (24, 256, 24, 3)

# TEST 6
# Ok 
wav = wl_transform(generate_PIL_img(), 512, 4)
assert len(wav.shape) == 5

The expected behavior would be an informative exception on wrong usage or a valid output otherwise.
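
A guard along the following lines could produce the informative exception asked for. It is a sketch, not ptwt code: `check_packet_level` is a hypothetical helper, and the maximum level follows the formula documented for `pywt.dwt_max_level`, floor(log2(data_len / (filter_len - 1))). The reproduction script does not pass a maximum level, so the tree depth is presumably chosen automatically. With db5 (10 filter taps) the formula gives 3 for a side length of 128, which would be consistent with the missing key 'aaaa' in TEST 4.

```python
import math


def dwt_max_level(data_len: int, filt_len: int) -> int:
    """Maximum useful decomposition level, following pywt.dwt_max_level."""
    if filt_len < 2 or data_len < filt_len - 1:
        return 0
    return int(math.floor(math.log2(data_len / (filt_len - 1))))


def check_packet_level(height, width, filt_len, level):
    """Raise a readable error if the requested packet level is too deep."""
    max_level = dwt_max_level(min(height, width), filt_len)
    if level > max_level:
        raise ValueError(
            f"level {level} exceeds the maximum level {max_level} for a "
            f"{height}x{width} input and a filter of length {filt_len}"
        )
    return max_level


# Sizes and levels taken from TEST 1, TEST 4 and TEST 5 above.
for size, level in [(128, 2), (128, 4), (256, 4)]:
    try:
        check_packet_level(size, size, filt_len=10, level=level)
        print(size, level, "ok")
    except ValueError as err:
        print(size, level, "rejected:", err)
```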

torch 2.0 support

This issue tracks our efforts to support the upcoming torch 2.0 release.

Insufficient padding removal in wavelet packet reconstruction.

I noticed that the removal of added padding is not done properly in our Wavelet packet reconstruction.

The following snippet calculates the reconstruction sizes for all valid combinations of

  • FWT / Wavelet Packets
  • reflect at boundary/boundary wavelets
  • separable/non-separable

import ptwt, torch
def print_reconstruction_sizes(size: int, wavelet="db4", level=3):
    data = torch.eye(size, device="cuda", dtype=torch.float32)
    reconstructions = []
    reconstructions.append(ptwt.waverec2(ptwt.wavedec2(data=data, wavelet=wavelet, level=level, mode="reflect"), wavelet=wavelet))
    reconstructions.append(ptwt.fswaverec2(ptwt.fswavedec2(data, wavelet=wavelet, mode="reflect", level=level), wavelet=wavelet))
    reconstructions.append(ptwt.WaveletPacket2D(data=data, wavelet=wavelet, maxlevel=level, mode="reflect").reconstruct()[""])
    reconstructions.append(ptwt.MatrixWaverec2(wavelet=wavelet)(ptwt.MatrixWavedec2(wavelet=wavelet, level=level)(data)))
    reconstructions.append(ptwt.MatrixWaverec2(wavelet=wavelet, separable=False)(ptwt.MatrixWavedec2(wavelet=wavelet, level=level, separable=False)(data)))
    reconstructions.append(ptwt.WaveletPacket2D(data=data, wavelet=wavelet, maxlevel=level, mode="boundary", separable=False).reconstruct()[""])
    reconstructions.append(ptwt.WaveletPacket2D(data=data, wavelet=wavelet, maxlevel=level, mode="boundary", separable=True).reconstruct()[""])
    print(f"reconstruction sizes for size {size}: {[rec.shape[-1] for rec in reconstructions]}")

Running it on the current v0.1.5 dev branch for some sample input sizes yields:

for size in range(42, 80, 2):
    print_reconstruction_sizes(size)

Output:

reconstruction sizes for size 42: [42, 42, 46, 42, 42, 48, 48]
reconstruction sizes for size 44: [44, 44, 46, 44, 44, 48, 48]
reconstruction sizes for size 46: [46, 46, 46, 46, 46, 48, 48]
reconstruction sizes for size 48: [48, 48, 54, 48, 48, 48, 48]
reconstruction sizes for size 50: [50, 50, 54, 50, 50, 56, 56]
reconstruction sizes for size 52: [52, 52, 54, 52, 52, 56, 56]
reconstruction sizes for size 54: [54, 54, 54, 54, 54, 56, 56]
reconstruction sizes for size 56: [56, 56, 62, 56, 56, 56, 56]
reconstruction sizes for size 58: [58, 58, 62, 58, 58, 64, 64]
reconstruction sizes for size 60: [60, 60, 62, 60, 60, 64, 64]
reconstruction sizes for size 62: [62, 62, 62, 62, 62, 64, 64]
reconstruction sizes for size 64: [64, 64, 70, 64, 64, 64, 64]
reconstruction sizes for size 66: [66, 66, 70, 66, 66, 72, 72]
reconstruction sizes for size 68: [68, 68, 70, 68, 68, 72, 72]
reconstruction sizes for size 70: [70, 70, 70, 70, 70, 72, 72]
reconstruction sizes for size 72: [72, 72, 78, 72, 72, 72, 72]
reconstruction sizes for size 74: [74, 74, 78, 74, 74, 80, 80]
reconstruction sizes for size 76: [76, 76, 78, 76, 76, 80, 80]
reconstruction sizes for size 78: [78, 78, 78, 78, 78, 80, 80]
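
The third column (reflect-mode wavelet packets) can be reproduced with plain length arithmetic, which hints at where a trim may be missing. The sketch assumes that each analysis step yields floor((n + F - 1) / 2) coefficients and that each transposed-convolution synthesis step yields 2n - F + 2 samples, with F = 8 for db4. These formulas and the helper names are assumptions for illustration, not code from ptwt.

```python
def analysis_lengths(size: int, filt_len: int, level: int) -> list[int]:
    """Signal length at every level, starting with the input length."""
    lengths = [size]
    for _ in range(level):
        lengths.append((lengths[-1] + filt_len - 1) // 2)
    return lengths


def naive_rec_len(size, filt_len, level):
    """Reconstruction length if no surplus sample is trimmed between levels."""
    n = analysis_lengths(size, filt_len, level)[-1]
    for _ in range(level):
        n = 2 * n - filt_len + 2
    return n


def trimmed_rec_len(size, filt_len, level):
    """Reconstruction length when each level is cropped to its stored length."""
    lengths = analysis_lengths(size, filt_len, level)
    n = lengths[-1]
    for target in reversed(lengths[:-1]):
        n = min(2 * n - filt_len + 2, target)  # crop the surplus sample
    return n


for size in (42, 48, 56):
    print(size, naive_rec_len(size, 8, 3), trimmed_rec_len(size, 8, 3))
```

Under these assumptions the untrimmed lengths come out as 46, 54 and 62 for inputs of 42, 48 and 56, matching the table, while cropping to the stored length at every level returns the input size.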

3D Matrix Wavelet Decomposition

I have a specific application where the sparse matrix representation of the DWT is really useful for 3D signals. I'd like to propose it as an enhancement.

Using wavelets in FNet instead of FFT

Hi @v0lta ,

My name is Alexander, I am one of the authors of https://github.com/snakers4/silero-models.

Some time ago Google published a paper (FNet: Mixing Tokens with Fourier Transforms) that essentially took a self-attention transformer module and replaced the self-attention mechanism with an FFT layer. This allegedly reduced computation considerably without hurting the metrics much.

In PyTorch, it is as simple as using this method instead of the actual self-attention layer. The data is essentially batch_size * sequence_length * hidden_dimension.

I tried this idea in my domain (speech to text) and, to my astonishment, it really worked (!). Of course, I applied it to a small network (aiming to make it much faster, remove the SoftMax layer, and compress the network 2-3x in terms of both speed and parameters). I could not get anywhere near the mere 10% quality reduction that Google reports, but surprisingly, ceteris paribus, the 2x smaller network (which is probably also much faster, since the most expensive part is removed) showed maybe 25-30% worse metrics.

I had an idea: maybe try a wavelet transform instead of the FFT, since wavelets are supposed to be localized in both time and frequency? Then I googled a bit and found pywt and then your library. I am not very well acquainted with wavelets, so I found the wavedec2 function in your README and would like to ask a couple of questions:

  • If I run wavelist(), I get a long list of filters. Since I am not very familiar with them, could you point me to a resource that gives some intuition for choosing among them?

  • Do gradients flow through your functions, i.e. can they be used as feature extractors in the middle of a network? In your learnable example you construct the network manually, but there you learn the filters themselves rather than using them as layers.
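
To give some intuition for the "localized in both time and frequency" remark, here is a one-level Haar transform written out by hand. It is a toy in plain Python with hypothetical helper names, not the ptwt implementation. An impulse in the input produces a single non-zero approximation and detail coefficient at the matching position, whereas its Fourier transform would spread over all frequencies.

```python
import math


def haar_dwt(signal):
    """One-level orthonormal Haar transform of an even-length sequence."""
    s = 1.0 / math.sqrt(2.0)
    approx = [s * (signal[i] + signal[i + 1]) for i in range(0, len(signal), 2)]
    detail = [s * (signal[i] - signal[i + 1]) for i in range(0, len(signal), 2)]
    return approx, detail


def haar_idwt(approx, detail):
    """Invert haar_dwt."""
    s = 1.0 / math.sqrt(2.0)
    signal = []
    for a, d in zip(approx, detail):
        signal.extend([s * (a + d), s * (a - d)])
    return signal


impulse = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
approx, detail = haar_dwt(impulse)
print(approx)  # only the third entry is non-zero
print(detail)  # only the third entry is non-zero
print(haar_idwt(approx, detail))
```

The transform consists only of additions and scalings, which is why a tensor version of it is differentiable with respect to its input.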

Stationary Wavelet Transform

Hi,
I really appreciate your work!
I wonder if it is possible to extend this toolbox with the SWT (in particular swt2 and iswt2), which is already implemented in the pywt library.
Thanks,
Matteo

My inverse stationary wavelet transform (ISWT) implementation gives wrong results. Is there a problem?

Thanks to the author for this contribution; this project is great! My current work requires swt and iswt. I see that the author has provided experimental swt code, but no code for iswt yet. I wrote an iswt modeled after the _swt() and waverec() functions. It produces a result with the right shape, but the result differs from the input of _swt(). Where might the problem be?

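from typing import List, Optional, Union

import torch

# Wavelet and _get_filter_tensors are not defined in this post; they are
# presumably taken from the ptwt sources.


def _iswt(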
    coeffs: List[torch.Tensor],
    wavelet: Union[Wavelet, str],
    level: Optional[int] = None,
) -> torch.Tensor:

    torch_device = coeffs[0].device
    torch_dtype = coeffs[0].dtype
    
    for coeff in coeffs[1:]:
        if torch_device != coeff.device:
            raise ValueError("coefficients must be on the same device")
        elif torch_dtype != coeff.dtype:
            raise ValueError("coefficients must have the same dtype")
    
    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )
    
    filt_len = rec_lo.shape[-1]
    filt = torch.stack([rec_lo, rec_hi], 0)

    res_lo = coeffs[0]
    for cpos, res_hi in enumerate(coeffs[1:]):
        dilation = 2**cpos
        res_lo = torch.stack([res_lo, res_hi], 1)
        res_lo = torch.nn.functional.conv_transpose1d(res_lo, filt, stride=1,dilation=dilation)
        # remove the padding
        padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)

        if padl > 0:
            res_lo = res_lo[..., padl:]
        if padr > 0:
            res_lo = res_lo[..., :-padr]

    return res_lo
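
Two things that often matter for an inverse stationary transform, and that may be worth checking in the function above, are the order of the dilations and the normalization. If the coefficients are ordered coarsest first, the first synthesis step needs the largest dilation, whereas `2**cpos` starts with 1. In addition, the undecimated transform is redundant, so the two possible reconstructions are usually averaged. The following toy shows both points for the Haar wavelet with periodic boundaries. It is plain Python with hypothetical helper names, and its sign and shift conventions are assumptions rather than those of ptwt or pywt.

```python
import math


def haar_swt(signal, level):
    """Undecimated Haar transform with periodic boundaries.

    Returns [approx_J, detail_J, ..., detail_1], coarsest level first.
    """
    s = 1.0 / math.sqrt(2.0)
    n = len(signal)
    approx = list(signal)
    details = []
    for j in range(level):
        step = 2**j  # the filter is dilated by 2**j at level j + 1
        shifted = [approx[(i + step) % n] for i in range(n)]
        details.append([s * (a - b) for a, b in zip(approx, shifted)])
        approx = [s * (a + b) for a, b in zip(approx, shifted)]
    return [approx] + details[::-1]


def haar_iswt(coeffs):
    """Invert haar_swt by averaging the two possible reconstructions."""
    s = 1.0 / math.sqrt(2.0)
    approx = coeffs[0]
    n = len(approx)
    details = coeffs[1:]
    for pos, detail in enumerate(details):
        step = 2 ** (len(details) - pos - 1)  # largest dilation first
        first = [s * (a + d) for a, d in zip(approx, detail)]
        second = [
            s * (approx[(i - step) % n] - detail[(i - step) % n])
            for i in range(n)
        ]
        approx = [0.5 * (p + q) for p, q in zip(first, second)]
    return approx


signal = [1.0, 4.0, 2.0, 8.0, 5.0, 7.0, 3.0, 6.0]
reconstruction = haar_iswt(haar_swt(signal, level=3))
print(max(abs(a - b) for a, b in zip(signal, reconstruction)))
```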

3D doesn't work

I tried to pass a tensor of shape B,C,D,H,W, but internally a dimension is added after B, so the tensor becomes B,1,C,D,H,W.
Then conv3d fails. Is there any way to resolve this?

Support for Differentiable CWT

Hello!

I was wondering if it would be possible to support a differentiable version of the ptwt.continuous_transform.cwt function. I see that internally, the function converts everything to numpy arrays, and so it's not able to handle input tensors with gradients attached to them.

This would be very useful for my case where I'm using CWT scalograms for computing a similarity score/loss between signals. I understand that several other transforms support gradients e.g. wavedec and waverec, which work fantastically in ML pipelines I've tested, so I was hoping that such functionality could be extended to the continuous transform as well.

Cheers!

wavelet_linear

Does the wavelet linear layer used for network compression perform a wavelet transformation of each image passed into it?
I want to fit a unit into a CNN such that it outputs a wavelet transform, but I can't figure out how to do that.

Error in wavedec2

Hello,

Thank you for the great work.

I'm trying to apply wavedec2 to an image tensor (e.g. of size 1x3x10x10).

When I run

tmp = torch.randn(1, 3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff2d = ptwt.wavedec2(tmp, wavelet, level=1, mode="zero")

I get

RuntimeError: Given groups=1, weight of size [4, 1, 2, 2], expected input[1, 3, 10, 10] to have 1 channels, but got 3 channels instead

It seems to be a problem with input tensors whose channel dimension is larger than 1.

It would be greatly appreciated if you could help me out with this.
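
The error message suggests that the convolution kernel expects a single channel, so one possible workaround is to fold the channel axis into the batch axis before the transform and unfold it afterwards. The sketch below shows only the reshaping. `per_image_transform` is a hypothetical stand-in (a scaled 2x2 average, which corresponds to the one-level Haar approximation band) for whatever function accepts [N, H, W] input; it is not a ptwt call.

```python
import torch


def per_image_transform(images: torch.Tensor) -> torch.Tensor:
    """Stand-in for a transform that takes [N, H, W] input.

    Computes the one-level Haar approximation band (2 * mean of 2x2 blocks).
    """
    pooled = torch.nn.functional.avg_pool2d(images.unsqueeze(1), kernel_size=2)
    return 2.0 * pooled.squeeze(1)


def apply_per_channel(data: torch.Tensor) -> torch.Tensor:
    """Fold [B, C, H, W] into [B*C, H, W], transform, then unfold again."""
    batch, channels, height, width = data.shape
    folded = data.reshape(batch * channels, height, width)
    out = per_image_transform(folded)
    return out.reshape(batch, channels, *out.shape[-2:])


tmp = torch.randn(1, 3, 10, 10)
print(apply_per_channel(tmp).shape)
```

For a transform that returns a list of coefficient tensors, the same unfolding would have to be applied to every entry of the list.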

Best,

Questions about input dimensions

Great job!
Questions about input dimensions:
Q1: To apply a 2D DWT to an image batch of shape [B, C, H, W], do I have to combine B and C (e.g. [BC, H, W]) to use ptwt.wavedec2?
Q2: To apply a 3D DWT to a video sequence of shape [B, T, C, H, W], do I have to combine B and C (e.g. [BC, T, H, W]) to use ptwt.wavedec3?

In ptwt:
wavedec2:
    data (torch.Tensor): The input data tensor with up to three dimensions.
        2d inputs are interpreted as [height, width],
        3d inputs are interpreted as [batch_size, height, width].
wavedec3:
    data (torch.Tensor): The input data of shape
        [batch_size, length, height, width]

Adaptive 1D wavelet filters

Hello, Thank you for the great repo!

I was wondering if there are plans for adding adaptive 1D wavelet filters with examples?

Thanks again for the wonderful effort!

1d boundary filters seem transposed.

Our 1d boundary filter code produces transposed coefficient matrices, which means the packet code has a problem. For example:

@pytest.mark.slow
@pytest.mark.parametrize("max_lev", [1, 2, 3, 4])
def test_boundary_matrix_packets1(max_lev):
    """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same."""
    _compare_trees1("db1", max_lev, "zero", "boundary")

This test currently has a problem. The 2d case is ok.
