Butterfly Factorization¶
Following the paper "An algorithm for the rapid evaluation of special function transforms", by Michael O’Neil, Franco Woolfe, Vladimir Rokhlin.
In [2]:
import numpy as np
import scipy.linalg as la
import scipy.linalg.interpolative as sli
import matplotlib.pyplot as plt
import scipy.special as sps
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
Set some parameters: (ignore nlevels
for now)
In [95]:
#nlevels = 4
#nlevels = 7
nlevels = 9
n = 2**(nlevels + 2)
In [96]:
def make_dft(n, power):
omega = np.exp(2*np.pi*1j/n)
ns = np.arange(n)
exponents = ns.reshape(-1, 1) * ns
return omega**(power*exponents)
dft = make_dft(n, power=1)
idft = make_dft(n, power=-1)
In [75]:
la.norm(np.abs(idft @ dft) - n*np.eye(n))
Out[75]:
1.0665710446355353e-07
Verify the FFT property:
In [76]:
quotient = dft[::2, :n//2] / make_dft(n//2, power=1)
plt.imshow(quotient.real)
plt.colorbar()
print(np.max(np.abs(quotient - 1)))
7.629497034145061e-11
In [77]:
plt.imshow(dft.real)
Out[77]:
<matplotlib.image.AxesImage at 0x7f661aba7170>
Consider the the claim that the numerical rank of the kernel $e^{ixt}$ for $x\in[0,X]$ and $t\in[0,T]$ depends only on the product $XT$:
In [78]:
T = 15
X = 15
resolution = 200
x, t = np.mgrid[0:X:resolution * 1j, 0:T:resolution * 1j]
mat = np.exp(1j*x*t)
plt.imshow(mat[:, ::-1].T.real, extent=(0, X, 0, T))
scale = 6
for exp in np.linspace(-1.25, 1.25, 5):
subX = 2**exp * scale
subT = 2**-exp * scale
plt.gca().add_patch(Rectangle((0, 0), subX, subT, fill=False))
# Observe: These are all the same matrix!
In [79]:
Xfacs = np.linspace(1/2, 2, 30)
Tfacs = 1/Xfacs
# Change me
scale = np.pi
resolution = 30
for Xfac, Tfac in zip(Xfacs, Tfacs):
x, t = np.mgrid[0:Xfac*scale:resolution * 1j, 0:Tfac*scale:resolution * 1j]
mat = np.exp(1j*x*t)
_, sigma, _ = la.svd(mat)
print(f"{Xfac:.2f} {Tfac:.2f}\t", np.sum(sigma > 1e-7))
0.50 2.00 11 0.55 1.81 11 0.60 1.66 11 0.66 1.53 11 0.71 1.41 11 0.76 1.32 11 0.81 1.23 11 0.86 1.16 11 0.91 1.09 11 0.97 1.04 11 1.02 0.98 11 1.07 0.94 11 1.12 0.89 11 1.17 0.85 11 1.22 0.82 11 1.28 0.78 11 1.33 0.75 11 1.38 0.72 11 1.43 0.70 11 1.48 0.67 11 1.53 0.65 11 1.59 0.63 11 1.64 0.61 11 1.69 0.59 11 1.74 0.57 11 1.79 0.56 11 1.84 0.54 11 1.90 0.53 11 1.95 0.51 11 2.00 0.50 11
The Legendre Vandermonde / Transform¶
In [82]:
lege_nodes = sps.legendre(n).weights[:, 0]
lege_vdm = np.array([sps.eval_legendre(i, lege_nodes) for i in range(n)]).T
plt.imshow(lege_vdm)
Out[82]:
<matplotlib.image.AxesImage at 0x7f66135a4c50>