Skip to content

Commit

Permalink
fix: don't reexport NNlib.dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 27, 2024
1 parent e5ca05f commit a65f2aa
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/src/api/Building_Blocks/LuxLib.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fused_conv_bias_activation

```@docs
alpha_dropout
LuxLib.dropout
dropout
```

## Normalization
Expand Down
5 changes: 3 additions & 2 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ using NNlib: NNlib
using Optimisers: Optimisers
using Preferences: load_preference, @load_preference
using Random: Random, AbstractRNG
using Reexport: @reexport
using Reexport: Reexport, @reexport
using Statistics: mean
using UnrolledUtilities: unrolled_map, unrolled_mapreduce

import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters,
initialstates, parameterlength, statelength, inputsize, outputsize,
update_state, trainmode, testmode, setup, apply, display_name, replicate

@reexport using LuxCore, LuxLib, MLDataDevices, NNlib, WeightInitializers
@reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers
@eval Expr(:export, filter(x -> x !== :dropout, Reexport.exported_names(NNlib))...)

const CRC = ChainRulesCore

Expand Down
5 changes: 2 additions & 3 deletions src/layers/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function Dropout(p; dims=:)
end

function (d::Dropout)(x, ps, st::NamedTuple)
y, _, rng = LuxLib.dropout(st.rng, x, d.p, st.training, d.q, d.dims)
y, _, rng = dropout(st.rng, x, d.p, st.training, d.q, d.dims)
return y, merge(st, (; rng))
end

Expand Down Expand Up @@ -176,8 +176,7 @@ end

function (d::VariationalHiddenDropout)(x, ps, st::NamedTuple)
_mask = st.mask === nothing ? x : st.mask
y, mask, rng = LuxLib.dropout(
st.rng, x, _mask, d.p, st.training, st.update_mask, d.q, d.dims)
y, mask, rng = dropout(st.rng, x, _mask, d.p, st.training, st.update_mask, d.q, d.dims)
return y, merge(st, (; mask, rng, update_mask=Val(false)))
end

Expand Down

0 comments on commit a65f2aa

Please sign in to comment.