diff --git a/src/longsequences/chunk_iterator.jl b/src/longsequences/chunk_iterator.jl index 773b5f63..a51962a6 100644 --- a/src/longsequences/chunk_iterator.jl +++ b/src/longsequences/chunk_iterator.jl @@ -120,6 +120,8 @@ end Base.length(it::PairedChunkIterator) = length(it.a) Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64} +first_state(x::PairedChunkIterator) = (first_state(x.a), first_state(x.b)) + @inline function Base.iterate( it::PairedChunkIterator, @@ -131,6 +133,13 @@ Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64} ((first(a), first(b)), (last(a), last(b))) end +@inline function iter_inbounds(it::PairedChunkIterator, state) + (sa, sb) = state + a = iter_inbounds(it.a, sa) + b = iter_inbounds(it.b, sb) + ((first(a), first(b)), (last(a), last(b))) +end + # This returns (head, body, tail), where: # - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number # of coding bits in that element. Head is the partial coding element before any diff --git a/src/longsequences/operators.jl b/src/longsequences/operators.jl index 4d0e0afe..24782bb2 100644 --- a/src/longsequences/operators.jl +++ b/src/longsequences/operators.jl @@ -207,37 +207,65 @@ function Base.:(==)(seq1::SeqOrView{A}, seq2::SeqOrView{A}) where {A <: Alphabet end # Check all filled UInts - (it, (ch1, ch2, rem)) = iter_chunks(seq1, seq2) - for (i, j) in it + (it, (ch1, ch2, rm)) = iter_chunks(seq1, seq2) + chunks = length(it) + state = first_state(it) + unroll = 8 + while chunks ≥ unroll + same = true + for _ in 1:unroll + ((i, j), state) = iter_inbounds(it, state) + same &= i == j + end + same || return false + chunks -= unroll + end + itval = iterate(it, state) + while itval !== nothing + ((i, j), state) = itval i == j || return false + itval = iterate(it, state) end # Check last coding UInt (or compare two zeros, if none) - mask = UInt64(1) << (rem & 63) - 1 + mask = UInt64(1) << (rm & 63) - 1 return (ch1 & mask) == (ch2 & mask) end function Base.:(==)(seq1::LongSequence{A}, seq2::LongSequence{A}) where {A <: Alphabet} length(seq1) == length(seq2) || return false isempty(seq1) && return true + (data1, data2) = (seq1.data, seq2.data) # Check all filled UInts nextind = nextposition(lastbitindex(seq1)) - @inbounds for i in 1:index(nextind) - 1 - seq1.data[i] == seq2.data[i] || return false + last_chunk_index = index(nextind) % Int - 1 + i = 1 + unroll = 8 + while i + unroll - 2 < last_chunk_index + same = true + @inbounds for j in 0:unroll-1 + same &= data1[i + j] == data2[i + j] + end + same || return false + i += unroll + end + @inbounds for i in i:last_chunk_index + data1[i] == data2[i] || return false end - # Check last coding UInt, if any @inbounds if !iszero(offset(nextind)) - mask = bitmask(offset(nextind)) + mask = UInt64(1) << (offset(nextind) & 63) - 1 i = index(nextind) - (seq1.data[i] & mask) == (seq2.data[i] & mask) || return false + (data1[i] & mask) == (data2[i] & mask) || return false end - return true end ## Search +function Base.findnext(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer) + findnext(==(gap(eltype(seq))), seq, i) +end # We only dispatch on known alphabets, because new alphabets may implement == in surprising ways function Base.findnext( @@ -267,9 +295,18 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64) tu < symbols_in_head && return tu + 1 end i = symbols_in_head + 1 + unroll = 8 + while body_i + unroll - 2 < body_stop + u = zero(UInt64) + for j in 0:unroll - 1 + u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) ⊻ enc) + end + iszero(u) || break + i += symbols_per_data_element(seq) * unroll + body_i += unroll + end while body_i ≤ body_stop - chunk = @inbounds data[body_i] ⊻ enc - ze = set_zero_encoding(BitsPerSymbol(seq), chunk) + ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) ⊻ enc) if !iszero(ze) return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int end @@ -283,6 +320,10 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64) nothing end +function Base.findprev(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer) + findprev(==(gap(eltype(seq))), seq, i) +end + function Base.findprev( cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}}, seq::SeqOrView{<:KNOWN_ALPHABETS}, @@ -316,9 +357,19 @@ function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64) lu < symbols_in_tail && return (i - lu) % Int i -= lu end + unroll = 8 + (body_i, body_stop) = (body_i % Int, body_stop % Int) + while body_i - unroll + 2 > body_stop + u = zero(UInt64) + for j in 0:unroll - 1 + u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) ⊻ enc) + end + iszero(u) || break + i -= symbols_per_data_element(seq) * unroll + body_i -= unroll + end while body_i ≥ body_stop - chunk = @inbounds data[body_i] ⊻ enc - ze = set_zero_encoding(BitsPerSymbol(seq), chunk) + ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) ⊻ enc) if !iszero(ze) return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int end