Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework type parameter stripping #78

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 72 additions & 19 deletions src/FixedSizeArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ end
Base.size(a::FixedSizeArray) = a.size

function Base.similar(::T, ::Type{E}, size::NTuple{N,Int}) where {T<:FixedSizeArray,E,N}
with_replaced_parameters(DenseArray, T, Val(E), Val(N))(undef, size)
a = with_replaced_type_parameters(TypeParametersElementType(), T, Val(E))
b = with_replaced_type_parameters(TypeParametersDimensionality(), val_parameter(a), Val(N))
val_parameter(b)(undef, size)
end

Base.isassigned(a::FixedSizeArray, i::Int) = isassigned(a.mem, i)
Expand Down Expand Up @@ -106,7 +108,8 @@ end
# broadcasting

function Base.BroadcastStyle(::Type{T}) where {T<:FixedSizeArray}
Broadcast.ArrayStyle{stripped_type(DenseArray, T)}()
spec = TypeParametersElementTypeAndDimensionality()
Broadcast.ArrayStyle{val_parameter(with_stripped_type_parameters(spec, T))}()
end

function Base.similar(
Expand All @@ -118,30 +121,80 @@ end

# helper functions

normalized_type(::Type{T}) where {T} = T
val_parameter(::Val{P}) where {P} = P

function stripped_type_unchecked(::Type{DenseVector}, ::Type{<:GenericMemory{K,<:Any,AS}}) where {K,AS}
GenericMemory{K,<:Any,AS}
struct TypeParametersElementType end
struct TypeParametersDimensionality end
struct TypeParametersElementTypeAndDimensionality end

"""
with_stripped_type_parameters_unchecked(spec, t::Type)::Val{s}

An implementation detail of [`with_stripped_type_parameters`](@ref). Don't call
directly.
"""
function with_stripped_type_parameters_unchecked end

function with_stripped_type_parameters_unchecked(::TypeParametersElementType, ::Type{<:(GenericMemory{K, T, AS} where {T})}) where {K, AS}
s = GenericMemory{K, T, AS} where {T}
Val{s}()
end

# `Base.@assume_effects :consistent` is a workaround for:
# https://github.com/JuliaLang/julia/issues/56966

Base.@assume_effects :consistent function with_stripped_type_parameters_unchecked(spec::TypeParametersElementType, ::Type{<:(FixedSizeArray{T, N, Mem} where {T})}) where {N, Mem}
mem_v = with_stripped_type_parameters(spec, Mem)
mem = val_parameter(mem_v)
s = FixedSizeArray{T, N, mem{T}} where {T}
Val{s}()
end

Base.@assume_effects :consistent function stripped_type_unchecked(
::Type{DenseArray}, ::Type{<:FixedSizeArray{<:Any,<:Any,V}},
) where {V}
U = stripped_type(DenseVector, V)
FixedSizeArray{E,N,U{E}} where {E,N}
Base.@assume_effects :consistent function with_stripped_type_parameters_unchecked(::TypeParametersDimensionality, ::Type{<:(FixedSizeArray{T, N, Mem} where {N})}) where {T, Mem}
s = FixedSizeArray{T, N, Mem} where {N}
Val{s}()
end

function stripped_type(::Type{T}, ::Type{S}) where {T,S<:T}
ret = stripped_type_unchecked(T, S)::Type{<:T}::UnionAll
S::Type{<:ret}
normalized_type(ret) # ensure `UnionAll` type variable order is normalized
Base.@assume_effects :consistent function with_stripped_type_parameters_unchecked(::TypeParametersElementTypeAndDimensionality, ::Type{<:(FixedSizeArray{T, N, Mem} where {T, N})}) where {Mem}
spec_mem = TypeParametersElementType()
mem_v = with_stripped_type_parameters(spec_mem, Mem)
mem = val_parameter(mem_v)
s = FixedSizeArray{T, N, mem{T}} where {T, N}
Val{s}()
end

function with_replaced_parameters(::Type{T}, ::Type{S}, ::Val{P1}, ::Val{P2}) where {T,S<:T,P1,P2}
t = T{P1,P2}::Type{<:T}
s = stripped_type(T, S)
S::Type{<:s}
s{P1,P2}::Type{<:s}::Type{<:T}::Type{<:t}
"""
with_stripped_type_parameters(spec, t::Type)::Val{s}

The type `s` is a `UnionAll` supertype of `t`:

```julia
(s isa UnionAll) && (t <: s)
```

Furthermore, `s` has type variables in place of the type parameters specified
via `spec`.

NB: `Val{s}()` is returned instead of `s` so the method would be *consistent*
from the point of view of Julia's effect inference, enabling constant folding.

NB: this function is supposed to only have the one method. To add
functionality, add methods to [`with_stripped_type_parameters_unchecked`](@ref).
"""
function with_stripped_type_parameters(spec, t::Type)
ret = with_stripped_type_parameters_unchecked(spec, t)
s = val_parameter(ret)
s = s::UnionAll
s = s::Type{>:t}
Val{s}()
end

function with_replaced_type_parameters(spec, type::Type, parameter::Val)
tv = with_stripped_type_parameters(spec, type)
t = val_parameter(tv)
p = val_parameter(parameter)
s = t{p}
Val{s}()
end

dimension_count_of(::Base.SizeUnknown) = 1
Expand Down
Loading