See also the updated article

JAX is a Python package for efficient numerical computation and automatic differentiation. I’ve written about its automatic differentiation capabilities in the posts on the Hessian and on differentiating the Black-Scholes model. In essence, JAX traces the Python program and then JIT-compiles it using the XLA compiler.

Sometimes it is useful to be able to visualise the computational graph that is constructed by this tracing. Here is how it can be done in JAX.

Export the HLO Intermediate Representation (IR)

The JAX intermediate representation of a computation (i.e., the whole composed computation rather than just individual functions) can be saved to a text file using the jax.tools.jax_to_hlo module.

As an example, here is how to dump the IR of the gradient of a simple composed function:


import jax.tools.jax_to_hlo
from jax import numpy, grad
from jax.lib import xla_client

def tanh(x):
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return numpy.log(tanh(x).sum())

def dlfn(x):
    return grad(lfn)(x)

# jax_to_hlo returns a (serialised proto, text) pair; element [1] is
# the textual HLO.
with open("t.txt", "w") as f:
    f.write(jax.tools.jax_to_hlo.jax_to_hlo(
        dlfn, [("x", xla_client.Shape("f32[100]"))])[1])

This will create a text file t.txt with a text-based HLO.
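On recent JAX releases the jax.tools.jax_to_hlo module is no longer available; a similar text dump can be obtained through the ahead-of-time lowering API instead. A sketch, assuming JAX 0.4 or later (note the textual IR on these versions is StableHLO rather than classic HLO, so it may need conversion before feeding it to HLO-based tools):

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return jnp.log(tanh(x).sum())

dlfn = jax.grad(lfn)

# Trace and lower the function for a concrete shape, without executing it.
lowered = jax.jit(dlfn).lower(jnp.ones(100, jnp.float32))

# Textual form of the lowered computation.
ir_text = lowered.as_text()

with open("t.txt", "w") as f:
    f.write(ir_text)
```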

Visualise

A tool for visualising the HLO IR is developed together with the rest of the TensorFlow/XLA toolchain, but I did not find it distributed as a binary, which means it needs to be built from source. Follow the instructions for building TensorFlow from source and then issue the following bazel build command:

bazel build  //tensorflow/compiler/xla/tools:interactive_graphviz

The build is likely to take rather longer than it takes to make a coffee!

After you have had a longish nap you can start interactive_graphviz with --text_hlo="t.txt". The command

list computations

will list all computations, and the graph of any computation can be displayed by entering its name. Here is the graph I obtained from this example:

[Graph of computation xla_computation_ordered_wrapper__2.39: the forward pass (multiply by a broadcast -2, exponential, subtract/add against broadcast 1 constants, divide, a reduce-sum subcomputation, log) feeding the backward-pass nodes (divide, multiply, negate, add, a final multiply by broadcast -2), ending in a tuple ROOT node.]
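If you only need a quick look at the traced graph, without building the XLA tools at all, JAX can also print its own intermediate representation, the jaxpr. A minimal sketch using jax.make_jaxpr on the same example:

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return jnp.log(tanh(x).sum())

# make_jaxpr traces the function and returns its jaxpr: a readable
# listing of the primitive operations making up the computation graph.
jaxpr = jax.make_jaxpr(jax.grad(lfn))(jnp.ones(100, jnp.float32))
print(jaxpr)
```

The printout shows the same primitives (mul, exp, div, reduce_sum, log and the transposed backward-pass operations) that appear as nodes in the HLO graph above.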