
JAX is a Python library for high-performance numerical computing, particularly on GPUs. Its capabilities range from automatic differentiation to a GPU-enabled implementation of the NumPy API.
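Here is a minimal sketch of the two capabilities mentioned above. It assumes JAX is already installed, and the toy function `loss` is a hypothetical example chosen only for illustration. `jax.numpy` mirrors the NumPy API, while `jax.grad` and `jax.jit` add automatic differentiation and compilation:

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Hypothetical toy objective: sum of squares, written with the NumPy-like jax.numpy API
    return jnp.sum(w ** 2)


# jax.grad builds a new function that returns d(loss)/dw (automatic differentiation)
grad_loss = jax.grad(loss)
# jax.jit compiles the gradient function with XLA for whichever backend is available (CPU or GPU)
fast_grad = jax.jit(grad_loss)

w = jnp.array([1.0, 2.0, 3.0])
print(loss(w))       # 14.0
print(fast_grad(w))  # [2. 4. 6.]
```

The same code runs unchanged on CPU-only and GPU installs, because JAX places arrays on its default device automatically.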
Install for Yourself
We recommend installing JAX in your own Python virtual environment or conda environment. The sections below explain how to install the CPU-only and GPU-enabled versions of JAX into either type of environment. These instructions are adapted from the official JAX installation page.
Python Virtual Environment
- Create and activate your virtualenv as explained on the Python Installs page.
- Install the CPU-only or CPU+GPU version of JAX with pip:
CPU-Only
(my_newenv) [rcs@scc1 ~] pip install "jax[cpu]"
OR
CPU + GPU
(my_newenv) [rcs@scc1 ~] pip install --upgrade "jax[cuda12]"
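Once the install finishes, a quick way to confirm which build you got is to ask JAX which backend and devices it sees. This is a sketch, not an official check. On a login node or with a CPU-only install, the backend should be `cpu`. On a GPU node with the CUDA build, it should report `gpu`.

```python
import jax

# Installed JAX version
print("JAX version:", jax.__version__)
# Backend JAX will use by default: "cpu", "gpu", or "tpu"
print("Default backend:", jax.default_backend())
# All devices visible to the default backend
print("Devices:", jax.devices())
```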
Conda Environment
- Create and activate your conda env as explained on the Miniconda Installs page.
- Install the CPU-only or CPU+GPU version of JAX with conda install:
CPU-Only
(my_conda_env) [rcs@scc1 ~] conda install jax -c conda-forge
OR
CPU + GPU
(my_conda_env) [rcs@scc1 ~] conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
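Whichever installer you use, a GPU build only helps when your job runs on a node that actually has a GPU. The following sketch checks for a GPU backend and runs a small matrix multiply on JAX's default device, which falls back to the CPU if no GPU is found. The helper `gpu_available` is a hypothetical name for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp


def gpu_available() -> bool:
    """Hypothetical helper: True if JAX initialized a GPU backend."""
    try:
        # jax.devices("gpu") raises RuntimeError when no GPU backend is present
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        return False


x = jnp.ones((1000, 1000))
# JAX dispatches work asynchronously; block_until_ready waits for the result
y = (x @ x).block_until_ready()

print("GPU available:", gpu_available())
print("Result computed on:", list(y.devices()))
print("y[0, 0] =", float(y[0, 0]))  # 1000.0
```

Running this inside a GPU batch job rather than on a login node is the reliable way to confirm that the CUDA build is being used.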