Skip to content

Commit

Permalink
Unroll == and some search functions (BioJulia#327)
Browse files Browse the repository at this point in the history
This allows autovectorization, approximately doubling speed for long sequences.
Performance for short sequences is almost unchanged, so this is a clear win.
  • Loading branch information
jakobnissen authored Oct 26, 2024
1 parent 1c63ecf commit 9a3f893
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 22 deletions.
5 changes: 0 additions & 5 deletions src/counting.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# TODO: SubSeqChunks may be unnecessary for counting, since we don't care about
# the order of the chunks.
# Instead, we could emit the head and tail chunk seperately, then an iterator of
# all the full chunks loaded as single elements from the underlying vector.

trunc_seq(x::LongSequence, len::Int) = typeof(x)(x.data, len % UInt)
trunc_seq(x::LongSubSeq, len::Int) = typeof(x)(x.data, first(x.part):first(x.part)+len-1)

Expand Down
13 changes: 9 additions & 4 deletions src/longsequences/chunk_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,22 @@ end

Base.length(it::PairedChunkIterator) = length(it.a)
Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64}
first_state(x::PairedChunkIterator) = (first_state(x.a), first_state(x.b))

@inline function Base.iterate(
it::PairedChunkIterator,
state=(first_state(it.a), first_state(it.b))
)
@inline function Base.iterate(it::PairedChunkIterator, state=first_state(it))
a = iterate(it.a, first(state))
isnothing(a) && return nothing
b = iter_inbounds(it.b, last(state))
((first(a), first(b)), (last(a), last(b)))
end

@inline function iter_inbounds(it::PairedChunkIterator, state)
(sa, sb) = state
a = iter_inbounds(it.a, sa)
b = iter_inbounds(it.b, sb)
((first(a), first(b)), (last(a), last(b)))
end

# This returns (head, body, tail), where:
# - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number
# of coding bits in that element. Head is the partial coding element before any
Expand Down
77 changes: 64 additions & 13 deletions src/longsequences/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,37 +207,65 @@ function Base.:(==)(seq1::SeqOrView{A}, seq2::SeqOrView{A}) where {A <: Alphabet
end

# Check all filled UInts
(it, (ch1, ch2, rem)) = iter_chunks(seq1, seq2)
for (i, j) in it
(it, (ch1, ch2, rm)) = iter_chunks(seq1, seq2)
chunks = length(it)
state = first_state(it)
unroll = 8
while chunks unroll
same = true
for _ in 1:unroll
((i, j), state) = iter_inbounds(it, state)
same &= i == j
end
same || return false
chunks -= unroll
end
itval = iterate(it, state)
while itval !== nothing
((i, j), state) = itval
i == j || return false
itval = iterate(it, state)
end

# Check last coding UInt (or compare two zeros, if none)
mask = UInt64(1) << (rem & 63) - 1
mask = UInt64(1) << (rm & 63) - 1
return (ch1 & mask) == (ch2 & mask)
end

function Base.:(==)(seq1::LongSequence{A}, seq2::LongSequence{A}) where {A <: Alphabet}
length(seq1) == length(seq2) || return false
isempty(seq1) && return true
(data1, data2) = (seq1.data, seq2.data)

# Check all filled UInts
nextind = nextposition(lastbitindex(seq1))
@inbounds for i in 1:index(nextind) - 1
seq1.data[i] == seq2.data[i] || return false
last_chunk_index = index(nextind) % Int - 1
i = 1
unroll = 8
while i + unroll - 2 < last_chunk_index
same = true
@inbounds for j in 0:unroll-1
same &= data1[i + j] == data2[i + j]
end
same || return false
i += unroll
end
@inbounds for i in i:last_chunk_index
data1[i] == data2[i] || return false
end

# Check last coding UInt, if any
@inbounds if !iszero(offset(nextind))
mask = bitmask(offset(nextind))
mask = UInt64(1) << (offset(nextind) & 63) - 1
i = index(nextind)
(seq1.data[i] & mask) == (seq2.data[i] & mask) || return false
(data1[i] & mask) == (data2[i] & mask) || return false
end

return true
end

## Search
function Base.findnext(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer)
findnext(==(gap(eltype(seq))), seq, i)
end

# We only dispatch on known alphabets, because new alphabets may implement == in surprising ways
function Base.findnext(
Expand Down Expand Up @@ -267,9 +295,18 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
tu < symbols_in_head && return tu + 1
end
i = symbols_in_head + 1
unroll = 8
while body_i + unroll - 2 < body_stop
u = zero(UInt64)
for j in 0:unroll - 1
u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i + j]) enc)
end
iszero(u) || break
i += symbols_per_data_element(seq) * unroll
body_i += unroll
end
while body_i body_stop
chunk = @inbounds data[body_i] enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) enc)
if !iszero(ze)
return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
Expand All @@ -283,6 +320,10 @@ function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
nothing
end

