An update on visualising the computational graph of a JAX program
I wrote a while ago about visualising the computational graph of a JAX program here. JAX has evolved since then, so here is an update for the current (as of the time of writing) version of JAX (0.3.1).
Export the HLO Intermediate Representation (IR)
The main change is that the graph is now produced from the object returned by jax.xla_computation. The other change is that instead of specifying the input shape I supply an example array (numpy.ones(100)).
Here is the updated program:
import jax
from jax import numpy, grad

def tanh(x):
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return numpy.log(tanh(x).sum())

def dlfn(x):
    return grad(lfn)(x)

# Trace dlfn with an example input to obtain an XlaComputation object
z = jax.xla_computation(dlfn)(numpy.ones(100))

with open("t.txt", "w") as f:
    f.write(z.as_hlo_text())
This creates a text file t.txt containing the HLO in its text form.
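As an aside, newer JAX versions also expose an ahead-of-time lowering API. The following is only a sketch, under the assumption that jit(...).lower(...) and compiler_ir("hlo") are available in your JAX version; if they are, they hand back the same kind of XlaComputation object:

# Sketch only: assumes the AOT lowering API (lower / compiler_ir) exists in your JAX version
lowered = jax.jit(dlfn).lower(numpy.ones(100))
computation = lowered.compiler_ir("hlo")  # should behave like the xla_computation result
print(computation.as_hlo_text())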
Visualise
The simplest way of visualising the computation is to dump it as a dot graph and run Graphviz dot on it:
with open("t.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())

and then, from the command line:

dot t.dot -Tpng > t.png
This produces the following image:
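If you prefer to keep everything in one script, here is a minimal sketch that calls Graphviz from Python instead of the shell (it assumes the dot executable is installed and on your PATH):

import subprocess

# Write the dot source produced from the HLO and render it to PNG with Graphviz
with open("t.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())
subprocess.run(["dot", "-Tpng", "t.dot", "-o", "t.png"], check=True)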
Visualise a neural network
The same basic approach can be used to visualise a Flax neural network, e.g.:
import functools
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

model = CNN()
batch = numpy.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)

# Bind the initialised parameters so we can trace a single-argument function
f = functools.partial(model.apply, variables)
z = jax.xla_computation(f)(batch)

with open("t2.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())
This produces the following image:
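For a network of this size the PNG can get unwieldy; Graphviz can also emit SVG, which is easier to pan and zoom in a browser:

dot t2.dot -Tsvg > t2.svg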