关注

python小波变换3-代码实现(pywt库,cwt-2D/3D时频图绘制,dwt-信号分解及重建)

感谢前辈大佬,引用自:
[1]https://ataspinar.com/2018/12/21/a-guide-for-using-the-wavelet-transform-in-machine-learning/
[2]https://blog.csdn.net/Fvine_/article/details/83381250
[3]https://blog.csdn.net/weixin_46713695/article/details/127106554
码龄一年半,目前用python做一些数据分析。加上了一些自己的整理和总结,欢迎大家提出建议,侵删!

〇、更多一点的原理

1.小波变换如何工作?

傅里叶变换使用一系列不同频率的正弦波来分析信号。即,信号通过正弦波的线性组合来表示。
小波变换使用一系列称为小波的函数,每个函数具有不同的尺度。小波这个词的意思是小波,这正是小波的意思。
在这里插入图片描述

我们可以看到正弦波和小波之间的区别。主要区别在于正弦波不在时间上局部化(它从-无穷大延伸到+无穷大),而小波在时间上局部化。这允许小波变换除了频率信息之外还获得时间信息。

由于小波在时间上是局部的,我们可以将我们的信号与时间上不同位置的小波相乘。我们从信号的开头开始,慢慢地将小波移动到信号的结尾。此过程也称为卷积。在对原始(母)小波完成此操作后,我们可以对其进行缩放,使其变大并重复该过程。这个过程如下图所示。

wavelet动态图

小波变换的这个二维输出是以比例图形式表示信号的时间尺度。那么这个维度为什么叫尺度呢?由于术语频率是为傅里叶变换保留的,因此小波变换通常用尺度来表示。这就是为什么比例图的两个维度是时间和比例。对于那些发现频率比比例更直观的人,可以使用等式将比例转换为伪频率
在这里插入图片描述
其中fa是伪频率,fc是母小波的中心频率,a一个是比例因子。

我们可以看到更高的比例因子(更长的小波)对应于更小的频率,因此通过在时域中缩放小波,我们将在频域中分析更小的频率(实现更高分辨率)。反之亦然,通过使用较小的比例,我们在时域中有更多的细节。所以尺度基本上是频率的倒数。

PS:PyWavelets 包含函数scale2frequency 将比例域转换为频域。
小波工作原理:https://www.youtube.com/watch?v=QX1-xGVFqmw

2.母小波的选择

每种类型的小波都有不同的形状、平滑度和紧凑度,可用于不同的目的。由于小波只需满足两个数学条件,因此很容易生成新型小波。

这两个数学条件就是所谓的归一化和正交化约束:
小波必须具有 1) 有限能量和 2) 零均值。
有限的能量意味着它局限在时间和频率上;它是可积的,小波和信号之间的内积总是存在的。
可接受性条件意味着小波在时域中具有零均值,在时域中零频率处为零。这是必要的,以确保它是可积的,并且小波变换的逆也可以计算。

此外:
小波可以是正交的或非正交的。
小波可以是双正交的,也可以不是。
小波可以是对称的,也可以不是。
小波可以是复数或实数。如果是复数,通常分为表示幅度的实部和表示相位的虚部。
小波被归一化为具有单位能量。

在每个小波家族中,可以有许多属于该家族的不同小波子类别。您可以通过系数的数量(即消失动量Vanishing Moments维基百科的数量)和分解级别来区分小波的不同子类别。

import pywt
import matplotlib.pyplot as plt
 
db_wavelets = pywt.wavelist('db')[:5]
print(db_wavelets)
*** ['db1', 'db2', 'db3', 'db4', 'db5']
 
fig, axarr = plt.subplots(ncols=5, nrows=5, figsize=(20,16))
fig.suptitle('Daubechies family of wavelets', fontsize=16)
for col_no, waveletname in enumerate(db_wavelets):
    wavelet = pywt.Wavelet(waveletname)
    no_moments = wavelet.vanishing_moments_psi
    family_name = wavelet.family_name
    for row_no, level in enumerate(range(1,6)):
        wavelet_function, scaling_function, x_values = wavelet.wavefun(level = level)
        axarr[row_no, col_no].set_title("{} - level {}\n{} vanishing moments\n{} samples".format(
            waveletname, level, no_moments, len(x_values)), loc='left')
        axarr[row_no, col_no].plot(x_values, wavelet_function, 'bD--')
        axarr[row_no, col_no].set_yticks([])
        axarr[row_no, col_no].set_yticklabels([])
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()