function Base.findprev(::typeof(isgap), seq::SeqOrView{<:KNOWN_ALPHABETS}, i::Integer)
findprev(==(gap(eltype(seq))), seq, i)
end

function Base.findprev(
cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
seq::SeqOrView{<:KNOWN_ALPHABETS},
Expand Down Expand Up @@ -316,9 +357,19 @@ function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
lu < symbols_in_tail && return (i - lu) % Int
i -= lu
end
unroll = 8
(body_i, body_stop) = (body_i % Int, body_stop % Int)
while body_i - unroll + 2 > body_stop
u = zero(UInt64)
for j in 0:unroll - 1
u |= set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i - j]) enc)
end
iszero(u) || break
i -= symbols_per_data_element(seq) * unroll
body_i -= unroll
end
while body_i body_stop
chunk = @inbounds data[body_i] enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
ze = set_zero_encoding(BitsPerSymbol(seq), @inbounds(data[body_i]) enc)
if !iszero(ze)
return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
Expand Down
11 changes: 11 additions & 0 deletions test/longsequences/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,13 @@ end
ACGTN
"""
@test a == b

a = randdnaseq(512)
a[111] = DNA_G
b = copy(a)
@test a == b
a[111] = DNA_C
@test a != b
end

@testset "RNA" begin
Expand Down Expand Up @@ -373,6 +380,10 @@ end
X
"""
@test a == b

a = randaaseq(131)
b = copy(a)
@test a == b
end
end

Expand Down
10 changes: 10 additions & 0 deletions test/longsequences/find.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@
@test findnext(==(DNA_M), seq, 8) == 8
@test findnext(==(DNA_M), seq, 9) == 21

# Very large sequence so it hits the inner SIMD loop
seq = randdnaseq(541)
seq[100] = DNA_Gap
@test findfirst(==(DNA_Gap), seq) == findfirst(isgap, seq) == 100
seq[100] = DNA_A
seq[500] = DNA_Gap
@test findlast(==(DNA_Gap), seq) == findlast(isgap, seq) == 500
seq[500] = DNA_A
@test findfirst(==(DNA_Gap), seq) == findlast(isgap, seq) == nothing

# View with only tail
# 1234
seq = view(aa"KWYPAV-L", 3:6)
Expand Down
14 changes: 14 additions & 0 deletions test/longsequences/seqview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ end
@test String(seq) == "ANKYH"
end

@testset "Equality" begin
for size in [41, 504, 7]
for offset in [0, 3, 32]
seq = randrnaseq(size)
seq2 = view(randrnaseq(offset) * seq * randrnaseq(15), offset+1:offset+size)
@test seq == seq2
seq[4] = RNA_Gap
@test seq != seq2
seq3 = view(seq, 1:size)
@test seq == seq3
end
end
end

# Added after issue 260
@testset "Random construction" begin
for i in 1:100
Expand Down

0 comments on commit 9a3f893

Please sign in to comment.