I have an issue when I install JAX and jaxlib version 0.4.23 and then install Optax. I've tried several Optax versions, such as 0.1.5, but in every case GPU support stops working, and running this:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
prints the warning above. I don't have this issue if I install only JAX and jaxlib. Is it caused by incompatible library versions? If so, which version of Optax should I install? If not, how can I fix it?
asked Nov 28, 2024 at 10:11 by Alessandro Castelli

1 Answer
Because JAX v0.4.23 is relatively old (released Dec 2023), I suspect you installed an Optax version that requires a more recent JAX, so `pip install optax` overwrote your GPU-compatible jax and jaxlib installation with a newer, CPU-only jaxlib.
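If you want to confirm that diagnosis before reinstalling, a minimal sketch like the following (standard library only, so it runs even when JAX is broken) prints what pip actually left installed and which jax/jaxlib requirements the installed Optax declares. The helper names `installed_version`, `optax_jax_requirements` and `report` are hypothetical, made up for this example. As far as I know, the 0.4.23 GPU wheels from the CUDA release index carry a local version suffix such as `+cuda12.cudnn89` on jaxlib, so a plain jaxlib version after installing Optax would suggest it was replaced by a CPU wheel from PyPI.

```python
# Sketch: report installed jax/jaxlib/optax versions and the jax-related
# requirements optax declares. Names of helpers are illustrative only.
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def optax_jax_requirements():
    """Return optax's declared jax/jaxlib requirements, or None if optax is absent."""
    try:
        reqs = metadata.requires("optax") or []
    except metadata.PackageNotFoundError:
        return None
    # "jax" also matches "jaxlib..." requirement strings.
    return [r for r in reqs if r.lower().startswith("jax")]


def report():
    """Collect installed versions plus optax's jax requirements."""
    versions = {name: installed_version(name) for name in ("jax", "jaxlib", "optax")}
    return versions, optax_jax_requirements()


if __name__ == "__main__":
    versions, reqs = report()
    for name, ver in versions.items():
        print(f"{name}: {ver or 'not installed'}")
    print("optax declares:", reqs)
```

If the printed requirement asks for a newer jax than 0.4.23, that explains why pip upgraded jax and jaxlib.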
You can fix this by installing a newer jaxlib version (following the GPU installation instructions) together with optax; for example, on Linux with a GPU and CUDA 12, something like this should work:
pip install "jax[cuda12]" optax
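After reinstalling, a quick way to check that JAX really picked up the GPU is to query the public `jax.default_backend()` and `jax.devices()` functions, which report the same platform name as `xla_bridge.get_backend().platform`. The wrapper `jax_backend` below is a hypothetical helper for this sketch; it returns None instead of raising when JAX is not importable.

```python
# Sketch: verify which backend JAX selected after reinstalling.
def jax_backend():
    """Return JAX's default backend name (e.g. 'gpu' or 'cpu'), or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed")
    else:
        import jax
        print("backend:", backend)   # expect 'gpu' on a working CUDA install
        print("devices:", jax.devices())
```

If this still reports `cpu`, the installed jaxlib (or its CUDA plugin) is not a GPU build.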
If for some reason you need an older JAX version, you should install jax and optax in the same command to avoid a later optax install overwriting your original jax installation; for example:
pip install "jax[cuda11_pip]==0.4.23" optax -f https://storage.googleapis/jax-releases/jax_cuda_releases.html
(this follows the 0.4.23 GPU installation instructions, which you can read here: https://github/jax-ml/jax/blob/jaxlib-v0.4.23/docs/installation.md)