在这里插入图片描述在图中,我们可以看到“Daubechies”家族 (db) 的小波。在第一列中,我们可以看到一阶 Daubechies 小波 (db1),在第二列中可以看到二阶 (db2),在第五列中可以看到五阶。PyWavelets 包含高达 20 阶 (db20) 的 Daubechies 小波。
阶数表示消失动量的数量。所以 db3 有 3 个消失动量,db5 有 5 个消失动量。消失动量的数量与小波的逼近阶数和平滑度有关。如果一个小波有 p 个消失动量,它可以逼近 p – 1 次多项式。

选择小波时,我们还可以指明分解的级别。默认情况下,PyWavelets 选择输入信号可能的最大分解级别。最大分解级别(参见pywt.dwt_max_level())取决于输入信号长度和小波的长度。

正如我们所看到的,随着消失动量的数量增加,小波的多项式次数增加并且变得更平滑。并且随着分解层次的增加,该小波表示的样本数增加。

一、环境配置

python 3.9
pywavelets 1.3.0 (即pywt)

二、pywavelets库API

官网:https://pywavelets.readthedocs.io/en/latest/index.html

1.一些小波对象

1.1 小波家族、小波名称

PyWavelets库包含 14 个母小波 (小波家族)

import pywt
pywt.families()

> ['haar',  'db',  'sym',  'coif',  'bior',  'rbio',  'dmey',  'gaus', 
> 'mexh',  'morl',  'cgau',  'shan',  'fbsp',  'cmor']
pywt.wavelist(family=None, kind='all') 

> ['bior1.1',  'bior1.3',  'bior1.5',  'bior2.2',  'bior2.4', 
> 'bior2.6',  'bior2.8',  'bior3.1',  'bior3.3',  'bior3.5',  'bior3.7',
> 'bior3.9',  'bior4.4',  'bior5.5',  'bior6.8',  'cgau1',  'cgau2', 
> 'cgau3',  'cgau4',  'cgau5',  'cgau6',  'cgau7',  'cgau8',  'cmor', 
> 'coif1',  'coif2',  'coif3',  'coif4',  'coif5',  'coif6',  'coif7', 
> 'coif8',  'coif9',  'coif10',  'coif11',  'coif12',  'coif13', 
> 'coif14',  'coif15',  'coif16',  'coif17',  'db1',  'db2',  'db3', 
> 'db4',  'db5',  'db6',  'db7',  'db8',  'db9',  'db10',  'db11', 
> 'db12',  'db13',  'db14',  'db15',  'db16',  'db17',  'db18',  'db19',
> 'db20',  'db21',  'db22',  'db23',  'db24',  'db25',  'db26',  'db27',
> 'db28',  'db29',  'db30',  'db31',  'db32',  'db33',  'db34',  'db35',
> 'db36',  'db37',  'db38',  'dmey',  'fbsp',  'gaus1',  'gaus2', 
> 'gaus3',  'gaus4',  'gaus5',  'gaus6',  'gaus7',  'gaus8',  'haar', 
> 'mexh',  'morl',  'rbio1.1',  'rbio1.3',  'rbio1.5',  'rbio2.2', 
> 'rbio2.4',  'rbio2.6',  'rbio2.8',  'rbio3.1',  'rbio3.3',  'rbio3.5',
> 'rbio3.7',  'rbio3.9',  'rbio4.4',  'rbio5.5',  'rbio6.8',  'shan', 
> 'sym2',  'sym3',  'sym4',  'sym5',  'sym6',  'sym7',  'sym8',  'sym9',
> 'sym10',  'sym11',  'sym12',  'sym13',  'sym14',  'sym15',  'sym16', 
> 'sym17',  'sym18',  'sym19',  'sym20']

pywt.wavelist(kind='continuous')

