diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 3948048f..1e748723 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -60,7 +60,7 @@ function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_p end """ - EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + EfficientNet(config::Symbol; pretrain::Union{Bool,String} = false, inchannels::Integer = 3, nclasses::Integer = 1000) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). @@ -68,7 +68,8 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). # Arguments - `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`. - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet, or provide a local path string to load a + custom weights file. - `inchannels`: number of input channels. - `nclasses`: number of output classes. @@ -83,12 +84,16 @@ struct EfficientNet end @functor EfficientNet -function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, +function EfficientNet(config::Symbol; pretrain::Union{Bool,String} = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = efficientnet(config; inchannels, nclasses) model = EfficientNet(layers) - if pretrain + if pretrain === true loadpretrain!(model, string("efficientnet_", config)) + elseif pretrain isa String + isfile(pretrain) || error("Weights file does not exist at `pretrain`") + m = load_weights_file(pretrain) + Flux.loadmodel!(model, m) end return model end diff --git a/src/pretrain.jl b/src/pretrain.jl index 4403d634..67f23dcd 100644 --- a/src/pretrain.jl +++ b/src/pretrain.jl @@ -5,7 +5,7 @@ Load the pre-trained weights for `model` using the stored artifacts. """ function loadweights(artifact_name) artifact_dir = try - @artifact_str(artifact_name) + @artifact_str(artifact_name) catch e throw(ArgumentError("No pre-trained weights available for $artifact_name.")) end @@ -23,15 +23,19 @@ function loadweights(artifact_name) end file_path = joinpath(artifact_dir, file_name) - - if endswith(file_name, ".bson") + + return load_weights_file(file_path) +end + +function load_weights_file(file_path::String) + if endswith(file_path, ".bson") artifact = BSON.load(file_path, @__MODULE__) if haskey(artifact, :model_state) return artifact[:model_state] elseif haskey(artifact, :model) return artifact[:model] else - throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key :model_state or :model.")) + throw(ErrorException("Weights in the file `$file_path` are not saved under the key :model_state or :model.")) end elseif endswith(file_path, ".jld2") artifact = JLD2.load(file_path) @@ -40,10 +44,10 @@ function loadweights(artifact_name) elseif haskey(artifact, "model") return artifact["model"] else - throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key \"model_state\" or \"model\".")) + throw(ErrorException("Weights in the file `$file_path` are not saved under the key \"model_state\" or \"model\".")) end else - throw(ErrorException("Found weight artifact for $artifact_name but only jld2 and bson serialization format are supported.")) + throw(ErrorException("Only jld2 and bson serialization format are supported for weights files.")) end end