Skip to content

Commit

Permalink
Make sure first example in Custom Layers docs uses type parameter (#2415
Browse files Browse the repository at this point in the history
)

* Make sure first example uses type parameter

* Move performance explanation up and expand a bit.

* tweak
  • Loading branch information
BioTurboNick authored Apr 1, 2024
1 parent d4b94ee commit 348c56f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions docs/src/models/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ Here we will try and describe usage of some more advanced features that Flux pro
Here is a basic example of a custom model. It simply adds the input to the result from the neural network.

```julia
struct CustomModel
chain::Chain
struct CustomModel{T <: Chain} # Parameter to avoid type instability
chain::T
end

function (m::CustomModel)(x)
Expand All @@ -21,6 +21,7 @@ end
# Call @layer to allow for training. Described below in more detail.
Flux.@layer CustomModel
```
Notice that we parameterized the type of the `chain` field. This is necessary for fast Julia code, so that that struct field can be given a concrete type. `Chain`s have a type parameter fully specifying the types of the layers they contain. By using a type parameter, we are freeing Julia to determine the correct concrete type, so that we do not need to specify the full, possibly quite long, type ourselves.

You can then use the model like:

Expand Down Expand Up @@ -140,7 +141,7 @@ end
# allow Join(op, m1, m2, ...) as a constructor
Join(combine, paths...) = Join(combine, paths)
```
Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
Notice again that we parameterized the type of the `combine` and `paths` fields. In addition to the performance considerations of concrete types, this allows either field to be `Vector`s, `Tuple`s, or one of each - we don't need to pay attention to which.

The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `Flux.setup` on a `Join` maps over the underlying trainable arrays on each path.
```julia
Expand Down

0 comments on commit 348c56f

Please sign in to comment.