>['cgau1', 'cgau2', 'cgau3', 'cgau4', 'cgau5', 'cgau6', 'cgau7', ...
for family in pywt.families():
    print("%s family: " % family + ', '.join(pywt.wavelist(family)))

> haar family: haar db family: db1, db2, db3, db4, db5, db6, db7, db8,
> db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20,
> db21, db22, db23, db24, db25, db26, db27, db28, db29, db30, db31,
> db32, db33, db34, db35, db36, db37, db38 sym family: sym2, sym3, sym4,
> sym5, sym6, sym7, sym8, sym9, sym10, sym11, sym12, sym13, sym14,
> sym15, sym16, sym17, sym18, sym19, sym20 coif family: coif1, coif2,
> coif3, coif4, coif5, coif6, coif7, coif8, coif9, coif10, coif11,
> coif12, coif13, coif14, coif15, coif16, coif17 bior family: bior1.1,
> bior1.3, bior1.5, bior2.2, bior2.4, bior2.6, bior2.8, bior3.1,
> bior3.3, bior3.5, bior3.7, bior3.9, bior4.4, bior5.5, bior6.8 rbio
> family: rbio1.1, rbio1.3, rbio1.5, rbio2.2, rbio2.4, rbio2.6, rbio2.8,
> rbio3.1, rbio3.3, rbio3.5, rbio3.7, rbio3.9, rbio4.4, rbio5.5, rbio6.8
> dmey family: dmey gaus family: gaus1, gaus2, gaus3, gaus4, gaus5,
> gaus6, gaus7, gaus8 mexh family: mexh morl family: morl cgau family:
> cgau1, cgau2, cgau3, cgau4, cgau5, cgau6, cgau7, cgau8 shan family:
> shan fbsp family: fbsp cmor family: cmor

1.2 连续小波对象及属性

cwt = pywt.ContinuousWavelet('morl') 
print(cwt)

> ContinuousWavelet morl   
> Family name:    Morlet wavelet   
> Short name:     morl   
> Symmetry:       symmetric   
> DWT:            False   
> CWT:            True   
> Complex CWT:    False

对象属性:

symmetry:对称性
orthogonal:正交性
biorthogonal:双正交性
family_name:小波家族 族名
short_family_name:家族 族名简称
family_number:小波家族里排行老几
name
number
除了以上属性是和离散相似的,连续小波还有自己比较特殊的
lower_bound
upper_bound
dt
complex_cwt
针对 shan, fbsp, cmor这三族小波。还可查看带宽频率和中心频率属性
bandwidth_frequency
center_frequency

针对fbsp小波,还可查看其参数顺序属性
fbsp_order

cwt = pywt.ContinuousWavelet('morl') 
cwt.symmetry

> 'symmetric'

2.一些小波函数

2.1 绘制尺度函数和小波函数图形**

import matplotlib.pyplot as plt
wavelet=pywt.ContinuousWavelet('gaus8')
[psi,xval] = wavelet.wavefun( level=10)#或者写length=1024
plt.plot(xval,psi)
plt.title("Gaussian Wavelet of order 1024")
plt.show()

在这里插入图片描述

#选择对称且连续的小波类型
import pywt
count=0
fig=plt.figure(figsize=(10, 10))
for wave in pywt.wavelist(kind='continuous'):
    if wave in ['cgau2','cgau6']:
        continue
    else:
        cwt = pywt.ContinuousWavelet(wave) 
        print(cwt)
        if cwt.symmetry=='symmetric':
            [psi,xval] = cwt.wavefun(level=10)#或者写length=1024 
            count += 1 
            # count_list.append(count)
            print(count)
            # 画子图
            ax = fig.add_subplot(6, 2, count)
            ax.plot(xval,psi)
            ax.set_title("{} of order 1024".format(wave))
            plt.tight_layout()
plt.show()

在这里插入图片描述

3.CWT尺度图(scaleogram)绘制

#二维时频图
#1.2.3为参数,y为参数
sr=128 #1.sampling rate
wavename = 'morl'#2.母小波名称
totalscal = 150 # 3.totalscal是对信号进行小波变换时所用尺度序列的长度(通常需要预先设定好)
fc = pywt.central_frequency(wavename)  # 计算小波函数的中心频率
cparam = 2 * fc * totalscal  # 常数c
scales = cparam / np.arange(totalscal, 1, -1)  # 为使转换后的频率序列是一等差序列,尺度序列必须取为这一形式(也即小波尺度)
[cwtmatr, frequencies] = pywt.cwt(y, scales, wavename, 1.0 / sr)#4.y为将要进行cwt变换的一维输入信号
t = np.arange(0, y.shape[0]/sr, 1.0/sr)
plt.contourf(t, frequencies, abs(cwtmatr))
plt.ylabel(u"freq(Hz)")
plt.xlabel(u"time(s)")
# plt.subplots_adjust(hspace=0.4)  # 调整边距和子图的间距 hspace为子图之间的空间保留的高度,平均轴高度的一部分
plt.title = ("小波时频图,totalscale = {}".format(i))
plt.show()

在这里插入图片描述

#接上,三维时频图
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
plt.rc('font',family='Arial') 
 
fig = plt.figure(figsize=(10,10))
ax = plt.axes(projection='3d')
ax.contour3D(t, frequencies, abs(cwtmatr), 50,cmap='rainbow')
ax.tick_params(labelsize=14)
ax.set_xlabel('Translation',fontsize=22)
ax.set_ylabel('Scale',fontsize=22)
ax.set_zlabel('Amplitude',fontsize=22)

plt.show()

在这里插入图片描述

厄尔尼诺数据集是一个用于跟踪厄尔尼诺现象的时间序列数据集,包含从1871年到1997年的每季度海面温度测量值。为了了解尺度图,让我们将其与原始时间序列数据及其傅里叶变换一起可视化为厄尔尼诺数据集。

def plot_wavelet(time, signal, scales, 
                 waveletname = 'cmor', 
                 cmap = plt.cm.seismic, 
                 title = 'Wavelet Transform (Power Spectrum) of signal', 
                 ylabel = 'Period (years)', 
                 xlabel = 'Time'):
    
    dt = time[1] - time[0]
    [coefficients, frequencies] = pywt.cwt(signal, scales, waveletname, dt)
    power = (abs(coefficients)) ** 2
    period = 1. / frequencies
    levels = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8]
    contourlevels = np.log2(levels)
    
    fig, ax = plt.subplots(figsize=(15, 10))
    im = ax.contourf(time, np.log2(period), np.log2(power), contourlevels, extend='both',cmap=cmap)
    
    ax.set_title(title, fontsize=20)
    ax.set_ylabel(ylabel, fontsize=18)
    ax.set_xlabel(xlabel, fontsize=18)
    
    yticks = 2**np.arange(np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))
    ax.set_yticks(np.log2(yticks))
    ax.set_yticklabels(yticks)
    ax.invert_yaxis()
    ylim = ax.get_ylim()
    ax.set_ylim(ylim[0], -1)
    
    cbar_ax = fig.add_axes([0.95, 0.5, 0.03, 0.25])
    fig.colorbar(im, cax=cbar_ax, orientation="vertical")
    plt.show()
 
