Testing Derivatives and Automatic Differentiation
Copyright (C) 2026 Andreas Kloeckner
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import numpy as np
import numpy.linalg as la
def f(xvec):
    """Evaluate the nonlinear test function f: R^2 -> R^2 at ``xvec = (x, y)``.

    Returns the NumPy array
    ``[x*y + 2*y**3 - 2, x**2*y + 4*y**2*cos(x) - 4]``.
    """
    x, y = xvec
    first = x*y + 2*y**3 - 2
    second = x**2*y + 4*y**2*np.cos(x) - 4
    return np.array([first, second])
def Jf(xvec):
    """Return the analytic 2x2 Jacobian of the test function f at ``xvec = (x, y)``.

    Row i holds the partial derivatives of component i of f with respect
    to x and y, in that order.
    """
    x, y = xvec
    cos_x, sin_x = np.cos(x), np.sin(x)
    # d/dx and d/dy of  x*y + 2*y**3 - 2
    row1 = [y, x + 6*y**2]
    # d/dx and d/dy of  x**2*y + 4*y**2*cos(x) - 4
    row2 = [2*x*y - 4*y**2*sin_x, x**2 + 8*y*cos_x]
    return np.array([row1, row2])
# Check Jf against a one-sided (forward) finite difference along a random
# unit direction s. The printed discrepancy should shrink roughly in
# proportion to h (first-order accuracy).
x = np.random.randn(2)
s = np.random.randn(2)
s /= la.norm(s, 2)

# Neither f(x) nor the directional derivative Jf(x) @ s depends on h,
# so evaluate them once instead of on every iteration.
fx = f(x)
Jf_s = Jf(x) @ s
for h in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
    print(h, (f(x + h*s) - fx)/h - Jf_s)
0.1 [-0.05075743 -0.00600501] 0.01 [-0.00559639 -0.00193633] 0.001 [-0.00056485 -0.00020722] 0.0001 [-5.65366619e-05 -2.08583528e-05] 1e-05 [-5.65420810e-06 -2.08717268e-06]
Now try centered differences.
# Centered difference (f(x + h s) - f(x - h s)) / (2h), using the same x and s
# as the forward-difference check. Its error should decrease roughly like
# h**2 until floating-point round-off starts to dominate.
Jf_s = Jf(x) @ s  # independent of h, so compute it once
for h in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
    print(h, (f(x + h*s) - f(x - h*s))/(2*h) - Jf_s)
0.1 [0.00578502 0.01511624] 0.01 [5.78501843e-05 1.51267266e-04] 0.001 [5.78501809e-07 1.51268329e-06] 0.0001 [5.78437953e-09 1.51234321e-08] 1e-05 [4.67468286e-11 1.35421230e-10]
Automatic differentiation (with JAX)
import jax.numpy as jnp
from jax import jacfwd, jacrev, make_jaxpr, jvp
def f(xvec):
    """JAX version of the test function, written with ``jnp`` so JAX can trace it.

    Computes ``[x*y + 2*y**3 - 2, x**2*y + 4*y**2*cos(x) - 4]`` at
    ``xvec = (x, y)`` and returns it as a ``jnp`` array.
    """
    x, y = xvec
    first = x*y + 2*y**3 - 2
    second = x**2*y + 4*y**2*jnp.cos(x) - 4
    return jnp.array([first, second])
# Fresh random base point x and random unit direction s (both NumPy float64
# arrays) for testing the JAX-computed derivatives. x is drawn before s.
x, s = np.random.randn(2), np.random.randn(2)
s = s / la.norm(s, 2)
Now subject the JAX-computed Jacobian to the same test as above:
# Jacobian of f computed by forward-mode automatic differentiation.
Jf = jacfwd(f)

# f(x) and Jf(x) @ s do not depend on h. Evaluating them once avoids redoing
# the jacfwd computation on every iteration.
# The jaxprs printed below show JAX working in float32, so round-off limits
# how small h can usefully be (presumably why h stops at 1e-4 here).
fx = f(x)
Jf_s = Jf(x) @ s
for h in [1e-1, 1e-2, 1e-3, 1e-4]:
    print(h, (f(x + h*s) - fx)/h - Jf_s)
