#!/usr/bin/env python
# coding: utf-8

# # Conjugate Gradient Mechanics
# 
# Copyright (C) 2026 Andreas Kloeckner
# 
# <details>
# <summary>MIT License</summary>
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# </details>

# In[1]:


import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt

# Allow up to 200 characters per line when numpy prints arrays, so that the
# matrices shown further down are not wrapped as early.
np.set_printoptions(linewidth=200)


# In[2]:


# Fixed seed: every random quantity below is reproducible, but only as long
# as the order of the np.random calls stays the same.
np.random.seed(22)

# n-by-n test problem. L@L.T is symmetric by construction.
n = 10
L = np.random.randn(n, n)
A = L@L.T
b = np.random.randn(n)

# Result is discarded; the factorization raises LinAlgError if A is not
# (numerically) positive definite, so this doubles as a sanity check.
la.cholesky(A)

# 2-by-2 counterpart of the problem above, used for the 2D plots.
L2 = np.random.randn(2, 2)
A2 = L2@L2.T
b2 = np.random.randn(2)

# Notebook cell output: Cholesky factor of A2 together with b2.
la.cholesky(A2), b2


# In[3]:


def plot_bowl(A, b, res=100, ext=12):
    """Contour-plot the quadratic phi(x) = 1/2 x^T A x - b^T x in 2D.

    The function is sampled on the square [-ext, ext] x [-ext, ext] with
    *res* points per axis and drawn onto the current matplotlib axes,
    requesting 50 contour levels.

    :arg A: 2x2 matrix of the quadratic form.
    :arg b: length-2 vector of the linear term.
    :arg res: number of sample points along each axis.
    :arg ext: half-width of the plotted square.
    """
    # A complex step makes mgrid treat `res` as a point count.
    axis = slice(-ext, ext, res*1j)
    grid = np.mgrid[axis, axis]

    quadratic_term = np.einsum("ixy,ij,jxy->xy", grid, A, grid)
    linear_term = np.einsum("ixy,i->xy", grid, b)

    plt.contour(grid[0], grid[1], 0.5*quadratic_term - linear_term, 50)

# Show the contours of the 2x2 example problem.
plot_bowl(A=A2, b=b2)


# ## Line search

# Implement
# ```
# def alpha(A, b, x, s):
#     ...
# ```

# In[4]:


def alpha(A, b, x, s):
    """Return the line-search step length along *s* starting from *x*.

    Computes s^T (b - A x) / (s^T A s). For symmetric *A* this is the value
    of t at which phi(x + t s), with phi(x) = 1/2 x^T A x - b^T x, is
    stationary.

    :arg A: square matrix.
    :arg b: right-hand-side vector.
    :arg x: current point.
    :arg s: search direction; s^T A s must be nonzero.
    """
    residual = b - A @ x
    numerator = s @ residual
    curvature = s @ A @ s
    return numerator / curvature


# In[5]:


# Random starting point and search direction for the 2D demo. The starting
# point must be drawn before the direction to keep the random stream
# unchanged.
x2 = 4 * np.random.randn(2)
s2 = 4 * np.random.randn(2)

plot_bowl(A2, b2)

# Step length returned by the line search along s2 from x2.
alpha2 = alpha(A2, b2, x2, s2)

# Arrow from x2 to x2 + alpha2*s2, drawn in data coordinates.
plt.quiver(
    *x2, *(alpha2*s2),
    color='blue', angles='xy', scale_units='xy', scale=1, label='Vector A')


# In[6]:


# Step lengths bracketing alpha2 (from -0.5x to 1.5x of it).
alphas = np.linspace(-0.5*alpha2, 1.5*alpha2, num=100)

# Column a of x2s is the point x2 + alphas[a]*s2.
x2s = x2[:, np.newaxis] + alphas * s2[:, np.newaxis]

# Value of the quadratic 1/2 x^T A2 x - b2^T x at each of those points.
phis = (
    0.5*np.einsum("ia,ij,ja->a", x2s, A2, x2s)
    - np.einsum("ia,i->a", x2s, b2))

plt.plot(alphas, phis)
# Vertical marker at the step length chosen by the line search.
plt.vlines(alpha2, ymin=-100, ymax=100)


# ### $A$-orthogonality
# 
# h/t [J. Shewchuk](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf) for the plot idea.
# 
# His is nicer because the vectors are maps of each other.

# In[7]:


import matplotlib.pyplot as plt
import numpy as np

def draw_A_orthogonal_pairs(rng, A, b, n_pairs=10, ext=12, scale=1):
    """Draw randomly placed pairs of A-orthogonal arrows on the current axes.

    Each pair starts out as two perpendicular vectors of equal Euclidean
    length. Both are then multiplied by the inverse transpose of the
    Cholesky factor of *A*, which makes the pair orthogonal in the
    A-inner product.

    :arg rng: random generator providing ``uniform`` and ``normal``
        (e.g. the result of ``np.random.default_rng``).
    :arg A: 2x2 matrix that ``la.cholesky`` can factor.
    :arg b: accepted for symmetry with ``plot_bowl``; not used here.
    :arg n_pairs: number of arrow pairs to draw.
    :arg ext: arrow origins are sampled uniformly from
        ``[-0.7*ext, 0.7*ext]`` in each coordinate.
    :arg scale: multiplier for the arrow length (before the mapping).
    """
    arrow_len = ext/10*2*scale

    # Keep the order of the two draws: it determines the random stream.
    starts = rng.uniform(-ext*0.7, ext*0.7, (n_pairs, 2))
    first = rng.normal(size=(2, n_pairs))

    # Quarter-turn of every column: (x, y) -> (y, -x).
    second = np.array([first[1], -first[0]])

    # Normalize each column, then give all arrows the same length.
    first = first/la.norm(first, axis=0)*arrow_len
    second = second/la.norm(second, axis=0)*arrow_len

    # With A = C C^T, mapping by C^{-T} turns u^T v = 0 into
    # (C^{-T} u)^T A (C^{-T} v) = 0.
    chol_inv_t = la.inv(la.cholesky(A)).T
    first = np.einsum("ij,jn->in", chol_inv_t, first)
    second = np.einsum("ij,jn->in", chol_inv_t, second)

    arrow_sets = (
        (first, 'blue', 'Vector A'),
        (second, 'red', 'Vector B (Orthogonal)'),
    )
    for arrows, color, label in arrow_sets:
        plt.quiver(starts[:, 0], starts[:, 1], arrows[0], arrows[1],
                   color=color, angles='xy', scale_units='xy', scale=1,
                   label=label)

