Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move StaticArrays support to extension #265

Merged
merged 9 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,33 @@ uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.6.16"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
N5N3 marked this conversation as resolved.
Show resolved Hide resolved

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
StructArraysAdaptExt = "Adapt"
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysStaticArraysExt = "StaticArrays"

[compat]
Adapt = "1, 2, 3"
Adapt = "2, 3"
ConstructionBase = "1"
DataAPI = "1"
GPUArraysCore = "0.1.2"
StaticArrays = "1.5.6"
StaticArraysCore = "1.3"
Tables = "1"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Expand All @@ -32,4 +40,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Adapt"]
97 changes: 32 additions & 65 deletions docs/src/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,56 @@ StructArrays support structures with custom data layout. The user is required to

Here is an example of a type `MyType` that has as custom fields either its field `data` or fields of its field `rest` (which is a named tuple):

```jldoctest advanced1
julia> using StructArrays
```@repl advanced1
using StructArrays

julia> struct MyType{T, NT<:NamedTuple}
data::T
rest::NT
end
struct MyType{T, NT<:NamedTuple}
data::T
rest::NT
end

julia> MyType(x; kwargs...) = MyType(x, values(kwargs))
MyType
MyType(x; kwargs...) = MyType(x, values(kwargs))
```

Let's create a small array of these objects:

```jldoctest advanced1
julia> s = [MyType(i/5, a=6-i, b=2) for i in 1:5]
5-element Vector{MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}}:
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.2, (a = 5, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.4, (a = 4, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.6, (a = 3, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.8, (a = 2, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(1.0, (a = 1, b = 2))
```@repl advanced1
s = [MyType(i/5, a=6-i, b=2) for i in 1:5]
```

The default `StructArray` does not unpack the `NamedTuple`:

```jldoctest advanced1
julia> sa = StructArray(s);

julia> sa.rest
5-element Vector{NamedTuple{(:a, :b), Tuple{Int64, Int64}}}:
(a = 5, b = 2)
(a = 4, b = 2)
(a = 3, b = 2)
(a = 2, b = 2)
(a = 1, b = 2)

julia> sa.a
ERROR: type NamedTuple has no field a
Stacktrace:
[1] component
[...]
```@repl advanced1
sa = StructArray(s);
sa.rest
sa.a
```

Suppose we wish to give the keywords their own fields. We can define custom `staticschema`, `component`, and `createinstance` methods for `MyType`:

```jldoctest advanced1
julia> function StructArrays.staticschema(::Type{MyType{T, NamedTuple{names, types}}}) where {T, names, types}
# Define the desired names and eltypes of the "fields"
return NamedTuple{(:data, names...), Base.tuple_type_cons(T, types)}
end;

julia> function StructArrays.component(m::MyType, key::Symbol)
# Define a component-extractor
return key === :data ? getfield(m, 1) : getfield(getfield(m, 2), key)
end;

julia> function StructArrays.createinstance(::Type{MyType{T, NT}}, x, args...) where {T, NT}
# Generate an instance of MyType from components
return MyType(x, NT(args))
end;
```@repl advanced1
function StructArrays.staticschema(::Type{MyType{T, NamedTuple{names, types}}}) where {T, names, types}
# Define the desired names and eltypes of the "fields"
return NamedTuple{(:data, names...), Base.tuple_type_cons(T, types)}
end;

function StructArrays.component(m::MyType, key::Symbol)
# Define a component-extractor
return key === :data ? getfield(m, 1) : getfield(getfield(m, 2), key)
end;

function StructArrays.createinstance(::Type{MyType{T, NT}}, x, args...) where {T, NT}
# Generate an instance of MyType from components
return MyType(x, NT(args))
end;
```

and now:

```jldoctest advanced1
julia> sa = StructArray(s);

julia> sa.a
5-element Vector{Int64}:
5
4
3
2
1

julia> sa.b
5-element Vector{Int64}:
2
2
2
2
2
```@repl advanced1
sa = StructArray(s);
sa.a
sa.b
```

The above strategy has been tested and implemented in [GeometryBasics.jl](https://github.com/JuliaGeometry/GeometryBasics.jl).
Expand Down
70 changes: 18 additions & 52 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,73 +9,39 @@ The package was largely inspired by the `Columns` type in [IndexedTables](https:
## Collection and initialization

One can create a `StructArray` by providing the struct type and a tuple or NamedTuple of field arrays:
```jldoctest intro
julia> using StructArrays

julia> struct Foo{T}
a::T
b::T
end

julia> adata = [1 2; 3 4]; bdata = [10 20; 30 40];

julia> x = StructArray{Foo}((adata, bdata))
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Foo:
Foo{Int64}(1, 10) Foo{Int64}(2, 20)
Foo{Int64}(3, 30) Foo{Int64}(4, 40)
```@repl intro
using StructArrays
struct Foo{T}
a::T
b::T
end
adata = [1 2; 3 4]; bdata = [10 20; 30 40];
x = StructArray{Foo}((adata, bdata))
```

You can also initialze a StructArray by passing in a NamedTuple, in which case the name (rather than the order) specifies how the input arrays are assigned to fields:

```jldoctest intro
julia> x = StructArray{Foo}((b = adata, a = bdata)) # initialize a with bdata and vice versa
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Foo:
Foo{Int64}(10, 1) Foo{Int64}(20, 2)
Foo{Int64}(30, 3) Foo{Int64}(40, 4)
```@repl intro
x = StructArray{Foo}((b = adata, a = bdata)) # initialize a with bdata and vice versa
```

If a struct is not specified, a StructArray with Tuple or NamedTuple elements will be created:
```jldoctest intro
julia> x = StructArray((adata, bdata))
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Tuple{Int64, Int64}:
(1, 10) (2, 20)
(3, 30) (4, 40)

julia> x = StructArray((a = adata, b = bdata))
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype NamedTuple{(:a, :b), Tuple{Int64, Int64}}:
(a = 1, b = 10) (a = 2, b = 20)
(a = 3, b = 30) (a = 4, b = 40)
```@repl intro
x = StructArray((adata, bdata))
x = StructArray((a = adata, b = bdata))
```

It's also possible to create a `StructArray` by choosing a particular dimension to interpret as the components of a struct:

```jldoctest intro
julia> x = StructArray{Complex{Int}}(adata; dims=1) # along dimension 1, the first item `re` and the second is `im`
2-element StructArray(view(::Matrix{Int64}, 1, :), view(::Matrix{Int64}, 2, :)) with eltype Complex{Int64}:
1 + 3im
2 + 4im

julia> x = StructArray{Complex{Int}}(adata; dims=2) # along dimension 2, the first item `re` and the second is `im`
2-element StructArray(view(::Matrix{Int64}, :, 1), view(::Matrix{Int64}, :, 2)) with eltype Complex{Int64}:
1 + 2im
3 + 4im
```@repl intro
x = StructArray{Complex{Int}}(adata; dims=1) # along dimension 1, the first item `re` and the second is `im`
x = StructArray{Complex{Int}}(adata; dims=2) # along dimension 2, the first item `re` and the second is `im`
```

One can also create a `StructArray` from an iterable of structs without creating an intermediate `Array`:

```jldoctest intro
julia> StructArray(log(j+2.0*im) for j in 1:10)
10-element StructArray(::Vector{Float64}, ::Vector{Float64}) with eltype ComplexF64:
0.8047189562170501 + 1.1071487177940904im
1.0397207708399179 + 0.7853981633974483im
1.2824746787307684 + 0.5880026035475675im
1.4978661367769954 + 0.4636476090008061im
1.683647914993237 + 0.3805063771123649im
1.8444397270569681 + 0.3217505543966422im
1.985145956776061 + 0.27829965900511133im
2.1097538525880535 + 0.24497866312686414im
2.2213256282451583 + 0.21866894587394195im
2.3221954495706862 + 0.19739555984988078im
```@repl intro
StructArray(log(j+2.0*im) for j in 1:10)
```

Another option is to create an uninitialized `StructArray` and then fill it with data. Just like in normal arrays, this is done with the `undef` syntax:
Expand Down
16 changes: 16 additions & 0 deletions ext/StructArraysAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module StructArraysAdaptExt
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays
using Adapt, StructArrays
@static if !applicable(Adapt.adapt, Int)
# Adapt.jl has curried support, implement it ourself
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
adpat(to) = Base.Fix1(Adapt.adapt, to)

Check warning on line 6 in ext/StructArraysAdaptExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/StructArraysAdaptExt.jl#L6

Added line #L6 was not covered by tests
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
if VERSION < v"1.9.0-DEV.857"
@eval function adapt(to::Type{T}) where {T}
(@isdefined T) || return Base.Fix1(Adapt.adapt, to)
AT = Base.Fix1{typeof(Adapt.adapt),Type{T}}
return $(Expr(:new, :AT, :(Adapt.adapt), :to))

Check warning on line 11 in ext/StructArraysAdaptExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/StructArraysAdaptExt.jl#L8-L11

Added lines #L8 - L11 were not covered by tests
end
end
end
Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
end
21 changes: 21 additions & 0 deletions ext/StructArraysGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module StructArraysGPUArraysCoreExt

using StructArrays
using StructArrays: map_params, array_types

using Base: tail

import GPUArraysCore

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module
82 changes: 82 additions & 0 deletions ext/StructArraysStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module StructArraysStaticArraysExt

using StructArrays
using StaticArrays: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}

The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`.
```julia
julia> StructArrays.staticschema(SVector{2, Float64})
Tuple{Float64, Float64}
```
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct
which subtypes `FieldArray`.
"""
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
return quote
Base.@_inline_meta
return NTuple{$(tuple_prod(S)), T}
end
end
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
StructArrays.component(s::StaticArray, i) = getindex(s, i)

# invoke general fallbacks for a `FieldArray` type.
@inline function StructArrays.staticschema(T::Type{<:FieldArray})
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
using StructArrays: isnonemptystructtype
using Base.Broadcast: Broadcasted

# StaticArrayStyle has no similar defined.
# Overload `try_struct_copy` instead.
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
argsizes = broadcast_sizes(as...)
ax = axes(bc)
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.")
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
end

@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
first_staticarray = first_statictype(a...)
elements, ET = if prod(newsize) == 0
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
# Use inference to get eltype in empty case (see also comments in _map)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
eltys = Tuple{map(eltype, a)...}
(), Core.Compiler.return_type(f, eltys)

Check warning on line 55 in ext/StructArraysStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/StructArraysStaticArraysExt.jl#L54-L55

Added lines #L54 - L55 were not covered by tests
else
temp = __broadcast(f, sz, s, a...)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part worries me a little bit, we are using something explicitly marked as internal in StaticArrays. Is there no way to achieve this using only public methods? Or maybe we could check over at StaticArrays if they can offer some solution (maybe add a public method that does what we need).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well __broadcast was splitted from _broadcast in JuliaArrays/StaticArrays.jl#1001.
So perhaps the best solution is defining it ourselves.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, thanks for pointing out that discussion. In that case, maybe one could just add a small docstring in StaticArrays.__broadcast to mention that it is the method to be used by outside packages to implement broadcasting of wrapped static arrays (in that way it doesn't accidentally get removed in some StaticArrays refactor that would accidentally break our code).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense too.
Anyway, there's not much code added here. Should we just merge this PR as is, and revert that specific commit once StaticArrays get that mention.

temp, eltype(temp)
end
if isnonemptystructtype(ET)
@static if VERSION >= v"1.7"

Check warning on line 61 in ext/StructArraysStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/StructArraysStaticArraysExt.jl#L61

Added line #L61 was not covered by tests
arrs = ntuple(Val(fieldcount(ET))) do i
@inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i))
end
else
similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i))
arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET)))
end
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
return StructArray{ET}(arrs)
end
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
end

@inline function _getfields(x::Tuple, i::Int)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
if @generated
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
else
return map(Base.Fix2(getfield, i), x)
end
end

end
Loading
Loading