-
Hi, I have two custom layers:
using Lux
"""
    ImportanceScaling(dim)

Lux layer that rescales its input element-wise with a learnable weight vector.
`dim` is the length of that weight vector.
"""
struct ImportanceScaling <: Lux.AbstractLuxLayer
    dim::Int
end
"""
    Lux.initialparameters(rng, layer::ImportanceScaling)

Return the parameters of `layer` as a named tuple with a single entry,
`importance_weights`: a `Float32` vector of length `layer.dim` drawn with `rand`
from `rng`.
"""
function Lux.initialparameters(rng::AbstractRNG, layer::ImportanceScaling)
    return (; importance_weights=rand(rng, Float32, layer.dim))
end
# Forward pass of `ImportanceScaling`: multiply the input element-wise (with
# broadcasting) by `params.importance_weights` and pass `state` through unchanged.
function (layer::ImportanceScaling)(
    x::AbstractVecOrMat, params::NamedTuple, state::NamedTuple
)
    weights = params.importance_weights
    return (x .* weights, state)
end
"""
    Decoder(dim_encoding, dim_decoding)

Lux layer that maps an encoding of size `dim_encoding` to an output of size
`dim_decoding`.
"""
struct Decoder <: Lux.AbstractLuxLayer
    dim_encoding::Int
    dim_decoding::Int
end
"""
    Lux.initialparameters(rng, decoder::Decoder)

Return the decoder parameters as a named tuple:

- `word_embedding`: `dim_decoding × dim_encoding` `Float32` matrix created with
  `kaiming_uniform`.
- `bias`: result of `Lux.init_linear_bias` called with `nothing` as the initializer,
  `dim_encoding` as fan-in and `dim_decoding` as length.
- `importance_weights`: `Float32` vector of length `dim_encoding` drawn with `rand`.
  The decoder owns this entry so that its forward pass can read it from `params`.
  To tie it to the weights of an `ImportanceScaling` layer, link the two entries on
  the parameter tree with `Lux.Experimental.share_parameters`.
"""
function Lux.initialparameters(rng::AbstractRNG, decoder::Decoder)
    word_embedding = kaiming_uniform(
        rng, Float32, decoder.dim_decoding, decoder.dim_encoding
    )
    bias = Lux.init_linear_bias(rng, nothing, decoder.dim_encoding, decoder.dim_decoding)
    # Drawn last so that `word_embedding` and `bias` keep the values they had for a
    # given `rng` before this entry existed.
    importance_weights = rand(rng, Float32, decoder.dim_encoding)
    return (; word_embedding, bias, importance_weights)
end

"""
    Lux.outputsize(decoder::Decoder, _, rng)

Return the output size of the decoder, `(decoder.dim_decoding,)`.
"""
function Lux.outputsize(decoder::Decoder, _, ::AbstractRNG)
    return (decoder.dim_decoding,)
end

# Forward pass of `Decoder`: rescale the input with `params.importance_weights`,
# then apply the affine map `word_embedding * x1 .+ bias`. `state` is returned
# unchanged.
function (decoder::Decoder)(x, params::NamedTuple, state::NamedTuple)
    # The weights must come from this layer's own parameters: inside a `Chain` a
    # layer only sees the previous layer's output, so a bare `importance_weights`
    # is not defined here and would throw an `UndefVarError`.
    x1 = params.importance_weights .* x
    x2 = params.word_embedding * x1 .+ params.bias
    return (x2, state)
end
# Two-layer model: importance scaling followed by the decoder, both sized 10.
model = Chain(ImportanceScaling(10), Decoder(10, 10))
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 4 replies
-
If you Chain the models, it is assumed that layer[I] cannot interact with layer[I - 1] or layer[I + 1] in any form other than through the outputs it generates. My recommendation would be to use a Lux.AbstractLuxContainerLayer and write out the forward pass manually. If you want to hack a solution with Chain, the only way would be to make ImportanceScaling return its importance_weights as part of its output (for example a tuple `(y, importance_weights)`), so that the Decoder receives them as input.
Beta Was this translation helpful? Give feedback.
Yes, this would work if you are okay with having `importance_weights` in the decoder params. For `share_parameters` to work correctly, you might want to initialize the decoder with `importance_weights` and later link them as you did.
One pointer to make your debugging easier (if needed): construct the Chain as
Chain(; importance_scaling=ImportanceScaling(10), decoder=Decoder(10, 10))
then the sharing becomes `Lux.Experimental.share_parameters(ps, (("importance_scaling.importance_weights", "decoder.importance_weights"),))`