diff --git a/src/FixedSizeArrays.jl b/src/FixedSizeArrays.jl index a9a74ac..11d3c12 100644 --- a/src/FixedSizeArrays.jl +++ b/src/FixedSizeArrays.jl @@ -160,4 +160,14 @@ Base.unsafe_convert(::Type{Ptr{T}}, a::FixedSizeArray{T}) where {T} = Base.unsaf Base.elsize(::Type{A}) where {A<:FixedSizeArray} = Base.elsize(parent_type(A)) +# `reshape`: specializing it to ensure it returns a `FixedSizeArray` + +function Base.reshape(a::FixedSizeArray{T}, size::NTuple{N,Int}) where {T,N} + len = checked_dims(size) + if length(a) != len + throw(DimensionMismatch("new shape not consistent with existing array length")) + end + FixedSizeArray{T,N}(Internal(), a.mem, size) +end + end # module FixedSizeArrays diff --git a/test/runtests.jl b/test/runtests.jl index 10b21b7..6142e06 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -301,4 +301,36 @@ end end end end + + @testset "`reshape`" begin + length_to_shapes = Dict( + (0 => ((0,), (0, 0), (0, 1), (1, 0), (1, 0, 0), (0, 0, 1))), + (1 => ((), (1,), (1, 1), (1, 1, 1))), + (2 => ((2,), (1, 2), (2, 1), (1, 2, 1))), + (3 => ((3,), (1, 3), (3, 1), (1, 3, 1))), + (4 => ((4,), (1, 4), (4, 1), (2, 2), (1, 2, 2), (2, 1, 2))), + (6 => ((6,), (1, 6), (6, 1), (2, 3), (3, 2), (1, 3, 2), (2, 1, 3))), + ) + for elem_type ∈ (Int, Number, Union{Nothing,Int}) + for len ∈ keys(length_to_shapes) + shapes = length_to_shapes[len] + for shape1 ∈ shapes + a = FixedSizeArray{elem_type,length(shape1)}(undef, shape1) + @test_throws DimensionMismatch reshape(a, length(a)+1) + @test_throws DimensionMismatch reshape(a, length(a)+1, 1) + @test_throws DimensionMismatch reshape(a, 1, length(a)+1) + for shape2 ∈ shapes + @test prod(shape1) === prod(shape2) === len # meta + T = FixedSizeArray{elem_type,length(shape2)} + test_inferred_noalloc(reshape, T, (a, shape2)) + test_inferred_noalloc(reshape, T, (a, shape2...)) + b = reshape(a, shape2) + @test size(b) === shape2 + @test a.mem === b.mem + @test a === reshape(b, shape1) + end + end + end + end + end end