Efficient Hessian calculation with JAX and automatic forward- and reverse-mode differentiation
This is already fairly well documented in the JAX Autodiff Cookbook (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), but here it is with a moderately complex example that is not drawn from machine learning, together with some timings. The advantages of JAX for this particular example are:

Composable reverse-mode and forward-mode automatic differentiation, which enables efficient Hessian computation

Minimal adjustment of NumPy/Python programs needed

Compilation via XLA to efficient GPU or CPU (or TPU) code
As you will see below, the acceleration versus plain NumPy code is about a factor of 500!
Introduction
The Hessian matrix is the matrix of second derivatives of a function of multiple variables:
$$H_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j}$$
where $f$ is a function of the variables $x_1, \dots, x_n$. Since the matrix is symmetric there are $n(n+1)/2$ elements of the Hessian to compute. Where a finite-difference method is used to compute the Hessian and $n$ is large, the computation can become prohibitively expensive.
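To make the cost argument concrete, here is a minimal sketch of a central finite-difference Hessian in plain NumPy. The helper name `fd_hessian`, the step size `h` and the quadratic test function are illustrative choices, not part of the original program. Each of the n(n+1)/2 independent entries needs four function evaluations, so the number of calls grows quadratically with the number of variables.

```python
import numpy as np

def fd_hessian(f, x, h=1e-4):
    """Central finite-difference Hessian of a scalar function f at x.

    Illustrative helper (not from the original post). Returns the
    Hessian and the number of calls made to f.
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    H = np.zeros((n, n))
    calls = 0
    for i in range(n):
        for j in range(i, n):  # upper triangle only: n(n+1)/2 entries
            ei = np.zeros(n)
            ei[i] = h
            ej = np.zeros(n)
            ej[j] = h
            # Four-point central stencil for d2f / dx_i dx_j
            H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                       - f(x - ei + ej) + f(x - ei - ej)) / (4 * h * h)
            H[j, i] = H[i, j]  # fill the lower triangle by symmetry
            calls += 4
    return H, calls

# Quadratic test function f(x) = 0.5 x^T A x, whose exact Hessian is A
A = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 0.5],
              [0.0, 0.5, 1.0]])
f = lambda x: 0.5 * x @ A @ x

H, calls = fd_hessian(f, np.array([0.3, -0.2, 0.1]))
print(np.allclose(H, A, atol=1e-5), calls)
```

When every call to `f` is itself expensive, as with a likelihood summed over a large image, this call count is what dominates the runtime.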
Uses of the Hessian matrix
The most frequently cited use of the Hessian is in optimisation with Newton-type methods (see https://en.wikipedia.org/wiki/Hessian_matrix#Applications). The expense of computing it, however, means that quasi-Newton methods are often used, where the Hessian is estimated rather than computed. Much more efficient Hessian matrix computation makes true Newton methods tractable for a wide range of problems^{1}.
Another application is the calculation of the Fisher information (https://en.wikipedia.org/wiki/Fisher_information) and the Cramér–Rao bound by calculating the Hessian of the negative log-likelihood. This in turn allows estimation of the variance of parameters when using maximum likelihood.
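As a small illustration of the variance estimate, the sketch below uses a one-parameter toy problem of my own choosing: the mean of normally distributed samples with known width. The inverse of the second derivative of the negative log-likelihood at the optimum should then reproduce the textbook result sigma^2 / n. The data, seed and step size are assumptions made for this example only.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, n = 2.0, 400
x = rng.normal(loc=1.0, scale=sigma, size=n)

def nll(mu):
    """Negative log-likelihood of the sample for a normal with known sigma."""
    return (0.5 * np.sum((x - mu) ** 2) / sigma**2
            + n * np.log(sigma * np.sqrt(2 * np.pi)))

mu_hat = x.mean()  # maximum-likelihood estimate of mu
h = 1e-3
# Second derivative of the NLL at the optimum (a 1x1 "Hessian")
d2 = (nll(mu_hat + h) - 2 * nll(mu_hat) + nll(mu_hat - h)) / h**2
var_mu = 1.0 / d2  # inverse of the observed information

print(var_mu, sigma**2 / n)  # the two values should agree closely
```

With several parameters the same idea applies with the full Hessian matrix, whose inverse gives an estimate of the parameter covariance.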
Example
As an example I am using a maximum likelihood calculation of fitting a simple model to two-dimensional raster observations (i.e., a monochromatic image).
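The model^{2} is a two-dimensional Gaussian:
$$G(x, y) = A \exp\left\{ -\frac{1}{2\sigma^2} \left[ r^2 + \rho\,\Delta x\,\Delta y + d\,(\Delta x^2 - \Delta y^2) \right] \right\}$$
where $\Delta x = x - x_0$, $\Delta y = y - y_0$ and $r^2 = \Delta x^2 + \Delta y^2$. The symbols $A$, $\sigma$, $\rho$ and $d$ correspond to amp, sigma, rho and diff in the code.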
The Python code is straightforward:
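import numpy as np

def gauss2d(x0, y0, amp, sigma, rho, diff, a):
    """
    Sample model: Gaussian on a 2d plane
    """
    # Offsets of every pixel coordinate from the centre (x0, y0)
    dx = a[..., 0] - x0
    dy = a[..., 1] - y0
    r = np.hypot(dx, dy)
    return amp * np.exp(-1.0 / (2 * sigma**2) * (r**2 +
                                                rho * (dx * dy) +
                                                diff * (dx**2 - dy**2)))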
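It may help to see what the two dimensionless parameters do. Writing $\Delta x$ and $\Delta y$ for dx and dy, $\rho$ for rho and $d$ for diff (symbol names chosen here for illustration), and using $r^2 = \Delta x^2 + \Delta y^2$, the bracket in the exponent can be expanded as a quadratic form:

$$
r^2 + \rho\,\Delta x\,\Delta y + d\,(\Delta x^2 - \Delta y^2)
= (1 + d)\,\Delta x^2 + \rho\,\Delta x\,\Delta y + (1 - d)\,\Delta y^2 .
$$

With $\rho = d = 0$ this reduces to a circular Gaussian of width $\sigma$. A positive $d$ narrows the profile along x and widens it along y, while $\rho$ does the same along the diagonals. The form stays positive definite, so that the model decays in every direction, only as long as $d^2 + \rho^2/4 < 1$.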
Likelihood
Very often the next step would be to assume that observational uncertainty is normally distributed. This has a deep computational advantage^{3}, but it is often not really the case, as the prevalence of various attempts to clean the observed data before it is put into maximum likelihood analysis shows.
For this example I’ll instead assume that the uncertainty follows the Cauchy distribution. The wider tails of this distribution mean that a maximum likelihood solution will be far less affected by a few outlier points.
The Python is simple:
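def cauchy(x, g, x0):
    # Cauchy probability density with location x0 and scale g
    return 1.0 / (np.pi * g) * g**2 / ((x - x0)**2 + g**2)

def cauchyll(o, m, g):
    # Negative log-likelihood of the observation o given the model m
    return -1 * np.log(cauchy(m, g, o)).sum()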
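The robustness claim can be checked on a toy problem. The sketch below uses made-up data (nine points near zero plus one gross outlier) and estimates a single location parameter under both assumptions. The function `cauchy_nll` and the brute-force grid search are illustrative simplifications, not the method used in the rest of this post.

```python
import numpy as np

def cauchy_nll(loc, data, g=1.0):
    """Negative log-likelihood of data under a Cauchy with location loc, scale g."""
    pdf = 1.0 / (np.pi * g) * g**2 / ((data - loc) ** 2 + g**2)
    return -np.log(pdf).sum()

# Nine well-behaved points around zero plus one gross outlier (made-up data)
data = np.array([-0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 50.0])

# Normal errors: the maximum-likelihood location is the sample mean
loc_normal = data.mean()

# Cauchy errors: minimise the NLL by brute force on a grid of candidate locations
grid = np.linspace(-5.0, 55.0, 6001)
loc_cauchy = grid[np.argmin([cauchy_nll(c, data) for c in grid])]

print(loc_normal, loc_cauchy)
```

The normal-error estimate is dragged a long way towards the outlier, whereas the Cauchy estimate stays close to the bulk of the points.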
Observed data
I will use a simulated data set instead of an observation: an image with a Gaussian source in the middle and normally distributed noise.
Putting it together and applying JAX
Applying JAX to this non-trivial NumPy code is extremely simple. It consists of:
- Using the jax.numpy module instead of the standard numpy
- Decorating functions with @jit
- Calculating the Hessian by first doing reverse-mode and then forward-mode differentiation
Overall the number of changes is very small, and it is quite practical to maintain a codebase which can be switched between NumPy and JAX without modifications.
The main part of the program looks like this:
import numpy
import jax.numpy as np
from jax import grad, jit, vmap, jacfwd, jacrev

@jit
def gauss2d(x0, y0, amp, sigma, rho, diff, a):
    """
    Sample model: Gaussian on a 2d plane
    """
    dx = a[..., 0] - x0
    dy = a[..., 1] - y0
    r = np.hypot(dx, dy)
    return amp * np.exp(-1.0 / (2 * sigma**2) * (r**2 +
                                                rho * (dx * dy) +
                                                diff * (dx**2 - dy**2)))

def mkobs(p, n, s):
    "A simulated observation"
    # n x n grid over [-2, 2] x [-2, 2] with the coordinate pair on the last axis
    aa = numpy.moveaxis(numpy.mgrid[-2:2:n*1j, -2:2:n*1j], 0, -1)
    aa = np.array(aa, dtype="float64")
    m = gauss2d(*p, aa)
    return aa, m + numpy.random.normal(size=m.shape, scale=s)

@jit
def cauchy(x, g, x0):
    return 1.0 / (numpy.pi * g) * g**2 / ((x - x0)**2 + g**2)

@jit
def cauchyll(o, m, g):
    # Negative log-likelihood under Cauchy-distributed uncertainty
    return -1 * np.log(cauchy(m, g, o)).sum()

def makell(o, a, g):
    "Build the negative log-likelihood as a function of the parameters p only"
    def ll(p):
        m = gauss2d(*p, a)
        return cauchyll(o, m, g)
    return jit(ll)

def hessian(f):
    # Forward-mode over reverse-mode differentiation
    return jit(jacfwd(jacrev(f)))
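To see the forward-over-reverse composition in isolation, here is a minimal sketch on a toy function whose Hessian is known analytically. The function `f`, the test point and the `jnp` alias are choices made for this example. The `jax_enable_x64` line is also my addition: as far as I know JAX computes in 32-bit precision by default, so it is needed for a float64 request to take effect.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev

# Assumption: enable 64-bit floats; otherwise JAX falls back to float32
jax.config.update("jax_enable_x64", True)

def f(x):
    """Toy scalar function with a known Hessian."""
    return jnp.sum(x**3) + x[0] * x[1]

# Forward-over-reverse: jacrev gives the gradient, jacfwd differentiates it
hess = jax.jit(jacfwd(jacrev(f)))

x = jnp.array([1.0, 2.0, 3.0])
H = hess(x)

# Analytic Hessian: diag(6 x) plus the cross term between x[0] and x[1]
H_exact = jnp.diag(6.0 * x).at[0, 1].add(1.0).at[1, 0].add(1.0)
print(jnp.allclose(H, H_exact))
```

Reverse mode is the natural choice for the inner step because the function maps many inputs to one scalar. Forward mode then suits the outer step because the gradient has as many outputs as inputs.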
The measurement of timings was done as follows:
import timeit
import numdifftools as nd

# Calculate the Hessian around this point, which by design is the
# most-likely point
P = np.array([0., 0., 1.0, 0.5, 0., 0.], dtype="float64")
a, o = mkobs(P, 3000, 0.5)
ff = makell(o, a, 0.5)
hf = hessian(ff)
ndhf = nd.Hessian(ff)
# JIT warm-up call. Smallish effect if number is large in timeit
hf(P).block_until_ready()
print("Time with JAX:", timeit.timeit("hf(P).block_until_ready()", number=10, globals=globals()))
print("Time with finite diff:", timeit.timeit("ndhf(P)", number=10, globals=globals()))
How efficient?
Running this on an Intel CPU (no GPU), I get the following timings for 10 runs using timeit:
- Time with JAX and automatic differentiation: 16s
- Time with JAX function evaluation and finite-difference differentiation (with Numdifftools): 891s
- Time with plain NumPy and numerical differentiation (with Numdifftools): 9900s
So there is an impressive reduction in runtime, from 9900s to 16s! Note that this is a fairly large problem where the JIT costs are well amortized; the performance would not be this good for scattered small one-off problems. The improvement in performance can be broken down into a gain due to automatic differentiation (891s versus 16s, roughly a factor of 56) and a gain due to fusion and compilation using XLA (9900s versus 891s, roughly a factor of 11).
Footnotes

Quasi-Newton methods have some intrinsic advantages, so even with efficient Hessian matrix computation they could be competitive. The choice of algorithm therefore probably needs to be considered on a case-by-case basis.

This is not an often-used parametrisation, but it is useful because the azimuthally independent dimensional width (sigma) is a separate parameter from the squashedness of the Gaussian in the + and X directions, which are given by the dimensionless parameters diff and rho respectively.

Because the log-likelihood will correspond to the sum of squares of deviations of the model from observations, hence the traditional least-squares methods.