Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28 #291705

Merged
merged 12 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pkgs/development/python-modules/blackjax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

buildPythonPackage rec {
pname = "blackjax";
version = "1.2.0";
version = "1.2.1";
pyproject = true;

disabled = pythonOlder "3.9";
Expand All @@ -25,7 +25,7 @@ buildPythonPackage rec {
owner = "blackjax-devs";
repo = "blackjax";
rev = "refs/tags/${version}";
hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw=";
hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
};

build-system = [
Expand Down Expand Up @@ -56,6 +56,10 @@ buildPythonPackage rec {
disabledTests = [
# too slow
"test_adaptive_tempered_smc"
] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
# Numerical test (AssertionError)
# https://github.com/blackjax-devs/blackjax/issues/668
"test_chees_adaptation"
];

pythonImportsCheck = [
Expand Down
17 changes: 15 additions & 2 deletions pkgs/development/python-modules/equinox/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,21 @@ buildPythonPackage rec {
pythonImportsCheck = [ "equinox" ];

disabledTests = [
# Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
"test_tracetime"
# For simplicity, JAX has removed its internal frames from the traceback of the following exception.
# https://github.com/patrick-kidger/equinox/issues/716
"test_abstract"
"test_complicated"
"test_grad"
"test_jvp"
"test_mlp"
"test_num_traces"
"test_pytree_in"
"test_simple"
"test_vmap"

# AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n'
# Also reported in patrick-kidger/equinox#716
"test_backward_nan"
];

meta = with lib; {
Expand Down
8 changes: 4 additions & 4 deletions pkgs/development/python-modules/flax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

buildPythonPackage rec {
pname = "flax";
version = "0.8.2";
version = "0.8.3";
pyproject = true;

disabled = pythonOlder "3.9";
Expand All @@ -34,16 +34,16 @@ buildPythonPackage rec {
owner = "google";
repo = "flax";
rev = "refs/tags/v${version}";
hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
};

nativeBuildInputs = [
build-system = [
jaxlib
pythonRelaxDepsHook
setuptools-scm
];

propagatedBuildInputs = [
dependencies = [
jax
msgpack
numpy
Expand Down
12 changes: 10 additions & 2 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.25";
version = "0.4.28";
pyproject = true;

disabled = pythonOlder "3.9";
Expand All @@ -39,7 +39,7 @@ buildPythonPackage rec {
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/jax-v${version}";
hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok=";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};

nativeBuildInputs = [
Expand Down Expand Up @@ -81,6 +81,14 @@ buildPythonPackage rec {
"tests/"
];

# Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
# PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
# See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
# NOTE: this doesn't seem to be an issue on linux
preCheck = lib.optionalString stdenv.isDarwin ''
export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
'';

disabledTests = [
# Exceeds tolerance when the machine is busy
"test_custom_linear_solve_aux"
Expand Down
60 changes: 22 additions & 38 deletions pkgs/development/python-modules/jaxlib/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@
, stdenv
# Options:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages
GaetanLepage marked this conversation as resolved.
Show resolved Hide resolved
}:

let
inherit (cudaPackagesGoogle) cudaVersion;
inherit (cudaPackages) cudaVersion;

version = "0.4.24";
version = "0.4.28";

inherit (python) pythonVersion;

cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [
cudaLibPath = lib.makeLibraryPath (with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cupti.lib # libcupti.so
cudnn.lib # libcudnn.so
Expand All @@ -56,65 +56,65 @@ let
"3.9-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp39";
hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
};
"3.9-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp39";
hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
};
"3.9-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp39";
hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
};

"3.10-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp310";
hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
};
"3.10-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp310";
hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
};
"3.10-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp310";
hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
};

"3.11-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp311";
hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
};
"3.11-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp311";
hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
};
"3.11-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp311";
hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
};

"3.12-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp312";
hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
};
"3.12-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp312";
hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
};
"3.12-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp312";
hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
};
};

Expand All @@ -130,35 +130,19 @@ let
gpuSrcs = {
"cuda12.2-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
};
"cuda12.2-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
};
"cuda12.2-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
};
"cuda12.2-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
};
"cuda11.8-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
};
"cuda11.8-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
};
"cuda11.8-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
};
"cuda11.8-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
};
};

Expand Down Expand Up @@ -213,7 +197,7 @@ buildPythonPackage {
# for more info.
postInstall = lib.optional cudaSupport ''
mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
'';

inherit (jaxlib-build) pythonImportsCheck;
Expand All @@ -227,7 +211,7 @@ buildPythonPackage {
platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
broken =
!(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
|| !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2")
|| !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
|| !(cudaSupport -> stdenv.isLinux)
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"))
# Fails at pythonImportsCheckPhase:
Expand Down
40 changes: 15 additions & 25 deletions pkgs/development/python-modules/jaxlib/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
, curl
, cython
, fetchFromGitHub
, fetchpatch
, git
, IOKit
, jsoncpp
Expand Down Expand Up @@ -45,22 +44,22 @@
, config
# CUDA flags:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages

# MKL:
, mklSupport ? true
}@inputs:

let
inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl;
inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl;

pname = "jaxlib";
version = "0.4.24";
version = "0.4.28";

# It's necessary to consistently use backendStdenv when building with CUDA
# support, otherwise we get libstdc++ errors downstream
stdenv = throw "Use effectiveStdenv instead";
effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
Expand All @@ -78,7 +77,7 @@ let
# These are necessary at build time and run time.
cuda_libs_joined = symlinkJoin {
name = "cuda-joined";
paths = with cudaPackagesGoogle; [
paths = with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cudart.static # libcudart_static.a
cuda_cupti.lib # libcupti.so
Expand All @@ -92,11 +91,11 @@ let
# These are only necessary at build time.
cuda_build_deps_joined = symlinkJoin {
name = "cuda-build-deps-joined";
paths = with cudaPackagesGoogle; [
paths = with cudaPackages; [
cuda_libs_joined

# Binaries
cudaPackagesGoogle.cuda_nvcc.bin # nvcc
cudaPackages.cuda_nvcc.bin # nvcc

# Headers
cuda_cccl.dev # block_load.cuh
Expand Down Expand Up @@ -181,19 +180,10 @@ let
owner = "openxla";
repo = "xla";
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
};

patches = [
# Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to
# ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
(fetchpatch {
url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
})
];

GaetanLepage marked this conversation as resolved.
Show resolved Hide resolved
dontBuild = true;

# This is necessary for patchShebangs to know the right path to use.
Expand All @@ -220,7 +210,7 @@ let
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};

nativeBuildInputs = [
Expand Down Expand Up @@ -364,10 +354,10 @@ let
];

sha256 = (if cudaSupport then {
x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM=";
x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k=";
} else {
x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk=";
aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY==";
x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
}).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
};

Expand Down Expand Up @@ -414,7 +404,7 @@ buildPythonPackage {
# for more info.
postInstall = lib.optionalString cudaSupport ''
mkdir -p $out/bin
ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas

find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
Expand All @@ -423,7 +413,7 @@ buildPythonPackage {

nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

propagatedBuildInputs = [
dependencies = [
absl-py
curl
double-conversion
Expand Down
Loading