diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c1dd2db..420bdcf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ on: jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} + name: Julia ${{ matrix.julia-version }} - ${{ matrix.os }} - ${{ matrix.julia-arch }} runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/compat.yml b/.github/workflows/compat.yml index 4ec0fbf..37f914b 100644 --- a/.github/workflows/compat.yml +++ b/.github/workflows/compat.yml @@ -7,7 +7,7 @@ on: jobs: Compat: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} + name: Julia ${{ matrix.julia-version }} - ${{ matrix.os }} - ${{ matrix.julia-arch }} runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/Manifest.toml b/Manifest.toml index bf941a3..b2614a1 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -29,6 +29,11 @@ git-tree-sha1 = "45bb6705d93be619b81451bb2006b7ee5d4e4453" uuid = "15f4f7f2-30c1-5605-9d31-71845cf9641f" version = "0.2.0" +[[BSON]] +git-tree-sha1 = "e794bd8f3f319218e8c8b46657631bdbea2807ca" +uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +version = "0.2.5" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -50,6 +55,12 @@ git-tree-sha1 = "7d10b92c4d9951ccf3009d960d9b66883c174474" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" version = "2.2.0" +[[CodecZlib]] +deps = ["BinaryProvider", "Libdl", "TranscodingStreams"] +git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.6.0" + [[CommonSubexpressions]] deps = ["Test"] git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" @@ -58,9 +69,14 @@ version = "0.2.0" [[CompilerSupportLibraries_jll]] deps = ["Libdl", "Pkg"] -git-tree-sha1 = "067567a322fe466c5ec8d01413eee7127bd11699" +git-tree-sha1 = "ff8101d6736414bc93c0f8df77b1e4095ca988c3" uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.1+0" +version = "0.3.2+0" + +[[Crayons]] +git-tree-sha1 = "cb7a62895da739fe5bb43f1a26d4292baf4b3dc0" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.0.1" [[DataAPI]] git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252" @@ -104,6 +120,27 @@ git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.8.1" +[[ElasticArrays]] +git-tree-sha1 = "5b5b7cb8cba44bcf337b8af0a1f3e57c89468660" +uuid = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" +version = "1.0.0" + +[[EllipsisNotation]] +git-tree-sha1 = "65dad386e877850e6fce4fc77f60fe75a468ce9d" +uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" +version = "0.4.0" + +[[ExprTools]] +git-tree-sha1 = "08c1f74d9ad03acf0ee84c12c9e665ab1a9a6e33" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.0" + +[[EzXML]] +deps = ["Printf", "XML2_jll"] +git-tree-sha1 = "0fa3b52a04a4e210aeb1626def9c90df3ae65268" +uuid = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615" +version = "1.1.0" + [[FFTW]] deps = ["AbstractFFTs", "FFTW_jll", "IntelOpenMP_jll", "Libdl", "LinearAlgebra", "MKL_jll", "Reexport"] git-tree-sha1 = "109d82fa4b00429f9afcce873e9f746f11f018d3" @@ -112,24 +149,34 @@ version = "1.2.0" [[FFTW_jll]] deps = ["Libdl", "Pkg"] -git-tree-sha1 = "ddb57f4cf125243b4aa4908c94d73a805f3cbf2c" +git-tree-sha1 = "6c975cd606128d45d1df432fb812d6eb10fee00b" uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" -version = "3.3.9+4" +version = "3.3.9+5" + +[[FilePathsBase]] +deps = ["Dates", "LinearAlgebra", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "923fd3b942a11712435682eaa95cc8518c428b2c" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.8.0" [[FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "85c6b57e2680fa28d5c8adc798967377646fbf66" +git-tree-sha1 = "3eb5253af6186eada40de3df524a1c10f0c6bfa2" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.8.5" +version = "0.8.6" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "88b082d492be6b63f967b6c96b352e25ced1a34c" +git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.9" +version = "0.10.10" + +[[Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[GitForge]] deps = ["Dates", "HTTP", "JSON2"] @@ -182,6 +229,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" +[[JLSO]] +deps = ["BSON", "CodecZlib", "FilePathsBase", "Memento", "Pkg", "Serialization"] +git-tree-sha1 = "ef6164da5b2cad11c0bd8282cc0f029622bc057d" +uuid = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11" +version = "2.2.0" + [[JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" @@ -213,6 +266,12 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +[[Libiconv_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "802f5b23c846cb4ed568cae0bfb0ce0d2ba1926d" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.16.0+1" + [[LinearAlgebra]] deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -220,21 +279,27 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[LyceumBase]] +deps = ["Adapt", "Dates", "ElasticArrays", "EllipsisNotation", "Future", "InteractiveUtils", "JLSO", "LibGit2", "LinearAlgebra", "Logging", "MacroTools", "Parameters", "Pkg", "Random", "Shapes", "StaticArrays", "Statistics", "UnicodePlots", "UniversalLogger", "UnsafeArrays"] +git-tree-sha1 = "453499625deeeeac95f07a695dcede8c53f3d810" +uuid = "db31fed1-ca1e-4084-8a49-12fae1996a55" +version = "0.1.0" + [[LyceumCore]] deps = ["StaticNumbers"] -git-tree-sha1 = "503641247a835f656eedf1c4c9141b5d60217bf1" +git-tree-sha1 = "2ce45d30ca36dda08f0a189c9eeb90d8a50ecc68" repo-rev = "master" repo-url = "https://github.com/Lyceum/LyceumCore.jl.git" uuid = "e5bd5517-2193-49f0-ba9c-d5a8508cb639" version = "0.1.0" [[LyceumDevTools]] -deps = ["Base64", "BenchmarkTools", "Dates", "GitHub", "HTTP", "JuliaFormatter", "LibGit2", "MacroTools", "Markdown", "Parameters", "Pkg", "PkgTemplates", "Reexport", "Registrator", "RegistryTools", "Test"] -git-tree-sha1 = "3db08ca2163bb4ef3829bd3a4f8b435ddbc64bbc" +deps = ["Base64", "BenchmarkTools", "Dates", "Distributed", "GitHub", "HTTP", "JuliaFormatter", "LibGit2", "LyceumBase", "MacroTools", "Markdown", "Parameters", "Pkg", "PkgTemplates", "Random", "Reexport", "Registrator", "RegistryTools", "Requires", "Shapes", "Test"] +git-tree-sha1 = "42b5fac7967884767870514561878dcc01d2ed87" repo-rev = "master" repo-url = "https://github.com/Lyceum/LyceumDevTools.jl.git" uuid = "fd23256c-5a67-41c4-8f5a-c8cf5526e505" -version = "0.3.0" +version = "0.3.1" [[MKL_jll]] deps = ["IntelOpenMP_jll", "Libdl", "Pkg"] @@ -258,9 +323,27 @@ git-tree-sha1 = "85f5947b53c8cfd53ccfa3f4abae31faa22c2181" uuid = "739be429-bea8-5141-9913-cc70e7f3736d" version = "0.7.0" +[[Memento]] +deps = ["Dates", "Distributed", "JSON", "Serialization", "Sockets", "Syslogs", "Test", "TimeZones", "UUIDs"] +git-tree-sha1 = "090463b13da88689e5eae6468a6f531a21392175" +uuid = "f28f55f0-a522-5efc-85c2-fe41dfb9b2d9" +version = "0.12.1" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.3" + [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[Mocking]] +deps = ["ExprTools"] +git-tree-sha1 = "916b850daad0d46b8c71f65f719c49957e9513ed" +uuid = "78c3b35d-d492-501b-9361-3d52fe80e533" +version = "0.7.1" + [[Mustache]] deps = ["Printf", "Tables"] git-tree-sha1 = "f39de3a12232eb47bd0629b3a661054287780276" @@ -328,6 +411,11 @@ version = "0.6.4" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[PushVectors]] +git-tree-sha1 = "f157c6758aba95f179d28fcb6b3928d9e5e8c4d9" +uuid = "36b54c61-190e-5a5f-82d5-6f0a962d7362" +version = "0.2.0" + [[REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -350,9 +438,9 @@ version = "1.1.0" [[RegistryTools]] deps = ["AutoHashEquals", "LibGit2", "Pkg", "UUIDs"] -git-tree-sha1 = "3dfa671318ac8af835cc621b08fcdcc28a4aee67" +git-tree-sha1 = "14873d5a5c36b53897b47e64d123e363176f6cde" uuid = "d1eb7eb1-105f-429d-abf5-b0f65cb9e2c4" -version = "1.3.2" +version = "1.3.3" [[Requires]] deps = ["UUIDs"] @@ -375,6 +463,12 @@ version = "0.2.0" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + [[SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -401,6 +495,18 @@ version = "0.3.2" deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "19bfcb46245f69ff4013b3df3b977a289852c3a1" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.32.2" + +[[Syslogs]] +deps = ["Printf", "Sockets"] +git-tree-sha1 = "46badfcc7c6e74535cc7d833a91f4ac4f805f86d" +uuid = "cea106d9-e007-5e6c-ad93-58fe2094e9c4" +version = "0.3.0" + [[TableTraits]] deps = ["IteratorInterfaceExtensions"] git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e" @@ -423,11 +529,23 @@ git-tree-sha1 = "43defcaf72b89b047f11b778cd83b71ac3e418b0" uuid = "37f0c46e-897f-50ef-b453-b26c3eed3d6c" version = "0.2.0" +[[TimeZones]] +deps = ["Dates", "EzXML", "Mocking", "Printf", "Serialization", "Unicode"] +git-tree-sha1 = "f60a33649ef8380bafe6be7d1af1eeb8a3a3ea92" +uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" +version = "1.0.1" + [[Tokenize]] git-tree-sha1 = "73c00ad506d88a7e8e4f90f48a70943101728227" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" version = "0.5.8" +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.5" + [[URIParser]] deps = ["Test", "Unicode"] git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" @@ -441,6 +559,18 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +[[UnicodePlots]] +deps = ["Crayons", "Dates", "SparseArrays", "StatsBase"] +git-tree-sha1 = "af0c29913f108f649999e74098814c7ef0f644de" +uuid = "b8865327-cd53-5732-bb35-84acbb429228" +version = "1.2.0" + +[[UniversalLogger]] +deps = ["Logging", "PushVectors"] +git-tree-sha1 = "ae1a73e7681e27bc37d8d9db21af73d7cf4518fd" +uuid = "5c5e3362-9445-4819-9f95-51c44c51adeb" +version = "0.2.0" + [[UnsafeArrays]] git-tree-sha1 = "1de6ef280110c7ad3c5d2f7a31a360b57a1bde21" uuid = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" @@ -452,6 +582,12 @@ git-tree-sha1 = "13f763d38c7a05688938808b49cb29b18b60c8c8" uuid = "104b5d7c-a370-577a-8038-80a2059c5097" version = "1.5.2" +[[XML2_jll]] +deps = ["Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "ed5603a695aefe3e9e404fc7b052e02cc72cfab6" +uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.9.9+1" + [[ZMQ]] deps = ["FileWatching", "Sockets", "ZeroMQ_jll"] git-tree-sha1 = "adb2d52aa12c8284da12714f35d2b21fc3d5b2bb" @@ -464,6 +600,12 @@ git-tree-sha1 = "d24fc0004686b534cc7518412b626deeea0b0208" uuid = "8f1865be-045e-5c20-9c9f-bfbfb0764568" version = "4.3.2+1" +[[Zlib_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "fd36a6739e256527287c5444960d0266712cd49e" +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.11+8" + [[Zygote]] deps = ["ArrayLayouts", "DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] git-tree-sha1 = "9688fce24bd8a9468fed12f3d5206099a39054dc" diff --git a/Project.toml b/Project.toml index 5360d82..040915b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,19 +12,8 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Shapes = "175de200-b73b-11e9-28b7-9b5b306cec37" StaticNumbers = "c5e4b96a-f99f-5557-8ed2-dc63ef9b5131" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] julia = "1.4" - -[extras] -AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "AxisArrays", "BenchmarkTools", "Random", "Parameters"] diff --git a/src/SpecialArrays.jl b/src/SpecialArrays.jl index fd9f0c2..ec7e56d 100644 --- a/src/SpecialArrays.jl +++ b/src/SpecialArrays.jl @@ -1,11 +1,15 @@ module SpecialArrays using Adapt + using Base: @propagate_inbounds, @pure, @_inline_meta, require_one_based_indexing +using Base.MultiplicativeInverses: SignedMultiplicativeInverse + using DocStringExtensions using LyceumCore using MacroTools: @forward using Requires: @require +using Shapes using StaticNumbers using UnsafeArrays @@ -15,6 +19,7 @@ const Idx = Union{Colon,Real,AbstractArray} include("viewtype.jl") include("cartesianindexer.jl") +include("typedbool.jl") export innereltype, innerndims, inneraxes, innersize, innerlength include("functions.jl") @@ -25,6 +30,12 @@ include("slicedarray.jl") export FlattenedArray, flatten include("flattenedarray.jl") +export ElasticArray +include("elasticarray.jl") + +export BatchedVector, batch, batchlike +include("batchedvector.jl") + function __init__() @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl") end diff --git a/src/batchedvector.jl b/src/batchedvector.jl new file mode 100644 index 0000000..bdbb274 --- /dev/null +++ b/src/batchedvector.jl @@ -0,0 +1,118 @@ +struct BatchedVector{T,P<:AbsVec} <: AbsVec{T} + parent::P + offsets::Vector{Int} + @inline function BatchedVector{T,P}(parent, offsets) where {T,P<:AbsVec} + check_offsets(parent, offsets) + new(parent, offsets) + end +end + +@inline function BatchedVector(parent::AbsVec, offsets::AbsVec{<:Integer}) + T = viewtype(parent, firstindex(parent):firstindex(parent)) + BatchedVector{T,typeof(parent)}(parent, offsets) +end + + +#### +#### Core Array Interface +#### + +@inline Base.size(A::BatchedVector) = (length(A), ) + +@inline Base.length(A::BatchedVector) = length(A.offsets) - 1 + + +@propagate_inbounds function Base.getindex(A::BatchedVector, i::Int) + return view(A.parent, batchrange(A, i)) +end + +@propagate_inbounds function Base.setindex!(A::BatchedVector, x, i::Int) + A.parent[batchrange(A, i)] = x + return A +end + +Base.IndexStyle(::Type{<:BatchedVector}) = IndexLinear() + + +#### +#### Misc +#### + +@inline function Base.:(==)(A::BatchedVector, B::BatchedVector) + A.offsets == B.offsets && A.parent == B.parent +end + +Base.copy(A::BatchedVector) = BatchedVector(copy(A.parent), copy(A.offsets)) + +Base.dataids(A::BatchedVector) = (Base.dataids(A.parent)..., Base.dataids(A.offsets)...) + +Base.parent(A::BatchedVector) = A.parent + + +#### +#### Extra +#### + +""" + $(SIGNATURES) + +View `A` as a vector of batches where `batch(parent, batch_lengths)[i]` has +length `batch_lengths[i]`. +""" +@inline function batch(parent::AbsVec, batch_lengths) + offsets = Vector{Int}(undef, length(batch_lengths) + 1) + offsets[1] = cumsum = 0 + for (i, l) in enumerate(batch_lengths) + cumsum += l + offsets[i + 1] = cumsum + end + BatchedVector(parent, offsets) +end + +""" + $(SIGNATURES) +View `A` as a vector of batches using the same batch lengths as `B`. +""" +@inline function batchlike(A::AbsVec, B::BatchedVector) + length(A) == length(B.parent) || throw(ArgumentError("length(A) != length(parent(B))")) + BatchedVector(A, copy(B.offsets)) +end + + +#### +#### 3rd Party +#### + +@inline function UnsafeArrays.unsafe_uview(A::BatchedVector) + BatchedVector(uview(A.parent), A.offsets) +end + + +#### +#### Util +#### + +@propagate_inbounds function batchrange(A::BatchedVector, i::Int) + j = firstindex(A.parent) + from = A.offsets[i] + j + to = A.offsets[i + 1] - 1 + j + return from:to +end + +check_offsets(A::BatchedVector) = check_offsets(A.parent, A.offsets) + +function check_offsets(parent::AbsVec, offsets::AbsVec{<:Integer}) + length(offsets) >= 1 || throw(ArgumentError("offsets cannot be empty")) + first(offsets) == 0 || throw(ArgumentError("First offset is non-zero")) + len = 0 + for i in LinearIndices(offsets)[2:end] + o1 = offsets[i-1] + o2 = offsets[i] + o2 > o1 || throw(ArgumentError("Overlapping indices found in offsets")) + len += o2 - o1 + end + if len != length(parent) + throw(ArgumentError("Length computed from offsets is not equal to length(parent)")) + end + return nothing +end \ No newline at end of file diff --git a/src/elasticarray.jl b/src/elasticarray.jl new file mode 100644 index 0000000..f16a8e3 --- /dev/null +++ b/src/elasticarray.jl @@ -0,0 +1,277 @@ +# Modified from https://github.com/JuliaArrays/ElasticArrays.jl +# Copyright (c) 2017: Oliver Schulz oschulz@mpp.mpg.de +# Copyright (c) 2020: Colin Summers colinxs@cs.washington.edu + +const IDims{N} = NTuple{N,Integer} +const IVararg{N} = Vararg{Integer,N} + +""" + ElasticArray{T,N,M} <: DenseArray{T,N} + +An `ElasticArray` can grow/shrink in its last dimension. `N` is the total number of +dimensions, `M == N - 1` the number of non-resizable dimensions (all but the last dimension). + +Constructors: + + ElasticArray(kernel_size::Dims{M}, data::Vector{T}, len::Int) + ElasticArray{T}(::UndefInitializer, dims::NTuple{N,Integer}) + ElasticArray{T}(::UndefInitializer, dims::Integer...) + convert(ElasticArray, A::AbsArr) +""" +struct ElasticArray{T,N,M} <: DenseArray{T,N} + kernel_size::Dims{M} + kernel_length::SignedMultiplicativeInverse{Int} + data::Vector{T} + + function ElasticArray{T,N,M}(kernel_size::Dims{M}, data::Vector{T}) where {T,N,M} + check_elasticarray_parameters(T, Val(N), Val(M)) + kernel_length = SignedMultiplicativeInverse{Int}(prod(kernel_size)) + if rem(length(eachindex(data)), kernel_length) != 0 + throw(ArgumentError("length(data) must be integer multiple of prod(kernel_size)")) + end + new{T,N,M}(kernel_size, kernel_length, data) + end +end + +@inline function ElasticArray{T,N,M}(kernel_size::IDims{M}, data::AbsVec) where {T,N,M} + check_elasticarray_parameters(T, Val(N), Val(M)) + ElasticArray{T,N,M}(convert(Dims{M}, kernel_size), convert(Vector{T}, data)) +end +@inline function ElasticArray{T,N}(kernel_size::IDims{M}, data::AbsVec) where {T,N,M} + ElasticArray{T,N,M}(kernel_size, data) +end +@inline function ElasticArray{T}(kernel_size::IDims{M}, data::AbsVec) where {T,M} + ElasticArray{T,M + 1}(kernel_size, data) +end +@inline function ElasticArray(kernel_size::IDims, data::AbsVec{T}) where {T} + ElasticArray{T}(kernel_size, data) +end + + +@inline function ElasticArray{T,N,M}(::UndefInitializer, dims::IDims{N}) where {T,N,M} + check_elasticarray_parameters(T, Val(N), Val(M)) + ElasticArray{T,N,M}(front(dims), Vector{T}(undef, prod(dims))) +end +@inline function ElasticArray{T,N}(::UndefInitializer, dims::IDims{N}) where {T,N} + ElasticArray{T,N,N - 1}(undef, dims) +end +@inline function ElasticArray{T}(::UndefInitializer, dims::IDims{N}) where {T,N} + ElasticArray{T,N}(undef, dims) +end + +@inline function ElasticArray{T,N,M}(::UndefInitializer, dims::IVararg{N}) where {T,N,M} + ElasticArray{T,N,M}(undef, dims) +end +@inline function ElasticArray{T,N}(::UndefInitializer, dims::IVararg{N}) where {T,N} + ElasticArray{T,N}(undef, dims) +end +@inline function ElasticArray{T}(::UndefInitializer, dims::IVararg{N}) where {T,N} + ElasticArray{T,N}(undef, dims) +end + + +@propagate_inbounds function ElasticArray{T,N,M}(A::AbsArr{<:Any,N}) where {T,N,M} + check_elasticarray_parameters(T, Val(N), Val(M)) + ElasticArray{T,N,M}(front(size(A)), copyto!(Vector{T}(undef, length(A)), A)) +end +@propagate_inbounds function ElasticArray{T,N}(A::AbsArr{<:Any,N}) where {T,N} + ElasticArray{T,N,N - 1}(A) +end +@propagate_inbounds function ElasticArray{T}(A::AbsArr) where {T} + ElasticArray{T,ndims(A)}(A) +end + +@propagate_inbounds ElasticArray(A::AbsArr) = ElasticArray{eltype(A)}(A) + + +@inline function ElasticArray(undef, shape::AbstractShape{<:Any,T,M}, dims::NTuple{N,Integer}) where {T,M,N} + ElasticArray{T,M+N}(undef, size(shape)..., dims...) +end + +@inline function ElasticArray(undef, shape::AbstractShape, dims::Integer...) where {T,M,N} + ElasticArray(undef, shape, dims) +end + + +Base.convert(::Type{T}, A::AbsArr) where {T<:ElasticArray} = A isa T ? A : T(A) + + +#### +#### Core Array Interface +#### + +@inline function Base.size(A::ElasticArray) + (A.kernel_size..., div(length(eachindex(A.data)), A.kernel_length)) +end + +@inline function Base.size(A::ElasticArray, d) + if d > ndims(A) + return 1 + elseif d == ndims(A) + return div(length(eachindex(A.data)), A.kernel_length) + else + A.kernel_size[d] + end +end + +@propagate_inbounds Base.getindex(A::ElasticArray, i::Int) = getindex(A.data, i) +@propagate_inbounds Base.setindex!(A::ElasticArray, x, i::Int) = (setindex!(A.data, x, i); return A) + +Base.IndexStyle(::Type{<:ElasticArray}) = IndexLinear() + +@inline Base.length(A::ElasticArray) = length(A.data) + +@inline function Base.similar(A::ElasticArray, T::Type, dims::Dims{N}) where {N} + ElasticArray{T,N}(front(dims), similar(A.data, T, prod(dims))) +end + + +#### +#### Misc +#### + +@inline function Base.:(==)(A::ElasticArray{<:Any,N,M}, B::ElasticArray{<:Any,N,M}) where {N,M} + A.kernel_size == B.kernel_size && A.data == B.data +end +@inline Base.:(==)(A::ElasticArray, B::ElasticArray) = false + + +@inline function Base.unsafe_convert(::Type{Ptr{T}}, A::ElasticArray{T}) where {T} + Base.unsafe_convert(Ptr{T}, A.data) +end + +@inline Base.pointer(A::ElasticArray, i::Integer) = pointer(A.data, i) + +@inline Base.dataids(A::ElasticArray) = Base.dataids(A.data) + + +@inline function Base.copyto!( + dest::ElasticArray, + doffs::Integer, + src::AbsArr, + soffs::Integer, + N::Integer, +) + copyto!(dest.data, doffs, src, soffs, N) + return dest +end +@inline function Base.copyto!( + dest::AbsArr, + doffs::Integer, + src::ElasticArray, + soffs::Integer, + N::Integer, +) + copyto!(dest, doffs, src.data, soffs, N) +end + +@inline Base.copyto!(dest::ElasticArray, src::AbsArr) = (copyto!(dest.data, src); dest) +@inline Base.copyto!(dest::AbsArr, src::ElasticArray) = copyto!(dest, src.data) + +@inline function Base.copyto!( + dest::ElasticArray, + doffs::Integer, + src::ElasticArray, + soffs::Integer, + N::Integer, +) + copyto!(dest.data, doffs, src.data, soffs, N) + return dest +end +@inline function Base.copyto!(dest::ElasticArray, src::ElasticArray) + copyto!(dest.data, src.data) + return dest +end + + +@inline function Base.resize!(A::ElasticArray{<:Any,N}, dims::IDims{N}) where {N} + kernel_size, size_lastdim = _split_resize_dims(A, dims) + resize!(A.data, A.kernel_length.divisor * size_lastdim) + return A +end +@inline function Base.resize!(A::ElasticArray{<:Any,N}, dims::IVararg{N}) where {N} + resize!(A, dims) +end + +@inline function growlastdim!(A::ElasticArray, n::Integer) + n < 0 && throw(DomainError(n, "n must be positive")) + return resizelastdim!(A, size(A, ndims(A)) + n) +end + +@inline function shrinklastdim!(A::ElasticArray, n::Integer) + n < 0 && throw(DomainError(n, "n must be positive")) + return resizelastdim!(A, size(A, ndims(A)) - n) +end + +@inline resizelastdim!(A::ElasticArray, n::Integer) = resize!(A, (A.kernel_size..., n)) + + +@inline function Base.sizehint!(A::ElasticArray{<:Any,N}, dims::IDims{N}) where {N} + kernel_size, size_lastdim = _split_resize_dims(A, dims) + sizehint!(A.data, A.kernel_length.divisor * size_lastdim) + return A +end +@inline function Base.sizehint!(A::ElasticArray{<:Any,N}, dims::IVararg{N}) where {N} + sizehint!(A, dims) +end + + +function Base.append!(dest::ElasticArray, src::AbsArr) + if rem(length(eachindex(src)), dest.kernel_length) != 0 + throw(DimensionMismatch("Can't append, length of source array is incompatible")) + end + append!(dest.data, src) + return dest +end + +function Base.append!(dest::ElasticArray, iter) + for el in iter + append!(dest, el) + end + return dest +end + +function Base.prepend!(dest::ElasticArray, src::AbsArr) + if rem(length(eachindex(src)), dest.kernel_length) != 0 + throw(DimensionMismatch("Can't prepend, length of source array is incompatible")) + end + prepend!(dest.data, src) + return dest +end + + +#### +#### Broadcasting +#### + +Broadcast.BroadcastStyle(::Type{<:ElasticArray}) = Broadcast.ArrayStyle{ElasticArray}() + +function Base.similar( + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ElasticArray}}, + ::Type{ElType}, +) where {ElType} + similar(ElasticArray{ElType}, axes(bc)) +end + + +#### +#### Util +#### + +@inline function _split_resize_dims(A::ElasticArray, dims::IDims{N}) where {N} + kernel_size, size_lastdim = front(dims), last(dims) + if kernel_size != A.kernel_size + throw(ArgumentError("Can only resize last dimension of an ElasticArray")) + end + return kernel_size, size_lastdim +end + +@generated function check_elasticarray_parameters(::Type{T}, ::Val{N}, ::Val{M}) where {T,N,M} + !isa(N, Int) && return :(throw(ArgumentError("ElasticArray parameter N must be of type Int"))) + !isa(M, Int) && return :(throw(ArgumentError("ElasticArray parameter M must be of type Int"))) + M < 0 && return :(throw(DomainError($M, "ElasticArray parameter M cannot be negative"))) + N < 0 && return :(throw(DomainError($N, "ElasticArray parameter N cannot be negative"))) + if M != N - 1 + return :(throw(ArgumentError("ElasticArray{$T,$N,$M} does not satisfy requirement M == N - 1"))) + end +end diff --git a/src/flattenedarray.jl b/src/flattenedarray.jl index f4cda68..ffc44de 100644 --- a/src/flattenedarray.jl +++ b/src/flattenedarray.jl @@ -178,4 +178,7 @@ julia> B == reshape(hcat(B...), (2, 3, 2)) true ``` """ -flatten(A::AbsArr) = FlattenedArray(A) +flatten(A::AbsArr) = _flatten(A, inneraxes(A), eltype(A)) +_flatten(A::AbsArr, inax::Tuple{}, ::Type{<:AbsArr{<:Any,0}}) = FlattenedArray(A, inax) +_flatten(A::AbsArr, inax::Tuple, ::Type) = FlattenedArray(A, inax) +_flatten(A::AbsArr, ::Tuple{}, ::Type) = A diff --git a/src/slicedarray.jl b/src/slicedarray.jl index 3f85955..3d9eeca 100644 --- a/src/slicedarray.jl +++ b/src/slicedarray.jl @@ -1,31 +1,35 @@ -struct SlicedArray{T,N,M,P<:AbsArr,A} <: AbstractArray{T,N} +struct SlicedArray{T,N,M,P,A} <: AbstractArray{T,N} parent::P alongs::A - @inline function SlicedArray{T,N,M,P,A}(parent, alongs) where {T,N,M,P<:AbsArr,A} - check_slices_parameters(T, Val(N), Val(M), P, A) + @inline function SlicedArray{T,N,M,P,A}(parent, alongs) where {T,N,M,P,A} + check_slices_parameters(T::Type, Val(N::Int), Val(M::Int), P::Type, A::Type) new(parent, alongs) end end @inline function SlicedArray( parent::AbsArr{<:Any,L}, - alongs::NTuple{L,SBool}, + alongs::NTuple{L,TypedBool}, inaxes::NTuple{M,Any}, outaxes::NTuple{N,Any}, ) where {L,M,N} - I = ntuple(i -> first(outaxes[i]), Val(N)) - J = static_merge(alongs, ntuple(i -> Base.Slice(inaxes[i]), Val(M)), I) - T = viewtype(parent, J...) + I = ntuple(i -> (@_inline_meta; first(outaxes[i])), Val(N)) + J = parentindices(axes(parent), alongs, I) + T = viewtype(parent, J) SlicedArray{T,N,M,typeof(parent),typeof(alongs)}(parent, alongs) end -@inline function SlicedArray(parent::AbsArr{<:Any,L}, alongs::NTuple{L,SBool}) where {L} +@inline function SlicedArray(parent::AbsArr{<:Any,L}, alongs::NTuple{L,TypedBool}) where {L} paxes = axes(parent) - inaxes = static_filter(STrue(), alongs, paxes) - outaxes = static_filter(SFalse(), alongs, paxes) + inaxes = paxes[alongs] + outaxes = paxes[tuple_map(!, alongs)] SlicedArray(parent, alongs, inaxes, outaxes) end +@inline function SlicedArray(S::SlicedArray{<:Any,L}, alongs::NTuple{L,TypedBool}) where {L} + SlicedArray(S.parent, Base.setindex(S.alongs, alongs, tuple_map(!, S.alongs))) +end + @generated function check_slices_parameters( ::Type{T}, ::Val{N}, @@ -33,16 +37,15 @@ end ::Type{P}, ::Type{A}, ) where {T,N,M,P,A} - if !(N isa Int && M isa Int) - return :(throw(ArgumentError("SlicedArray parameters N and M must be of type Int"))) - elseif !(A <: NTuple{M + N,SBool}) - return :(throw(ArgumentError("SlicedArray parameter A should be of type NTuple{M+N,$SBool}"))) - elseif N < 0 || M < 0 || ndims(P) != N + M && (M > 0 && sum(unwrap, A.parameters) != M) - return :(throw(ArgumentError("Dimension mismatch in SlicedArray parameters"))) # got N=$N, M=$M, ndims(P)=$(ndims(P)), and sum(A)=$(sum(A))"))) - else - return nothing + A <: NTuple{M + N,TypedBool} || return :(throw(ArgumentError("SlicedArray parameter A should <: NTuple{M+N,$TypedBool}"))) + ntrue = 0 + for AA in A.parameters + AA === True && (ntrue += 1) + end + if !(P <: AbsArr{eltype(T),M+N} && N >= 0 && M >= 0 && M == ntrue) + return :(throw(ArgumentError("SlicedArray parameter P should be <: AbstractArray{eltype(T),M+N}"))) end - error("Internal error. Please file a bug report") + return nothing end @@ -84,7 +87,7 @@ julia> innersize(B) ``` """ @inline function slice(A::AbsArr, I::Vararg{Union{Colon,typeof(*)},L}) where {L} - alongs = ntuple(i -> (Base.@_inline_meta; I[i] isa Colon ? STrue() : SFalse()), Val(L)) + alongs = ntuple(i -> (Base.@_inline_meta; I[i] === Colon() ? True() : False()), Val(L)) SlicedArray(reshape(A, Val(L)), alongs) end @@ -115,8 +118,10 @@ julia> B = slice(A, 1, 3) julia> view(A, :, 1, :) === B[1] true ``` + `alongs` can also be specified using `StaticInteger` from StaticNumbers.jl for better type inference: + ```jldoctest julia> using StaticNumbers @@ -127,12 +132,12 @@ true ``` """ @inline function slice(A::AbsArr{<:Any,L}, alongs::TupleN{StaticOrInt}) where {L} - SlicedArray(A, ntuple(dim -> static_in(dim, alongs), Val(L))) + SlicedArray(A, ntuple(dim -> (@_inline_meta; static_in(dim, alongs)), Val(L))) end @inline slice(A::AbsArr, alongs::Vararg{StaticOrInt}) = slice(A, alongs) @inline function slice(A::AbsArr{<:Any,L}, ::Tuple{}) where {L} - SlicedArray(A, ntuple(_ -> static(false), Val(L))) + SlicedArray(A, ntuple(_ -> (@_inline_meta; False()), Val(L))) end @inline slice(A::AbsArr) = slice(A, ()) @@ -141,9 +146,9 @@ end #### Core Array Interface #### -@inline Base.axes(S::SlicedArray) = static_filter(SFalse(), S.alongs, axes(S.parent)) +@inline Base.axes(S::SlicedArray) = axes(S.parent)[tuple_map(!, S.alongs)] -@inline Base.size(S::SlicedArray) = static_filter(SFalse(), S.alongs, size(S.parent)) +@inline Base.size(S::SlicedArray) = size(S.parent)[tuple_map(!, S.alongs)] # standard Cartesian indexing @@ -167,7 +172,8 @@ end # If `I isa Vararg{Idx,N} && length(Base.index_ndims(I...)) == N` We can just forward the # indices to the parent array and drop the corresponding entries in `S.alongs` # (that is, we can drop `S.alongs[i]` iff -# `S.alongs[i] === static(false) && and Base.index_shape(I[i]) === ())`. +# `S.alongs[i] === False() && and Base.index_shape(I[i]) === ())`) rather than allocating a new +# output array. @inline function Base.getindex(S::SlicedArray{<:Any,N}, I::Vararg{Idx,N}) where {N} J = Base.to_indices(S, I) @@ -189,28 +195,30 @@ end @inbounds CartesianIndexer(A)[J...] end -@inline _maybe_wrap(A::AbstractArray, alongs::TupleN{SBool}) = SlicedArray(A, alongs) +@inline _maybe_wrap(A::AbstractArray, alongs::TupleN{TypedBool}) = SlicedArray(A, alongs) # A single element, so no need to wrap with a SlicedArray -@inline _maybe_wrap(A::AbstractArray{<:Any,M}, ::NTuple{M,STrue}) where {M} = A +@inline _maybe_wrap(A::AbstractArray{<:Any,M}, ::NTuple{M,True}) where {M} = A -# add/drop non-sliced dimensions (i.e. alongs[dim] == SFalse()) to match J -@inline function reslice(alongs::NTuple{L,SBool}, K::NTuple{L,Any}) where {L} +# add/drop non-sliced dimensions (i.e. alongs[dim] == False()) to match J +@inline function reslice(alongs::NTuple{L,TypedBool}, K::NTuple{L,Any}) where {L} (_reslice1(first(alongs), first(K))..., reslice(tail(alongs), tail(K))...) end reslice(::Tuple{}, ::Tuple{}) = () -@inline _reslice1(::STrue, k) = (static(true),) # keep inner dimension -@inline _reslice1(::SFalse, k) = _reslicefalse(k) +@inline _reslice1(::True, k) = (True(), ) # keep inner dimension +@inline _reslice1(::False, k) = _reslicefalse(k) @inline _reslicefalse(::Real) = () # drop this dimension -@inline _reslicefalse(::Colon) = (static(false),) # keep this dimension +@inline _reslicefalse(::Colon) = (False(), ) # keep this dimension @inline function _reslicefalse(::AbstractArray{<:Any,N}) where {N} - ntuple(_ -> static(false), Val(N)) + ntuple(_ -> False(), Val(N)) end @inline function parentindices(S::SlicedArray{<:Any,N,M}, I::NTuple{N,Any}) where {N,M} - inaxes = inneraxes(S) - slices = ntuple(i -> (Base.@_inline_meta; Base.Slice(inaxes[i])), Val(M)) - static_merge(S.alongs, slices, I) + parentindices(axes(S.parent), S.alongs, I) +end + +@inline function parentindices(paxes::NTuple{N,Any}, alongs::NTuple{N,TypedBool}, I::Tuple) where {N} + Base.setindex(tuple_map(ax->(@_inline_meta; Base.Slice(ax)), paxes), I, tuple_map(!, alongs)) end @@ -218,9 +226,7 @@ end ##### Misc ##### -function Base.:(==)(A::SlicedArray, B::SlicedArray) - A.alongs == B.alongs && A.parent == B.parent -end +Base.:(==)(A::SlicedArray, B::SlicedArray) = A.alongs == B.alongs && A.parent == B.parent Base.parent(S::SlicedArray) = S.parent @@ -228,6 +234,21 @@ Base.dataids(S::SlicedArray) = Base.dataids(S.parent) Base.copy(S::SlicedArray) = SlicedArray(copy(S.parent), S.alongs) + +Base.append!(S::SlicedArray, iter) = (append!(S.parent, iter); S) + +Base.prepend!(S::SlicedArray, iter) = (prepend!(S.parent, iter); S) + +function Base.resize!(S::SlicedArray{<:Any,N}, dims::NTuple{N,Integer}) where {N} + indims = innersize(S) + parentdims = static_merge(S.alongs, indims, dims) + resize!(S.parent, parentdims) + return S +end + +Base.resize!(S::SlicedArray, dims::Integer...) = resize!(S, dims) + + function Base.showarg(io::IO, A::SlicedArray, toplevel) print(io, "slice(") Base.showarg(io, parent(A), false) @@ -240,23 +261,23 @@ function Base.showarg(io::IO, A::SlicedArray, toplevel) end return nothing end -along2string(::STrue) = ':' -along2string(::SFalse) = '*' +along2string(::True) = ':' +along2string(::False) = '*' ##### ##### Extra ##### -@inline innersize(S::SlicedArray) = static_filter(STrue(), S.alongs, size(S.parent)) -@inline inneraxes(S::SlicedArray) = static_filter(STrue(), S.alongs, axes(S.parent)) +@inline innersize(S::SlicedArray) = size(S.parent)[S.alongs] +@inline inneraxes(S::SlicedArray) = axes(S.parent)[S.alongs] flatten(S::SlicedArray) = S.parent function mapslices(f, A::AbstractArray; dims::TupleN{StaticOrInt}) S = slice(A, dims) - B = _alloc_keepdims(S, f(first(S))) + B = _alloc_mapslices(S, f(first(S))) @assert axes(S) == axes(B) for I in eachindex(S, B) # TODO start from second _unsafe_copy_inner!(B, f(S[I]), I) @@ -282,7 +303,7 @@ end return B end -function _alloc_keepdims(S::SlicedArray{<:Any,N,M}, b1) where {T,N,M} +function _alloc_mapslices(S::SlicedArray{<:Any,N,M}, b1) where {T,N,M} innerax = _reshape_axes(axes(b1), Val(M)) innersz = ntuple(i -> (Base.@_inline_meta; length(innerax[i])), Val(M)) parentsz = static_merge(S.alongs, innersz, size(S)) diff --git a/src/typedbool.jl b/src/typedbool.jl new file mode 100644 index 0000000..b0055e2 --- /dev/null +++ b/src/typedbool.jl @@ -0,0 +1,66 @@ +struct True end + +struct False end + +const TypedBool = Union{True, False} + +@pure Base.:(!)(::False) = True() +@pure Base.:(!)(::True) = False() + + +@generated function static_merge(::Bys, x::X, y::Y) where {Bys<:TupleN{TypedBool},X<:Tuple,Y<:Tuple} + i = j = 0 + xy = [] + for By in Bys.parameters + if By === True + push!(xy, :(x[$(i += 1)])) + i > tuple_length(X) && return :(throw(BoundsError(x, $i))) + else + push!(xy, :(y[$(j += 1)])) + j > tuple_length(Y) && return :(throw(BoundsError(y, $j))) + end + end + return :(@_inline_meta; $(Expr(:tuple, xy...))) +end + +# See: https://github.com/JuliaLang/julia/issues/33126 +static_in(x::StaticOrInt, itr::TupleN{StaticOrInt}) = _static_in(x, itr) +@pure function _static_in(x::StaticOrInt, itr::TupleN{StaticOrInt}) + for y in itr + unstatic(y) === unstatic(x) && return True() + end + return False() +end + + +@inline function Base.getindex(xs::NTuple{N,Any}, I::NTuple{N,TypedBool}) where {N} + _getindex(xs, I) +end + +@inline function _getindex(xs::NTuple{N,Any}, I::NTuple{N,TypedBool}) where {N} + rest = _getindex(tail(xs), tail(I)) + return first(I) === True() ? (first(xs), rest...) : rest +end +_getindex(::Tuple{}, ::Tuple{}) = () + + +@inline function Base.setindex(t::NTuple{N,Any}, v::Tuple, I::NTuple{N,TypedBool}) where {N} + _setindex(t, v, I) +end + +@inline function _setindex(t::NTuple{N,Any}, v::Tuple, I::NTuple{N,TypedBool}) where {N} + if first(I) === True() + (first(v), _setindex(tail(t), tail(v), tail(I))...) + else + (first(t), _setindex(tail(t), v, tail(I))...) + end +end +_setindex(::Tuple{}, ::Tuple{}, ::Tuple{}) = () +function _setindex(::Tuple{}, ::Tuple, ::Tuple{}) + throw(DimensionMismatch("Cannot assign more values than indices")) +end + + +@inline tuple_map(f::F, t::NTuple{N,Any}) where {F,N} = ntuple(i -> (@_inline_meta; f(t[i])), Val(N)) + +@inline Base.findall(t::NTuple{N,TypedBool}) where {N} = ntuple(identity, Val(N))[t] \ No newline at end of file diff --git a/test/Manifest.toml b/test/Manifest.toml new file mode 100644 index 0000000..4902c0f --- /dev/null +++ b/test/Manifest.toml @@ -0,0 +1,461 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "0.5.0" + +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "1.0.1" + +[[ArrayLayouts]] +deps = ["FillArrays", "LinearAlgebra"] +git-tree-sha1 = "bc779df8d73be70e4e05a63727d3a4dfb4c52b1f" +uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +version = "0.1.5" + +[[AssetRegistry]] +deps = ["Distributed", "JSON", "Pidfile", "SHA", "Test"] +git-tree-sha1 = "b25e88db7944f98789130d7b503276bc34bc098e" +uuid = "bf4720bc-e11a-5d0c-854e-bdca1663c893" +version = "0.1.0" + +[[AutoHashEquals]] +git-tree-sha1 = "45bb6705d93be619b81451bb2006b7ee5d4e4453" +uuid = "15f4f7f2-30c1-5605-9d31-71845cf9641f" +version = "0.2.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"] +git-tree-sha1 = "9e62e66db34540a0c919d72172cc2f642ac71260" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "0.5.0" + +[[BinaryProvider]] +deps = ["Libdl", "SHA"] +git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.8" + +[[CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "7d10b92c4d9951ccf3009d960d9b66883c174474" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "2.2.0" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "ff8101d6736414bc93c0f8df77b1e4095ca988c3" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.3.2+0" + +[[DataAPI]] +git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.1.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "5a431d46abf2ef2a4d5d00bd0ae61f651cf854c8" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.17.10" + +[[DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.2" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.1" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FFTW]] +deps = ["AbstractFFTs", "FFTW_jll", "IntelOpenMP_jll", "Libdl", "LinearAlgebra", "MKL_jll", "Reexport"] +git-tree-sha1 = "109d82fa4b00429f9afcce873e9f746f11f018d3" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "1.2.0" + +[[FFTW_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "6c975cd606128d45d1df432fb812d6eb10fee00b" +uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" +version = "3.3.9+5" + +[[FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "3eb5253af6186eada40de3df524a1c10f0c6bfa2" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.8.6" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.10" + +[[GitForge]] +deps = ["Dates", "HTTP", "JSON2"] +git-tree-sha1 = "4469573ba6e4c262ba3c3018de2166c063ec5c2d" +uuid = "8f6bce27-0656-5410-875b-07a5572985df" +version = "0.1.5" + +[[GitHub]] +deps = ["Base64", "Dates", "HTTP", "JSON", "MbedTLS", "Sockets"] +git-tree-sha1 = "f8f9c05004861b6680c1bd363e7e2fcff602a283" +uuid = "bc5e4493-9b4d-5f90-b8aa-2b2bcaad7a26" +version = "5.1.4" + +[[HTTP]] +deps = ["Base64", "Dates", "IniFile", "MbedTLS", "Sockets"] +git-tree-sha1 = "cd60d9a575d3b70c026d7e714212fd4ecf86b4bb" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "0.8.13" + +[[Hiccup]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "6187bb2d5fcbb2007c39e7ac53308b0d371124bd" +uuid = "9fb69e20-1954-56bb-a84f-559cc56a8ff7" +version = "0.2.2" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "1a4355e4b5b50be2311ebb644f34f3306dbd0410" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.3.1" + +[[IniFile]] +deps = ["Test"] +git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" +uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" +version = "0.5.0" + +[[IntelOpenMP_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "fb8e1c7a5594ba56f9011310790e03b5384998d6" +uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" +version = "2018.0.3+0" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.0" + +[[JSON2]] +deps = ["Dates", "Parsers", "Test"] +git-tree-sha1 = "6cbbbab27d9411946725f5d5c91e8b8fb5f7d5db" +uuid = "2535ab7d-5cd8-5a07-80ac-9b1792aadce3" +version = "0.3.1" + +[[JuliaFormatter]] +deps = ["CSTParser", "DataStructures", "Test", "Tokenize"] +git-tree-sha1 = "a6c9c29d1dfab0f62b617064cca73e1506e599ae" +uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +version = "0.3.9" + +[[Lazy]] +deps = ["MacroTools"] +git-tree-sha1 = "0bd934e15f5df97414aa81abf74ba8a2d5042964" +uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" +version = "0.15.0" + +[[LibGit2]] +deps = ["Printf"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[LyceumCore]] +deps = ["StaticNumbers"] +git-tree-sha1 = "503641247a835f656eedf1c4c9141b5d60217bf1" +uuid = "e5bd5517-2193-49f0-ba9c-d5a8508cb639" +version = "0.1.0" + +[[LyceumDevTools]] +deps = ["Base64", "BenchmarkTools", "Dates", "GitHub", "HTTP", "JuliaFormatter", "LibGit2", "MacroTools", "Markdown", "Parameters", "Pkg", "PkgTemplates", "Reexport", "Registrator", "RegistryTools", "Test"] +git-tree-sha1 = "657b0d09a5e8ae33920835e40591fe6a8a2d4b4e" +uuid = "fd23256c-5a67-41c4-8f5a-c8cf5526e505" +version = "0.3.1" + +[[MKL_jll]] +deps = ["IntelOpenMP_jll", "Libdl", "Pkg"] +git-tree-sha1 = "720629cc8cbd12c146ca01b661fd1a6cf66e2ff4" +uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +version = "2019.0.117+2" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.5" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MbedTLS]] +deps = ["BinaryProvider", "Dates", "Libdl", "Random", "Sockets"] +git-tree-sha1 = "85f5947b53c8cfd53ccfa3f4abae31faa22c2181" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "0.7.0" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[Mustache]] +deps = ["Printf", "Tables"] +git-tree-sha1 = "f39de3a12232eb47bd0629b3a661054287780276" +uuid = "ffc61752-8dc7-55ee-8c37-f3e9cdd09e70" +version = "0.5.13" + +[[Mux]] +deps = ["AssetRegistry", "Base64", "HTTP", "Hiccup", "Lazy", "Pkg", "Sockets", "Test", "WebSockets"] +git-tree-sha1 = "5b41f03d63400c290bab4e1a49fb9ac36de1084a" +uuid = "a975b10e-0019-58db-a62f-e48ff68538c9" +version = "0.7.0" + +[[NNlib]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] +git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.6" + +[[NaNMath]] +git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.3" + +[[OpenSpecFun_jll]] +deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] +git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+3" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.1.0" + +[[Parameters]] +deps = ["OrderedCollections"] +git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.0" + +[[Parsers]] +deps = ["Dates", "Test"] +git-tree-sha1 = "d6d82d5bdbb75048e574cd2d2c89dfbf2c74250c" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "1.0.0" + +[[Pidfile]] +deps = ["FileWatching", "Test"] +git-tree-sha1 = "1ffd82728498b5071cde851bbb7abd780d4445f3" +uuid = "fa939f87-e72e-5be4-a000-7fc836dbe307" +version = "1.1.0" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[PkgTemplates]] +deps = ["Dates", "InteractiveUtils", "LibGit2", "Mustache", "Pkg", "REPL", "URIParser"] +git-tree-sha1 = "fb5de58bbf1823c4373f813a5611d37e57edb109" +uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" +version = "0.6.4" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Registrator]] +deps = ["AutoHashEquals", "Base64", "Dates", "Distributed", "FileWatching", "GitForge", "GitHub", "HTTP", "JSON", "LibGit2", "Logging", "MbedTLS", "Mustache", "Mux", "Pkg", "RegistryTools", "Serialization", "Sockets", "TimeToLive", "UUIDs", "ZMQ"] +git-tree-sha1 = "722d249d1f29be9281dafd14bdf1dd5cb25cf225" +uuid = "4418983a-e44d-11e8-3aec-9789530b3b3e" +version = "1.1.0" + +[[RegistryTools]] +deps = ["AutoHashEquals", "LibGit2", "Pkg", "UUIDs"] +git-tree-sha1 = "14873d5a5c36b53897b47e64d123e363176f6cde" +uuid = "d1eb7eb1-105f-429d-abf5-b0f65cb9e2c4" +version = "1.3.3" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.0.1" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.10.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.12.1" + +[[StaticNumbers]] +deps = ["Requires"] +git-tree-sha1 = "8fedba0c674686bfa52dba5f68acf916b824f53a" +uuid = "c5e4b96a-f99f-5557-8ed2-dc63ef9b5131" +version = "0.3.2" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.0" + +[[Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] +git-tree-sha1 = "242b7fde70b8bc6a30d6476adf17ca3cf1ced6ee" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.0.3" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TimeToLive]] +deps = ["Dates", "Test"] +git-tree-sha1 = "43defcaf72b89b047f11b778cd83b71ac3e418b0" +uuid = "37f0c46e-897f-50ef-b453-b26c3eed3d6c" +version = "0.2.0" + +[[Tokenize]] +git-tree-sha1 = "73c00ad506d88a7e8e4f90f48a70943101728227" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.8" + +[[URIParser]] +deps = ["Test", "Unicode"] +git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" +uuid = "30578b45-9adc-5946-b283-645ec420af67" +version = "0.4.0" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[UnsafeArrays]] +git-tree-sha1 = "1de6ef280110c7ad3c5d2f7a31a360b57a1bde21" +uuid = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" +version = "1.0.0" + +[[WebSockets]] +deps = ["Base64", "Dates", "Distributed", "HTTP", "Logging", "Random", "Sockets", "Test"] +git-tree-sha1 = "13f763d38c7a05688938808b49cb29b18b60c8c8" +uuid = "104b5d7c-a370-577a-8038-80a2059c5097" +version = "1.5.2" + +[[ZMQ]] +deps = ["FileWatching", "Sockets", "ZeroMQ_jll"] +git-tree-sha1 = "adb2d52aa12c8284da12714f35d2b21fc3d5b2bb" +uuid = "c2297ded-f4af-51ae-bb23-16f91089e4e1" +version = "1.2.0" + +[[ZeroMQ_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "d24fc0004686b534cc7518412b626deeea0b0208" +uuid = "8f1865be-045e-5c20-9c9f-bfbfb0764568" +version = "4.3.2+1" + +[[Zygote]] +deps = ["ArrayLayouts", "DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "9688fce24bd8a9468fed12f3d5206099a39054dc" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.4.12" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.0" diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..375a66f --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,10 @@ +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +LyceumCore = "e5bd5517-2193-49f0-ba9c-d5a8508cb639" +LyceumDevTools = "fd23256c-5a67-41c4-8f5a-c8cf5526e505" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticNumbers = "c5e4b96a-f99f-5557-8ed2-dc63ef9b5131" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/preamble.jl b/test/preamble.jl index cd50bee..8d56ade 100644 --- a/test/preamble.jl +++ b/test/preamble.jl @@ -1,23 +1,19 @@ -using Adapt: Adapt, adapt -using AxisArrays: AxisArrays - using Base: index_shape, index_dimsum, index_ndims, to_indices -using Base: mightalias, unalias, dataids +using Base: mightalias, unalias, dataids, unsafe_convert -using BenchmarkTools -using Parameters -using Random +using Adapt: Adapt, adapt using LyceumCore using LyceumDevTools +using Parameters +using Random using StaticNumbers using Test using UnsafeArrays -using Zygote using SpecialArrays -using SpecialArrays: CartesianIndexer include("util.jl") + testdims(L::Integer) = ntuple(i -> 3 + i, Val(unstatic(L))) diff --git a/test/runtests.jl b/test/runtests.jl index d939d69..c40ca79 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,15 +2,6 @@ module TestSpecialArrays include("preamble.jl") -if isempty(ARGS) - TEST_FILES = - sort([file for file in readdir(@__DIR__) if match(r"^test_.*\.jl$", file) !== nothing]) -else - TEST_FILES = ARGS -end - -@testset "$file" for file in TEST_FILES - include(file) -end +@includetests ProgressTestSet "SpecialArrays" end # module diff --git a/test/test_batchedvector.jl b/test/test_batchedvector.jl new file mode 100644 index 0000000..a425d7e --- /dev/null +++ b/test/test_batchedvector.jl @@ -0,0 +1,63 @@ +module TestBatchedVector + +include("preamble.jl") + +function makedata(V::Type, nbatches::Integer) + batch_lengths = testdims(nbatches) + offsets = [0, cumsum([batch_lengths...])...] + nested = Vector{V}[rand(V, bl) for bl in batch_lengths] + flat = nbatches == 0 ? V[] : reduce(vcat, nested) + return (batch_lengths = batch_lengths, offsets = offsets, nested=nested, flat=flat) +end + + +@testset "nbatches = $nbatches, V = $V" for nbatches in (1, 10), V in (Float64, ) + @testset "constructors" begin + @unpack batch_lengths, offsets, nested, flat = makedata(V, nbatches) + Expected = BatchedVector{<:AbsArr{V,innerndims(nested)},typeof(flat)} + @test BatchedVector(flat, offsets) isa Expected + @test_inferred BatchedVector(flat, offsets) + end + + @testset "Extra" begin + @unpack batch_lengths, offsets, nested, flat = makedata(V, nbatches) + Expected = BatchedVector{<:AbsArr{V,innerndims(nested)},typeof(flat)} + @test batch(flat, batch_lengths) isa Expected + @test_inferred batch(flat, batch_lengths) + let B1 = batch(flat, batch_lengths) + B2 = batchlike(copy(flat), B1) + @test B1 == B2 + @test B1.parent !== B2.parent + end + end + + let nbatches = nbatches, V = V + test_array_AB() do + @unpack batch_lengths, nested, flat = makedata(V, nbatches) + A = batch(flat, batch_lengths) + B = nested + return A, B + end + end +end + +@testset "Misc" begin + d = makedata(Float64, 5) + B = BatchedVector(d.flat, d.offsets) + @test Base.dataids(B) === (Base.dataids(d.flat)..., Base.dataids(d.offsets)...) + @test parent(B) === d.flat +end + +@testset "UnsafeArrays" begin + d = makedata(Float64, 5) + B = BatchedVector(d.flat, d.offsets) + U = uview(B) + @test U == B + @test U.parent isa UnsafeArray{eltype(d.flat),1} + @test U.offsets isa Vector{Int} + @test_inferred uview(B) +end + + +end # module + diff --git a/test/test_elasticarray.jl b/test/test_elasticarray.jl new file mode 100644 index 0000000..779a215 --- /dev/null +++ b/test/test_elasticarray.jl @@ -0,0 +1,249 @@ +module TestElasticArray + +using Base.MultiplicativeInverses: SignedMultiplicativeInverse +using SpecialArrays: growlastdim!, shrinklastdim!, resizelastdim! + +include("preamble.jl") + +function makedata(T::Type, N::Integer) + dims = testdims(N) + kernel_size = front(dims) + data = rand(prod(dims))::Vector + return ( + M=N - 1, + dims=dims, + kernel_size=kernel_size, + kernel_length = SignedMultiplicativeInverse(prod(kernel_size)), + data = data, + A = reshape(copy(data), dims), + ) +end + + +@testset "N = $N, T = $T" for N in 1:2, T in (Float64, ) + @testset "constructors" begin + @unpack M, dims, kernel_size, kernel_length, data = makedata(T, N) + + @test ElasticArray{T,N,M}(kernel_size, data) isa ElasticArray{T,N,M} + @test ElasticArray{T,N,M}(kernel_size, data).kernel_size === kernel_size + @test ElasticArray{T,N,M}(kernel_size, data).kernel_length === kernel_length + @test ElasticArray{T,N,M}(kernel_size, data).data === data + + @test ElasticArray{T,N}(kernel_size, data) isa ElasticArray{T,N,M} + @test ElasticArray{T,N}(kernel_size, data).kernel_size === kernel_size + @test ElasticArray{T,N}(kernel_size, data).kernel_length === kernel_length + @test ElasticArray{T,N}(kernel_size, data).data === data + + @test ElasticArray{T}(kernel_size, data) isa ElasticArray{T,N,M} + @test ElasticArray{T}(kernel_size, data).kernel_size === kernel_size + @test ElasticArray{T}(kernel_size, data).kernel_length === kernel_length + @test ElasticArray{T}(kernel_size, data).data === data + end + + @testset "undef" begin + @unpack M, dims, kernel_size, kernel_length, data = makedata(T, N) + + @test ElasticArray{T,N,M}(undef, dims) isa ElasticArray{T,N,M} + @test ElasticArray{T,N,M}(undef, dims).kernel_size === kernel_size + @test ElasticArray{T,N,M}(undef, dims).kernel_length === kernel_length + + @test ElasticArray{T,N}(undef, dims) isa ElasticArray{T,N,M} + @test ElasticArray{T,N}(undef, dims).kernel_size === kernel_size + @test ElasticArray{T,N}(undef, dims).kernel_length === kernel_length + + @test ElasticArray{T}(undef, dims) isa ElasticArray{T,N,M} + @test ElasticArray{T}(undef, dims).kernel_size === kernel_size + @test ElasticArray{T}(undef, dims).kernel_length === kernel_length + + + @test ElasticArray{T,N,M}(undef, dims...) isa ElasticArray{T,N,M} + @test ElasticArray{T,N,M}(undef, dims...).kernel_size === kernel_size + @test ElasticArray{T,N,M}(undef, dims...).kernel_length === kernel_length + + @test ElasticArray{T,N}(undef, dims...) isa ElasticArray{T,N,M} + @test ElasticArray{T,N}(undef, dims...).kernel_size === kernel_size + @test ElasticArray{T,N}(undef, dims...).kernel_length === kernel_length + + @test ElasticArray{T}(undef, dims...) isa ElasticArray{T,N,M} + @test ElasticArray{T}(undef, dims...).kernel_size === kernel_size + @test ElasticArray{T}(undef, dims...).kernel_length === kernel_length + end + + @testset "from array" begin + @unpack M, dims, kernel_size, kernel_length, data, A = makedata(T, N) + + @test ElasticArray{T,N,M}(A) isa ElasticArray{T,N,M} + @test ElasticArray{T,N,M}(A).kernel_size === kernel_size + @test ElasticArray{T,N,M}(A).kernel_length === kernel_length + @test ElasticArray{T,N,M}(A).data == vec(A) + @test ElasticArray{T,N,M}(A).data !== A + + @test ElasticArray{T,N}(A) isa ElasticArray{T,N,M} + @test ElasticArray{T,N}(A).kernel_size === kernel_size + @test ElasticArray{T,N}(A).kernel_length === kernel_length + @test ElasticArray{T,N}(A).data == vec(A) + @test ElasticArray{T,N}(A).data !== A + + @test ElasticArray{T}(A) isa ElasticArray{T,N,M} + @test ElasticArray{T}(A).kernel_size === kernel_size + @test ElasticArray{T}(A).kernel_length === kernel_length + @test ElasticArray{T}(A).data == vec(A) + @test ElasticArray{T}(A).data !== A + end + + @testset "convert" begin + @unpack M, dims, kernel_size, kernel_length, data, A = makedata(T, N) + + @test convert(ElasticArray{T,N,M}, A) isa ElasticArray{T,N,M} + @test convert(ElasticArray{T,N,M}, A).kernel_size === kernel_size + @test convert(ElasticArray{T,N,M}, A).kernel_length === kernel_length + @test convert(ElasticArray{T,N,M}, A).data == vec(A) + @test convert(ElasticArray{T,N,M}, A).data !== A + + @test convert(ElasticArray{T,N}, A) isa ElasticArray{T,N,M} + @test convert(ElasticArray{T,N}, A).kernel_size === kernel_size + @test convert(ElasticArray{T,N}, A).kernel_length === kernel_length + @test convert(ElasticArray{T,N}, A).data == vec(A) + @test convert(ElasticArray{T,N}, A).data !== A + + @test convert(ElasticArray{T}, A) isa ElasticArray{T,N,M} + @test convert(ElasticArray{T}, A).kernel_size === kernel_size + @test convert(ElasticArray{T}, A).kernel_length === kernel_length + @test convert(ElasticArray{T}, A).data == vec(A) + @test convert(ElasticArray{T}, A).data !== A + end + + @testset "misc" begin + data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + @test dataids(E) === dataids(data.data) + end + + @testset "resize!" begin + function resize_test(delta::Integer) + data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + + new_size = (front(size(E))..., last(size(E)) + delta) + cmp_idxs = (front(axes(E))..., 1:(last(size(E))+min(0, delta))) + @test sizehint!(E, new_size...) === E + @test resize!(E, new_size...) === E + @test size(E) === new_size + @test E[cmp_idxs...] == data.A[cmp_idxs...] + end + + resize_test(0) + resize_test(2) + resize_test(-2) + end + + @testset "append!/prepend!" begin + let data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + + dims = front(size(E)) + len_lastdim = last(size(E)) + V = Array{T,length(dims)}[] + + for i = 1:4 + push!(V, rand!(zeros(dims...))) + append!(E, last(V)) + end + + @test size(E) == (dims..., len_lastdim + length(V)) + @test all(1:length(V)) do i + selectdim(E, ndims(E), i + len_lastdim) == V[i] + end + end + + let data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + dims = front(size(E)) + len_lastdim = last(size(E)) + V = Array{T,length(dims)}[] + + for i = 1:4 + pushfirst!(V, rand!(zeros(dims...))) + prepend!(E, first(V)) + end + + @test size(E) == (dims..., len_lastdim + length(V)) + @test all(1:length(V)) do i + selectdim(E, ndims(E), i) == V[i] + end + end + end + + let N = N, T = T + test_array_AB() do + data = makedata(T, N) + A = ElasticArray(data.kernel_size, data.data) + B = data.A + return A, B + end + end + + @testset "growlastdim!, shrinklastdim!, resizelastdim!" begin + let data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + dims, d = front(size(E)), last(size(E)) + growlastdim!(E, 2) + @test size(E) == (dims..., d + 2) + end + let data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + dims, d = front(size(E)), last(size(E)) + shrinklastdim!(E, 2) + @test size(E) == (dims..., d - 2) + end + let data = makedata(T, N) + E = ElasticArray(data.kernel_size, data.data) + dims, d = front(size(E)), last(size(E)) + resizelastdim!(E, 2) + @test size(E) == (dims..., 2) + end + end +end + +@testset "pointer/unsafe_convert" begin + data = makedata(Float64, 2) + E = ElasticArray(data.kernel_size, data.data) + + @test pointer(E) === pointer(parent(E)) + @test pointer(E, length(E)) === pointer(parent(E), length(E)) + @test unsafe_convert(Ptr{eltype(E)}, E) === unsafe_convert(Ptr{eltype(E)}, parent(E)) +end + +@testset "basic math" begin + T = Float64 + + E1 = rand!(ElasticArray{T}(undef, 9, 9)) + E2 = rand!(ElasticArray{T}(undef, 9, 9)) + E3 = rand!(ElasticArray{T}(undef, 9, 7)) + + A1 = Array(E1) + A2 = Array(E2) + A3 = Array(E3) + + @test_inferred 2 * E1 + @test 2 * E1 isa ElasticArray{T,2,1} + @test 2 * E1 == 2 * A1 + + @test_inferred E1 .+ 2 + @test E1 .+ 2 isa ElasticArray{T,2,1} + @test E1 .+ 2 == A1 .+ 2 + + @test_inferred E1 + E2 + @test E1 + E2 isa ElasticArray{T,2,1} + @test E1 + E2 == A1 + A2 + + @test_inferred E1 * E2 + @test E1 * E2 isa ElasticArray{T,2,1} + @test E1 * E2 == A1 * A2 + @test E1 * E3 == A1 * A3 + + @test E1^3 == A1^3 + @test inv(E1) == inv(A1) +end + +end # module \ No newline at end of file diff --git a/test/test_flattenedarray.jl b/test/test_flattenedarray.jl index bb67bb3..8bf3843 100644 --- a/test/test_flattenedarray.jl +++ b/test/test_flattenedarray.jl @@ -1,7 +1,11 @@ -module FlattenedArrayTest +module TestFlattenedArray + +using Zygote + include("preamble.jl") + function makedata(V::Type, M::Integer, N::Integer) dims = testdims(M + N) innersize, outersize = tuple_split(dims, M) @@ -22,7 +26,7 @@ function makedata(V::Type, M::Integer, N::Integer) ) end -@testset "M = $M, N = $N, V = $V" for M = 1:2, N = 1:2, V in (Float64,) +@testset "M = $M, N = $N, V = $V" for M in (0, 2), N = 1:2, V in (Float64,) @testset "constructors" begin data = makedata(V, M, N) Expected = FlattenedArray{V,M + N,M,typeof(data.nested),typeof(data.inneraxes)} @@ -53,9 +57,15 @@ end F = flatten(data.nested) @test F.inneraxes === inneraxes(F) === inneraxes(data.nested) @test map(length, F.inneraxes) === innersize(F) === innersize(data.nested) + @test eltype(F) === innereltype(data.nested) end end +@testset "flatten(flat)" begin + A = rand(10) + @test flatten(A) === A +end + @testset "aliasing" begin x1 = rand(2, 3) x2 = rand(2, 3) diff --git a/test/test_slicedarray.jl b/test/test_slicedarray.jl index f3f6cef..ed41f2a 100644 --- a/test/test_slicedarray.jl +++ b/test/test_slicedarray.jl @@ -1,25 +1,31 @@ -module TestSlices +module TestSlicedArray + +using Zygote using SpecialArrays: along2string +using SpecialArrays: CartesianIndexer +using SpecialArrays: tuple_map, True, False, TypedBool + include("preamble.jl") + const TEST_ALONGS = [ - (static(true),), - (static(false),), - (static(true), static(true)), - (static(false), static(true)), - (static(false), static(false)), -] + (True(), ), + (False(), ), -slicedims(al::TupleN{SBool}) = Tuple(i for i = 1:length(al) if unstatic(al[i])) + (True(), True()), + (True(), False()), + (False(), True()), + (False(), False()), +] -function makedata(V::Type, al::TupleN{SBool}) +function makedata(V::Type, al::TupleN{TypedBool}) L = length(al) pdims = testdims(L) - sdims = slicedims(al) - innersize = Tuple(pdims[i] for i in 1:L if unstatic(al[i])) - outersize = Tuple(pdims[i] for i in 1:L if !unstatic(al[i])) + sdims = findall(al) + innersize = pdims[al] + outersize = pdims[tuple_map(!, al)] M, N = length(innersize), length(outersize) flat = rand!(Array{V,L}(undef, pdims...)) @@ -66,7 +72,7 @@ showalongs(al) = "($(join(map(SpecialArrays.along2string, al), ", ")))" @test typeof(slice(flat, sdims)) <: Expected @test typeof(slice(flat, sdims...)) <: Expected - I = map(a -> a isa STrue ? Colon() : *, al) + I = map(a -> a isa True ? Colon() : *, al) @test typeof(slice(flat, I...)) <: Expected @test_inferred slice(flat, I...) end @@ -146,7 +152,7 @@ Adapt.adapt_storage(::Type{<:CartesianIndexer}, A) = CartesianIndexer(A) end @testset "Zygote" begin - data = makedata(Float64, (STrue(), SFalse(), STrue())) + data = makedata(Float64, (True(), False(), True())) x = rand!(zeros(last(data.innersize))) g1 = Zygote.gradient(x -> sum(sum(a -> a * x, data.nested)), x) g2 = Zygote.gradient(x -> sum(sum(a -> a * x, slice(data.flat, data.sdims))), x)