Skip to content

Commit

Permalink
cleanup, add zygote
Browse files Browse the repository at this point in the history
  • Loading branch information
colinxs committed Mar 28, 2020
1 parent 7aa53c1 commit d4a88e2
Show file tree
Hide file tree
Showing 15 changed files with 511 additions and 458 deletions.
470 changes: 228 additions & 242 deletions Manifest.toml

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ version = "0.1.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LyceumBase = "db31fed1-ca1e-4084-8a49-12fae1996a55"
LyceumCore = "e5bd5517-2193-49f0-ba9c-d5a8508cb639"
LyceumDevTools = "fd23256c-5a67-41c4-8f5a-c8cf5526e505"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Shapes = "175de200-b73b-11e9-28b7-9b5b306cec37"
StaticNumbers = "c5e4b96a-f99f-5557-8ed2-dc63ef9b5131"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1.4"
Expand Down
14 changes: 9 additions & 5 deletions src/SpecialArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,30 @@ module SpecialArrays
using Adapt
using Base: @propagate_inbounds, @pure, @_inline_meta, require_one_based_indexing
using DocStringExtensions
using LyceumBase.LyceumCore
using LyceumCore
using MacroTools: @forward
using Requires: @require
using StaticNumbers
using UnsafeArrays


const Idx = Union{Colon, Real, AbstractArray}
const Idx = Union{Colon,Real,AbstractArray}


include("viewtype.jl")
include("cartesianindexer.jl")

export innereltype, innerndims, inneraxes, innersize, innerlength
export flatten
include("functions.jl")

export SlicedArray, slice
include("slicedarray.jl")

export FlattenedArray, flatview
export FlattenedArray, flatten
include("flattenedarray.jl")

end # module
function __init__()
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
end

end # module
7 changes: 5 additions & 2 deletions src/cartesianindexer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ end
getindex(A.parent, I...)
end

@propagate_inbounds function Base.setindex!(A::CartesianIndexer{<:Any,N}, v, I::Vararg{Int,N}) where {N}
@propagate_inbounds function Base.setindex!(
A::CartesianIndexer{<:Any,N},
v,
I::Vararg{Int,N},
) where {N}
setindex!(A.parent, v, I...)
end

Expand All @@ -17,4 +21,3 @@ Base.IndexStyle(::Type{<:CartesianIndexer}) = IndexCartesian()
function Base.similar(A::CartesianIndexer, T::Type, dims::Dims)
similar(A.parent, T, dims)
end

98 changes: 45 additions & 53 deletions src/flattenedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,23 @@ struct FlattenedArray{V,L,M,P,InAx} <: AbstractArray{V,L}
parent::P
inneraxes::InAx
@inline function FlattenedArray{V,L,M,P,InAx}(parent, inneraxes) where {V,L,M,P,InAx}
_check_flattenedarray_parameters(V, Val(L), Val(M), P, InAx)
new{V,L,M,P,InAx}(parent, inneraxes)
end
end

@inline function FlattenedArray(parent::AbsSimilarNestedArr{V,M,N}, inneraxes::NTuple{M,Any}) where {V,M,N}
FlattenedArray{V,M+N,M,typeof(parent),typeof(inneraxes)}(parent, inneraxes)
@inline function FlattenedArray(
parent::AbsSimilarNestedArr{V,M,N},
inneraxes::NTuple{M,Any},
) where {V,M,N}
FlattenedArray{V,M + N,M,typeof(parent),typeof(inneraxes)}(parent, inneraxes)
end

@inline function FlattenedArray(parent::AbsSimilarNestedArr)
FlattenedArray(parent, inneraxes(parent))
@inline function FlattenedArray(parent::AbsArr{<:Any,N}, inneraxes::NTuple{M,Any}) where {N,M}
V = innereltype(parent)
FlattenedArray{V,M + N,M,typeof(parent),typeof(inneraxes)}(parent, inneraxes)
end

function _check_flattenedarray_parameters(::Type{V}, ::Val{L}, ::Val{M}, ::Type{P}, ::Type{InAx}) where {V,L,M,P,InAx}
if !(L isa Int && M isa Int)
throw(ArgumentError("FlattenedArray type parameters L and M must be of type Int"))
end
if !(P <: AbsSimilarNestedArr{V,M,L-M})
throw(ArgumentError("FlattenedArray parameter P should be <: AbstractArray{<:AbstractArray{V,M},N}"))
end
if !(InAx <: NTuple{M,Any})
throw(ArgumentError("FlattenedArray parameter InAx should be <: NTuple{M,Any}"))
end
return nothing
end


"""
$(SIGNATURES)
Like [`flatten`](@ref), but provides a view into the parent array `A` instead of
creating a new array.
```jldoctest
julia> A = [reshape(Vector(1:6), (2, 3)), reshape(Vector(7:12), (2, 3))]
2-element Array{Array{Int64,2},1}:
[1 3 5; 2 4 6]
[7 9 11; 8 10 12]
julia> B = flatview(A)
2×3×2 flatview(::Array{Array{Int64,2},1}) with eltype Int64 and inner size (2, 3):
[:, :, 1] =
1 3 5
2 4 6
[:, :, 2] =
7 9 11
8 10 12
julia> B == reshape(hcat(B...), (2, 3, 2))
true
```
"""
flatview(A::AbsNestedArr) = FlattenedArray(A)
flatview(A::AbsArr) = A
@inline FlattenedArray(parent::AbsArr) = FlattenedArray(parent, inneraxes(parent))


####
Expand All @@ -73,7 +35,11 @@ flatview(A::AbsArr) = A
F.parent[_outer_indices(F, I)...][_inner_indices(F, I)...]
end

@propagate_inbounds function Base.setindex!(F::FlattenedArray{<:Any,L}, v, I::Vararg{Int,L}) where {L}
@propagate_inbounds function Base.setindex!(
F::FlattenedArray{<:Any,L},
v,
I::Vararg{Int,L},
) where {L}
F.parent[_outer_indices(F, I)...][_inner_indices(F, I)...] = v
return F
end
Expand Down Expand Up @@ -172,12 +138,10 @@ end


function Base.showarg(io::IO, A::FlattenedArray, toplevel)
print(io, "flatview(")
print(io, "flatten(")
Base.showarg(io, parent(A), false)
print(io, ')')
if toplevel
print(io, " with eltype ", eltype(A), " and inner size ", innersize(A))
end
toplevel && print(io, " with eltype ", eltype(A))
return nothing
end

Expand All @@ -186,4 +150,32 @@ end
#### Extra
####

@inline inneraxes(F::FlattenedArray) = F.inneraxes
@inline inneraxes(F::FlattenedArray) = F.inneraxes

"""
flatten(A::AbstractArray{<:AbstractArray{V,M},N})
Return a `M+N`-dimensional flattened view of `A`. Throws an error if the elements of `A` do not
have equal size. If `A` is not a nested array, the return value is `A` itself.
```jldoctest
julia> A = [reshape(Vector(1:6), (2, 3)), reshape(Vector(7:12), (2, 3))]
2-element Array{Array{Int64,2},1}:
[1 3 5; 2 4 6]
[7 9 11; 8 10 12]
julia> B = flatten(A)
2×3×2 flatten(::Array{Array{Int64,2},1}) with eltype Int64:
[:, :, 1] =
1 3 5
2 4 6
[:, :, 2] =
7 9 11
8 10 12
julia> B == reshape(hcat(B...), (2, 3, 2))
true
```
"""
flatten(A::AbsArr) = FlattenedArray(A)
90 changes: 37 additions & 53 deletions src/functions.jl
Original file line number Diff line number Diff line change
@@ -1,89 +1,73 @@
"""
innereltype(A::AbstractArray{<:AbstractArray})
innereltype(A::Type{<:AbstractArray{<:AbstractArray}})
innereltype(A::Type{<:AbstractArray})
innereltype(A::AbstractArray)
Returns the common element type of the element arrays of `A`.
Equivalent to eltype(eltype(A)).
Returns the common `eltype` of the elements of `A`.
"""
function innereltype end

