Is JAX multi-threaded when run on CPUs?
JAX is a Python package for efficient computation and automatic differentiation. The question here is: will it always use all CPU cores efficiently?
Intra-op multi-threading
Some JAX operations are parallelised internally and will use multiple threads. This is true of many operations backed by, for example, the Eigen library, such as matrix multiplication. However, not all operations are multi-threaded internally: the FFT operations, for example, **do not have** multi-threading enabled.
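One rough way to check this on your own machine is to time a matrix multiplication and an FFT of the same input while watching per-core CPU usage in a tool such as `top` or `htop`. The sketch below is only an illustration under assumptions: the helper name `time_op`, the 1024 x 1024 input size and the repeat count are arbitrary choices, and the timings themselves say nothing about threading; the CPU monitor is what shows how many cores each operation keeps busy.

```python
import time

import jax
import jax.numpy as jnp


def time_op(fn, x, repeats=5):
    """Return the best wall-clock time (seconds) of fn(x) over several runs."""
    fn(x).block_until_ready()  # warm-up run: triggers JIT compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x).block_until_ready()  # JAX dispatches asynchronously, so wait for the result
        best = min(best, time.perf_counter() - start)
    return best


# Matrix multiplication: expected to spread across several cores.
matmul = jax.jit(lambda a: a @ a)
# Row-wise 1-D FFT: per the text above, expected to stay on one core.
fft = jax.jit(lambda a: jnp.fft.fft(a, axis=-1))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1024, 1024), dtype=jnp.float32)

t_matmul = time_op(matmul, x)
t_fft = time_op(fft, x)
print(f"matmul: {t_matmul:.4f} s, fft: {t_fft:.4f} s")
```

Increasing `repeats` or the input size makes each operation run long enough to observe comfortably in the CPU monitor.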
PMap-based multi-threading
In situations where intra-op multi-threading does not use the CPU resources well enough, it is possible to parallelise at the JAX level through the use of the pmap function. However, it is first necessary to split the CPU into multiple apparent devices using the --xla_force_host_platform_device_count flag (see the pmap-cpu multithreading issue). For example:
XLA_FLAGS="--xla_force_host_platform_device_count=8" python myscript.py
will split the CPU into 8 independent devices over which pmap will parallelise. See the pmap cookbook for examples of how to use pmap.
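The same flag can also be set from inside the script, as long as this happens before JAX initialises its CPU backend. The sketch below assumes a fresh Python process in which JAX has not been imported yet; the value 8 and the array shapes are arbitrary examples. It applies one row-wise FFT per apparent device, since the FFT is the example above of an operation that does not multi-thread on its own. Note that recent JAX releases encourage `jax.jit` with sharding over `pmap`, so check the documentation for your version.

```python
import os

# XLA reads XLA_FLAGS when the backend starts, so set it before importing JAX.
# 8 is an example value; something close to the physical core count is typical.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
n = len(cpu_devices)  # 8 if the flag took effect
print(f"{n} CPU devices: {cpu_devices}")

# pmap maps over the leading axis: each apparent CPU device gets one slice
# and runs its own row-wise FFT.
parallel_fft = jax.pmap(lambda a: jnp.fft.fft(a, axis=-1), backend="cpu")

key = jax.random.PRNGKey(0)
# Leading axis equals the device count, so every device receives one slice.
x = jax.random.normal(key, (n, 256, 1024), dtype=jnp.float32)
y = parallel_fft(x)
print(y.shape)
```

If JAX was already imported earlier in the same process, the flag is typically ignored and only a single CPU device is reported.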