An update on visualising the computational graph of a JAX program
I wrote a while ago about visualising the computational graph of a JAX program here. JAX has evolved since then, so here is an update for the current (as of the time of writing) version of JAX (0.3.1).
Export the HLO Intermediate Representation (IR)
The main change is that the graph is produced from a jax.xla_computation object. The other change is that instead of specifying the input shape I am supplying an example array (numpy.ones(100)).
Here is the updated program:
import jax
from jax import numpy, grad

def tanh(x):
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return numpy.log(tanh(x).sum())

def dlfn(x):
    return grad(lfn)(x)

# Trace the gradient function with an example array to obtain the XLA computation
z = jax.xla_computation(dlfn)(numpy.ones(100))

# Write the HLO in text form to a file
with open("t.txt", "w") as f:
    f.write(z.as_hlo_text())
This will create a text file t.txt containing the HLO in text form.
Visualise
The simplest way of visualising is to dump the computation as a dot graph and run dot:
with open("t.dot", "w") as f:
f.write(z.as_hlo_dot_graph())
dot t.dot -Tpng > t.png
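Alternatively, if the graphviz Python package happens to be installed (an assumption; it is not required by JAX), the same dot source can be rendered without leaving Python:

import graphviz

# Render the dot source produced by as_hlo_dot_graph() directly to t.png
graphviz.Source(z.as_hlo_dot_graph()).render("t", format="png")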
This produces the following image:
Visualise the optimised graph
(added on 2023-08-28)
It is also possible to visualise the optimised graph produced by the XLA compiler. The mechanism for getting the HLO representation of a compiled graph is shown in the test_compiler_ir test in the file api_test.py of the JAX repository. Basically it consists of calling the jax.jit function, lower()ing with a specific example data structure as the argument, compile()ing, and then using the as_text() method. So for example:
def ff(x):
    x = x * 3
    x = x + 2
    return x

# Lower with an example input, compile, and get the optimised HLO as text
jax.jit(ff).lower(numpy.ones(100)).compile().as_text()
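For use further below it is convenient to keep the optimised HLO in a variable and, optionally, write it to a file. This is a small sketch; the variable and file names (compiled_hlo, t_opt.txt) are my own choices:

# Capture the optimised HLO text of the jit-compiled function
compiled_hlo = jax.jit(ff).lower(numpy.ones(100)).compile().as_text()

with open("t_opt.txt", "w") as f:
    f.write(compiled_hlo)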
To visualise it, it is necessary to re-parse the HLO into an XLA computation and then use the XLA functionality to generate a dot graph. This can easily be achieved using the raw XLA Python bindings with the following function:
from jax.lib import xla_client

def todotgraph(x):
    # Parse the HLO text back into an HLO module and convert it to a dot graph
    return xla_client._xla.hlo_module_to_dot_graph(
        xla_client._xla.hlo_module_from_text(x))
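Putting the pieces together, a minimal sketch (reusing the compiled_hlo variable from above; the file name t_opt.dot is arbitrary):

# Convert the optimised HLO to dot format and write it out
with open("t_opt.dot", "w") as f:
    f.write(todotgraph(compiled_hlo))

The resulting file can then be rendered with dot t_opt.dot -Tpng > t_opt.png as before.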
As an example of the effect of compilation, the compiled graph of the function above shows XLA loop fusion. The un-optimised graph needs two loops over the array:
But the optimised graph of the jit-compiled function needs only one:
Visualise a neural network
The same basic approach can be used to visualise a flax neural network, for example:
import functools
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

model = CNN()
batch = numpy.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)

# Bind the initialised parameters so the traced function takes only the input batch
fn = functools.partial(model.apply, variables)
z = jax.xla_computation(fn)(batch)

with open("t2.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())
This produces the following image: