Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge pull request #15 from LuxDL/ap/fastpaths
Browse files Browse the repository at this point in the history
Add fast and type stable paths for certain datastructures
  • Loading branch information
avik-pal authored Sep 7, 2023
2 parents 0816490 + 8b741dc commit c5586c2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 22 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.7"
version = "0.1.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -14,15 +14,13 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxDeviceUtilsComponentArraysExt = "ComponentArrays"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
Expand All @@ -32,7 +30,6 @@ LuxDeviceUtilsZygoteExt = "Zygote"
[compat]
Adapt = "3"
ChainRulesCore = "1"
ComponentArrays = "0.13, 0.14"
FillArrays = "0.13, 1"
Functors = "0.2, 0.3, 0.4"
LuxAMDGPU = "0.1"
Expand All @@ -45,7 +42,6 @@ Zygote = "0.6"
julia = "1.6"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Expand Down
10 changes: 0 additions & 10 deletions ext/LuxDeviceUtilsComponentArraysExt.jl

This file was deleted.

25 changes: 18 additions & 7 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,25 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
"""
@inline cpu_device() = LuxCPUDevice()

(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)

for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
# Dispatches for Different Data Structures
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
# For all other types we rely on fmap which means we lose type stability.
# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
ldev = Symbol("Lux$(dev)Device")
ladaptor = Symbol("Lux$(dev)Adaptor")
@eval begin
function (::$dev)(::LuxCore.AbstractExplicitLayer)
function (::$(ldev))(x::AbstractArray)
fn = Base.Fix1(adapt, $(ladaptor)())
return _isbitsarray(x) ? fn(x) : map(fn, x)
end
(::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x)
(dev::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(dev(values(x)))
function (::$(ldev))(x)
_isleaf(x) && return adapt($(ladaptor)(), x)
return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf)
end
function (::$(ldev))(::LuxCore.AbstractExplicitLayer)
throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`."))
end
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.14.1"
julia = "1.6"

2 comments on commit c5586c2

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91010

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.8 -m "<description of version>" c5586c21dc2e0381c880ecf2eb464f635d082ece
git push origin v0.1.8

Please sign in to comment.