Skip to content

Commit 232edb0

Browse files
jakeharmon8jpuigcerver
authored andcommitted
Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886874
1 parent 1e207e8 commit 232edb0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

vmoe/scripts/install_gce.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ source env/bin/activate;
1919

2020
if ( which nvidia-smi &> /dev/null ); then
2121
# This assumes CUDA 11 and cuDNN 8.2.
22-
# Check https://github.com/google/jax#pip-installation-gpu-cuda for alternatives.
22+
# Check https://github.com/jax-ml/jax#pip-installation-gpu-cuda for alternatives.
2323
pip install -q 'jax[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html;
2424
else
2525
# Since pjit does not work on CPUs, if nvidia-smi is not found, we assume that
@@ -31,7 +31,7 @@ fi;
3131
python3 -m pip install -q --upgrade pip;
3232
# Upgrade the following packages from GIT, since we use some features not part
3333
# of any release yet.
34-
pip install -q --upgrade git+https://github.com/google/jax.git;
34+
pip install -q --upgrade git+https://github.com/jax-ml/jax.git;
3535
pip install -q --upgrade git+https://github.com/google/flax.git;
3636
pip install -q --upgrade git+https://github.com/google/CommonLoopUtils.git;
3737
# Install the rest of necessary packages from PyPi.

0 commit comments

Comments
 (0)