JAX is a Python package for efficient computation and automatic differentiation. The question here is: will it always use all CPU cores efficiently?
Some JAX operations are internally multi-threaded and will use multiple threads. This is true of many operations backed by, for example, the Eigen library, such as matrix multiplication. However, not all operations are internally multi-threaded; for example, the FFT operations **do not have** multi-threading enabled.
Where an operation's internal multi-threading does not make good use of the CPU, it is possible to parallelise at the JAX level using the pmap function. However, the CPU must first be split into multiple apparent devices using the --xla_force_host_platform_device_count flag (see pmap-cpu). For example:
XLA_FLAGS="--xla_force_host_platform_device_count=8" python myscript.py
will split the CPU into 8 independent devices over which pmap can parallelise. See the pmap-cookbook for examples of how to use pmap.
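
As a minimal sketch of the approach described above, the flag can also be set from inside the script, provided this happens before JAX is first imported. The function name `power_spectrum`, the array sizes, and the use of the `JAX_PLATFORMS` environment variable to pin the CPU backend are illustrative assumptions, not part of the original text.

```python
import os

# Assumption: XLA_FLAGS is set here rather than on the command line.
# It must be set before JAX is imported, and this overwrites any existing value.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
# Assumption: pin the CPU backend so the device count below refers to CPU devices.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp

# Number of apparent devices the CPU has been split into.
n_dev = jax.local_device_count()


def power_spectrum(x):
    # Hypothetical per-device workload: an FFT followed by a squared magnitude.
    return jnp.abs(jnp.fft.fft(x)) ** 2


# pmap maps over the leading axis, so it must not exceed the device count.
batch = jnp.arange(n_dev * 1024, dtype=jnp.float32).reshape(n_dev, 1024)

# Each row of `batch` is processed on its own apparent device.
out = jax.pmap(power_spectrum)(batch)
print(n_dev, out.shape)
```

Whether this is actually faster than a single-device call depends on the workload and the number of physical cores, so it is worth timing both variants on the target machine.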