JAX is a Python package for efficient
computation and automatic differentiation. I’ve written about its
automatic differentiation capabilities in the posts on the Hessian and on differentiating the
Black-Scholes model. In essence,
JAX traces the Python program and then JIT-compiles it using the XLA
compiler.
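To make the tracing step concrete, here is a minimal sketch using jax.make_jaxpr, which shows the traced program (the jaxpr) before XLA compiles it. The function f and its input are made up for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function: elementwise tanh, square, then sum.
    return jnp.sum(jnp.tanh(x) ** 2)


x = jnp.ones(3)

# Trace f with an example input; the result is the jaxpr, JAX's own
# record of the primitive operations the Python code performed.
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# jax.jit performs the same tracing and then hands the program to XLA.
value = jax.jit(f)(x)
print(value)
```

The printed jaxpr lists the primitives in the order they were traced, which is already a textual view of the computational graph.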
Sometimes it is useful to visualise the computational graph
that the tracing constructs. Here is how this can be done in
JAX.
Export the HLO Intermediate Representation (IR)
The HLO intermediate representation of a JAX computation (i.e., not just
of individual functions but of the whole composed computation) can be saved to a
text file using the jax.tools.jax_to_hlo module.
As an example here is how to dump the IR of the gradient of a simple
composed function:
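A minimal sketch of such a dump follows. It assumes a recent JAX release and uses the ahead-of-time lowering API, jax.jit(...).lower(...).as_text(dialect="hlo"), rather than the jax.tools.jax_to_hlo module named above. The functions inner and outer are hypothetical stand-ins for a simple composed function.

```python
import jax
import jax.numpy as jnp


def inner(x):
    # Hypothetical inner function, chosen only for illustration.
    return jnp.sin(x) ** 2


def outer(x):
    # Hypothetical composed function: a scalar reduction of inner(x).
    return jnp.sum(jnp.exp(inner(x)))


# Gradient of the composed function with respect to its input.
grad_outer = jax.grad(outer)

# Lower the jitted gradient for a concrete example input, then ask for
# the HLO text (assumption: the "hlo" dialect is supported by the
# installed JAX version).
example = jnp.ones(4, dtype=jnp.float32)
ir_text = jax.jit(grad_outer).lower(example).as_text(dialect="hlo")

# Save the textual IR to t.txt.
with open("t.txt", "w") as fh:
    fh.write(ir_text)
```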
This creates a text file t.txt containing the HLO in text form.
Visualise
A tool for visualising the HLO IR is developed together with the rest
of the TensorFlow/XLA toolchain, but I did not find it distributed as a
binary, which means it needs to be built from source. Follow the
instructions for building from
source and then issue the
following bazel build command: