Google’s JAX is a powerful foundation for many numerical computations, not just neural networks. It works by tracing the execution of a Python function and then compiling this trace with XLA into efficient executable code that can target CPUs, GPUs or TPUs (Tensor Processing Units).

Iterative algorithms that modify state are a little more difficult to express in JAX, because plain tracing unrolls every loop iteration into the trace, which does not scale to long iterations. After working this through on another algorithm, I thought: what better way to illustrate how this works than with Conway’s Game of Life?
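To make the unrolling problem concrete, here is a small sketch of my own (not from the notebook). It counts the operations in the traced program (the jaxpr) for a plain Python loop and for jax.lax.fori_loop. The function names python_loop and fori_version are hypothetical. The Python loop produces at least one operation per iteration, while fori_loop stays a single loop primitive however many iterations it runs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def python_loop(x, n=100):
    # A plain Python loop is unrolled at trace time: one add per iteration
    for _ in range(n):
        x = x + 1.0
    return x


def fori_version(x, n=100):
    # lax.fori_loop traces its body once and becomes one loop primitive
    return lax.fori_loop(0, n, lambda i, v: v + 1.0, x)


x = jnp.zeros(3)
python_eqns = len(jax.make_jaxpr(python_loop)(x).jaxpr.eqns)
fori_eqns = len(jax.make_jaxpr(fori_version)(x).jaxpr.eqns)
print(python_eqns, fori_eqns)
```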

So here is how to implement the Game of Life in JAX – and it will run on GPUs and CPUs! The complete program is available in this Google Colab Notebook.

Calculating the number of neighbours

My first try was with lax.conv, but this only works with floating-point types, which we don’t need for the Game of Life. Instead I use the nice lax.reduce_window operation:

nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a

Here a is the array for which we want to compute the number of neighbours at each point, and L=jax.lax. The 3x3 window sum includes the cell itself, which is why a is subtracted at the end.
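As a quick check of the neighbour count, here is a sketch (my own, with a hypothetical helper name count_neighbours) that applies the same reduce_window call to a horizontal "blinker" of three live cells. The "SAME" padding is filled with the init value 0, so cells outside the grid count as dead.

```python
import jax.numpy as jnp
from jax import lax


def count_neighbours(a):
    # Sum each 3x3 window (padding filled with the init value 0), then
    # subtract the centre cell so only the 8 surrounding cells are counted
    return lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a


# A horizontal blinker: three live cells in the middle row
a = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
n = count_neighbours(a)
print(n)
```

The cell directly above the middle of the blinker sees all three live cells, while the middle cell itself sees only its two partners.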

Modifying the state

Cells with births and deaths are easily computed with logical_and and logical_or. But this gives arrays of boolean values, which cannot be used as indices in jit-compiled JAX code the way NumPy boolean masks can. The solution is the select operator:

na=L.select(birth,
            Iρ(a.shape, N.int32),
            a)

na=L.select(death,
            Zρ(a.shape, N.int32),
            na)

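Here N=jax.numpy, Zρ=jax.numpy.zeros and Iρ=jax.numpy.ones: cells where birth is true become 1, cells where death is true become 0, and all others keep their value.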
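The two select calls can be checked on a tiny hand-made example. This is a sketch with made-up birth and death masks. It also shows jnp.where as an alternative I am adding (not what the notebook uses), which broadcasts scalar branches so no full arrays of ones and zeros are needed.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([[0, 1], [1, 0]], dtype=jnp.int32)
birth = jnp.array([[True, False], [False, False]])  # illustrative mask
death = jnp.array([[False, True], [False, False]])  # illustrative mask

# lax.select(pred, on_true, on_false) picks elementwise; the operands must
# share a shape and the two branches must share a dtype (int32 here)
na = lax.select(birth, jnp.ones(a.shape, jnp.int32), a)
na = lax.select(death, jnp.zeros(a.shape, jnp.int32), na)

# jnp.where does the same job and broadcasts the scalar branches
na_where = jnp.where(death, 0, jnp.where(birth, 1, a))
print(na)
```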

Combining the neighbour calculation and the select, here is the whole function to evolve a grid by one generation:

import jax
import jax.numpy as N
import jax.lax as L

Zρ=N.zeros
Iρ=N.ones

@jax.jit
def rgen(a):
    """
    Evolve a grid one generation according to Game of Life.
    a must be an int32 array of 0s (dead) and 1s (alive).
    """
    # This reduction over-counts the neighbours of live cells since it includes the
    # central cell itself. Subtract out the array to correct for this.
    nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
    birth=N.logical_and(a==0, nghbrs==3)
    underpop=N.logical_and(a==1, nghbrs<2)
    overpop=N.logical_and(a==1, nghbrs>3)
    death=N.logical_or(underpop, overpop)

    # Cells that are born become 1, cells that die become 0
    na=L.select(birth,
                Iρ(a.shape, N.int32),
                a)

    na=L.select(death,
                Zρ(a.shape, N.int32),
                na)

    return na
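A simple way to test a Game of Life step is an oscillator. The sketch below restates rgen with plain jnp/lax names so it runs on its own. It checks that a blinker turns from horizontal to vertical after one generation and back after two. The test pattern is my choice, not the notebook's.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def rgen(a):
    # Same rule as above, written with jax.numpy / jax.lax names
    nghbrs = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    birth = jnp.logical_and(a == 0, nghbrs == 3)
    death = jnp.logical_and(a == 1, jnp.logical_or(nghbrs < 2, nghbrs > 3))
    na = lax.select(birth, jnp.ones(a.shape, jnp.int32), a)
    return lax.select(death, jnp.zeros(a.shape, jnp.int32), na)


# A blinker alternates between horizontal and vertical with period 2
horizontal = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
vertical = jnp.zeros((5, 5), dtype=jnp.int32).at[1:4, 2].set(1)
flips = bool(jnp.array_equal(rgen(horizontal), vertical))
returns = bool(jnp.array_equal(rgen(rgen(horizontal)), horizontal))
print(flips, returns)
```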

Iterating over many generations

I did the iteration with the loops module from jax.experimental. Note that the algorithm state must be carried in the attributes of the loop Scope variable!

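from jax.experimental import loops  # only available in older JAX releases

def ngen(a, n):
  "Evolve n generations"
  with loops.Scope() as s:
    s.a=a  # the grid lives on the Scope object so the loop can update it
    for i in s.range(n):
        s.a=rgen(s.a)
  return s.a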
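As far as I know, jax.experimental.loops has since been deprecated and no longer ships with current JAX releases. Here is a minimal sketch of the same ngen written with jax.lax.fori_loop, where the grid is passed explicitly as the loop carry instead of being stored on a Scope. This is a substitute technique, not the author's code. The glider check is my own test: a glider moves one cell diagonally every four generations.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def rgen(a):
    # One Game of Life generation (same rule as above)
    nghbrs = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    birth = jnp.logical_and(a == 0, nghbrs == 3)
    death = jnp.logical_and(a == 1, jnp.logical_or(nghbrs < 2, nghbrs > 3))
    na = lax.select(birth, jnp.ones(a.shape, jnp.int32), a)
    return lax.select(death, jnp.zeros(a.shape, jnp.int32), na)


def ngen(a, n):
    # The grid is the loop carry; fori_loop traces the body only once
    return lax.fori_loop(0, n, lambda i, grid: rgen(grid), a)


# Glider in the top-left corner of an 8x8 grid
glider = jnp.zeros((8, 8), dtype=jnp.int32)
glider = glider.at[0, 1].set(1).at[1, 2].set(1).at[2, 0:3].set(1)
moved = ngen(glider, 4)
print(moved)
```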

This is basically it – the rest of the notebook is simply generating the animations.

Speed of execution

Looking into speed of execution was not a primary goal, but in case people are interested, an initial study suggests this code does about 10 billion cells per second on the Google Colab GPU runtime.
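For anyone who wants to measure throughput on their own hardware, here is a rough timing sketch. It does not reproduce the author's measurement. The grid size, generation count, 50% density and the helper name cells_per_second are my assumptions. The first call is excluded because it includes compilation, and block_until_ready makes the timer wait for JAX's asynchronous dispatch to finish.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def rgen(a):
    # One Game of Life generation (same rule as above)
    nghbrs = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    birth = jnp.logical_and(a == 0, nghbrs == 3)
    death = jnp.logical_and(a == 1, jnp.logical_or(nghbrs < 2, nghbrs > 3))
    na = lax.select(birth, jnp.ones(a.shape, jnp.int32), a)
    return lax.select(death, jnp.zeros(a.shape, jnp.int32), na)


@jax.jit
def ngen(a, n):
    # n is traced, so fori_loop becomes a while loop compiled once
    return lax.fori_loop(0, n, lambda i, grid: rgen(grid), a)


def cells_per_second(size=1000, generations=100, seed=0):
    key = jax.random.PRNGKey(seed)
    grid = jax.random.bernoulli(key, 0.5, (size, size)).astype(jnp.int32)
    ngen(grid, generations).block_until_ready()  # warm-up: compile first
    start = time.perf_counter()
    ngen(grid, generations).block_until_ready()
    elapsed = time.perf_counter() - start
    # Throughput = cells per grid x generations / wall-clock seconds
    return size * size * generations / elapsed


print(f"{cells_per_second():.3e} cells/second")
```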

Example output

Here is an example output on a 100x100 grid initialised with a random seed. Have a look at the Google Colab Notebook to generate as many of your own as you’d like!
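To produce frames like these, one option (a sketch, not necessarily what the notebook does) is to start from a random 100x100 grid and collect every generation with jax.lax.scan, which stacks the per-step outputs into one array. The helper names random_grid and history, the density of 0.5 and the key 42 are hypothetical choices.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def rgen(a):
    # One Game of Life generation (same rule as above)
    nghbrs = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    birth = jnp.logical_and(a == 0, nghbrs == 3)
    death = jnp.logical_and(a == 1, jnp.logical_or(nghbrs < 2, nghbrs > 3))
    na = lax.select(birth, jnp.ones(a.shape, jnp.int32), a)
    return lax.select(death, jnp.zeros(a.shape, jnp.int32), na)


def random_grid(key, size=100, density=0.5):
    # Each cell is independently alive with probability `density`
    return jax.random.bernoulli(key, density, (size, size)).astype(jnp.int32)


def history(a, n):
    # scan carries the grid and stacks each new generation: shape (n, H, W)
    def step(grid, _):
        new = rgen(grid)
        return new, new
    _, frames = lax.scan(step, a, None, length=n)
    return frames


grid = random_grid(jax.random.PRNGKey(42))
frames = history(grid, 50)
print(frames.shape)
```

The stacked frames can then be converted to a NumPy array and handed to whatever plotting or animation tool you prefer.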