This package provides some `mapslices`-like functions, with gradients defined for Tracker and Zygote:
```julia
mapcols(f, M) ≈ mapreduce(f, hcat, eachcol(M))
MapCols{d}(f, M)         # where d=size(M,1), for SVector slices
ThreadMapCols{d}(f, M)   # using Threads.@threads

maprows(f, M) ≈ mapslices(f, M, dims=2)

slicemap(f, A; dims) ≈ mapslices(f, A, dims=dims)  # only Zygote
```
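As a rough illustration, the sketch below checks the equivalences above on a small matrix and then takes a Zygote gradient through `mapcols`. It assumes SliceMap and Zygote are installed; `fun` is a hypothetical slice function chosen because it acts elementwise, so the gradient has a closed form to compare against.

```julia
using SliceMap, Zygote

M = rand(3, 5)
fun(x) = 2 .+ x .^ 2   # hypothetical slice function: vector in, vector of same length out

# forward results should match the Base expressions listed above
@assert mapcols(fun, M) ≈ mapreduce(fun, hcat, eachcol(M))
@assert maprows(fun, M) ≈ mapslices(fun, M, dims=2)
@assert slicemap(fun, M; dims=1) ≈ mapslices(fun, M, dims=1)

# reverse-mode gradient of a scalar loss through mapcols
g = Zygote.gradient(m -> sum(sin, mapcols(fun, m)), M)[1]

# fun acts elementwise, so d/dM sum(sin(2 + M^2)) = cos(2 + M^2) * 2M
@assert g ≈ cos.(2 .+ M .^ 2) .* 2 .* M
```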
The capitalised functions `MapCols` and `ThreadMapCols` differ in two ways: they pass each column to `f` as a StaticArrays `SVector`, and they use ForwardDiff for the gradient of each slice, instead of the same reverse-mode Tracker/Zygote used for the rest of the computation. For small slices this is often much faster, with or without gradients.
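A minimal sketch of the static-slice variants, assuming SliceMap and Zygote are installed and using a hypothetical elementwise `fun` that works equally on a plain `Vector` and on an `SVector`. It checks that `MapCols{3}` and `ThreadMapCols{3}` agree with `mapcols`, and that the Zygote gradient through `MapCols` (whose per-column derivatives come from ForwardDiff) matches the purely reverse-mode one.

```julia
using SliceMap, Zygote

M = rand(3, 100)       # many short columns, the case MapCols is aimed at
fun(x) = 2 .+ x .^ 2   # hypothetical slice function; returns an SVector when given one

# the type parameter must equal size(M, 1): each column reaches fun as an SVector{3}
@assert MapCols{3}(fun, M) ≈ mapcols(fun, M)
@assert ThreadMapCols{3}(fun, M) ≈ MapCols{3}(fun, M)

# the outer call is still Zygote; inside MapCols each column is differentiated by ForwardDiff
g_static = Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), M)[1]
g_plain  = Zygote.gradient(m -> sum(sin, mapcols(fun, m)), M)[1]
@assert g_static ≈ g_plain
```

Writing `d` as a type parameter lets the column length be known at compile time, which is what allows each slice to be an `SVector`; it also means `d` must be supplied by the caller rather than read from `M`.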
The package also defines Zygote gradients for the `Slices`/`Align` functions in JuliennedArrays, which are a good way to roll your own `mapslices`-like function (this is exactly how `slicemap(f, A; dims)` works). Similar gradients are also available in TensorCast and in LazyStack.
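For example, here is a sketch of `slicemap` on a 3-dimensional array, assuming SliceMap and Zygote are installed; `normalise` is a hypothetical slice function. Because every output slice sums to 1, `sum(B)` does not depend on `A`, so its gradient should vanish up to rounding, which gives a simple check on the slice gradients.

```julia
using SliceMap, Zygote

A = rand(2, 3, 4)
normalise(v) = v ./ sum(v)   # hypothetical slice function: rescale a slice to sum to 1

# apply normalise to every vector A[i, :, k], i.e. slices along dimension 2
B = slicemap(normalise, A; dims=2)
@assert B ≈ mapslices(normalise, A, dims=2)
@assert all(sum(B, dims=2) .≈ 1)

# each slice of B sums to 1, so sum(B) is constant and its gradient is (numerically) zero
g = Zygote.gradient(a -> sum(slicemap(normalise, a; dims=2)), A)[1]
@assert isapprox(g, zero(A); atol=1e-10)
```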
There are more details & examples at docs/intro.md.