# Two panels side by side. Both use the same generator seed, so the arrow
# origins and the pre-mapping vectors agree between the panels.
plt.figure(figsize=(16, 8))

# Left: identity matrix, where A-orthogonal means plainly perpendicular.
plt.subplot(1, 2, 1)
draw_A_orthogonal_pairs(
    np.random.default_rng(seed=17), np.eye(2), np.zeros(2))
plot_bowl(np.eye(2), np.zeros(2))
plt.gca().set_aspect("equal")

# Right: the same construction for A2; the contours omit the linear term.
plt.subplot(1, 2, 2)
draw_A_orthogonal_pairs(
    np.random.default_rng(seed=17), A2, b2)
plot_bowl(A2, np.zeros(2))
plt.gca().set_aspect("equal")


# ### Search Directions
# 
# - Generate using (modified) Gram-Schmidt with the $A$ inner product.
# - Observe what parts of the orthogonalization were actually unnecessary.
# - Note that (at least for illustrative purposes) we can generate these ahead of time!

# In[16]:


# Random starting point for the iteration.
x0 = np.random.randn(n)

# We *could* choose this differently.
# But residual orthogonality would fail if we did.
r0 = A@x0 - b
# First search direction: r0 rescaled to unit length in the A-norm.
search_dirs = [r0/(r0@A@r0)**0.5]

for i in range(n-1):
    # Candidate for the next direction: A applied to the newest direction.
    znext = A@search_dirs[-1]
    coeffs = []
    for s in search_dirs:
        # A-inner product of the *current* remainder with s. It is subtracted
        # immediately (modified Gram-Schmidt); since s has unit A-norm, no
        # division is needed.
        coeff = znext@A@s
        coeffs.append(coeff)
        znext = znext - coeff * s

    # Rescale the remainder to unit A-norm before storing it.
    znext = znext/(znext@A@znext)**0.5
    search_dirs.append(znext)

    # The rounded coefficients show which projections actually contributed.
    print(f"vector {i+1}: {np.array(coeffs).round(3)}")

# Turn the list into a matrix whose columns are the search directions.
search_dirs = np.array(search_dirs).T


# In[17]:


# Gram matrix of the search directions in the A-inner product, rounded to
# 8 decimals; A-orthonormal columns give the identity.
np.round(search_dirs.T @ A @ search_dirs, 8)


# ### Compare step sizes with error decomposition
# 
# In general $\boldsymbol e_0= \boldsymbol x_0 - \boldsymbol x^\ast$; assuming $\boldsymbol x_0=\boldsymbol 0$, this reduces to $\boldsymbol e_0=- \boldsymbol x^\ast$.

# In[18]:


# Reference solution from a direct solve.
xtrue = la.solve(A, b)

# Error of the initial guess. The iteration starts from x0, which need not
# be zero, so x0 has to enter here; with x0 == 0 this is just -xtrue.
# Using 0 - xtrue with a nonzero x0 makes the coefficients below disagree
# with the negated step sizes they are compared against.
error = x0 - xtrue

# Coefficients of the initial error in the basis formed by the columns of
# search_dirs.
deltas = la.solve(search_dirs, error)


# In[19]:


# Visit the precomputed search directions one by one, taking the
# line-search step along each, and record every iterate.
x = x0
xs = [x]

for i, s in enumerate(search_dirs.T):
    myalpha = alpha(A, b, x, s)
    # Out-of-place update on purpose: x aliases x0 on the first pass, and the
    # arrays already stored in xs must not be modified.
    x = x + myalpha*s
    xs.append(x)

    # Negated step size next to the matching error-expansion coefficient.
    print(-myalpha, deltas[i])

# Columns of xs are the iterates, starting with x0.
xs = np.array(xs).T


# In[20]:


# Difference between the final iterate and a direct solve of A x = b.
x - la.solve(A, b)


# ### Errors
# 
# Note: Residual norms or $A$-norms are not strictly decreasing!

# In[21]:


# Column k holds the error of iterate k relative to the direct solution.
errors = xs - xtrue[:, np.newaxis]

# Squared Euclidean norm of each error column.
error_norms = np.array([column @ column for column in errors.T])

plt.plot(error_norms)


# ### Residuals

# In[22]:


# Column k is the residual A x_k - b of iterate k.
residuals = A @ xs - b[:, np.newaxis]

# Entry (k, j): residual k dotted with search direction j, to 5 decimals.
np.round(residuals.T @ search_dirs, 5)


# In[23]:


# Gram matrix of the residuals (Euclidean inner product), to 6 decimals;
# off-diagonal entries show how close the residuals are to being mutually
# orthogonal.
np.round(residuals.T @ residuals, 6)


# In[ ]:




