Differentiable filtering using a cascade of second order IIR filters #3808

Open
SuperKogito opened this issue Jul 10, 2024 · 0 comments

SuperKogito commented Jul 10, 2024

🚀 The feature

A differentiable PyTorch implementation of sosfilt(), similar to https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfilt.html, would allow filtering data along one dimension using cascaded second-order sections. This would provide better support for stable high-order filtering.

Motivation, pitch

The current alternative is to convert the cascade of biquads (second-order IIR filters) into a single high-order filter and then apply it with https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. Unfortunately, this only works up to a certain order (order < 6). The following code illustrates the stability issues that arise when using lfilter with a high-order filter. Hence, an option for cascaded filtering that maintains stability would be of great advantage.

```python
import torch
import scipy.signal as signal
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchaudio.functional import lfilter

hardware = "cpu"
device = torch.device(hardware)
eps = 1e-8


def coeff_product(polynomials):
    """Multiply the rows of `polynomials` (shape (n, k)) into one polynomial.

    Works recursively: each half is reduced to a single (1, m) polynomial and
    the two halves are multiplied by polynomial convolution via conv1d.
    """
    n = polynomials.shape[0]
    if n == 1:
        return polynomials

    c1 = coeff_product(polynomials[n // 2 :])
    c2 = coeff_product(polynomials[: n // 2])
    # Use the shorter polynomial as the convolution kernel
    if c1.shape[1] > c2.shape[1]:
        c1, c2 = c2, c1
    # conv1d computes a cross-correlation, so flip the kernel to get a convolution
    weight = c1.unsqueeze(1).flip(2)
    prod = F.conv1d(
        c2.unsqueeze(0),
        weight,
        padding=weight.shape[2] - 1,
        groups=c2.shape[0],
    ).squeeze(0)
    return prod


if __name__ == "__main__":
    n_samples = 500  # length of the impulse response
    for order in range(2, 12, 2):
        # Design the same elliptic low-pass filter in (b, a) and SOS form
        b, a = signal.ellip(order, 0.009, 80, 0.05, output='ba')
        sos = signal.ellip(order, 0.009, 80, 0.05, output='sos')
        zeros, poles, gain = signal.sos2zpk(sos)

        # Print the zeros, poles, and gain
        print("-" * 52)
        print("Order  : ", order)
        print("Zeros  : ", zeros)
        print("Poles  : ", poles)
        print("Gain   : ", gain)
        print("-" * 52)

        print("sos : ", sos)
        print("-" * 52)

        print("b : ", b)
        print("a : ", a)
        print("-" * 52)

        # Unit impulse used as the test input
        dirac = torch.tensor(signal.unit_impulse(n_samples), dtype=torch.float32)

        # PYTORCH IMPLEMENTATION
        # Split the SOS matrix into numerator and denominator sections
        torch_sos = torch.tensor(sos, dtype=torch.float32)
        torch_b = torch_sos[:, :3]
        torch_a = torch_sos[:, 3:]
        # Expand the cascade into a single high-order transfer function
        high_order_b = coeff_product(torch_b)
        high_order_a = coeff_product(torch_a)

        print("torch_sos : ", torch_sos)
        print("-" * 52)

        print("torch_b : ", torch_b)
        print("torch_a : ", torch_a)
        print("-" * 52)

        print("high_order_b : ", high_order_b)
        print("high_order_a : ", high_order_a)
        print("-" * 52)

        # Compute the impulse response; clamp=False keeps lfilter from clipping
        # the output to [-1, 1], which would hide a diverging response
        y_torch_ba = lfilter(dirac.unsqueeze(0), high_order_a, high_order_b, clamp=False)

        # SCIPY IMPLEMENTATION
        x = signal.unit_impulse(n_samples)
        y_tf = signal.lfilter(
            high_order_b.squeeze(0).detach().numpy(),
            high_order_a.squeeze(0).detach().numpy(),
            x,
        )
        y_sos = signal.sosfilt(sos, x)

        # Plot the three impulse responses
        plt.figure(figsize=(15, 30))
        plt.suptitle(f"Impulse responses, order {order}")
        plt.subplot(3, 1, 1)
        plt.plot(y_sos, 'g', label='SOS (scipy.signal.sosfilt)')
        plt.legend(loc='best')

        plt.subplot(3, 1, 2)
        plt.plot(y_tf, 'k', label='TF (scipy.signal.lfilter)')
        plt.legend(loc='best')

        plt.subplot(3, 1, 3)
        plt.plot(y_torch_ba.squeeze(0).detach().numpy(), "r", label="torch (torchaudio lfilter)")
        plt.legend(loc='best')
        plt.show()
```
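One likely contributor to this breakdown (an assumption; roundoff accumulated inside the recursion itself also plays a role) is that the roots of a single high-order polynomial are far more sensitive to coefficient errors than those of individual biquads, especially when the poles cluster near z = 1 as they do for this narrow low-pass. The hypothetical helper `pole_drift` expands the SOS denominators with `scipy.signal.sos2tf`, rounds them to float32 (the dtype the script uses), and compares the largest pole radius before and after. It only inspects pole locations; it does not reproduce the lfilter output.

```python
import numpy as np
import scipy.signal as signal


def pole_drift(order, dtype=np.float32):
    """Hypothetical helper returning (max |pole| of the SOS design,
    max |pole| after expanding the denominator and rounding it to `dtype`)."""
    sos = signal.ellip(order, 0.009, 80, 0.05, output="sos")
    _, sos_poles, _ = signal.sos2zpk(sos)
    # Multiply all section denominators into one polynomial a(z) = a_1(z) a_2(z) ...
    _, a = signal.sos2tf(sos)
    # Round the coefficients, then find the roots in float64 so that only the
    # coefficient rounding (not the root finder) differs from the SOS design.
    expanded_poles = np.roots(a.astype(dtype).astype(np.float64))
    return float(np.max(np.abs(sos_poles))), float(np.max(np.abs(expanded_poles)))


for order in range(2, 12, 2):
    r_sos, r_tf = pole_drift(order)
    print(f"order {order:2d}: max |pole| SOS = {r_sos:.6f}, expanded float32 = {r_tf:.6f}")
```

If the expanded radius drifts toward or past 1 while the SOS radius stays below it, the single-polynomial filter is poorly conditioned; a per-section implementation never builds that polynomial.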

This feature would allow users to apply high-order filtering (order > 6) within loss functions and training loops.

Alternatives

Since no filtering based on a cascade of biquads is available, the current alternatives are:

Additional context

https://dsp.stackexchange.com/questions/31457/multiple-biquads-vs-higher-order-filtering
