Google’s JAX is a powerful foundation for many numerical computations, not just neural networks. It works by tracing the execution of a Python function and then compiling this trace with XLA into efficient executable code that can target CPUs, GPUs or TPUs (Tensor Processing Units).
Iterative algorithms that modify state are a little more difficult to express in JAX, because plain tracing unrolls a Python loop one step at a time, which is impractical for long iterations. After working this through on another algorithm, I thought: what better way to illustrate how this works than with Conway’s Game of Life?
So here is how to implement the Game of Life in JAX – and it will run on GPUs and CPUs! The complete program is available in this Google Colab Notebook.
Calculating the number of neighbours
My first try was with lax.conv, but this only works with floating-point types, which we don’t need for the Game of Life. Instead I use the nice reduce_window function from jax.lax:
nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
Here a is the array for which we want the number of neighbours at each point. The 3x3 window sum also counts the centre cell itself, so a is subtracted to correct for this.
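To see what the window sum produces, here is a small sketch that counts neighbours for a horizontal "blinker" of three live cells. It assumes jax.numpy imported as jnp and jax.lax as lax, and count_neighbours is a hypothetical helper name. With "SAME" padding the grid is padded with the init value 0, so cells beyond the edge behave as dead.

```python
import jax.numpy as jnp
from jax import lax

def count_neighbours(a):
    # Sum every 3x3 window; "SAME" padding keeps the output the same shape as
    # the input and pads with the init value 0 (dead cells beyond the edge).
    window_sum = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME")
    # The window includes the centre cell, so subtract it back out.
    return window_sum - a

# Hypothetical test pattern: three live cells in the middle row of a 5x5 grid.
blinker = jnp.array([[0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 1, 1, 1, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]], dtype=jnp.int32)
print(count_neighbours(blinker))
```

The centre of the blinker reports 2 neighbours and its two ends 1 each, while the dead cells directly above and below the centre see 3, which is exactly what drives the blinker to flip.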
Modifying the state
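Cells with births and deaths are easily computed with logical_and and logical_or. But this gives arrays of boolean values, which cannot be used as indices in JAX. The solution is to use lax.select, which takes each element from the first array where the predicate is true and from the second array elsewhere:
na=L.select(birth, Iρ(a.shape, N.int32), a)
na=L.select(death, Zρ(a.shape, N.int32), na)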
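lax.select needs its two branches to have matching shapes and dtypes, which is presumably why full arrays of ones and zeros are built for it. A possible alternative, not the one used in this post, is jnp.where, which broadcasts scalar branches. A minimal sketch, where apply_rules is a hypothetical name:

```python
import jax.numpy as jnp

def apply_rules(a, birth, death):
    # jnp.where broadcasts the scalars 1 and 0, so no full arrays are needed.
    # Births set a cell to 1 first, then deaths set a cell to 0, in the same
    # order as the two lax.select calls.
    na = jnp.where(birth, 1, a)
    na = jnp.where(death, 0, na)
    return na.astype(a.dtype)

a = jnp.array([0, 1, 1, 0], dtype=jnp.int32)
birth = jnp.array([True, False, False, False])
death = jnp.array([False, True, False, False])
print(apply_rules(a, birth, death))  # [1 0 1 0]
```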
Combining the neighbour calculation and the select, here is the whole function to evolve a grid by one generation:
import jax
import jax.numpy as N   # N: jax.numpy, as used throughout
from jax import lax as L  # L: jax.lax

Zρ=N.zeros
Iρ=N.ones

@jax.jit
def rgen(a):
    """ Evolve a grid one generation according to Game of Life """
    # a is an int32 array of 0s (dead) and 1s (alive).
    # This reduction over-counts the neighbours of live cells since it includes the
    # central cell itself. Subtract out the array to correct for this.
    nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
    birth=N.logical_and(a==0, nghbrs==3)
    underpop=N.logical_and(a==1, nghbrs<2)
    overpop=N.logical_and(a==1, nghbrs>3)
    death=N.logical_or(underpop, overpop)
    na=L.select(birth, Iρ(a.shape, N.int32), a)
    na=L.select(death, Zρ(a.shape, N.int32), na)
    return na
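The birth and death rules can also be folded into one boolean expression: a cell is alive in the next generation if it has exactly three live neighbours, or if it is already alive and has exactly two. The sketch below uses that condensed form (life_step is a hypothetical name, and this is a restatement rather than the post's code) and checks it on a blinker, which should turn vertical and then return to its starting state after two generations.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def life_step(a):
    # Live-neighbour count, excluding the cell itself.
    n = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    # Alive next generation: exactly 3 neighbours (birth or survival),
    # or currently alive with exactly 2 neighbours (survival).
    alive = (n == 3) | ((a == 1) & (n == 2))
    return alive.astype(a.dtype)

# Horizontal blinker in the middle row of a 5x5 grid.
blinker = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
step1 = life_step(blinker)  # vertical bar in column 2
step2 = life_step(step1)    # back to the horizontal bar
print(step1)
print(bool((step2 == blinker).all()))
```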
Iterating over many generations
Iteration needs to be done using the loops module from jax.experimental. Note that the algorithm state must be carried in attributes of the loop scope object (s.a below):
from jax.experimental import loops

def ngen(a, n):
    "Evolve n generations"
    with loops.Scope() as s:
        # The loop state lives in an attribute of the scope object s.
        s.a=a
        for i in s.range(n):
            s.a=rgen(s.a)
    return s.a
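As far as I know, jax.experimental.loops is not shipped with recent JAX releases. On a current version the same loop can be written with jax.lax.fori_loop, which carries the grid as explicit loop state instead of as an attribute. A sketch under that assumption, where life_step restates the one-generation update and the glider is just a test pattern:

```python
import jax
import jax.numpy as jnp
from jax import lax

def life_step(a):
    # One generation: count live neighbours (excluding the cell itself),
    # then birth on 3 neighbours, survival on 2 or 3.
    n = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    return ((n == 3) | ((a == 1) & (n == 2))).astype(a.dtype)

@jax.jit
def ngen(a, n):
    "Evolve n generations"
    # fori_loop passes (index, carried state) to the body; the grid is the
    # carried state, so no mutable attributes are needed.
    return lax.fori_loop(0, n, lambda i, grid: life_step(grid), a)

glider = jnp.zeros((8, 8), dtype=jnp.int32)
glider = glider.at[jnp.array([1, 2, 3, 3, 3]), jnp.array([2, 3, 1, 2, 3])].set(1)
after4 = ngen(glider, 4)
# A glider repeats its shape every 4 generations, one cell down and right.
print(bool((after4 == jnp.roll(glider, (1, 1), axis=(0, 1))).all()))
```

Because n is a traced argument to the jitted function, fori_loop lowers to a while loop here, so calling ngen with a different n does not trigger a recompilation.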
This is basically it – the rest of the notebook is simply generating the animations.
Speed of execution
Looking into speed of execution was not a primary goal, but in case people are interested, an initial study suggests this code does about 10 billion cells per second on the Google Colab GPU runtime.
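A rough way to take this kind of measurement is sketched below. The grid size, generation count and use of fori_loop are assumptions, not the setup behind the figure above, and the result depends entirely on the hardware. Two details matter: a warm-up call keeps compilation time out of the measurement, and block_until_ready() waits for JAX's asynchronous dispatch to finish before the clock stops.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

def life_step(a):
    # One generation of the Game of Life on an int32 grid of 0s and 1s.
    n = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    return ((n == 3) | ((a == 1) & (n == 2))).astype(a.dtype)

@jax.jit
def ngen(a, n):
    return lax.fori_loop(0, n, lambda i, grid: life_step(grid), a)

def cells_per_second(size=256, generations=100, seed=0):
    # Hypothetical benchmark helper: random square grid, about half alive.
    key = jax.random.PRNGKey(seed)
    grid = jax.random.bernoulli(key, 0.5, (size, size)).astype(jnp.int32)
    # Warm-up call: compiles ngen so compilation is not timed.
    ngen(grid, generations).block_until_ready()
    start = time.perf_counter()
    ngen(grid, generations).block_until_ready()
    elapsed = time.perf_counter() - start
    return size * size * generations / elapsed

print(f"{cells_per_second():.3g} cells/second")
```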
Here is an example output on a 100x100 grid initialised with a random seed. Have a look at the Google Colab Notebook to generate as many of your own as you’d like!
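For a random starting grid and a sequence of frames to animate, one option (not necessarily what the notebook does) is to draw a 100x100 grid with jax.random.bernoulli and collect every generation with jax.lax.scan. The names random_grid and history and the 0.5 density are hypothetical choices:

```python
import jax
import jax.numpy as jnp
from jax import lax

def life_step(a):
    # One generation of the Game of Life on an int32 grid of 0s and 1s.
    n = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    return ((n == 3) | ((a == 1) & (n == 2))).astype(a.dtype)

def random_grid(seed, size=100, density=0.5):
    # Each cell is alive with probability `density` (an assumed value).
    key = jax.random.PRNGKey(seed)
    return jax.random.bernoulli(key, density, (size, size)).astype(jnp.int32)

def history(grid, n):
    # lax.scan carries the current grid and emits it as the per-step output,
    # giving an array of shape (n, rows, cols) holding generations 0..n-1.
    def step(g, _):
        return life_step(g), g
    _, frames = lax.scan(step, grid, None, length=n)
    return frames

frames = history(random_grid(seed=42), 50)
print(frames.shape)  # (50, 100, 100)
```

Each frames[i] can then be handed to a plotting library to build the animation.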