From 877ef96bbcea3abf76fa3627a613d3e2f44a9c2a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 25 Oct 2024 14:59:49 -0400
Subject: [PATCH] fix: task switching in AMDGPU complex batched_matmul (#178)

* ci(buildkite): add downstream testing for NeuralOperators

* perf: restore old batched_mul

* fix: disable threading for certain devices

* revert: "perf: restore old batched_mul"

This reverts commit a8c0f3b4615f96a8773577e16fac61ba310d8123.
---
 .buildkite/testing.yml  |  5 ++---
 Project.toml            |  2 +-
 src/impl/batched_mul.jl | 41 +++++++++++++++++++++++++++++++++++++----
 3 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml
index a4cfaa6e..ad88470c 100644
--- a/.buildkite/testing.yml
+++ b/.buildkite/testing.yml
@@ -38,7 +38,6 @@ steps:
             - src
             - ext
     env:
-      RETESTITEMS_NWORKERS: 2
       BACKEND_GROUP: "AMDGPU"
     agents:
       queue: "juliagpu"
@@ -126,6 +125,7 @@ steps:
         repo:
           - "Boltz"
           - "Lux"
+          - "NeuralOperators"
 
   - group: ":telescope: Downstream AMD GPU"
     steps:
@@ -143,8 +143,6 @@ steps:
         queue: "juliagpu"
         rocm: "*"
         rocmgpu: "*"
-    env:
-      RETESTITEMS_NWORKERS: 2
     if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main"
     timeout_in_minutes: 240
     matrix:
@@ -152,6 +150,7 @@ steps:
         repo:
           - "Boltz"
           - "Lux"
+          - "NeuralOperators"
 
 env:
   JULIA_PKG_SERVER: ""
diff --git a/Project.toml b/Project.toml
index 7225334c..6f6005b7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "1.3.4"
+version = "1.3.5"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl
index 257b4e0f..b8900d8e 100644
--- a/src/impl/batched_mul.jl
+++ b/src/impl/batched_mul.jl
@@ -70,15 +70,15 @@ function batched_matmul_loopvec_impl!(
 end
 
 function fallback_batched_matmul(
-        dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT}
+        opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT}
     z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1),
         size(y, 2), max(size(x, 3), size(y, 3)))
-    fallback_batched_matmul!(z, dev, x, y)
+    fallback_batched_matmul!(z, opmode, x, y)
     return z
 end
 
 function fallback_batched_matmul!(
-        z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3},
+        z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3},
         y::AbstractArray{yT, 3}) where {zT, xT, yT}
     # XXX: bring back once the enzyme segfault is fixed
     # @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
@@ -90,6 +90,36 @@ function fallback_batched_matmul!(
         throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul."))
     end
 
+    if use_threaded_batched_matmul(get_device_type(x))
+        unsafe_fallback_threaded_batched_matmul!(z, x, y)
+    else
+        unsafe_fallback_serial_batched_matmul!(z, x, y)
+    end
+
+    return
+end
+
+function unsafe_fallback_serial_batched_matmul!(
+        z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
+        y::AbstractArray{yT, 3}) where {zT, xT, yT}
+    if size(x, 3) == size(y, 3)
+        for L in axes(z, 3)
+            mul!(batchview(z, L), batchview(x, L), batchview(y, L))
+        end
+    elseif size(x, 3) == 1
+        for L in axes(z, 3)
+            mul!(batchview(z, L), batchview(x, 1), batchview(y, L))
+        end
+    else # has to be size(y, 3) == 1
+        for L in axes(z, 3)
+            mul!(batchview(z, L), batchview(x, L), batchview(y, 1))
+        end
+    end
+end
+
+function unsafe_fallback_threaded_batched_matmul!(
+        z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
+        y::AbstractArray{yT, 3}) where {zT, xT, yT}
     old_threads = maybe_reduce_BLAS_threads(z)
 
     if size(x, 3) == size(y, 3)
@@ -107,10 +137,13 @@ function fallback_batched_matmul!(
     end
 
     reset_BLAS_threads(old_threads)
-
     return
 end
 
+use_threaded_batched_matmul(::Type) = false
+use_threaded_batched_matmul(::Type{CUDADevice}) = true
+use_threaded_batched_matmul(::Type{CPUDevice}) = true
+
 function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3},
         y::AbstractArray{yT, 3}) where {xT, yT}
     ∇batched_matmul = @closure Δ_ -> begin
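
Note on the fix: the behavioral core of this patch is the new `use_threaded_batched_matmul`
trait. The generic fallback now consults it via `get_device_type(x)`, so CPU and CUDA keep the
threaded per-slice loop while AMDGPU (where task switching inside complex batched_matmul
misbehaves) takes the serial loop, since it has no `true` method. The sketch below is a minimal
standalone illustration of that dispatch pattern, not LuxLib's actual code: the device tag types,
`use_threaded`, and `batched_matmul_fallback!` are hypothetical stand-ins for the MLDataDevices
device types and the internal helpers (`batchview`, `maybe_reduce_BLAS_threads`) used in the
patch, and the size-1 batch broadcasting branches are omitted for brevity.

    using LinearAlgebra

    abstract type AbstractDeviceTag end
    struct CPUTag <: AbstractDeviceTag end     # hypothetical stand-in for CPUDevice
    struct AMDGPUTag <: AbstractDeviceTag end  # hypothetical stand-in for AMDGPUDevice

    # Threading is opt-in per device; any device without an explicit `true`
    # method stays serial, which is how the patch disables threading for AMDGPU.
    use_threaded(::Type{<:AbstractDeviceTag}) = false
    use_threaded(::Type{CPUTag}) = true

    function batched_matmul_fallback!(z::AbstractArray{<:Number, 3}, dev,
            x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3})
        if use_threaded(dev)
            # One task per batch slice; fine where the backend tolerates task switches.
            Threads.@threads for L in axes(z, 3)
                mul!(view(z, :, :, L), view(x, :, :, L), view(y, :, :, L))
            end
        else
            # Plain serial loop: no task switching around the per-slice GEMM calls.
            for L in axes(z, 3)
                mul!(view(z, :, :, L), view(x, :, :, L), view(y, :, :, L))
            end
        end
        return z
    end

    x = rand(ComplexF32, 4, 3, 8)
    y = rand(ComplexF32, 3, 5, 8)
    z = similar(x, 4, 5, 8)
    batched_matmul_fallback!(z, CPUTag, x, y)     # threaded path
    batched_matmul_fallback!(z, AMDGPUTag, x, y)  # serial path

Keeping the toggle as a dispatch-based trait rather than a runtime flag means adding a backend
to either path is a one-line method definition, mirroring how the patch opts in CPUDevice and
CUDADevice.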