Conway's Game of Life implemented in JAX
Google’s JAX is a powerful basis for building many kinds of numerical computation, not just neural networks. It works by tracing the execution of a Python function and then compiling this trace with XLA into efficient code that can target CPUs, GPUs or TPUs (Tensor Processing Units).
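A minimal sketch of the trace-then-compile behaviour, assuming a recent JAX release. The Python body of a `jax.jit`-decorated function runs only while JAX traces it. A Python side effect (appending to a list, here the illustrative `traces`) therefore happens once, even though the compiled function is called three times with inputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp

traces = []  # records each time the Python body actually executes

@jax.jit
def step(x):
    traces.append(1)  # Python side effect: only runs while JAX traces the function
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.arange(4.0)
for _ in range(3):
    # The first call traces and compiles; later calls with the same
    # shape/dtype reuse the compiled XLA executable.
    y = step(x)

print(len(traces))  # 1
```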
Algorithms that iterate while modifying state are a little more difficult to express in JAX. Plain tracing unrolls Python loops, so a long iteration turns into a very large trace that is slow to compile. After working this through on another algorithm, I thought: what better way to illustrate how this works than with Conway’s Game of Life?
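To see why long Python loops are a problem, one can compare the size of the traces (jaxprs) that `jax.make_jaxpr` records for a Python loop and for `lax.fori_loop`. The functions `unrolled` and `rolled` and the update `x * 0.5 + 1.0` are purely illustrative and are not part of the original program.

```python
import jax
import jax.numpy as jnp
from jax import lax

def unrolled(x, n=100):
    # A Python loop runs during tracing, so every iteration is
    # recorded as separate operations in the trace.
    for _ in range(n):
        x = x * 0.5 + 1.0
    return x

def rolled(x, n=100):
    # lax.fori_loop records the body once; XLA executes it n times.
    return lax.fori_loop(0, n, lambda i, v: v * 0.5 + 1.0, x)

x = jnp.ones(3)
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(x).jaxpr.eqns)
print(n_unrolled, n_rolled)  # the first count grows with n, the second does not
```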
So here is how to implement the Game of Life in JAX – and it will run on GPUs and CPUs! The complete program is available in this Google Colab Notebook.
Calculating the number of neighbours
My first try was with `lax.conv`, but this only works with floating-point types, which we don’t need for the Game of Life. Instead I use the nice `lax.reduce_window` operation:

```python
nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
```

Here `a` is the array for which we want to compute the number of neighbours at each point, and `L=jax.lax`. With `"SAME"` padding the grid is padded with the initial value 0, so cells beyond the edge count as dead.
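As a quick check of the neighbour count, the sketch below applies the same `reduce_window` call to a horizontal "blinker" (three live cells in a row) on a 5x5 `int32` grid. The helper name `neighbours` is hypothetical.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax as L

def neighbours(a):
    # Sum each 3x3 window (border padded with 0), then subtract the
    # centre cell so only the 8 surrounding cells are counted.
    return L.reduce_window(a, 0, L.add, (3, 3), (1, 1), "SAME") - a

# Horizontal blinker: live cells at row 2, columns 1..3
a = jnp.zeros((5, 5), jnp.int32).at[2, 1:4].set(1)
n = neighbours(a)
print(n)
# [[0 0 0 0 0]
#  [1 2 3 2 1]
#  [1 1 2 1 1]
#  [1 2 3 2 1]
#  [0 0 0 0 0]]
```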
Modifying the state
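Cells with births and deaths are easily computed with `logical_and` and `logical_or`. But this gives arrays of boolean values, and in JAX these cannot be used as indices to assign new values the way one would in NumPy (`a[birth] = 1`), since JAX arrays are immutable. The solution is the `select` operator, where `L.select(pred, on_true, on_false)` takes elements from `on_true` where `pred` is true and from `on_false` elsewhere (all three arrays must have the same shape):

```python
na=L.select(birth,
            Iρ(a.shape, N.int32),
            a)
na=L.select(death,
            Zρ(a.shape, N.int32),
            na)
```

Here `Zρ=jax.numpy.zeros`, `Iρ=jax.numpy.ones` and `N=jax.numpy`.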
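`jnp.where` is a commonly used alternative to `lax.select`: it broadcasts scalar arguments, so the explicit arrays of ones and zeros are not needed. The small arrays below are made up for illustration, and both formulations give the same result. Because a birth requires `a==0` and a death requires `a==1`, the two masks never overlap, so the order of the two updates does not matter.

```python
import jax.numpy as jnp
from jax import lax as L

a = jnp.array([[0, 1], [1, 0]], jnp.int32)
birth = jnp.array([[True, False], [False, False]])
death = jnp.array([[False, False], [True, False]])

# lax.select: predicate and both value arrays must have identical shapes,
# and the value arrays identical dtypes.
na = L.select(birth, jnp.ones(a.shape, jnp.int32), a)
na = L.select(death, jnp.zeros(a.shape, jnp.int32), na)

# jnp.where broadcasts the scalars 1 and 0 against `a`.
nb = jnp.where(death, 0, jnp.where(birth, 1, a))
print(na)
print(nb)
```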
Combining the neighbour calculation and the `select` gives the whole function to evolve a grid by one generation:
```python
import jax
import jax.numpy as N
import jax.lax as L

Zρ=N.zeros
Iρ=N.ones

@jax.jit
def rgen(a):
    """
    Evolve a grid one generation according to Game of Life
    (a is expected to be an int32 array of 0s and 1s)
    """
    # This reduction over-counts the neighbours of live cells since it includes the
    # central cell itself. Subtract out the array to correct for this.
    nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
    # A dead cell with exactly three live neighbours is born.
    birth=N.logical_and(a==0, nghbrs==3)
    # A live cell with fewer than two or more than three live neighbours dies.
    underpop=N.logical_and(a==1, nghbrs<2)
    overpop=N.logical_and(a==1, nghbrs>3)
    death=N.logical_or(underpop, overpop)
    na=L.select(birth,
                Iρ(a.shape, N.int32),
                a)
    na=L.select(death,
                Zρ(a.shape, N.int32),
                na)
    return na
```
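A sketch of a sanity check for the generation step, assuming `int32` grids of 0s and 1s. It restates the rule compactly (a cell is alive in the next generation if it is born, or if it is alive and has two or three neighbours), which is equivalent to the birth/death formulation above. It then checks two well-known patterns: a 2x2 block is a still life, and a blinker oscillates with period 2. The `grid` helper is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax as L

@jax.jit
def rgen(a):
    # Live neighbours of every cell (zero padding: cells beyond the edge are dead).
    nghbrs = L.reduce_window(a, 0, L.add, (3, 3), (1, 1), "SAME") - a
    born = (a == 0) & (nghbrs == 3)
    survives = (a == 1) & ((nghbrs == 2) | (nghbrs == 3))
    return (born | survives).astype(a.dtype)

def grid(cells, shape=(6, 6)):
    # Hypothetical helper: int32 grid with the given live (row, col) cells.
    a = jnp.zeros(shape, jnp.int32)
    for r, c in cells:
        a = a.at[r, c].set(1)
    return a

block = grid([(2, 2), (2, 3), (3, 2), (3, 3)])
blinker_h = grid([(2, 1), (2, 2), (2, 3)])
blinker_v = grid([(1, 2), (2, 2), (3, 2)])

print(bool(jnp.array_equal(rgen(block), block)))          # True: still life
print(bool(jnp.array_equal(rgen(blinker_h), blinker_v)))  # True
print(bool(jnp.array_equal(rgen(blinker_v), blinker_h)))  # True: period 2
```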
Iterating over many generations
Iteration needs to be done using the `loops` object from `jax.experimental`. Note that the algorithm state must be carried in the attributes of the loop `Scope` variable!
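```python
# Requires an older JAX release that still ships jax.experimental.loops
from jax.experimental import loops

def ngen(a, n):
    "Evolve n generations"
    with loops.Scope() as s:
        s.a=a
        for i in s.range(n):
            s.a=rgen(s.a)
    return s.a
```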
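`jax.experimental.loops` has since been deprecated and removed from current JAX releases, so the code above needs an older version. Below is a sketch of the same iteration with `lax.fori_loop`, where the grid is simply the loop carry. It restates `rgen` compactly so that the block runs on its own. It then checks the well-known fact that a glider moves one cell diagonally every four generations, on a 10x10 grid chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax as L

@jax.jit
def rgen(a):
    # One generation (compact restatement of rgen from the text above).
    nghbrs = L.reduce_window(a, 0, L.add, (3, 3), (1, 1), "SAME") - a
    alive = ((a == 0) & (nghbrs == 3)) | ((a == 1) & ((nghbrs == 2) | (nghbrs == 3)))
    return alive.astype(a.dtype)

def ngen(a, n):
    "Evolve n generations; the grid is carried as the fori_loop state."
    return L.fori_loop(0, n, lambda i, s: rgen(s), a)

# Glider:  .X.
#          ..X
#          XXX   placed with its top-left corner at (1, 1)
g = jnp.zeros((10, 10), jnp.int32)
g = g.at[1, 2].set(1).at[2, 3].set(1).at[3, 1:4].set(1)

g4 = ngen(g, 4)
# After 4 generations the glider is the same shape, shifted down and right by one.
print(bool(jnp.array_equal(g4, jnp.roll(g, (1, 1), axis=(0, 1)))))  # True
```

Unlike the `Scope` version, no special state object is needed: `fori_loop` passes the previous grid into the body and returns the final one.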
This is basically it – the rest of the notebook is simply generating the animations.
Speed of execution
Looking into speed of execution was not a primary goal, but in case people are interested, an initial study suggests this code does about 10 billion cell updates per second on the Google Colab GPU runtime.
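One way to measure throughput in cell updates per second is sketched below, assuming a square random grid. The grid size, seed and the `cells_per_second` helper are illustrative. JAX dispatches work asynchronously, so the timing waits on `block_until_ready()`, and a warm-up call keeps compilation time out of the measurement. Results depend heavily on the hardware and the grid size.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax as L

@jax.jit
def rgen(a):
    # One generation (compact restatement of rgen from the text above).
    nghbrs = L.reduce_window(a, 0, L.add, (3, 3), (1, 1), "SAME") - a
    alive = ((a == 0) & (nghbrs == 3)) | ((a == 1) & ((nghbrs == 2) | (nghbrs == 3)))
    return alive.astype(a.dtype)

@jax.jit
def ngen(a, n):
    # n is traced, so one compiled loop serves any number of generations.
    return L.fori_loop(0, n, lambda i, s: rgen(s), a)

def cells_per_second(size=1024, gens=100, seed=0):
    """Hypothetical helper: cell updates per second on a size x size random grid."""
    key = jax.random.PRNGKey(seed)
    a = jax.random.bernoulli(key, 0.5, (size, size)).astype(jnp.int32)
    ngen(a, gens).block_until_ready()  # warm-up: compilation is not timed
    start = time.perf_counter()
    ngen(a, gens).block_until_ready()  # wait for the asynchronous computation
    elapsed = time.perf_counter() - start
    return size * size * gens / elapsed
```

For example, `cells_per_second(size=4096, gens=1000)` could be run on a GPU runtime. Larger grids and more generations amortise launch overhead better.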
Example output
Here is an example output on a 100x100 grid initialised with a random seed. Have a look at the Google Colab Notebook to generate as many of your own as you’d like!
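A random starting grid like this can be generated with `jax.random.bernoulli`. The sketch below assumes a 50% initial density and an integer seed. These are choices made for illustration, not necessarily what the notebook uses, and `random_grid` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def random_grid(seed, shape=(100, 100), p_alive=0.5):
    # Each cell is independently alive (1) with probability p_alive.
    key = jax.random.PRNGKey(seed)
    return jax.random.bernoulli(key, p_alive, shape).astype(jnp.int32)

grid = random_grid(42)
print(grid.shape, int(grid.sum()))  # the same seed always gives the same grid
```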