# Conway's Game of Life implemented in JAX

Google’s JAX is a powerful base on which to build many kinds of numerical computation, not just neural networks. It works by tracing a Python execution pass and then compiling this trace using XLA into efficient executable code that can target CPUs, GPUs or TPUs (Tensor Processing Units).

Iterative algorithms that modify state are a little more difficult to express in JAX, as plain tracing cannot capture long iterations. After working this through on another algorithm, I thought: what better way to illustrate how this works than with Conway’s Game of Life?
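To make the tracing issue concrete, here is a minimal sketch, assuming a made-up function `repeat_add` that is not part of the notebook. A plain Python loop is simply executed while JAX traces, so every iteration is recorded as separate operations and the traced program grows with the iteration count:

```python
import jax
import jax.numpy as jnp

def repeat_add(x, n):
    # A plain Python loop: during tracing it just runs,
    # so each iteration adds its own operations to the trace.
    for _ in range(n):
        x = x + 1
    return x

x = jnp.zeros((4,), dtype=jnp.int32)

# make_jaxpr returns the traced program without compiling it.
short = jax.make_jaxpr(lambda v: repeat_add(v, 3))(x)
long = jax.make_jaxpr(lambda v: repeat_add(v, 30))(x)

n_short = len(short.jaxpr.eqns)
n_long = len(long.jaxpr.eqns)
print(n_short, n_long)  # the second count is larger: the loop was unrolled
```

For a handful of iterations this is harmless, but for thousands of generations the trace and the compile time grow accordingly, which is why a dedicated loop construct is needed.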

So here is how to implement the Game of Life in JAX – and it will run on GPUs and CPUs! The complete program is available in this Google Colab Notebook.

## Calculating the number of neighbours

My first try was with `lax.conv`, but this only works with floating point types, which we don’t need for the Game of Life. Instead I use the nice `lax.reduce_window` operation:

```
nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
```

Here `a` is the array for which we want to compute the number of neighbours at each point, and `L=jax.lax`.
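As a quick check of the windowed sum, here is a minimal sketch on a 3x3 grid holding a vertical "blinker". The helper name `count_neighbours` is an assumption for illustration and does not appear in the notebook:

```python
import jax.numpy as jnp
from jax import lax

def count_neighbours(a):
    # Sum every 3x3 window ("SAME" pads the edges with the init value 0),
    # then subtract the centre cell so it does not count as its own neighbour.
    return lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a

# A vertical blinker: three live cells in the middle column.
grid = jnp.array([[0, 1, 0],
                  [0, 1, 0],
                  [0, 1, 0]], dtype=jnp.int32)

nghbrs = count_neighbours(grid)
print(nghbrs)
# [[2 1 2]
#  [3 2 3]
#  [2 1 2]]
```

Because of the zero padding, cells beyond the edge of the grid are treated as permanently dead rather than wrapping around.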

## Modifying the state

Cells with births and deaths are easily computed with `logical_and` and `logical_or`. But this gives arrays of boolean values, which cannot be used as indices in JAX. The solution is the `select` operator:

```
na=L.select(birth,
            Iρ(a.shape, N.int32),
            a)
na=L.select(death,
            Zρ(a.shape, N.int32),
            na)
```

Here `Zρ=jax.numpy.zeros`, `Iρ=jax.numpy.ones` and `N=jax.numpy`.
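To see the masks in action on a small case, here is a minimal sketch with hand-written `birth` and `death` masks (made up for illustration, not computed from the rules). It also shows `jnp.where`, which is not used in the notebook but should give the same result here since it broadcasts scalars:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([0, 1, 1, 0], dtype=jnp.int32)
birth = jnp.array([True, False, False, False])
death = jnp.array([False, True, False, False])

# lax.select(pred, on_true, on_false) expects all three arrays to have
# the same shape, and on_true/on_false to have the same dtype.
na = lax.select(birth, jnp.ones(a.shape, jnp.int32), a)
na = lax.select(death, jnp.zeros(a.shape, jnp.int32), na)

# jnp.where broadcasts the scalars 0 and 1 against the mask.
na_where = jnp.where(death, 0, jnp.where(birth, 1, a))

print(na)        # [1 0 1 0]
print(na_where)  # [1 0 1 0]
```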

Combining the neighbour calculation and the `select`, here is the whole function to evolve a grid by one generation:

```
import jax
import jax.numpy as N
from jax import lax as L

Zρ=N.zeros
Iρ=N.ones

@jax.jit
def rgen(a):
    """
    Evolve a grid one generation according to Game of Life
    """
    # This reduction over-counts the neighbours of live cells since it includes the
    # central cell itself. Subtract out the array to correct for this.
    nghbrs=L.reduce_window(a, 0, L.add, (3,3), (1,1), "SAME")-a
    # A dead cell with exactly three live neighbours comes alive
    birth=N.logical_and(a==0, nghbrs==3)
    # A live cell dies with fewer than two or more than three live neighbours
    underpop=N.logical_and(a==1, nghbrs<2)
    overpop=N.logical_and(a==1, nghbrs>3)
    death=N.logical_or(underpop, overpop)
    na=L.select(birth,
                Iρ(a.shape, N.int32),
                a)
    na=L.select(death,
                Zρ(a.shape, N.int32),
                na)
    return na
```

## Iterating over many generations

Iteration needs to be done using the `loops` object from `jax.experimental`. Note that the algorithm state **must** be carried in the attributes of the loop `Scope` variable!

```
from jax.experimental import loops

def ngen(a, n):
    "Evolve n generations"
    with loops.Scope() as s:
        # The grid is loop state, so it lives on the scope object
        s.a=a
        for i in s.range(n):
            s.a=rgen(s.a)
        return s.a
```
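Since `loops` is an experimental module, it may not be present in every JAX release. As an alternative, here is a minimal sketch of the same iteration written with `lax.fori_loop`. This is a swapped-in technique, not the notebook's method, and the names `step` and `evolve` as well as the compact rule expression are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def step(a):
    # One generation: a cell is alive next if it has exactly three live
    # neighbours, or if it is alive now and has exactly two.
    n = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    alive = (n == 3) | ((a == 1) & (n == 2))
    return alive.astype(a.dtype)

def evolve(a, n):
    # fori_loop(lower, upper, body, init): the body receives the loop index
    # and the carried state, and returns the new state.
    return lax.fori_loop(0, n, lambda i, s: step(s), a)

# A vertical blinker on a 5x5 grid; it oscillates with period two.
blinker = jnp.zeros((5, 5), jnp.int32).at[1:4, 2].set(1)
after_one = evolve(blinker, 1)
after_two = evolve(blinker, 2)
print(after_one[2])  # [0 1 1 1 0]
```

Here the state is carried explicitly as the last argument of `fori_loop` instead of as attributes on a `Scope` object.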

This is basically it – the rest of the notebook is simply generating the animations.

## Speed of execution

Looking into speed of execution was not a primary goal, but in case people are interested: an initial study suggests this code will do about 10 billion cells / second on the Google Colab GPU runtime.
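A throughput figure of this kind can be estimated with a sketch along the following lines. This is not the measurement behind the number above: the grid size, the generation count and the compact `step` function are all assumptions chosen for illustration, and results will vary widely with hardware and grid size.

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def step(a):
    # One Game of Life generation on an int32 grid of 0s and 1s.
    n = lax.reduce_window(a, 0, lax.add, (3, 3), (1, 1), "SAME") - a
    return ((n == 3) | ((a == 1) & (n == 2))).astype(a.dtype)

def cells_per_second(size=256, generations=20):
    key = jax.random.PRNGKey(0)
    grid = jax.random.bernoulli(key, 0.5, (size, size)).astype(jnp.int32)
    # Warm-up call so that compilation time is not included in the timing.
    step(grid).block_until_ready()
    start = time.perf_counter()
    for _ in range(generations):
        grid = step(grid)
    # JAX dispatches work asynchronously; wait for the result before stopping the clock.
    grid.block_until_ready()
    elapsed = time.perf_counter() - start
    return size * size * generations / elapsed

rate = cells_per_second()
print(f"{rate:.3g} cells / second")
```

The `block_until_ready()` calls matter: without them the timer would mostly measure how quickly work is queued, not how quickly it is computed.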

## Example output

Here is an example output on a 100x100 grid initialised with a random seed. Have a look at the Google Colab Notebook to generate as many of your own as you’d like!