Skip to content

Commit

Permalink
Merge pull request #291705 from GaetanLepage/jax
Browse files Browse the repository at this point in the history
python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28
  • Loading branch information
samuela authored May 13, 2024
2 parents 8fe1aa6 + 32d1bb1 commit 3a993d3
Show file tree
Hide file tree
Showing 15 changed files with 128 additions and 100 deletions.
8 changes: 6 additions & 2 deletions pkgs/development/python-modules/blackjax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

buildPythonPackage rec {
pname = "blackjax";
version = "1.2.0";
version = "1.2.1";
pyproject = true;

disabled = pythonOlder "3.9";
Expand All @@ -25,7 +25,7 @@ buildPythonPackage rec {
owner = "blackjax-devs";
repo = "blackjax";
rev = "refs/tags/${version}";
hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw=";
hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
};

build-system = [
Expand Down Expand Up @@ -56,6 +56,10 @@ buildPythonPackage rec {
disabledTests = [
# too slow
"test_adaptive_tempered_smc"
] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
# Numerical test (AssertionError)
# https://github.com/blackjax-devs/blackjax/issues/668
"test_chees_adaptation"
];

pythonImportsCheck = [
Expand Down
17 changes: 15 additions & 2 deletions pkgs/development/python-modules/equinox/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,21 @@ buildPythonPackage rec {
pythonImportsCheck = [ "equinox" ];

disabledTests = [
# Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
"test_tracetime"
# For simplicity, JAX has removed its internal frames from the traceback of the following exception.
# https://github.com/patrick-kidger/equinox/issues/716
"test_abstract"
"test_complicated"
"test_grad"
"test_jvp"
"test_mlp"
"test_num_traces"
"test_pytree_in"
"test_simple"
"test_vmap"

# AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n'
# Also reported in patrick-kidger/equinox#716
"test_backward_nan"
];

meta = with lib; {
Expand Down
8 changes: 4 additions & 4 deletions pkgs/development/python-modules/flax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

buildPythonPackage rec {
pname = "flax";
version = "0.8.2";
version = "0.8.3";
pyproject = true;

disabled = pythonOlder "3.9";
Expand All @@ -34,16 +34,16 @@ buildPythonPackage rec {
owner = "google";
repo = "flax";
rev = "refs/tags/v${version}";
hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
};

nativeBuildInputs = [
build-system = [
jaxlib
pythonRelaxDepsHook
setuptools-scm
];

propagatedBuildInputs = [
dependencies = [
jax
msgpack
numpy
Expand Down
12 changes: 10 additions & 2 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.25";
version = "0.4.28";
pyproject = true;

disabled = pythonOlder "3.9";
Expand All @@ -39,7 +39,7 @@ buildPythonPackage rec {
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/jax-v${version}";
hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok=";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};

nativeBuildInputs = [
Expand Down Expand Up @@ -81,6 +81,14 @@ buildPythonPackage rec {
"tests/"
];

# Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
# PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
# See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
# NOTE: this doesn't seem to be an issue on linux
preCheck = lib.optionalString stdenv.isDarwin ''
export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
'';

disabledTests = [
# Exceeds tolerance when the machine is busy
"test_custom_linear_solve_aux"
Expand Down
60 changes: 22 additions & 38 deletions pkgs/development/python-modules/jaxlib/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@
, stdenv
# Options:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages
}:

let
inherit (cudaPackagesGoogle) cudaVersion;
inherit (cudaPackages) cudaVersion;

version = "0.4.24";
version = "0.4.28";

inherit (python) pythonVersion;

cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [
cudaLibPath = lib.makeLibraryPath (with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cupti.lib # libcupti.so
cudnn.lib # libcudnn.so
Expand All @@ -56,65 +56,65 @@ let
"3.9-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp39";
hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
};
"3.9-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp39";
hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
};
"3.9-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp39";
hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
};

"3.10-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp310";
hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
};
"3.10-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp310";
hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
};
"3.10-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp310";
hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
};

"3.11-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp311";
hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
};
"3.11-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp311";
hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
};
"3.11-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp311";
hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
};

"3.12-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp312";
hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
};
"3.12-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp312";
hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
};
"3.12-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp312";
hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
};
};

Expand All @@ -130,35 +130,19 @@ let
gpuSrcs = {
"cuda12.2-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
};
"cuda12.2-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
};
"cuda12.2-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
};
"cuda12.2-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
};
"cuda11.8-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
};
"cuda11.8-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
};
"cuda11.8-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
};
"cuda11.8-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
};
};

Expand Down Expand Up @@ -213,7 +197,7 @@ buildPythonPackage {
# for more info.
postInstall = lib.optional cudaSupport ''
mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
'';

inherit (jaxlib-build) pythonImportsCheck;
Expand All @@ -227,7 +211,7 @@ buildPythonPackage {
platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
broken =
!(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
|| !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2")
|| !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
|| !(cudaSupport -> stdenv.isLinux)
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"))
# Fails at pythonImportsCheckPhase:
Expand Down
40 changes: 15 additions & 25 deletions pkgs/development/python-modules/jaxlib/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
, curl
, cython
, fetchFromGitHub
, fetchpatch
, git
, IOKit
, jsoncpp
Expand Down Expand Up @@ -45,22 +44,22 @@
, config
# CUDA flags:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages

# MKL:
, mklSupport ? true
}@inputs:

let
inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl;
inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl;

pname = "jaxlib";
version = "0.4.24";
version = "0.4.28";

# It's necessary to consistently use backendStdenv when building with CUDA
# support, otherwise we get libstdc++ errors downstream
stdenv = throw "Use effectiveStdenv instead";
effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
Expand All @@ -78,7 +77,7 @@ let
# These are necessary at build time and run time.
cuda_libs_joined = symlinkJoin {
name = "cuda-joined";
paths = with cudaPackagesGoogle; [
paths = with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cudart.static # libcudart_static.a
cuda_cupti.lib # libcupti.so
Expand All @@ -92,11 +91,11 @@ let
# These are only necessary at build time.
cuda_build_deps_joined = symlinkJoin {
name = "cuda-build-deps-joined";
paths = with cudaPackagesGoogle; [
paths = with cudaPackages; [
cuda_libs_joined

# Binaries
cudaPackagesGoogle.cuda_nvcc.bin # nvcc
cudaPackages.cuda_nvcc.bin # nvcc

# Headers
cuda_cccl.dev # block_load.cuh
Expand Down Expand Up @@ -181,19 +180,10 @@ let
owner = "openxla";
repo = "xla";
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
};

patches = [
# Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to
# ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
(fetchpatch {
url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
})
];

dontBuild = true;

# This is necessary for patchShebangs to know the right path to use.
Expand All @@ -220,7 +210,7 @@ let
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};

nativeBuildInputs = [
Expand Down Expand Up @@ -364,10 +354,10 @@ let
];

sha256 = (if cudaSupport then {
x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM=";
x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k=";
} else {
x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk=";
aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY==";
x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
}).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
};

Expand Down Expand Up @@ -414,7 +404,7 @@ buildPythonPackage {
# for more info.
postInstall = lib.optionalString cudaSupport ''
mkdir -p $out/bin
ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
Expand All @@ -423,7 +413,7 @@ buildPythonPackage {

nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

propagatedBuildInputs = [
dependencies = [
absl-py
curl
double-conversion
Expand Down
Loading

0 comments on commit 3a993d3

Please sign in to comment.