refactor: Vision module

avik-pal committed Aug 23, 2024
1 parent 4be73b0 commit 0cd8112

Showing 6 changed files with 57 additions and 39 deletions.
32 changes: 17 additions & 15 deletions ext/BoltzMetalheadExt.jl
```diff
@@ -3,41 +3,43 @@ module BoltzMetalheadExt
 using ArgCheck: @argcheck
 using Metalhead: Metalhead
 
-using Boltz: Boltz, Utils, __maybe_initialize_model, Vision
 using Lux: Lux, FromFluxAdaptor
 
+using Boltz: Boltz, Utils, Vision
+using Boltz.InitializeModels: maybe_initialize_model
+
 Utils.is_extension_loaded(::Val{:Metalhead}) = true
 
-function Vision.__AlexNet(; pretrained=false, kwargs...)
+function Vision.AlexNetMetalhead(; pretrained=false, kwargs...)
     model = FromFluxAdaptor()(Metalhead.AlexNet().layers)
     pretrained && (model = Lux.Chain(model[1], model[2])) # Compatibility with pretrained weights
-    return __maybe_initialize_model(:alexnet, model; pretrained, kwargs...)
+    return maybe_initialize_model(:alexnet, model; pretrained, kwargs...)
 end
 
-function Vision.__ResNet(depth::Int; kwargs...)
+function Vision.ResNetMetalhead(depth::Int; kwargs...)
     @argcheck depth in (18, 34, 50, 101, 152)
     model = FromFluxAdaptor()(Metalhead.ResNet(depth).layers)
-    return __maybe_initialize_model(Symbol(:resnet, depth), model; kwargs...)
+    return maybe_initialize_model(Symbol(:resnet, depth), model; kwargs...)
 end
 
-function Vision.__ResNeXt(depth::Int; kwargs...)
+function Vision.ResNeXtMetalhead(depth::Int; kwargs...)
     @argcheck depth in (50, 101, 152)
     model = FromFluxAdaptor()(Metalhead.ResNeXt(depth).layers)
-    return __maybe_initialize_model(Symbol(:resnext, depth), model; kwargs...)
+    return maybe_initialize_model(Symbol(:resnext, depth), model; kwargs...)
 end
 
-function Vision.__GoogLeNet(; kwargs...)
+function Vision.GoogLeNetMetalhead(; kwargs...)
     model = FromFluxAdaptor()(Metalhead.GoogLeNet().layers)
-    return __maybe_initialize_model(:googlenet, model; kwargs...)
+    return maybe_initialize_model(:googlenet, model; kwargs...)
 end
 
-function Vision.__DenseNet(depth::Int; kwargs...)
+function Vision.DenseNetMetalhead(depth::Int; kwargs...)
     @argcheck depth in (121, 161, 169, 201)
     model = FromFluxAdaptor()(Metalhead.DenseNet(depth).layers)
-    return __maybe_initialize_model(Symbol(:densenet, depth), model; kwargs...)
+    return maybe_initialize_model(Symbol(:densenet, depth), model; kwargs...)
 end
 
-function Vision.__MobileNet(name::Symbol; kwargs...)
+function Vision.MobileNetMetalhead(name::Symbol; kwargs...)
     @argcheck name in (:v1, :v2, :v3_small, :v3_large)
     model = if name == :v1
         FromFluxAdaptor()(Metalhead.MobileNetv1().layers)
@@ -48,13 +50,13 @@ function Vision.__MobileNet(name::Symbol; kwargs...)
     elseif name == :v3_large
         FromFluxAdaptor()(Metalhead.MobileNetv3(:large).layers)
     end
-    return __maybe_initialize_model(Symbol(:mobilenet, "_", name), model; kwargs...)
+    return maybe_initialize_model(Symbol(:mobilenet, "_", name), model; kwargs...)
 end
 
-function Vision.__ConvMixer(name::Symbol; kwargs...)
+function Vision.ConvMixerMetalhead(name::Symbol; kwargs...)
     @argcheck name in (:base, :large, :small)
     model = FromFluxAdaptor()(Metalhead.ConvMixer(name).layers)
-    return __maybe_initialize_model(Symbol(:convmixer, "_", name), model; kwargs...)
+    return maybe_initialize_model(Symbol(:convmixer, "_", name), model; kwargs...)
 end
 
 end
```
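After the rename, nothing changes for callers: the public `Vision` stubs (see `src/vision/extensions.jl` below) forward to these `*Metalhead` methods once Metalhead.jl is loaded. A minimal usage sketch, assuming the default `initialized=Val(true)` behavior from `src/initialize.jl`:

```julia
# Hypothetical usage sketch: loading Metalhead.jl activates this extension,
# so the public Vision stubs dispatch to the *Metalhead methods defined above.
using Boltz, Lux, Metalhead

# Returns (model, parameters, states) since initialized=Val(true) by default.
model, ps, st = Vision.ResNet(50)

# initialized=Val(false) skips LuxCore.setup and returns only the model.
model = Vision.AlexNet(; initialized=Val(false))
```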
41 changes: 27 additions & 14 deletions src/initialize.jl
```diff
@@ -1,5 +1,16 @@
-__get_pretrained_weights_path(name::Symbol) = __get_pretrained_weights_path(string(name))
-function __get_pretrained_weights_path(name::String)
+module InitializeModels
+
+using ArgCheck: @argcheck
+using Artifacts: Artifacts, @artifact_str
+using JLD2: JLD2
+using Random: Random
+
+using LuxCore: LuxCore
+
+using ..Utils: unwrap_val
+
+get_pretrained_weights_path(name::Symbol) = get_pretrained_weights_path(string(name))
+function get_pretrained_weights_path(name::String)
     try
         return @artifact_str(name)
     catch err
@@ -9,18 +20,10 @@ function __get_pretrained_weights_path(name::String)
     end
 end
 
-const INITIALIZE_KWARGS = """
-  * `pretrained::Bool=false`: If `true`, returns a pretrained model.
-  * `rng::Union{Nothing, AbstractRNG}=nothing`: Random number generator.
-  * `seed::Int=0`: Random seed.
-  * `initialized::Val{Bool}=Val(true)`: If `Val(true)`, returns
-    `(model, parameters, states)`, otherwise just `model`.
-"""
-
-function __initialize_model(
+function initialize_model(
     name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0, kwargs...)
     if pretrained
-        path = __get_pretrained_weights_path(name)
+        path = get_pretrained_weights_path(name)
         ps = load(joinpath(path, "$name.jld2"), "parameters")
         st = load(joinpath(path, "$name.jld2"), "states")
         return ps, st
@@ -32,10 +35,20 @@ function __initialize_model(
     return LuxCore.setup(rng, model)
 end
 
-function __maybe_initialize_model(name::Symbol, model; pretrained=false,
+function maybe_initialize_model(name::Symbol, model; pretrained=false,
         initialized::Union{Val, Bool}=Val(true), kwargs...)
     @argcheck !pretrained || unwrap_val(initialized)
     unwrap_val(initialized) || return model
-    ps, st = __initialize_model(name, model; pretrained, kwargs...)
+    ps, st = initialize_model(name, model; pretrained, kwargs...)
     return model, ps, st
 end
+
+const INITIALIZE_KWARGS = """
+  * `pretrained::Bool=false`: If `true`, returns a pretrained model.
+  * `rng::Union{Nothing, AbstractRNG}=nothing`: Random number generator.
+  * `seed::Int=0`: Random seed.
+  * `initialized::Val{Bool}=Val(true)`: If `Val(true)`, returns
+    `(model, parameters, states)`, otherwise just `model`.
+"""
+
+end
```
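The `@argcheck !pretrained || unwrap_val(initialized)` guard encodes that pretrained weights only make sense when parameters are actually returned. A sketch of the three resulting paths, using a hypothetical `:alexnet` artifact name:

```julia
# Illustrative sketch of the three paths through maybe_initialize_model.

# 1. Default: fresh parameters and states from LuxCore.setup.
model, ps, st = maybe_initialize_model(:alexnet, model)

# 2. Pretrained: ps/st loaded from the "alexnet.jld2" file in the artifact.
model, ps, st = maybe_initialize_model(:alexnet, model; pretrained=true)

# 3. Uninitialized: just the model; adding pretrained=true here would
#    fail the @argcheck guard.
model_only = maybe_initialize_model(:alexnet, model; initialized=Val(false))
```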
7 changes: 5 additions & 2 deletions src/vision/Vision.jl
```diff
@@ -1,12 +1,15 @@
 module Vision
 
 using ArgCheck: @argcheck
 using Compat: @compat
+using Random: Xoshiro
 
 using Lux: Lux
 using LuxCore: LuxCore, AbstractExplicitLayer
 using NNlib: relu
-using Random: Xoshiro
 
-using ..Boltz: __maybe_initialize_model, Layers, INITIALIZE_KWARGS
+using ..InitializeModels: maybe_initialize_model, INITIALIZE_KWARGS
+using ..Layers: Layers
 using ..Utils: flatten_spatial, second_dim_mean, is_extension_loaded
 
 include("extensions.jl")
```
12 changes: 6 additions & 6 deletions src/vision/extensions.jl
```diff
@@ -143,15 +143,15 @@ $(INITIALIZE_KWARGS)
 function ConvMixer end
 
 for f in [:AlexNet, :ResNet, :ResNeXt, :GoogLeNet, :DenseNet, :MobileNet, :ConvMixer]
-    f_inner = Symbol("__", f)
+    f_metalhead = Symbol(f, :Metalhead)
     @eval begin
-        function $(f_inner) end
-        function $f(args...; kwargs...)
+        function $(f_metalhead) end
+        function $(f)(args...; kwargs...)
             if !is_extension_loaded(Val(:Metalhead))
-                error("Metalhead.jl is not loaded. Please load Metalhead.jl to use this \
-                       function.")
+                error("`Metalhead.jl` is not loaded. Please load `Metalhead.jl` to use \
+                       this function.")
             end
-            $(f_inner)(args...; kwargs...)
+            $(f_metalhead)(args...; kwargs...)
         end
     end
 end
```
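For reference, this is roughly what the loop now generates for a single entry such as `f = :ResNet` (an illustrative expansion, not code from the repository):

```julia
# Approximate expansion for f = :ResNet: a method-less function that the
# BoltzMetalheadExt extension later adds a method to, plus a public stub
# that guards on the extension being loaded.
function ResNetMetalhead end

function ResNet(args...; kwargs...)
    if !is_extension_loaded(Val(:Metalhead))
        error("`Metalhead.jl` is not loaded. Please load `Metalhead.jl` to use \
               this function.")
    end
    ResNetMetalhead(args...; kwargs...)
end
```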
2 changes: 1 addition & 1 deletion src/vision/vgg.jl
```diff
@@ -82,5 +82,5 @@ function VGG(depth::Int; batchnorm::Bool=false, kwargs...)
     name = Symbol(:vgg, depth, ifelse(batchnorm, "_bn", ""))
     config, inchannels, nclasses, fcsize = VGG_CONFIG[depth], 3, 1000, 4096
     model = VGG((224, 224); config, inchannels, batchnorm, nclasses, fcsize, dropout=0.5f0)
-    return __maybe_initialize_model(name, model; kwargs...)
+    return maybe_initialize_model(name, model; kwargs...)
 end
```
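`VGG` is implemented natively in Boltz, so it reaches `maybe_initialize_model` directly, with no Metalhead extension involved. A hypothetical call, assuming 16 is one of the `VGG_CONFIG` depths:

```julia
# Hypothetical usage; depth 16 is assumed to be a VGG_CONFIG key.
model, ps, st = Vision.VGG(16; batchnorm=true)
```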
2 changes: 1 addition & 1 deletion src/vision/vit.jl
```diff
@@ -62,7 +62,7 @@ $(INITIALIZE_KWARGS)
 function VisionTransformer(name::Symbol; kwargs...)
     @argcheck name in keys(VIT_CONFIGS)
     model = VisionTransformer(; VIT_CONFIGS[name]..., kwargs...)
-    return __maybe_initialize_model(name, model; kwargs...)
+    return maybe_initialize_model(name, model; kwargs...)
 end
 
 const ViT = VisionTransformer
```
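The `ViT` alias keeps the short name available for the configuration-based constructor. A hypothetical call, assuming `:tiny` is among the `VIT_CONFIGS` keys:

```julia
# Hypothetical usage via the ViT alias; :tiny is assumed to be a VIT_CONFIGS key.
model, ps, st = Vision.ViT(:tiny)
```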
