I have a problem when I install jax, jaxlib and optax - Stack Overflow

I have an issue when I install JAX and jaxlib version 0.4.23 and then install Optax. I've tried different Optax versions, such as 0.1.5, but each time GPU support stops working, and I get this message when running:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

I don’t have this issue if I only install JAX and jaxlib. How can I resolve this problem? Is it because the libraries have incompatible versions? If so, which version of Optax should I install? Otherwise, how can I fix the issue?

asked Nov 28, 2024 at 10:11 by Alessandro Castelli

1 Answer

JAX v0.4.23 is relatively old (released Dec 2023), so I suspect you installed an optax version that requires a more recent JAX. In that case, running pip install optax would have replaced your GPU-compatible jax and jaxlib installation with a newer, CPU-only jaxlib.
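One way to check this suspicion is to look at which jax and jaxlib versions the installed optax declares as requirements. The snippet below is only a rough sketch using the standard-library importlib.metadata module; the helper names filter_jax_requirements and jax_requirements are made up for illustration, and the exact requirement strings you see will depend on the optax release you have.

```python
import re
from importlib import metadata


def filter_jax_requirements(requirements):
    """Keep only the requirement strings whose package name is jax or jaxlib."""
    kept = []
    for req in requirements:
        # The package name is the leading run of name characters, e.g. "jax" in "jax>=...".
        match = re.match(r"[A-Za-z0-9._-]+", req.strip())
        if match and match.group(0).lower() in {"jax", "jaxlib"}:
            kept.append(req)
    return kept


def jax_requirements(package="optax"):
    """Return the jax/jaxlib requirements declared by `package`, or None if it is not installed."""
    try:
        declared = metadata.requires(package) or []
    except metadata.PackageNotFoundError:
        return None
    return filter_jax_requirements(declared)


if __name__ == "__main__":
    print(jax_requirements("optax"))
```

If the printed requirement asks for a jax newer than the one you installed, pip will upgrade jax and jaxlib when it installs optax, which would be consistent with the behaviour described above.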

You can fix this by installing a newer jaxlib version (following the GPU installation instructions) along with optax; for example, on Linux GPU with CUDA 12 something like this should work:

pip install "jax[cuda12]" optax

If for some reason you need an older JAX version, you should install jax and optax in the same command to avoid a later optax install overwriting your original jax installation; for example:

pip install "jax[cuda11_pip]==0.4.23" optax -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

(this follows the 0.4.23 GPU installation instructions, which you can read here: https://github.com/jax-ml/jax/blob/jaxlib-v0.4.23/docs/installation.md)
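After reinstalling, it may help to confirm which versions actually ended up in the environment and which backend JAX selects. The following is a minimal sketch, not an official diagnostic: installed_versions and active_backend are hypothetical helper names, and the only JAX call it relies on is jax.default_backend().

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "optax")):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def active_backend():
    """Return the platform JAX selects by default, or None if jax cannot be imported."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    print("backend:", active_backend())
```

If the backend is still reported as cpu, the jaxlib in the environment is most likely a CPU-only build, and comparing the printed versions before and after installing optax should show whether pip replaced it.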
