Jax: Visualising the computational graph of a Jax program
See also the updated article
Jax is a Python package for efficient computation and automatic differentiation. I’ve written about its automatic differentiation capabilities in the post on the Hessian and differentiating the Black-Scholes model. In essence, Jax traces the Python program and then JIT-compiles it using the XLA system.
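The result of the tracing step can be inspected directly with jax.make_jaxpr, which records the traced program as a "jaxpr", Jax's own intermediate form, before anything is handed to XLA. A minimal sketch, assuming a reasonably recent Jax release and using the same tanh and lfn functions as the HLO example:

```python
import jax
import jax.numpy as jnp


def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)


def lfn(x):
    return jnp.log(tanh(x).sum())


# Trace lfn for a 100-element float32 input; only the sequence of
# primitive operations is recorded, not the numerical result.
x = jnp.zeros(100, dtype=jnp.float32)
jaxpr = jax.make_jaxpr(lfn)(x)
print(jaxpr)

# The traced graph can also be walked programmatically:
# each equation applies one primitive to some input variables.
primitives = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print(primitives)
```

The jaxpr is a flat list of primitive applications, which is exactly the graph that the visualisation below ends up drawing (after XLA has lowered it to HLO).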
Sometimes it is useful to visualise the computational graph that the tracing constructs. Here is how to do this in Jax.
Export the HLO Intermediate Representation (IR)
The XLA HLO intermediate representation of a Jax computation (i.e., not just of individual functions but of the composed computation) can be saved to a text file using the jax.tools.jax_to_hlo module.
As an example, here is how to dump the IR of the gradient of a simple composed function:
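```python
import jax.tools.jax_to_hlo
from jax import numpy, grad
from jax.lib import xla_client

# tanh written out explicitly in terms of exp
def tanh(x):
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

# Composed scalar function: log of the sum of tanh over the input vector
def lfn(x):
    return numpy.log(tanh(x).sum())

# Gradient of lfn with respect to x
def dlfn(x):
    return grad(lfn)(x)

# jax_to_hlo returns a pair (serialized HLO proto, HLO text) for dlfn
# with a single parameter "x" of shape f32[100]; element [1], the text
# form, is written to t.txt.
with open("t.txt", "w") as f:
    f.write(jax.tools.jax_to_hlo.jax_to_hlo(dlfn,
                                            [("x", xla_client.Shape("f32[100]"))])[1])
```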
This creates a text file t.txt containing the HLO in text form.
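In recent Jax releases jax.lib.xla_client may be deprecated or removed, so the script above may not run unchanged. A sketch of a similar export using Jax's ahead-of-time API (jax.jit(...).lower(...).compile()) is below. Note two assumptions: Compiled.as_text() returns the HLO after XLA's optimisation passes, so the graph will typically look different (e.g. fused) from the jax_to_hlo output, and it is not guaranteed that interactive_graphviz parses this text unchanged.

```python
import jax
import jax.numpy as jnp


def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)


def lfn(x):
    return jnp.log(tanh(x).sum())


# Only the shape and dtype of the argument are needed for lowering,
# mirroring the f32[100] parameter used with jax_to_hlo.
x_spec = jax.ShapeDtypeStruct((100,), jnp.float32)

# Lower and compile the gradient ahead of time, without executing it.
compiled = jax.jit(jax.grad(lfn)).lower(x_spec).compile()

# HLO text of the module as XLA will run it (after optimisation passes).
hlo_text = compiled.as_text()

with open("t.txt", "w") as f:
    f.write(hlo_text)
```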
Visualise
A tool for visualising the HLO IR, interactive_graphviz, is developed together with the rest of the TensorFlow/XLA toolchain, but I did not find it distributed as a binary, so it needs to be built from source. Follow the instructions for building TensorFlow from source and then issue the following bazel build command:

```shell
bazel build //tensorflow/compiler/xla/tools:interactive_graphviz
```
The build is likely to take rather longer than it takes to make a coffee!
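If waiting for the build is not an option, XLA can also write Graphviz .dot files itself whenever it compiles a computation, controlled by its --xla_dump_to and --xla_dump_hlo_as_dot debug flags passed through the XLA_FLAGS environment variable. A minimal sketch, assuming a recent Jax on CPU; the dump file names are chosen by XLA, so the code simply lists whatever .dot files appear:

```python
import os
import tempfile

# Assumption: XLA reads XLA_FLAGS when Jax initialises its backend,
# so the flags must be set before jax is imported.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + f" --xla_dump_to={dump_dir} --xla_dump_hlo_as_dot=true"
).strip()

import jax
import jax.numpy as jnp


def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)


def lfn(x):
    return jnp.log(tanh(x).sum())


# The first call compiles the jitted gradient, which makes XLA write
# its dumps (before and after optimisation) into dump_dir.
jax.jit(jax.grad(lfn))(jnp.ones(100, dtype=jnp.float32))

# List the Graphviz files XLA produced.
dot_files = sorted(n for n in os.listdir(dump_dir) if n.endswith(".dot"))
for name in dot_files:
    print(os.path.join(dump_dir, name))
```

Each .dot file can then be rendered with Graphviz, for example `dot -Tsvg <file>.dot -o graph.svg`. This gives static images only, without the interactive navigation of interactive_graphviz.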
After a longish nap, you can start interactive_graphviz (bazel places the binary under bazel-bin/tensorflow/compiler/xla/tools/) with `--text_hlo="t.txt"`. The command `list computations` shows all computations, and the graph of any computation can be displayed by entering its name. This is the graph I got from this example: