Building Google's JAX library from source using Nix
Getting easy, reproducible, maintainable builds of JAX with the Nix system is straightforward:
- The package is already defined in pkgs/development/python-modules/jaxlib and pkgs/development/python-modules/jax. See https://nixos.wiki/wiki/JAX
- Bringing it up to the most recent version requires updating the versions and hashes in the default.nix files, and potentially some dependencies as well:
  a. The dependency I've needed to update is Google's snappy library (a compression library).
- The jaxlib package can easily be modified by adding the patches option to the bazel-build element of its default.nix file.
- Build using the usual nix-build -A python3Packages.jaxlib <mynixpkgs> command, where <mynixpkgs> stands for your modified nixpkgs tree. Add -K (--keep-failed) to keep the build directory of a failed build for debugging.
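As a rough sketch of the version-bump and patching steps listed above, the excerpt below shows the kind of edits involved in pkgs/development/python-modules/jaxlib/default.nix. It is not a complete file: only the changed attributes are shown, the surrounding code is elided, and the version number, the fetchFromGitHub arguments and the patch file name are placeholders or assumptions, so copy the real values and the rev format from the existing file. It relies on the common nixpkgs trick of setting a hash to lib.fakeHash so that the first build fails and reports the correct hash.

```nix
# Excerpt only: edits inside pkgs/development/python-modules/jaxlib/default.nix.
# Surrounding code is elided; values marked as placeholders are assumptions.

# 1. Wherever `version` is defined, bump it (placeholder value shown).
version = "0.0.0";

# 2. Point the source at the new release and use a fake hash. The first
#    build fails with a hash mismatch whose "got:" line gives the real hash;
#    paste that value in place of lib.fakeHash.
src = fetchFromGitHub {
  owner = "google";
  repo = "jax";
  rev = "jaxlib-v${version}"; # illustrative; keep the rev format the file already uses
  hash = lib.fakeHash;
};

# 3. Inside the `bazel-build = buildBazelPackage { ... }` element, add a
#    patches list; each patch is applied to the unpacked source.
patches = [
  ./my-jaxlib-fix.patch # hypothetical patch file next to default.nix
];
```

The same fake-hash approach should also work for the Bazel dependency hash under fetchAttrs, where the file defines one, and for dependencies such as snappy.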
Be prepared to wait, though: the compilation stage takes a substantial amount of time and processing resources.