innereltype(::Type{A}) where {A<:AbsNestedArr} = eltype(eltype(A))
innereltype(A::AbsNestedArr) = eltype(eltype(A))
innereltype(::Type{A}) where {A<:AbsArr} = eltype(eltype(A))
innereltype(A::AbsArr) = eltype(eltype(A))


"""
innerndims(A::AbstractArray{<:AbstractArray})
innerndims(A::Type{<:AbstractArray{<:AbstractArray}})
innerndims(A::Type{<:AbstractArray})
innerndims(A::AbstractArray)
Returns the dimensionality of the element arrays of `A`.
Equivalent to ndims(eltype(A)).
Returns the common dimensionality of the elements of `A`.
Throws an error if the elements of `A` do not have equal dimensionality.
"""
function innerndims end

innerndims(::Type{A}) where {A<:AbsNestedArr} = ndims(eltype(A))
innerndims(A::AbsNestedArr) = ndims(eltype(A))
innerndims(::Type{<:AbsArr{<:AbsArr{<:Any,N}}}) where {N} = N
innerndims(A::AbsArr{<:AbsArr{<:Any,N}}) where {N} = N
# Unlike innereltype which defaults to Any if eltype(A) isa UnionAll there is no good default for N
innerndims(A::AbsArr) = _scan_inner(ndims, A)


"""
inneraxes(A::AbstractArray{<:AbstractArray}[, d])
inneraxes(A::AbstractArray[, d])
Returns the common length of the element arrays of `A`.
Throws an error if the element arrays of `A` do not have equal axes.
Returns the common axes of the elements of `A`.
Throws an error if the elements of `A` do not have equal axes.
"""
function inneraxes end

function inneraxes(A::AbsNestedArr)
M = innerndims(A)
function inneraxes(A::AbsArr)
if isempty(A)
# TODO this would be wrong for offset arrays?
ax = ntuple(_ -> Base.OneTo(0), Val(M))
return ntuple(_ -> Base.OneTo(0), innerndims(A))
else
ax = axes(first(A))
length(ax) == M || throw(DimensionMismatch("length(inneraxes(A)) != innerndims(A)"))
for a in A
if axes(a) != ax
throw(DimensionMismatch("The elements of A do not have equal axes"))
end
end
return _scan_inner(axes, A)
end
return ax
end

@inline function inneraxes(A::AbsNestedArr, d::Integer)
@inline function inneraxes(A::AbsArr, d::Integer)
d <= innerndims(A) ? inneraxes(A)[d] : Base.OneTo(1)
end


"""
innersize(A::AbstractArray{<:AbstractArray}[, d])
innersize(A::AbstractArray[, d])
Returns the size of the element arrays of `A`.
Returns the common size of the elements of `A`.
Throws an error if the elements of `A` do not have equal axes.
"""
function innersize end

@inline innersize(A::AbsArr) = map(Base.unsafe_length, inneraxes(A))
@inline innersize(A::AbsArr, d::Integer) = Base.unsafe_length(inneraxes(A, d))


"""
innerlength(A::AbstractArray{<:AbstractArray})
Returns the common length of the element arrays of `A`.
Throws an error if the element arrays of `A` do not have equal size.
"""
function innerlength end

@inline innerlength(A::AbsNestedArr) = prod(innersize(A))

innerlength(A::AbstractArray)
Returns the common length of the elements of `A`.
Throws an error if the elements of `A` do not have equal length.
"""
flatten(A::AbstractArray{<:AbstractArray{V,M},N}
@inline innerlength(A::AbsArr) = _scan_inner(length, A)

Flatten `A` into an AbstractArray{V,M+N}. Fails if the elements of `A` do not all
have the same size. If the `A` is not a nested array, the return value is `A` itself.
"""
function flatten end

flatten(A::AbsArr) = A
flatten(A::AbsNestedArr) = Array(flatview(A))
function _scan_inner(f::F, A::AbsArr) where {F}
if isempty(A)
throw(ArgumentError("Cannot apply $f to an empty array"))
else
x = f(first(A))
for a in A
f(a) == x || error("The element arrays of A do not have matching $f")
end
return x
end
end
Loading

0 comments on commit d4a88e2

Please sign in to comment.