Skip to content

Commit

Permalink
Unroll == and some search functions
Browse files Browse the repository at this point in the history
This allows autovectorization, approximately doubling speed for long sequences.
Performance for short sequences is almost unchanged, so this is a clear win.
  • Loading branch information
jakobnissen committed Oct 26, 2024
1 parent 1c63ecf commit d562b87
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
9 changes: 9 additions & 0 deletions src/longsequences/chunk_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ end

# Iteration traits for PairedChunkIterator, which steps two chunk iterators together.

# The number of chunk pairs is taken from the first underlying iterator.
function Base.length(it::PairedChunkIterator)
    return length(it.a)
end

# Each element is a pair of UInt64 chunks.
function Base.eltype(::Type{<:PairedChunkIterator})
    return NTuple{2, UInt64}
end

# The initial paired state is the tuple of both underlying iterators' initial states.
function first_state(x::PairedChunkIterator)
    return (first_state(x.a), first_state(x.b))
end


@inline function Base.iterate(
it::PairedChunkIterator,
Expand All @@ -131,6 +133,13 @@ Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64}
((first(a), first(b)), (last(a), last(b)))
end

# Step both underlying iterators once, each from its own half of `state`, and
# return `((chunk_a, chunk_b), (next_state_a, next_state_b))`.
# NOTE(review): presumably, as the name suggests, this omits the end-of-iteration
# check that `iterate` performs, so the caller must already know that another
# chunk pair exists — confirm against the single-iterator `iter_inbounds` methods.
# Codecov reported this method as not covered by tests when it was added.
@inline function iter_inbounds(it::PairedChunkIterator, state)
    (state_a, state_b) = state
    step_a = iter_inbounds(it.a, state_a)
    step_b = iter_inbounds(it.b, state_b)
    chunk_pair = (first(step_a), first(step_b))
    next_state = (last(step_a), last(step_b))
    return (chunk_pair, next_state)
end

# This returns (head, body, tail), where:
# - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number
# of coding bits in that element. Head is the partial coding element before any
Expand Down
77 changes: 64 additions & 13 deletions src/longsequences/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,37 +207,65 @@ function Base.:(==)(seq1::SeqOrView{A}, seq2::SeqOrView{A}) where {A <: Alphabet
end

# Check all filled UInts
(it, (ch1, ch2, rem)) = iter_chunks(seq1, seq2)
for (i, j) in it
(it, (ch1, ch2, rm)) = iter_chunks(seq1, seq2)
chunks = length(it)
state = first_state(it)
unroll = 8
while chunks unroll
same = true
for _ in 1:unroll
((i, j), state) = iter_inbounds(it, state)
same &= i == j
end
same || return false
chunks -= unroll
end

Check warning on line 222 in src/longsequences/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/operators.jl#L215-L222

Added lines #L215 - L222 were not covered by tests
itval = iterate(it, state)
while itval !== nothing
((i, j), state) = itval
i == j || return false
itval = iterate(it, state)
end

# Check last coding UInt (or compare two zeros, if none)
mask = UInt64(1) << (rem & 63) - 1
mask = UInt64(1) << (rm & 63) - 1
return (ch1 & mask) == (ch2 & mask)
end

# Equality for two `LongSequence`s sharing the alphabet type `A`.
#
# The sequences are compared through their packed `data` words rather than symbol
# by symbol:
#   1. sequences of different length are unequal; two empty sequences are equal;
#   2. every completely filled word is compared, `unroll` words at a time;
#   3. if the last word is only partially filled, only its low `offset(nextind)`
#      bits are compared.
#
# Returns a `Bool`.
function Base.:(==)(seq1::LongSequence{A}, seq2::LongSequence{A}) where {A <: Alphabet}
    length(seq1) == length(seq2) || return false
    isempty(seq1) && return true
    (data1, data2) = (seq1.data, seq2.data)

    # Check all filled UInts
    nextind = nextposition(lastbitindex(seq1))
    last_chunk_index = index(nextind) % Int - 1
    i = 1
    unroll = 8
    # Compare `unroll` words per iteration and branch only once per group: the
    # branch-free inner loop is what allows autovectorization. The guard ensures
    # that `i + unroll - 1`, the highest index read below, is still a filled word.
    while i + unroll - 1 <= last_chunk_index
        same = true
        @inbounds for j in 0:unroll-1
            same &= data1[i + j] == data2[i + j]
        end
        same || return false
        i += unroll
    end
    # Fewer than `unroll` filled words remain: compare them one by one.
    @inbounds for k in i:last_chunk_index
        data1[k] == data2[k] || return false
    end

    # Check last coding UInt, if any
    @inbounds if !iszero(offset(nextind))
        # Keep only the low `offset(nextind)` bits of the final word.
        mask = UInt64(1) << (offset(nextind) & 63) - 1
        # Separate name so that `i` keeps a single type throughout the function.
        tail_index = index(nextind)
        (data1[tail_index] & mask) == (data2[tail_index] & mask) || return false
    end

    return true
end

## Search
# Forward search for a gap: look up the gap symbol of the sequence's element type
# and delegate to the `findnext` method that takes an `==(symbol)` predicate.
# Codecov reported this method as not covered by tests when it was added.
function Base.findnext(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer)
    gap_symbol = gap(eltype(seq))
    return findnext(==(gap_symbol), seq, i)
end

# We only dispatch on known alphabets, because new alphabets may implement == in surprising ways
function Base.findnext(
Expand Down Expand Up @@ -267,9 +295,18 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
tu < symbols_in_head && return tu + 1
end
i = symbols_in_head + 1
unroll = 8
while body_i + unroll - 2 < body_stop
u = zero(UInt64)
for j in 0:unroll - 1
u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) enc)
end
iszero(u) || break
i += symbols_per_data_element(seq) * unroll
body_i += unroll
end

Check warning on line 307 in src/longsequences/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/operators.jl#L300-L307

Added lines #L300 - L307 were not covered by tests
while body_i body_stop
chunk = @inbounds data[body_i] enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) enc)
if !iszero(ze)
return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
Expand All @@ -283,6 +320,10 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
nothing
end

# Backward search for a gap: look up the gap symbol of the sequence's element type
# and delegate to the `findprev` method that takes an `==(symbol)` predicate.
# Codecov reported this method as not covered by tests when it was added.
function Base.findprev(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer)
    gap_symbol = gap(eltype(seq))
    return findprev(==(gap_symbol), seq, i)
end

function Base.findprev(
cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
seq::SeqOrView{<:KNOWN_ALPHABETS},
Expand Down Expand Up @@ -316,9 +357,19 @@ function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
lu < symbols_in_tail && return (i - lu) % Int
i -= lu
end
unroll = 8
(body_i, body_stop) = (body_i % Int, body_stop % Int)
while body_i - unroll + 2 > body_stop
u = zero(UInt64)
for j in 0:unroll - 1
u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) enc)
end
iszero(u) || break
i -= symbols_per_data_element(seq) * unroll
body_i -= unroll
end

Check warning on line 370 in src/longsequences/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/operators.jl#L363-L370

Added lines #L363 - L370 were not covered by tests
while body_i body_stop
chunk = @inbounds data[body_i] enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) enc)
if !iszero(ze)
return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
Expand Down

0 comments on commit d562b87

Please sign in to comment.