0.1 [0.0767622 0.3196597] 0.01 [0.00780439 0.03206491] 0.001 [0.00077128 0.0033834 ] 0.0001 [5.5789948e-05 3.1447411e-03]
Is there a computationally more efficient variant? Consider using
# jvp computes the Jacobian-vector product J_f(x) s directly (forward mode),
# without forming the full Jacobian. It does not depend on h, so compute it
# once, outside the loop.
_, jvp_val = jvp(f, (x,), (s,))
# Keep a separate f(x) call rather than reusing jvp's primal output. On a
# NumPy float64 x, f does part of its arithmetic in NumPy, so the two need not
# be bit-identical.
fx = f(x)
for h in [1e-1, 1e-2, 1e-3, 1e-4]:
    print(h, (f(x + h*s) - fx)/h - jvp_val)
0.1 [-0.00058496 0.26220798] 0.01 [6.3538551e-05 2.3724556e-02] 0.001 [0.00120783 0.00150394] 0.0001 [ 0.01646674 -0.00231075]
How does it work?
# Trace f at a sample point and print the resulting jaxpr, JAX's
# intermediate representation of the traced computation.
sample_point = jnp.array([1.0, 2.0])
print(make_jaxpr(f)(sample_point))
{ lambda ; a:f32[2]. let
b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] a
c:f32[] = squeeze[dimensions=(0,)] b
d:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] a
e:f32[] = squeeze[dimensions=(0,)] d
f:f32[] = mul c e
g:f32[] = integer_pow[y=3] e
h:f32[] = mul 2.0:f32[] g
i:f32[] = add f h
j:f32[] = sub i 2.0:f32[]
k:f32[] = integer_pow[y=2] c
l:f32[] = mul k e
m:f32[] = integer_pow[y=2] e
n:f32[] = mul 4.0:f32[] m
o:f32[] = cos c
p:f32[] = mul n o
q:f32[] = add l p
r:f32[] = sub q 4.0:f32[]
s:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] j
t:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] r
u:f32[2] = concatenate[dimension=0] s t
in (u,) }
# Jaxpr of the forward-mode Jacobian. Tangent seeds are the columns of a
# 2x2 identity (built from iota/eq), and they are propagated alongside the
# primal operations.
sample_point = jnp.array([1.0, 2.0])
print(make_jaxpr(jacfwd(f))(sample_point))
{ lambda ; a:f32[2]. let
b:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None]
c:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None]
d:i32[2,2] = add b 0:i32[]
e:bool[2,2] = eq d c
f:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] e
g:f32[2,2] = split[axis=1 sizes=(np.int64(2),)] f
h:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] a
i:f32[2,1] = slice[limit_indices=(2, 1) start_indices=(0, 0) strides=(1, 1)] g
j:f32[] = squeeze[dimensions=(0,)] h
k:f32[2] = squeeze[dimensions=(1,)] i
l:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] a
m:f32[2,1] = slice[limit_indices=(2, 2) start_indices=(0, 1) strides=(1, 1)] g
n:f32[] = squeeze[dimensions=(0,)] l
o:f32[2] = squeeze[dimensions=(1,)] m
p:f32[] = mul j n
q:f32[2] = mul k n
r:f32[2] = mul j o
s:f32[2] = add_any q r
t:f32[] = integer_pow[y=3] n
u:f32[] = integer_pow[y=2] n
v:f32[] = mul 3.0:f32[] u
w:f32[2] = mul o v
x:f32[] = mul 2.0:f32[] t
y:f32[2] = mul 2.0:f32[] w
z:f32[] = add p x
ba:f32[2] = add s y
bb:f32[] = sub z 2.0:f32[]
bc:f32[] = integer_pow[y=2] j
bd:f32[] = integer_pow[y=1] j
be:f32[] = mul 2.0:f32[] bd
bf:f32[2] = mul k be
bg:f32[] = mul bc n
bh:f32[2] = mul bf n
bi:f32[2] = mul bc o
bj:f32[2] = add_any bh bi
bk:f32[] = integer_pow[y=2] n
bl:f32[] = integer_pow[y=1] n
bm:f32[] = mul 2.0:f32[] bl
bn:f32[2] = mul o bm
bo:f32[] = mul 4.0:f32[] bk
bp:f32[2] = mul 4.0:f32[] bn
bq:f32[] = cos j
br:f32[] = sin j
bs:f32[2] = mul k br
bt:f32[2] = neg bs
bu:f32[] = mul bo bq
bv:f32[2] = mul bp bq
bw:f32[2] = mul bo bt
bx:f32[2] = add_any bv bw
by:f32[] = add bg bu
bz:f32[2] = add bj bx
ca:f32[] = sub by 4.0:f32[]
cb:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] bb
cc:f32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
sharding=None
] ba
cd:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] ca
ce:f32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
sharding=None
] bz
_:f32[2] = concatenate[dimension=0] cb cd
cf:f32[2,2] = concatenate[dimension=1] cc ce
cg:f32[2,2] = transpose[permutation=(1, 0)] cf
ch:f32[2,2] = split[axis=1 sizes=(np.int64(2),)] cg
in (ch,) }
# Jaxpr of the reverse-mode Jacobian. The primal computation runs first and
# keeps intermediates such as cos/sin; cotangents seeded from a 2x2 identity
# are then pulled back through the operations in reverse order.
sample_point = jnp.array([1.0, 2.0])
print(make_jaxpr(jacrev(f))(sample_point))
{ lambda ; a:f32[2]. let
b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] a
c:f32[] = squeeze[dimensions=(0,)] b
d:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] a
e:f32[] = squeeze[dimensions=(0,)] d
f:f32[] = mul c e
g:f32[] = integer_pow[y=3] e
h:f32[] = integer_pow[y=2] e
i:f32[] = mul 3.0:f32[] h
j:f32[] = mul 2.0:f32[] g
k:f32[] = add f j
l:f32[] = sub k 2.0:f32[]
m:f32[] = integer_pow[y=2] c
n:f32[] = integer_pow[y=1] c
o:f32[] = mul 2.0:f32[] n
p:f32[] = mul m e
q:f32[] = integer_pow[y=2] e
r:f32[] = integer_pow[y=1] e
s:f32[] = mul 2.0:f32[] r
t:f32[] = mul 4.0:f32[] q
u:f32[] = cos c
v:f32[] = sin c
w:f32[] = mul t u
x:f32[] = add p w
y:f32[] = sub x 4.0:f32[]
z:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] l
ba:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] y
_:f32[2] = concatenate[dimension=0] z ba
bb:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None]
bc:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None]
bd:i32[2,2] = add bb 0:i32[]
be:bool[2,2] = eq bd bc
bf:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] be
bg:f32[2,2] = split[axis=1 sizes=(np.int64(2),)] bf
bh:f32[2,1] bi:f32[2,1] = split[axis=1 sizes=(1, 1)] bg
bj:f32[2] = reduce_sum[axes=(np.int64(1),)] bi
bk:f32[2] = reduce_sum[axes=(np.int64(1),)] bh
bl:f32[2] = mul t bj
bm:f32[2] = mul bj u
bn:f32[2] = neg bl
bo:f32[2] = mul bn v
bp:f32[2] = mul 4.0:f32[] bm
bq:f32[2] = mul bp s
br:f32[2] = mul m bj
bs:f32[2] = add_any bq br
bt:f32[2] = mul bj e
bu:f32[2] = mul bt o
bv:f32[2] = add_any bo bu
bw:f32[2] = mul 2.0:f32[] bk
bx:f32[2] = mul bw i
by:f32[2] = add_any bs bx
bz:f32[2] = mul c bk
ca:f32[2] = add_any by bz
cb:f32[2] = mul bk e
cc:f32[2] = add_any bv cb
cd:f32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
sharding=None
] ca
ce:f32[2,2] = pad[padding_config=((0, 0, 0), (1, np.int64(0), 0))] cd 0.0:f32[]
cf:f32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
sharding=None
] cc
cg:f32[2,2] = pad[padding_config=((0, 0, 0), (0, np.int64(1), 0))] cf 0.0:f32[]
ch:f32[2,2] = add_any ce cg
ci:f32[2,2] = split[axis=0 sizes=(np.int64(2),)] ch
in (ci,) }
- Comment on `jacfwd` vs. `jacrev`.
- Comment on `jvp` vs. `vjp`.
- Mention `jit`.
- Mention `vmap`.