Skip to content

Commit

Permalink
Unroll == and some search functions
Browse files Browse the repository at this point in the history
This allows autovectorization, approximately doubling speed for long sequences.
Performance for short sequences is almost unchanged, so this is a clear win.
  • Loading branch information
jakobnissen committed Oct 26, 2024
1 parent 1c63ecf commit d562b87
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
9 changes: 9 additions & 0 deletions src/longsequences/chunk_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ end

# The paired iterator reports the length of its first wrapped iterator, `it.a`.
function Base.length(it::PairedChunkIterator)
    return length(it.a)
end

# Each element is a pair of `UInt64` values.
Base.eltype(::Type{<:PairedChunkIterator}) = Tuple{UInt64, UInt64}

# Initial iteration state: the `first_state` of each wrapped iterator, in (a, b) order.
function first_state(x::PairedChunkIterator)
    return (first_state(x.a), first_state(x.b))
end


@inline function Base.iterate(
it::PairedChunkIterator,
Expand All @@ -131,6 +133,13 @@ Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64}
((first(a), first(b)), (last(a), last(b)))
end

# Advance both wrapped iterators one step via their own `iter_inbounds` methods.
# Takes a state of the form `(state_a, state_b)` and returns
# `((elem_a, elem_b), (next_state_a, next_state_b))`, the same shape as the
# `Base.iterate` method for this type.
# NOTE(review): the name suggests the caller must guarantee another element
# exists in both iterators -- confirm against the wrapped `iter_inbounds` methods.
@inline function iter_inbounds(it::PairedChunkIterator, state)
    (state_a, state_b) = state
    step_a = iter_inbounds(it.a, state_a)
    step_b = iter_inbounds(it.b, state_b)
    elems = (first(step_a), first(step_b))
    next_state = (last(step_a), last(step_b))
    return (elems, next_state)
end

# This returns (head, body, tail), where:
# - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number
# of coding bits in that element. Head is the partial coding element before any
Expand Down
77 changes: 64 additions & 13 deletions src/longsequences/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,37 +207,65 @@ function Base.:(==)(seq1::SeqOrView{A}, seq2::SeqOrView{A}) where {A <: Alphabet
end

# Check all filled UInts
(it, (ch1, ch2, rem)) = iter_chunks(seq1, seq2)
for (i, j) in it
(it, (ch1, ch2, rm)) = iter_chunks(seq1, seq2)
chunks = length(it)
state = first_state(it)
unroll = 8
while chunks ≥ unroll
same = true
for _ in 1:unroll
((i, j), state) = iter_inbounds(it, state)
same &= i == j
end
same || return false
chunks -= unroll
end
itval = iterate(it, state)
while itval !== nothing
((i, j), state) = itval
i == j || return false
itval = iterate(it, state)
end

# Check last coding UInt (or compare two zeros, if none)
mask = UInt64(1) << (rem & 63) - 1
mask = UInt64(1) << (rm & 63) - 1
return (ch1 & mask) == (ch2 & mask)
end

function Base.:(==)(seq1::LongSequence{A}, seq2::LongSequence{A}) where {A <: Alphabet}
length(seq1) == length(seq2) || return false
isempty(seq1) && return true
(data1, data2) = (seq1.data, seq2.data)

# Check all filled UInts
nextind = nextposition(lastbitindex(seq1))
@inbounds for i in 1:index(nextind) - 1
seq1.data[i] == seq2.data[i] || return false
last_chunk_index = index(nextind) % Int - 1
i = 1
unroll = 8
while i + unroll - 2 < last_chunk_index
same = true
@inbounds for j in 0:unroll-1
same &= data1[i + j] == data2[i + j]
end
same || return false
i += unroll
end
@inbounds for i in i:last_chunk_index
data1[i] == data2[i] || return false
end

# Check last coding UInt, if any
@inbounds if !iszero(offset(nextind))
mask = bitmask(offset(nextind))
mask = UInt64(1) << (offset(nextind) & 63) - 1
i = index(nextind)
(seq1.data[i] & mask) == (seq2.data[i] & mask) || return false
(data1[i] & mask) == (data2[i] & mask) || return false
end

return true
end

## Search
# Searching with the `isgap` predicate is forwarded to the `==`-based search,
# comparing against the gap value of the sequence's element type.
function Base.findnext(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer)
    gap_symbol = gap(eltype(seq))
    return findnext(==(gap_symbol), seq, i)
end

# We only dispatch on known alphabets, because new alphabets may implement == in surprising ways
function Base.findnext(
Expand Down Expand Up @@ -267,9 +295,18 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
tu < symbols_in_head && return tu + 1
end
i = symbols_in_head + 1
unroll = 8
while body_i + unroll - 2 < body_stop
u = zero(UInt64)
for j in 0:unroll - 1
u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) ⊻ enc)
end
iszero(u) || break
i += symbols_per_data_element(seq) * unroll
body_i += unroll
end
while body_i ≤ body_stop
chunk = @inbounds data[body_i] ⊻ enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) ⊻ enc)
if !iszero(ze)
return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
Expand All @@ -283,6 +320,10 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
nothing
end

# Backwards search with the `isgap` predicate is forwarded to the `==`-based
# search, comparing against the gap value of the sequence's element type.
function Base.findprev(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer)
    gap_symbol = gap(eltype(seq))
    return findprev(==(gap_symbol), seq, i)
end

function Base.findprev(
cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
seq::SeqOrView{<:KNOWN_ALPHABETS},
Expand Down Expand Up @@ -316,9 +357,19 @@ function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
lu < symbols_in_tail && return (i - lu) % Int
i -= lu
end
unroll = 8
(body_i, body_stop) = (body_i % Int, body_stop % Int)
while body_i - unroll + 2 > body_stop
u = zero(UInt64)
for j in 0:unroll - 1
u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) ⊻ enc)
end
iszero(u) || break
i -= symbols_per_data_element(seq) * unroll
body_i -= unroll
end
while body_i ≥ body_stop
chunk = @inbounds data[body_i] ⊻ enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) ⊻ enc)
if !iszero(ze)
return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
Expand Down

0 comments on commit d562b87

Please sign in to comment.