Skip to content

Commit

Permalink
allow loading custom weights files for EfficientNet
Browse files Browse the repository at this point in the history
  • Loading branch information
IanButterworth committed Dec 4, 2023
1 parent 43e0e9d commit 7a93dfa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
13 changes: 9 additions & 4 deletions src/convnets/efficientnets/efficientnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_p
end

"""
EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
EfficientNet(config::Symbol; pretrain::Union{Bool,String} = false, inchannels::Integer = 3,
nclasses::Integer = 1000)
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
# Arguments
- `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`.
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet, or provide a local path string to load a
custom weights file.
- `inchannels`: number of input channels.
- `nclasses`: number of output classes.
Expand All @@ -83,12 +84,16 @@ struct EfficientNet
end
@functor EfficientNet

function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
function EfficientNet(config::Symbol; pretrain::Union{Bool,String} = false, inchannels::Integer = 3,
nclasses::Integer = 1000)
layers = efficientnet(config; inchannels, nclasses)
model = EfficientNet(layers)
if pretrain
if pretrain === true
loadpretrain!(model, string("efficientnet_", config))
elseif pretrain isa String
isfile(pretrain) || error("Weights file does not exist at `pretrain`")
m = load_weights_file(pretrain)
Flux.loadmodel!(model, m)
end
return model
end
Expand Down
16 changes: 10 additions & 6 deletions src/pretrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Load the pre-trained weights for `model` using the stored artifacts.
"""
function loadweights(artifact_name)
artifact_dir = try
@artifact_str(artifact_name)
@artifact_str(artifact_name)
catch e
throw(ArgumentError("No pre-trained weights available for $artifact_name."))
end
Expand All @@ -23,15 +23,19 @@ function loadweights(artifact_name)
end

file_path = joinpath(artifact_dir, file_name)

if endswith(file_name, ".bson")

return load_weights_file(file_path)
end

function load_weights_file(file_path::String)
if endswith(file_path, ".bson")
artifact = BSON.load(file_path, @__MODULE__)
if haskey(artifact, :model_state)
return artifact[:model_state]
elseif haskey(artifact, :model)
return artifact[:model]
else
throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key :model_state or :model."))
throw(ErrorException("Weights in the file `$file_path` are not saved under the key :model_state or :model."))
end
elseif endswith(file_path, ".jld2")
artifact = JLD2.load(file_path)
Expand All @@ -40,10 +44,10 @@ function loadweights(artifact_name)
elseif haskey(artifact, "model")
return artifact["model"]
else
throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key \"model_state\" or \"model\"."))
throw(ErrorException("Weights in the file `$file_path` are not saved under the key \"model_state\" or \"model\"."))
end
else
throw(ErrorException("Found weight artifact for $artifact_name but only jld2 and bson serialization format are supported."))
throw(ErrorException("Only jld2 and bson serialization format are supported for weights files."))
end
end

Expand Down

0 comments on commit 7a93dfa

Please sign in to comment.