Skip to content

Commit

Permalink
Fix Broadcast.broadcast_shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 21, 2023
1 parent ab51dbe commit 1f254d3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(

# sortedunion can assume inputs are already sorted so this could be improved
sortedunion(a,b) = sort!(union(a,b))
sortedunion(a::Tuple, b::Tuple) = (a..., b...)
sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b)))
sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b))
combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b)))

Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b)
Base.Broadcast.axistype(a::BlockedUnitRange, b) = length(b) == 1 ? a : combine_blockaxes(a, b)
Base.Broadcast.axistype(a, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b)
Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = combine_blockaxes(a, b)
Base.Broadcast.axistype(a::BlockedUnitRange, b) = combine_blockaxes(a, b)
Base.Broadcast.axistype(a, b::BlockedUnitRange) = combine_blockaxes(a, b)


similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} =
Expand Down
13 changes: 13 additions & 0 deletions test/test_blockbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
u = BlockArray(randn(5), [2,3]);
@inferred(copyto!(similar(u), Base.broadcasted(exp, u)))
@test exp.(u) == exp.(Vector(u))

function test_allocation!(shape1, shape2)
x = Base.Broadcast.broadcast_shape(shape1, shape2)
return nothing
end
shape1 = (BlockArrays._BlockedUnitRange((2,)),);
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
@inferred Base.Broadcast.axistype(shape1[1], shape2[1])
@inferred BlockArrays.combine_blockaxes(shape1[1], shape2[1])
@inferred Base.Broadcast.broadcast_shape(shape1, shape2)
test_allocation!(shape1, shape2) # compile first
p = @allocated test_allocation!(shape1, shape2)
@test p == 0
end

@testset "adjtrans" begin
Expand Down

0 comments on commit 1f254d3

Please sign in to comment.