How to build jax with label [cuda12] locally? - Stack Overflow

Question

I successfully built jaxlib. However, how can I build jax[cuda12]? I did not find any instructions in the jax documentation for building a specific wheel with the [cuda12] tag. Thank you for your response!

Background

I have some applications running on jax installed via pip install jax[cuda12]==0.4.34 jaxlib==0.4.34. Recently, I hit this error when enabling cuDNN and persistent caching at the same time: String field 'xla.gpu.CompilationResultProto.DnnCompiledGraphsEntry.value' contains invalid UTF-8 data. The bug was fixed by a commit to the xla repository on January 9, 2025, and the latest jax release after that commit was 0.5.0. So I removed the old jax and installed the new one with pip install jax[cuda12]==0.5.0 jaxlib==0.5.0, which resolved the error. However, my application is not compatible with jax 0.5.0: it runs slower and hits some NaN errors. I therefore decided to build jaxlib and jax[cuda12] myself against a local xla repository.

What I did

I pulled the jax and xla repositories and switched to specific commits using the following commands:

git clone --recurse-submodules https://github.com/jax-ml/jax.git
git clone --recurse-submodules https://github.com/openxla/xla.git
cd jax
# This corresponds to the [jax v0.4.34 release](https://github.com/jax-ml/jax/tree/jax-v0.4.34) version.
git checkout affba367c5533df8900e32cbc3d31ca92dd1c1ea
git submodule update --init --recursive
cd ..
cd xla
# This is the XLA version defined in the [/jax/third_party/xla/workspace.bzl](https://github.com/jax-ml/jax/commit/aa9ee7abfab2344ce56483af29266a31ca7b7708) file from the `jax v0.4.34 release`.
git checkout cd6e808c59f53b40a99df1f1b860db9a3e598bff
git submodule update --init --recursive
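
Before building, it may be worth confirming that the checked-out xla commit is the one this jax checkout pins. The following is a minimal sketch, assuming the pin is stored as an XLA_COMMIT = "..." line in third_party/xla/workspace.bzl and that the two clones sit side by side as above. The function names pinned_xla_commit and check_xla_pin are made up for illustration.

```shell
# Print the XLA commit pinned by a JAX checkout.
# Assumes workspace.bzl contains a line like: XLA_COMMIT = "<hex sha>"
pinned_xla_commit() {
  sed -n 's/^XLA_COMMIT *= *"\([0-9a-f]*\)".*/\1/p' "$1"
}

# Compare the pin with the commit checked out in the local xla clone.
# Uncommitted local patches do not change HEAD, so a patched tree still matches.
# $1 = path to the jax clone, $2 = path to the xla clone.
check_xla_pin() {
  pinned=$(pinned_xla_commit "$1/third_party/xla/workspace.bzl")
  actual=$(git -C "$2" rev-parse HEAD)
  if [ "$pinned" = "$actual" ]; then
    echo "xla is at the pinned commit $pinned"
  else
    echo "xla is at $actual but jax pins $pinned"
    return 1
  fi
}
```

From the directory that holds both clones, check_xla_pin jax xla would run the comparison.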

After modifying the XLA source code to fix the bug, I read developer.md and used the following command to build jaxlib:

python3 build/build.py \
    --python_version=3.11 \
    --enable_cuda \
    --cuda_version=12.6.1 \
    --cudnn_version=9.4.0 \
    --bazel_options=--override_repository=xla=/home/thomas/xla \
    --verbose

After a long wait, I received the following output:

Target //jaxlib/tools:build_wheel up-to-date:
  bazel-bin/jaxlib/tools/build_wheel
INFO: Elapsed time: 1928.511s, Critical Path: 287.80s
INFO: 3390 processes: 24 internal, 3366 local.
INFO: Build completed successfully, 3390 total actions
INFO: Running command line: bazel-bin/jaxlib/tools/build_wheel '--output_path=/home/thomas/jax/dist' '--jaxlib_git_hash=affba367c5533df8900e32cbc3d31ca92dd1c1ea' '--cpu=x86_64'
Output wheel: /home/thomas/jax/dist/jaxlib-0.4.34.dev20250218-cp311-cp311-manylinux2014_x86_64.whl
To install the newly-built jaxlib wheel on system Python, run:
  pip install /home/thomas/jax/dist/jaxlib-0.4.34.dev20250218-cp311-cp311-manylinux2014_x86_64.whl --force-reinstall
To install the newly-built jaxlib wheel on hermetic Python, run:
  echo -e "\n/home/thomas/jax/dist/jaxlib-0.4.34.dev20250218-cp311-cp311-manylinux2014_x86_64.whl" >> build/requirements.in
  bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.11

Asked Mar 5 at 19:53 by Thomas; edited Mar 5 at 23:19 by talonmies.

1 Answer

Quoting from https://docs.jax.dev/en/latest/developer.html:

If you would like to build jaxlib and the CUDA plugins: Run

python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt

to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and clang, but it can be restricted to clang via the --build_cuda_with_clang flag.

The resulting jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are local builds of the packages that jax[cuda12] installs, as you can see in JAX's setup.py definition: https://github.com/jax-ml/jax/blob/jax-v0.5.2/setup.py#L88-L91.

To install the equivalent of jax[cuda12] with your local builds, you would install these three wheels manually.
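
The manual install could be sketched roughly as follows. This is only an illustration: the wheel file-name prefixes (jaxlib, jax_cuda12_plugin, jax_cuda12_pjrt) are assumptions based on the PyPI package names, and the function names are hypothetical, so check what actually lands in your dist/ directory first.

```shell
# List the locally built wheels that together stand in for jax[cuda12].
# The prefixes are assumptions; adjust them to the files present in dist/.
local_cuda12_wheels() {
  dist_dir=$1
  for prefix in jaxlib jax_cuda12_plugin jax_cuda12_pjrt; do
    # Keep the lexicographically last match for each prefix.
    ls "$dist_dir"/"$prefix"-*.whl 2>/dev/null | sort | tail -n 1
  done
}

# Install the wheels found above in a single pip invocation.
# Word splitting is intentional here, so paths must not contain spaces.
install_local_cuda12_wheels() {
  pip install --force-reinstall $(local_cuda12_wheels "$1")
}
```

For the build in the question this would be called as install_local_cuda12_wheels /home/thomas/jax/dist. The jax package itself still has to be installed separately, for example from the same source checkout, and afterwards pip list should show the local .dev builds rather than PyPI releases.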
