Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor build-jax.sh and pin orbax-checkpoint #1211

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,21 @@ ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
RUN mkdir -p /opt/pip-tools.d

## Editable installations of jax and jaxlib
## For the 25.01 release we also pin several packages obtained
## from https://github.com/jax-ml/jax-ai-stack
RUN <<"EOF" bash -ex
# Build up /opt/pip-tools.d/requirements-jax.in, one requirement per line.
# Every entry listed under BUILD_PATH_JAXLIB is added as an editable (-e) requirement.
# NOTE(review): iterating over `ls` output word-splits names; this only works
# while the entries contain no whitespace or glob characters.
for component in $(ls ${BUILD_PATH_JAXLIB}); do
echo "-e file://${BUILD_PATH_JAXLIB}/${component}" >> /opt/pip-tools.d/requirements-jax.in;
done
# The jax source tree is added as an editable requirement as well.
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
# Keep numpy below 2.0.0.
echo "numpy<2.0.0" >> /opt/pip-tools.d/requirements-jax.in
# Append the exact version pins listed below to the same requirements file.
for pkg in \
"ml_dtypes==0.4.0" \
"optax==0.2.4" \
"orbax-checkpoint==0.10.2" \
"orbax-export==0.0.6" \
; do
echo "$pkg" >> /opt/pip-tools.d/requirements-jax.in
done
EOF

## Flax
Expand Down
16 changes: 16 additions & 0 deletions .github/container/Dockerfile.maxtext
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main
ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
ARG SRC_PATH_MAXTEXT=/opt/maxtext
ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
ARG URLREF_JETSTREAM=https://github.com/google/jetstream.git#main
ARG SRC_PATH_JETSTREAM=/opt/jetstream

###############################################################################
## build tensorflow-text and lingvo, which do not have working arm64 pip wheels
Expand Down Expand Up @@ -56,6 +58,7 @@ RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-

RUN <<"EOF" bash -ex
git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
sed -i '/google-jetstream/d' ${SRC_PATH_MAXTEXT}/requirements.txt
echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in

# specify some restrictions to speed up the build and
Expand All @@ -64,6 +67,7 @@ for pattern in \
"s|absl-py|absl-py>=2.1.0|g" \
"s|protobuf==3.20.3|protobuf>=3.19.0|g" \
"s|tensorflow-datasets|tensorflow-datasets>=4.8.0|g" \
"s|grain-nightly|grain|g" \
; do
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt;
done
Expand All @@ -76,6 +80,18 @@ EOF

ADD test-maxtext.sh /usr/local/bin

###############################################################################
## Add JetStream
###############################################################################

## Re-declare the JetStream build arguments for the instructions below.
## NOTE(review): presumably these inherit the defaults declared with the other
## URLREF_*/SRC_PATH_* arguments near the top of this Dockerfile -- confirm.
ARG URLREF_JETSTREAM
ARG SRC_PATH_JETSTREAM

RUN <<"EOF" bash -ex
# Fetch the JetStream sources into SRC_PATH_JETSTREAM.
# NOTE(review): git-clone.sh is defined outside this view; presumably it clones
# the repository/ref encoded in URLREF_JETSTREAM -- confirm.
git-clone.sh ${URLREF_JETSTREAM} ${SRC_PATH_JETSTREAM}
# Record JetStream as an editable (-e) requirement in its own requirements file.
echo "-e file://${SRC_PATH_JETSTREAM}" >> /opt/pip-tools.d/requirements-jetstream.in
EOF

###############################################################################
## Install accumulated packages from the base image and the previous stage
###############################################################################
Expand Down
1 change: 1 addition & 0 deletions .github/container/Dockerfile.t5x
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ echo "seqio-nightly>=0.0.18.dev20240714" >> /opt/pip-tools.d/requirements-t5x.in
# 2. Remove head-of-tree specs from select dependencies
pushd ${SRC_PATH_T5X}
sed -i "s| @ git+https://github.com/google/flax#egg=flax||g" setup.py
sed -i "s| @ git+https://github.com/deepmind/optax#egg=optax||g" setup.py

# for ARM64 build
if [[ "$(dpkg --print-architecture)" == "arm64" ]]; then
Expand Down
14 changes: 9 additions & 5 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,17 @@ else
fi

# install jax and jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}"
# Step 1: install the locally built jaxlib, PJRT and plugin packages in
# editable mode. They go in before jax so that the installed jaxlib version
# can be queried and written into jax's setup.py below.
pip --disable-pip-version-check install \
  -e "${BUILD_PATH_JAXLIB}/jaxlib" \
  -e "${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" \
  -e "${BUILD_PATH_JAXLIB}/jax-cuda-plugin"

# Step 2: read the bare version string from the "Version: X" line of
# `pip show`. Anchoring on "^Version:" and taking the second field avoids
# matching other lines and avoids the leading blank that splitting on ':'
# would leave in front of the value.
jaxlib_version=$(pip show jaxlib | awk '/^Version:/ {print $2}')
if [ -z "${jaxlib_version}" ]; then
  # Refuse to write an empty version into setup.py.
  echo "ERROR: could not determine the installed jaxlib version" >&2
  exit 1
fi

# Step 3: patch setup.py in the jax source tree (the same tree that is
# installed in step 4, hence SRC_PATH_JAX rather than a hard-coded path):
#   - set _current_jaxlib_version to the version that was just installed;
#   - replace the templated jaxlib requirement with a plain lower bound.
sed -i "s|^_current_jaxlib_version.*|_current_jaxlib_version = '${jaxlib_version}'|" "${SRC_PATH_JAX}/setup.py"
sed -i "s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.4.30',|" "${SRC_PATH_JAX}/setup.py"

# Step 4: install jax itself in editable mode from the patched source tree.
pip --disable-pip-version-check install -e "${SRC_PATH_JAX}"

## after installation (example)
# jax 0.4.36.dev20241125+f828f2d7d /opt/jax
# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib
# jax 0.4.36.dev20241220+f828f2d7d /opt/jax
# jax-cuda12-pjrt 0.4.36.dev20241220 /opt/jaxlibs/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241220 /opt/jaxlibs/jax-cuda-plugin
# jaxlib 0.4.36.dev20241220 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
Expand Down
6 changes: 3 additions & 3 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ flax:
url: https://github.com/google/flax.git
mirror_url: https://github.com/nvjax-svc-0/flax.git
tracking_ref: main
latest_verified_commit: 718aa8ccb12c3fdefcf3d196874e4fc667b3ade5
latest_verified_commit: d89c955d1faac9dd2162a0c674f7897f2c53f54d
mode: git-clone
patches:
pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
Expand Down Expand Up @@ -177,8 +177,8 @@ panopticapi:
mode: git-clone
orbax-checkpoint:
url: https://github.com/google/orbax.git
tracking_ref: main
latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f
tracking_ref: v0.10.2
latest_verified_commit: d6101bad9ec5ddee8ee8b8c10e1d27d6c57f0963
mode: pip-vcs
pathwaysutils:
url: https://github.com/google/pathways-utils.git
Expand Down
Loading