# Jax: Visualising the computational graph of a Jax program

*See also the updated article*

Jax is a Python package for efficient computation and automatic differentiation. I’ve written about the automatic differentiation capabilities in the post on the Hessian and differentiating the Black-Scholes model. In essence, Jax traces the Python program and then JIT-compiles it using the XLA compiler.

Sometimes it is useful to visualise the computational graph constructed by this tracing. Here is how to do it in Jax.
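For a first, purely Python-level look at what tracing produces, `jax.make_jaxpr` returns the jaxpr, Jax's own representation of the traced program. The sketch below applies it to the gradient of the same `tanh`/`lfn` functions used in the HLO example. Keep in mind that the jaxpr is the graph *before* lowering to XLA, so its operations will not match the HLO node names one-to-one.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    # tanh(x) written in terms of exp(-2x)
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return jnp.log(tanh(x).sum())

x = jnp.ones(100, dtype=jnp.float32)

# make_jaxpr traces the function with abstract values and returns the
# recorded graph of primitive operations (a ClosedJaxpr)
closed = jax.make_jaxpr(jax.grad(lfn))(x)
print(closed)

# Each top-level equation is one node of the graph: a primitive applied to inputs
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitive_names)
```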

## Export the HLO Intermediate Representation (IR)

The HLO intermediate representation of a computation (i.e., not just
individual functions but the whole composed computation) can be saved to a
text file using the `jax.tools.jax_to_hlo` module.

As an example here is how to dump the IR of the gradient of a simple composed function:

```
import jax.tools.jax_to_hlo
from jax import numpy, grad
from jax.lib import xla_client

def tanh(x):
    # tanh(x) written in terms of exp(-2x)
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return numpy.log(tanh(x).sum())

def dlfn(x):
    return grad(lfn)(x)

# jax_to_hlo returns (serialized HLO proto, HLO text); keep the text.
# The single input "x" is declared as a float32 vector of length 100.
with open("t.txt", "w") as f:
    f.write(jax.tools.jax_to_hlo.jax_to_hlo(
        dlfn, [("x", xla_client.Shape("f32[100]"))])[1])
```

This will create a text file `t.txt` containing the HLO in text form.
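Depending on the Jax version, `jax.lib.xla_client` may be deprecated or the `jax_to_hlo` helper may not import cleanly. In that case a similar dump can be produced with the ahead-of-time lowering API, which is a different route from the one used above: `jax.jit(...).lower(...)` traces and lowers the function, and `.compile().as_text()` returns the HLO text of the compiled module. This is a sketch assuming a reasonably recent Jax release; note that the compiled text is the HLO *after* XLA's optimisation passes (fusions included), so the graph will look different from the unoptimised output of `jax_to_hlo`.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return jnp.log(tanh(x).sum())

dlfn = jax.grad(lfn)

# Only shape and dtype are needed for lowering, so an abstract argument suffices
x_spec = jax.ShapeDtypeStruct((100,), jnp.float32)

# Assumes a recent Jax with the jit(...).lower(...).compile() staging API
lowered = jax.jit(dlfn).lower(x_spec)  # traced and lowered, not yet compiled
compiled = lowered.compile()           # runs XLA's compilation pipeline

hlo_text = compiled.as_text()          # HLO text of the optimised module
with open("t.txt", "w") as f:
    f.write(hlo_text)
```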

## Visualise

A tool for visualising the HLO IR is developed together with the rest
of the TensorFlow/XLA toolchain, but I did not find it distributed as a
binary, so it needs to be built from source. Follow the
instructions for building from source and then issue the
following `bazel` build command:

```
bazel build //tensorflow/compiler/xla/tools:interactive_graphviz
```

The build is likely to take rather longer than it takes to make a coffee!
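While the build runs (or if building is not an option at all), a rough picture of the traced graph can be drawn in pure Python. The sketch below is not the HLO viewer: it walks the jaxpr returned by `jax.make_jaxpr` (the graph *before* lowering to XLA) and emits a Graphviz DOT description, with one node per primitive equation and one edge per data dependency. The helper names `jaxpr_to_dot` and `is_literal` are hypothetical, and `is_literal` relies on the assumption that literal operands carry a `.val` attribute, a detail of Jax internals that may change between versions. Nested jaxprs (for example inside `pjit` equations) appear as single nodes.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return jnp.log(tanh(x).sum())

def is_literal(v):
    # Assumption about Jax internals: literal operands carry a `.val`
    # attribute, while traced variables do not
    return hasattr(v, "val")

def jaxpr_to_dot(closed_jaxpr):
    """Hypothetical helper: emit DOT text for the top level of a jaxpr."""
    jaxpr = closed_jaxpr.jaxpr
    producer = {}  # variable -> id of the DOT node that defines it
    lines = ["digraph jaxpr {"]
    # Graph inputs: captured constants first, then function arguments
    for i, var in enumerate(list(jaxpr.constvars) + list(jaxpr.invars)):
        node = f"in{i}"
        lines.append(f'  {node} [label="in {i}", shape=box];')
        producer[var] = node
    # One node per equation, with edges from the nodes producing its operands
    for j, eqn in enumerate(jaxpr.eqns):
        node = f"op{j}"
        lines.append(f'  {node} [label="{eqn.primitive.name}"];')
        for v in eqn.invars:
            if not is_literal(v):
                lines.append(f"  {producer[v]} -> {node};")
        for v in eqn.outvars:
            producer[v] = node
    # Graph outputs
    for k, v in enumerate(jaxpr.outvars):
        node = f"out{k}"
        lines.append(f'  {node} [label="out {k}", shape=box];')
        if not is_literal(v):
            lines.append(f"  {producer[v]} -> {node};")
    lines.append("}")
    return "\n".join(lines)

x = jnp.ones(100, dtype=jnp.float32)
dot_source = jaxpr_to_dot(jax.make_jaxpr(jax.grad(lfn))(x))
print(dot_source)
```

Saving `dot_source` to a file such as `graph.dot` and running `dot -Tsvg graph.dot -o graph.svg` (assuming Graphviz is installed) renders it as an image.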

After a longish nap you can start `interactive_graphviz` with
`--text_hlo="t.txt"`. The command

```
list computations
```

will show all computations, and the graph of any computation can be displayed by entering its name. This is the graph I got from this example: