Skip to content

Commit

Permalink
fix: update simplechains layer API
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 7, 2024
1 parent 5010f10 commit 0bd7099
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 26 deletions.
41 changes: 16 additions & 25 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.l
## SimpleChains.jl

"""
SimpleChainsLayer{ToArray}(layer, lux_layer=nothing)
SimpleChainsLayer(layer, ToArray::Union{Bool, Val}=Val(false))
SimpleChainsLayer(layer, to_array::Union{Bool, Val}=Val(false))
SimpleChainsLayer(layer, lux_layer, to_array)
Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using
`SimpleChains` but the layer satisfies the `AbstractLuxLayer` interface.
Expand All @@ -62,39 +62,30 @@ regular `Array` or not. Default is `false`.
- `layer`: SimpleChains layer
- `lux_layer`: Potentially equivalent Lux layer that is used for printing
"""
struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractLuxLayer}} <:
AbstractLuxLayer
to_array::ToArray
layer::SL
lux_layer::LL

function SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) where {ToArray}
to_array = static(ToArray)
return new{typeof(to_array), typeof(layer), typeof(lux_layer)}(
to_array, layer, lux_layer)
end
function SimpleChainsLayer(layer, ToArray::BoolType=False())
to_array = static(ToArray)
return new{typeof(to_array), typeof(layer), Nothing}(to_array, layer, nothing)
end
@concrete struct SimpleChainsLayer <: AbstractLuxLayer
layer
lux_layer <: Union{Nothing, AbstractLuxLayer}
to_array <: StaticBool
end

function SimpleChainsLayer(layer, to_array::BoolType=False())
return SimpleChainsLayer(layer, nothing, static(to_array))
end

function Base.show(
io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray}
PrettyPrinting.print_wrapper_model(
io, "SimpleChainsLayer{to_array=$ToArray}", s.lux_layer)
function Base.show(io::IO, ::MIME"text/plain", s::SimpleChainsLayer)
PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer", s.lux_layer)
end

function (sc::SimpleChainsLayer)(x, ps, st)
y = match_eltype(sc, ps, st, x)
return (
simple_chain_output(
sc, apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))),
to_array(sc.to_array,
apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))),
st)
end

simple_chain_output(::SimpleChainsLayer{False}, y) = y
simple_chain_output(::SimpleChainsLayer{True}, y) = convert(Array, y)
to_array(::False, y) = y
to_array(::True, y) = convert(Array, y)

apply_simple_chain(layer, x, ps, ::CPUDevice) = layer(x, ps)

Expand Down
2 changes: 1 addition & 1 deletion src/transform/simplechains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractLuxLayer)
error("`ToSimpleChainsAdaptor` requires `SimpleChains.jl` to be loaded.")
end
sc_layer = fix_simplechain_input_dims(make_simplechain_network(L), to.input_dims)
return SimpleChainsLayer{to.convert_to_array}(sc_layer, L)
return SimpleChainsLayer(sc_layer, L, static(to.convert_to_array))
end

function make_simplechain_network end
Expand Down

0 comments on commit 0bd7099

Please sign in to comment.