Jax is a Python library for high-performance numerical computing on CPUs and GPUs. Its capabilities range from automatic differentiation to a GPU-enabled version of the NumPy library.
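
The minimal sketch below illustrates those two capabilities and assumes Jax is already installed. It writes a small function with `jax.numpy`, which mirrors the NumPy API. It then uses `jax.grad` to differentiate that function and `jax.jit` to compile the result. The function `loss` and its inputs are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function written with jax.numpy (NumPy-style API).
def loss(w):
    x = jnp.array([1.0, 2.0, 3.0])
    return jnp.sum((w * x) ** 2)

# jax.grad returns a new function computing d(loss)/dw automatically.
grad_loss = jax.grad(loss)

# jax.jit compiles the gradient function with XLA for the default device.
fast_grad = jax.jit(grad_loss)

# loss(3.0) = 9 * (1 + 4 + 9) = 126
print(loss(3.0))
# d/dw sum(w^2 * x^2) = 2 * w * sum(x^2) = 2 * 3 * 14 = 84
print(fast_grad(3.0))
```

The same code runs unchanged on a CPU or a GPU. Jax places the arrays on its default device.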

 
Install for Yourself

It’s recommended that you install Jax in your own virtual environment or conda environment. The sections below explain how to install CPU and GPU versions of Jax into your environment. These instructions have been taken from the Jax installation page.

Python Virtual Environment

  1. Create and activate your virtualenv as explained on the Python Installs page.
  2. Install the CPU-Only or CPU+GPU versions of Jax with pip:

    CPU-Only

    (my_newenv) [rcs@scc1 ~] pip install "jax[cpu]"

    OR

    CPU + GPU

    (my_newenv) [rcs@scc1 ~] pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


Conda Environment

  1. Create and activate your conda env as explained on the Miniconda Installs page.
  2. Install the CPU-Only or CPU+GPU versions of Jax with conda install:

    CPU-Only

    (my_conda_env) [rcs@scc1 ~] conda install jax -c conda-forge

    OR

    CPU + GPU

    (my_conda_env) [rcs@scc1 ~] conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
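
Whichever method you used, a short check like the sketch below can confirm that Jax imports correctly and show which devices it detected. This check is not part of the official installation instructions. With the CPU-only install, you should see only a CPU device. With the CPU + GPU install, run the check on a node that has a GPU, and GPU devices should appear.

```python
import jax
import jax.numpy as jnp

# Report the installed version, the default backend, and the detected devices.
print("Jax version:", jax.__version__)
print("Default backend:", jax.default_backend())
print("Devices:", jax.devices())

# Run a small computation to confirm the installation works end to end.
x = jnp.arange(4.0)      # [0., 1., 2., 3.]
result = jnp.dot(x, x)   # 0 + 1 + 4 + 9 = 14
print("Dot product:", result)
```

Run this from inside the activated environment, for example by saving it to a file and executing it with `python`.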