def plot_signal_plus_average(time, signal, average_over = 5):
    fig, ax = plt.subplots(figsize=(15, 3))
    time_ave, signal_ave = get_ave_values(time, signal, average_over)
    ax.plot(time, signal, label='signal')
    ax.plot(time_ave, signal_ave, label = 'time average (n={})'.format(5))
    ax.set_xlim([time[0], time[-1]])
    ax.set_ylabel('Signal Amplitude', fontsize=18)
    ax.set_title('Signal + Time Average', fontsize=18)
    ax.set_xlabel('Time', fontsize=18)
    ax.legend()
    plt.show()
    
def get_fft_values(y_values, T, N, f_s):
    f_values = np.linspace(0.0, 1.0/(2.0*T), N//2)
    fft_values_ = fft(y_values)
    fft_values = 2.0/N * np.abs(fft_values_[0:N//2])
    return f_values, fft_values
 
def plot_fft_plus_power(time, signal):
    dt = time[1] - time[0]
    N = len(signal)
    fs = 1/dt
    
    fig, ax = plt.subplots(figsize=(15, 3))
    variance = np.std(signal)**2
    f_values, fft_values = get_fft_values(signal, dt, N, fs)
    fft_power = variance * abs(fft_values) ** 2     # FFT power spectrum
    ax.plot(f_values, fft_values, 'r-', label='Fourier Transform')
    ax.plot(f_values, fft_power, 'k--', linewidth=1, label='FFT Power Spectrum')
    ax.set_xlabel('Frequency [Hz / year]', fontsize=18)
    ax.set_ylabel('Amplitude', fontsize=18)
    ax.legend()
    plt.show()
 
dataset = "http://paos.colorado.edu/research/wavelets/wave_idl/sst_nino3.dat"
df_nino = pd.read_table(dataset)
N = df_nino.shape[0]
t0=1871
dt=0.25
time = np.arange(0, N) * dt + t0
signal = df_nino.values.squeeze()
 
scales = np.arange(1, 128)
plot_signal_plus_average(time, signal)
plot_fft_plus_power(time, signal)
plot_wavelet(time, signal, scales)

厄尔尼诺数据集(顶部),傅里叶变换(中间)和连续小波变换(底部)
我们可以在上图中看到厄尔尼诺数据集及其时间平均值,在中间图中看到傅里叶变换,在底部图中看到连续小波变换产生的尺度图。在尺度图中,我们可以看到大部分power集中在2-8年。如果我们将其转换为频率(T=1/f),则对应于0.125–0.5 Hz。在傅里叶变换中也可以看到对应功率的增加。小波变换和傅里叶变换的主要区别在于小波变换还提供了时间信息,而傅里叶变换没有。例如,在比例图中,我们可以看到1920年之前有很多波动,而在1960年到1990年之间没有那么多波动。我们还可以看到,随着时间的推移,周期从短变长。信号中的这种动态行为,可以用小波变换来可视化,傅里叶变换则不能。

4. DWT信号分解

DWT 通过高通和低通滤波器的级联成为一个滤波器组,将信号分成几个子频带。
举个例子,假设我们有一个频率高达 1000 Hz 的信号。在第一阶段,我们将信号分成低频部分和高频部分,即 0-500 Hz 和 500-1000 Hz。在第二阶段,我们将低频部分再次分成两部分:0-250 Hz 和 250-500 Hz。在第三阶段,我们将 0-250 Hz 部分拆分为 0-125 Hz 部分和 125-250 Hz 部分。这一直持续到我们达到所需的细化水平或直到我们用完样本。

import pywt
 
x = np.linspace(0, 1, num=2048)
chirp_signal = np.sin(250 * np.pi * x**2)
    
fig, ax = plt.subplots(figsize=(6,1))
ax.set_title("Original Chirp Signal: ")
ax.plot(chirp_signal)
plt.show()
    
data = chirp_signal
waveletname = 'sym5'
 
fig, axarr = plt.subplots(nrows=5, ncols=2, figsize=(6,6))
for ii in range(5):
    (data, coeff_d) = pywt.dwt(data, waveletname)
    axarr[ii, 0].plot(data, 'r')
    axarr[ii, 1].plot(coeff_d, 'g')
    axarr[ii, 0].set_ylabel("Level {}".format(ii + 1), fontsize=14, rotation=90)
    axarr[ii, 0].set_yticklabels([])
    if ii == 0:
        axarr[ii, 0].set_title("Approximation coefficients", fontsize=14)
        axarr[ii, 1].set_title("Detail coefficients", fontsize=14)
    axarr[ii, 1].set_yticklabels([])
plt.tight_layout()
plt.show()

在PyWavelets中,DWT的使用方式为pywt.dwt()。dwt将返回两组系数;近似系数和细节系数。近似系数表示DWT的低通滤波器(平均滤波器)的输出。细节系数表示DWT的高通滤波器(差分滤波器)的输出。通过对上一个DWT的近似系数再次应用DWT,我们得到下一级的小波变换。在每个下一级,原始信号被向下采样2倍。
所以现在我们已经看到了DWT作为滤波器组实现的含义;在每个后续级别,近似系数被划分为低通和高通部分,并且DWT被再次应用于低通部分。正如我们所看到的,我们的原始信号现在被转换为多个信号,每个信号对应于不同的频带。
PS:我们还可以使用pywt.wavedec()计算更高级别的系数。该函数将原始信号和分解级数作为输入,并返回一组近似系数(第n级)和n组细节系数(第1级至第n级)。
PS2:这种在不同尺度上分析信号的想法也称为多分辨率/多尺度分析,以这种方式分解信号也就是多分辨率分解或子带编码。

#短通滤波及信号重建,signal为一维信号序列,thresh, wavelet为可选参数
def lowpassfilter(signal, thresh = 1, wavelet="db22"):
    thresh = thresh*np.nanmax(signal)
    coeff = pywt.wavedec(signal, wavelet, mode="per" )
    coeff[1:] = (pywt.threshold(i, value=thresh, mode="soft" ) for i in coeff[1:])
    reconstructed_signal = pywt.waverec(coeff, wavelet, mode="per" )
    return reconstructed_signal
 
fig, ax = plt.subplots(figsize=(12,8))
ax.plot(signal, color="b", alpha=0.5, label='original signal')
rec = lowpassfilter(d5.iloc[0,:], 0.9)
ax.plot(signal,rec, 'k', label='DWT smoothing', linewidth=2)
ax.legend()
ax.set_title('Removing High Frequency Noise with DWT', fontsize=18)
ax.set_ylabel('Signal Amplitude', fontsize=16)
ax.set_xlabel('Sample No', fontsize=16)
plt.show()

在这里插入图片描述

转载自CSDN-专业IT技术社区

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

原文链接:https://blog.csdn.net/m0_67587806/article/details/128099265

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--