#!/usr/bin/env python
# coding: utf-8

# # Multipole and Local Expansions

# In[1]:


import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt


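# Let's consider the 2D potential $\sum_i q_i \log|\mathbf{x} - \mathbf{y}_i|$ generated by a handful of point charges $q_i$ located at sources $\mathbf{y}_i$. It may look familiar from a homework assignment.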

# In[2]:


# Problem setup: `nsources` point charges placed on the three-lobed curve
# r(theta) = 1 + 0.3*sin(3*theta), plus a regular nxtgts-by-nytgts grid of
# target points covering the rectangle `extent`.
nsources, nxtgts, nytgts = 15, 400, 400

angles = np.linspace(0, 2*np.pi, nsources, endpoint=False)
r = 1 + 0.3*np.sin(3*angles)
# Row 0 holds x-coordinates, row 1 holds y-coordinates: shape (2, nsources).
sources = np.stack([r*np.cos(angles), r*np.sin(angles)])

# Fixed seed so the random charges are reproducible between runs.
np.random.seed(15)
charges = np.random.randn(nsources)

extent = (-2, 4, -4, 2)
left, right, bottom, top = extent
# Complex step counts make mgrid include both endpoints;
# resulting shape is (2, nxtgts, nytgts).
targets = np.mgrid[left:right:nxtgts*1j, bottom:top:nytgts*1j]


# In[3]:


plt.plot(sources[0], sources[1], "x")

# Direct evaluation of the potential sum_i q_i log|x - y_i| at every target.
# sources.reshape(2, -1, 1, 1) has shape (2, nsources, 1, 1), and
# targets[:, np.newaxis] has shape (2, 1, nxtgts, nytgts). They broadcast to
# (2, nsources, nxtgts, nytgts). Inserting an axis, rather than reshaping,
# keeps the (nxtgts, nytgts) grid layout intact even when nxtgts != nytgts.
dist_vecs = sources.reshape(2, -1, 1, 1) - targets[:, np.newaxis]
dists = np.sqrt(np.sum(dist_vecs**2, axis=0))

# Shape (nxtgts, nytgts): the potential at each grid target.
potentials = np.sum(charges.reshape(-1, 1, 1) * np.log(dists), axis=0)
# Transpose so image rows follow y, then flip so the largest y is on top.
plt.imshow(potentials.T[::-1], extent=extent)


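# Now let's create a stash of the kernel $\psi(\mathbf{x}) = \log|\mathbf{x}|$ and its partial derivatives up to second order. To keep things simple, they are all taken about a center of 0: each function receives the already-shifted kernel argument.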

# In[4]:


def f(arg):
    """Return log of the Euclidean norm of `arg`, taken over axis 0.

    The vector components are stacked along axis 0 (e.g. shape (2,) for a
    single point or (2, nx, ny) for a grid); the result has the remaining
    shape.
    """
    squared_norm = np.sum(arg**2, axis=0)
    return np.log(np.sqrt(squared_norm))

def fdx(arg):
    """Return d/dx of log|arg|, i.e. x / |arg|^2.

    `arg` must have exactly two entries along axis 0 (the x and y
    components); unpacking raises ValueError otherwise.
    """
    x, _ = arg
    norm_sq = (arg**2).sum(axis=0)
    return x / norm_sq

def fdy(arg):
    """Return d/dy of log|arg|, i.e. y / |arg|^2.

    `arg` must have exactly two entries along axis 0 (the x and y
    components); unpacking raises ValueError otherwise.
    """
    _, y = arg
    norm_sq = (arg**2).sum(axis=0)
    return y / norm_sq

def fdxx(arg):
    """Return d^2/dx^2 of log|arg|, i.e. 1/|arg|^2 - 2 x^2/|arg|^4.

    `arg` must have exactly two entries along axis 0 (x and y components).
    """
    x, _ = arg
    norm_sq = (arg**2).sum(axis=0)
    return 1/norm_sq - 2*x**2/norm_sq**2

def fdyy(arg):
    """Return d^2/dy^2 of log|arg|, i.e. 1/|arg|^2 - 2 y^2/|arg|^4.

    `arg` must have exactly two entries along axis 0 (x and y components).
    """
    _, y = arg
    norm_sq = (arg**2).sum(axis=0)
    return 1/norm_sq - 2*y**2/norm_sq**2

def fdxy(arg):
    """Return the mixed derivative d^2/(dx dy) of log|arg|, i.e. -2xy/|arg|^4.

    `arg` must have exactly two entries along axis 0 (x and y components).
    """
    x, y = arg
    norm_sq = (arg**2).sum(axis=0)
    return -2*x*y/norm_sq**2


# ## Local expansions

# In[5]:


# Expansion center for the local expansion. The commented-out lines are
# alternative centers to try.
center = np.array([1.5, -1.0])
#center = np.array([2, -2])
#center = np.array([3, -3])
#center = np.array([0, 0])


# Local expansion:
# $$\psi(\mathbf{x} - \mathbf{y}) \approx \sum_{|p| \leqslant k}
#    \underbrace{\frac{D^p_{\mathbf{x}} \psi(\mathbf{x} - \mathbf{y})
#    \big|_{\mathbf{x} = \mathbf{c}}}{p!}}_{\text{depends on src/ctr}}
#    \underbrace{(\mathbf{x} - \mathbf{c})^p}_{\text{dep. on ctr/tgt}} $$
# 
# $\mathbf{x}$: targets, $\mathbf{y}$: sources.

# In[20]:


# Second-order (k = 2) local expansion of the potential about `center`,
# evaluated at every target.
# The target offsets h = x - c do not depend on the source, so compute them
# once instead of once per source.
hx, hy = targets - center.reshape(-1, 1, 1)

expn = 0

for isrc in range(nsources):
    # Kernel argument at the center, c - y_i. The coefficients
    # D^p psi(c - y_i) / p! depend only on the source and the center.
    a = center - sources[:, isrc]
    expn += charges[isrc]*(
        f(a)
        + fdx(a)*hx
        + fdy(a)*hy
        + fdxx(a)*hx**2/2
        + fdxy(a)*hx*hy
        + fdyy(a)*hy**2/2
        )


# In[7]:


# Pointwise error of the local expansion against direct evaluation.
err = expn - potentials

plt.plot(*center, "o")
plt.plot(*sources, "x")
# log10 of the error magnitude. The 1e-2 offset bounds the plotted values
# below by -2 and avoids log10(0). Transposing and flipping puts y on the
# image rows, with the largest y at the top.
log_err_img = np.log10(1e-2 + np.abs(err.T[::-1]))
plt.imshow(log_err_img, extent=extent)
plt.colorbar()


# Test accuracy at a point: use the grid row whose y is closest to the
# center's y, 7/8 of the way across in x.
test_y_idx = np.argmin(np.abs(targets[1, 0, :] - center[1]))
test_idx = (7*nxtgts//8, test_y_idx)
plt.plot(targets[0][test_idx], targets[1][test_idx], "ro")
rel_err = abs(err[test_idx])/abs(potentials[test_idx])
print("Relative error at (red) test point:", rel_err)


# * Move the center around and see how the errors change.
# * Truncate the expansion to linear terms and see how the errors change.

# In[8]:


# Error magnitude along the grid row y = targets[1, 0, test_y_idx], plotted
# on log-log axes against each target's x-offset from the center.
x_offsets = targets[0, :, test_y_idx] - center[0]
abs_err_row = np.abs(err[:, test_y_idx])

plt.grid()
plt.xlabel("Distance from center")
plt.ylabel("Error")
plt.loglog(x_offsets, abs_err_row)


# What is the slope of the error graph? What should it be?
# 
# (Disregard the region close to the center: the center and the target points are not at *exactly* the same vertical position, so the distance to the center never quite reaches zero.)

# ## Multipole expansions

# In[9]:


# Expansion center for the multipole expansion. Use float entries, as in the
# local-expansion center, so `center` has the same dtype as `sources`. With an
# integer array, in-place float updates (e.g. `center += 0.5`) would fail.
center = np.array([0.0, 0.0])
# center = np.array([1.0, 0.0])


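# Now evaluate a multipole expansion about the center at the targets. Watch out for negative signs from the chain rule: the expansion differentiates with respect to the source $\mathbf{y}$, but our derivative functions differentiate with respect to the kernel argument $\mathbf{x} - \mathbf{y}$, so every first-order term picks up a factor of $-1$ (second-order terms pick up $(-1)^2 = 1$).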
# 
# Multipole expansion:
# $$\psi(\mathbf{x} - \mathbf{y}) \approx \sum_{|p| \leqslant k}
#    \underbrace{\frac{D^p_{\mathbf{y}} \psi(\mathbf{x} - \mathbf{y})
#    \big|_{\mathbf{y} = \mathbf{c}}}{p!}}_{\text{depends on ctr/tgt}}
#    \underbrace{(\mathbf{y} - \mathbf{c})^p}_{\text{dep. on src/ctr}} . $$
# 
# $\mathbf{x}$: targets, $\mathbf{y}$: sources.

# In[22]:


# Second-order (k = 2) multipole expansion about `center`, evaluated at
# every target.
# The kernel argument x - c and all kernel derivatives depend only on the
# target and the center, so evaluate them on the grid once. The loop then
# combines them with each source's offset from the center.
a = targets - center.reshape(-1, 1, 1)
f_a = f(a)
fdx_a, fdy_a = fdx(a), fdy(a)
fdxx_a, fdxy_a, fdyy_a = fdxx(a), fdxy(a), fdyy(a)

expn = 0

for isrc in range(nsources):
    # Offset of source i from the center, (y_i - c)^p in the expansion.
    hx, hy = sources[:, isrc] - center

    expn += charges[isrc]*(
        f_a

        # Negative sign from the chain rule: the expansion differentiates
        # psi(x - y) with respect to y, while fdx/fdy differentiate with
        # respect to the kernel argument x - y.
        - fdx_a*hx
        - fdy_a*hy

        + fdxx_a*hx**2/2
        + fdxy_a*hx*hy
        + fdyy_a*hy**2/2
        )


# In[18]:


# Pointwise error of the multipole expansion against direct evaluation.
err = expn - potentials
# Array shown in the image below. It is an alias for `err`, presumably kept
# separate so other data can be swapped in while experimenting.
imgdata = err

plt.plot(*center, "o")
plt.plot(*sources, "x")
# log10 of the error magnitude. The 1e-2 offset bounds the plotted values
# below by -2 and avoids log10(0). Transposing and flipping puts y on the
# image rows, with the largest y at the top.
log_err_img = np.log10(1e-2 + np.abs(imgdata.T[::-1]))
plt.imshow(log_err_img, extent=extent)
plt.colorbar()

# Test accuracy at a point: grid indices 7/8 of the way across in x and
# 5/8 of the way up in y.
test_idx = (7*nxtgts//8, 5*nytgts//8)
test_y_idx = test_idx[1]
plt.plot(targets[0][test_idx], targets[1][test_idx], "ro")
rel_err = abs(err[test_idx])/abs(potentials[test_idx])
print("Relative error at (red) test point:", rel_err)


# * Move the center around and observe the convergence behavior.
# * Truncate the expansion to linear terms and observe the convergence behavior.

# In[19]:


# Error magnitude along the grid row y = targets[1, 0, test_y_idx], plotted
# on log-log axes against each target's x-offset from the center.
x_offsets = targets[0, :, test_y_idx] - center[0]
abs_err_row = np.abs(err[:, test_y_idx])

plt.grid()
plt.xlabel("Distance from center")
plt.ylabel("Error")
plt.loglog(x_offsets, abs_err_row)


# What is the slope in the far region? What should it be?

# Look at individual basis functions:

# In[14]:


# Basis function fdx = x/|x|^2 evaluated on the target grid (about the
# origin), with the color limits fixed to [-1, 1].
basis_fn_values = fdx(targets)
plt.imshow(
    basis_fn_values.T[::-1],
    extent=extent,
    vmin=-1,
    vmax=1,
)


# Why is this thing called a 'multipole' expansion?
