From 9a3f893f4103b48a39d7127ff31cb803465725a5 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Sat, 26 Oct 2024 20:23:09 +0200 Subject: [PATCH] Unroll == and some search functions (#327) This allows autovectorization, approximately doubling speed for long sequences. Performance for short sequences is almost unchanged, so this is a clear win. --- src/counting.jl | 5 -- src/longsequences/chunk_iterator.jl | 13 +++-- src/longsequences/operators.jl | 77 ++++++++++++++++++++++++----- test/longsequences/basics.jl | 11 +++++ test/longsequences/find.jl | 10 ++++ test/longsequences/seqview.jl | 14 ++++++ 6 files changed, 108 insertions(+), 22 deletions(-) diff --git a/src/counting.jl b/src/counting.jl index 753628b4..3333346a 100644 --- a/src/counting.jl +++ b/src/counting.jl @@ -1,8 +1,3 @@ -# TODO: SubSeqChunks may be unnecessary for counting, since we don't care about -# the order of the chunks. -# Instead, we could emit the head and tail chunk seperately, then an iterator of -# all the full chunks loaded as single elements from the underlying vector. - trunc_seq(x::LongSequence, len::Int) = typeof(x)(x.data, len % UInt) trunc_seq(x::LongSubSeq, len::Int) = typeof(x)(x.data, first(x.part):first(x.part)+len-1) diff --git a/src/longsequences/chunk_iterator.jl b/src/longsequences/chunk_iterator.jl index 773b5f63..4dd0962e 100644 --- a/src/longsequences/chunk_iterator.jl +++ b/src/longsequences/chunk_iterator.jl @@ -120,17 +120,22 @@ end Base.length(it::PairedChunkIterator) = length(it.a) Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64} +first_state(x::PairedChunkIterator) = (first_state(x.a), first_state(x.b)) -@inline function Base.iterate( - it::PairedChunkIterator, - state=(first_state(it.a), first_state(it.b)) -) +@inline function Base.iterate(it::PairedChunkIterator, state=first_state(it)) a = iterate(it.a, first(state)) isnothing(a) && return nothing b = iter_inbounds(it.b, last(state)) ((first(a), first(b)), (last(a), last(b))) end +@inline function iter_inbounds(it::PairedChunkIterator, state) + (sa, sb) = state + a = iter_inbounds(it.a, sa) + b = iter_inbounds(it.b, sb) + ((first(a), first(b)), (last(a), last(b))) +end + # This returns (head, body, tail), where: # - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number # of coding bits in that element. Head is the partial coding element before any diff --git a/src/longsequences/operators.jl b/src/longsequences/operators.jl index 4d0e0afe..ece5af91 100644 --- a/src/longsequences/operators.jl +++ b/src/longsequences/operators.jl @@ -207,37 +207,65 @@ function Base.:(==)(seq1::SeqOrView{A}, seq2::SeqOrView{A}) where {A <: Alphabet end # Check all filled UInts - (it, (ch1, ch2, rem)) = iter_chunks(seq1, seq2) - for (i, j) in it + (it, (ch1, ch2, rm)) = iter_chunks(seq1, seq2) + chunks = length(it) + state = first_state(it) + unroll = 8 + while chunks ≥ unroll + same = true + for _ in 1:unroll + ((i, j), state) = iter_inbounds(it, state) + same &= i == j + end + same || return false + chunks -= unroll + end + itval = iterate(it, state) + while itval !== nothing + ((i, j), state) = itval i == j || return false + itval = iterate(it, state) end # Check last coding UInt (or compare two zeros, if none) - mask = UInt64(1) << (rem & 63) - 1 + mask = UInt64(1) << (rm & 63) - 1 return (ch1 & mask) == (ch2 & mask) end function Base.:(==)(seq1::LongSequence{A}, seq2::LongSequence{A}) where {A <: Alphabet} length(seq1) == length(seq2) || return false isempty(seq1) && return true + (data1, data2) = (seq1.data, seq2.data) # Check all filled UInts nextind = nextposition(lastbitindex(seq1)) - @inbounds for i in 1:index(nextind) - 1 - seq1.data[i] == seq2.data[i] || return false + last_chunk_index = index(nextind) % Int - 1 + i = 1 + unroll = 8 + while i + unroll - 2 < last_chunk_index + same = true + @inbounds for j in 0:unroll-1 + same &= data1[i + j] == data2[i + j] + end + same || return false + i += unroll + end + @inbounds for i in i:last_chunk_index + data1[i] == data2[i] || return false end - # Check last coding UInt, if any @inbounds if !iszero(offset(nextind)) - mask = bitmask(offset(nextind)) + mask = UInt64(1) << (offset(nextind) & 63) - 1 i = index(nextind) - (seq1.data[i] & mask) == (seq2.data[i] & mask) || return false + (data1[i] & mask) == (data2[i] & mask) || return false end - return true end ## Search +function Base.findnext(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer) + findnext(==(gap(eltype(seq))), seq, i) +end # We only dispatch on known alphabets, because new alphabets may implement == in surprising ways function Base.findnext( @@ -267,9 +295,18 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64) tu < symbols_in_head && return tu + 1 end i = symbols_in_head + 1 + unroll = 8 + while body_i + unroll - 2 < body_stop + u = zero(UInt64) + for j in 0:unroll - 1 + u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) ⊻ enc) + end + iszero(u) || break + i += symbols_per_data_element(seq) * unroll + body_i += unroll + end while body_i ≤ body_stop - chunk = @inbounds data[body_i] ⊻ enc - ze = set_zero_encoding(BitsPerSymbol(seq), chunk) + ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) ⊻ enc) if !iszero(ze) return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int end @@ -283,6 +320,10 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64) nothing end +function Base.findprev(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer) + findprev(==(gap(eltype(seq))), seq, i) +end + function Base.findprev( cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}}, seq::SeqOrView{<:KNOWN_ALPHABETS}, @@ -316,9 +357,19 @@ function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64) lu < symbols_in_tail && return (i - lu) % Int i -= lu end + unroll = 8 + (body_i, body_stop) = (body_i % Int, body_stop % Int) + while body_i - unroll + 2 > body_stop + u = zero(UInt64) + for j in 0:unroll - 1 + u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i - j]) ⊻ enc) + end + iszero(u) || break + i -= symbols_per_data_element(seq) * unroll + body_i -= unroll + end while body_i ≥ body_stop - chunk = @inbounds data[body_i] ⊻ enc - ze = set_zero_encoding(BitsPerSymbol(seq), chunk) + ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) ⊻ enc) if !iszero(ze) return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int end diff --git a/test/longsequences/basics.jl b/test/longsequences/basics.jl index d0c71705..ca004f98 100644 --- a/test/longsequences/basics.jl +++ b/test/longsequences/basics.jl @@ -341,6 +341,13 @@ end ACGTN """ @test a == b + + a = randdnaseq(512) + a[111] = DNA_G + b = copy(a) + @test a == b + a[111] = DNA_C + @test a != b end @testset "RNA" begin @@ -373,6 +380,10 @@ end X """ @test a == b + + a = randaaseq(131) + b = copy(a) + @test a == b end end diff --git a/test/longsequences/find.jl b/test/longsequences/find.jl index e8a50fdc..6eac41c5 100644 --- a/test/longsequences/find.jl +++ b/test/longsequences/find.jl @@ -64,6 +64,16 @@ @test findnext(==(DNA_M), seq, 8) == 8 @test findnext(==(DNA_M), seq, 9) == 21 + # Very large sequence so it hits the inner SIMD loop + seq = randdnaseq(541) + seq[100] = DNA_Gap + @test findfirst(==(DNA_Gap), seq) == findfirst(isgap, seq) == 100 + seq[100] = DNA_A + seq[500] = DNA_Gap + @test findlast(==(DNA_Gap), seq) == findlast(isgap, seq) == 500 + seq[500] = DNA_A + @test findfirst(==(DNA_Gap), seq) == findlast(isgap, seq) == nothing + # View with only tail # 1234 seq = view(aa"KWYPAV-L", 3:6) diff --git a/test/longsequences/seqview.jl b/test/longsequences/seqview.jl index 777a8c35..f9306e1c 100644 --- a/test/longsequences/seqview.jl +++ b/test/longsequences/seqview.jl @@ -44,6 +44,20 @@ end @test String(seq) == "ANKYH" end +@testset "Equality" begin + for size in [41, 504, 7] + for offset in [0, 3, 32] + seq = randrnaseq(size) + seq2 = view(randrnaseq(offset) * seq * randrnaseq(15), offset+1:offset+size) + @test seq == seq2 + seq[4] = RNA_Gap + @test seq != seq2 + seq3 = view(seq, 1:size) + @test seq == seq3 + end + end +end + # Added after issue 260 @testset "Random construction" begin for i in 1:100