diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index 9423d6c18..f4420e234 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -99,12 +99,21 @@ ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ RUN mkdir -p /opt/pip-tools.d ## Editable installations of jax and jaxlib +## For 25.01 release we also pin several packages obtained +## from https://github.com/jax-ml/jax-ai-stack RUN <<"EOF" bash -ex for component in $(ls ${BUILD_PATH_JAXLIB}); do echo "-e file://${BUILD_PATH_JAXLIB}/${component}" >> /opt/pip-tools.d/requirements-jax.in; done echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in -echo "numpy<2.0.0" >> /opt/pip-tools.d/requirements-jax.in +for pkg in \ + "ml_dtypes==0.4.0" \ + "optax==0.2.4" \ + "orbax-checkpoint==0.10.2" \ + "orbax-export==0.0.6" \ +; do + echo "$pkg" >> /opt/pip-tools.d/requirements-jax.in +done EOF ## Flax diff --git a/.github/container/Dockerfile.maxtext b/.github/container/Dockerfile.maxtext index 87b73efcd..b5013f3c9 100644 --- a/.github/container/Dockerfile.maxtext +++ b/.github/container/Dockerfile.maxtext @@ -5,6 +5,8 @@ ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master ARG SRC_PATH_MAXTEXT=/opt/maxtext ARG SRC_PATH_TFTEXT=/opt/tensorflow-text +ARG URLREF_JETSTREAM=https://github.com/google/jetstream.git#main +ARG SRC_PATH_JETSTREAM=/opt/jetstream ############################################################################### ## build tensorflow-text and lingvo, which do not have working arm64 pip wheels @@ -56,6 +58,7 @@ RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip- RUN <<"EOF" bash -ex git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT} +sed -i '/google-jetstream/d' ${SRC_PATH_MAXTEXT}/requirements.txt echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in # specify some restrictions to speed up the build and @@ -64,6 +67,7 @@ for pattern in \ "s|absl-py|absl-py>=2.1.0|g" \ "s|protobuf==3.20.3|protobuf>=3.19.0|g" \ "s|tensorflow-datasets|tensorflow-datasets>=4.8.0|g" \ + "s|grain-nightly|grain|g" \ ; do sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt; done @@ -76,6 +80,18 @@ EOF ADD test-maxtext.sh /usr/local/bin +############################################################################### +## Add JetStream +############################################################################### + +ARG URLREF_JETSTREAM +ARG SRC_PATH_JETSTREAM + +RUN <<"EOF" bash -ex +git-clone.sh ${URLREF_JETSTREAM} ${SRC_PATH_JETSTREAM} +echo "-e file://${SRC_PATH_JETSTREAM}" >> /opt/pip-tools.d/requirements-jetstream.in +EOF + ############################################################################### ## Install accumulated packages from the base image and the previous stage ############################################################################### diff --git a/.github/container/Dockerfile.t5x b/.github/container/Dockerfile.t5x index ea4bbf2ec..84004ba09 100644 --- a/.github/container/Dockerfile.t5x +++ b/.github/container/Dockerfile.t5x @@ -73,6 +73,7 @@ echo "seqio-nightly>=0.0.18.dev20240714" >> /opt/pip-tools.d/requirements-t5x.in # 2. Remove head-of-tree specs from select dependencies pushd ${SRC_PATH_T5X} sed -i "s| @ git+https://github.com/google/flax#egg=flax||g" setup.py +sed -i "s| @ git+https://github.com/deepmind/optax#egg=optax||g" setup.py # for ARM64 build if [[ "$(dpkg --print-architecture)" == "arm64" ]]; then diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index b179e9456..2bcf9ce86 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -318,13 +318,17 @@ else fi # install jax and jaxlib -pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}" +pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin +jaxlib_version=$(pip show jaxlib | grep Version | tr ':' '\n' | tail -1) +sed -i "s|^_current_jaxlib_version.*|_current_jaxlib_version = '${jaxlib_version}'|" /opt/jax/setup.py +sed -i "s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.4.30',|" /opt/jax/setup.py +pip --disable-pip-version-check install -e "${SRC_PATH_JAX}" ## after installation (example) -# jax 0.4.36.dev20241125+f828f2d7d /opt/jax -# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt -# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin -# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib +# jax 0.4.36.dev20241220+f828f2d7d /opt/jax +# jax-cuda12-pjrt 0.4.36.dev20241220 /opt/jaxlibs/jax-cuda-pjrt +# jax-cuda12-plugin 0.4.36.dev20241220 /opt/jaxlibs/jax-cuda-plugin +# jaxlib 0.4.36.dev20241220 /opt/jaxlibs/jaxlib pip list | grep jax # Ensure directories are readable by all for non-root users diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index b9c06e2e6..95d0c03b5 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -12,7 +12,7 @@ flax: url: https://github.com/google/flax.git mirror_url: https://github.com/nvjax-svc-0/flax.git tracking_ref: main - latest_verified_commit: 718aa8ccb12c3fdefcf3d196874e4fc667b3ade5 + latest_verified_commit: d89c955d1faac9dd2162a0c674f7897f2c53f54d mode: git-clone patches: pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules @@ -177,8 +177,8 @@ panopticapi: mode: git-clone orbax-checkpoint: url: https://github.com/google/orbax.git - tracking_ref: main - latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f + tracking_ref: v0.10.2 + latest_verified_commit: d6101bad9ec5ddee8ee8b8c10e1d27d6c57f0963 mode: pip-vcs pathwaysutils: url: https://github.com/google/pathways-utils.git