From 176dd5b6c57f24f63d4ef23a57cd6483f1fef395 Mon Sep 17 00:00:00 2001 From: Trenton w Fleming Date: Thu, 18 Jan 2024 00:24:42 -0500 Subject: [PATCH] #360 Bwt (#411) * basic bitvector * jacobsons start and refactor to uint for accurate machine words * confident that jacobson rank is working * reusing the incoming bitvector instead of copying everyithing for jacobson rank * access and bounds checking * just do uint64 for simplicity. bound checking and access * bit vector fixes, rsa good enough, wavelet start * Simple wavelet tree with access * wavelet fix access, add select, fix rsa bitvector select * got count working, but had to throw out jacobsons * rsa fixes and refactors * bwt locate * extract * doc BWT, refactor, and return a possible error during construction * add TODO about sorting and the nullChar * bwt examples * wavelet tree doc * wavelet tree explanation * doc and note for waveletTree * add bwt high level. move wavelet tree's some rsa bv docs * simplify bitvector, docs for bitvector and rsaBitvector * Cite Ben Langmead. 
--------- Co-authored-by: Willow Carretero Chavez Co-authored-by: Timothy Stiles --- CHANGELOG.md | 3 + bwt/bitvector.go | 77 ++++++ bwt/bitvector_test.go | 119 ++++++++++ bwt/bwt.go | 482 ++++++++++++++++++++++++++++++++++++++ bwt/bwt_test.go | 419 +++++++++++++++++++++++++++++++++ bwt/example_test.go | 92 ++++++++ bwt/rsa_bitvector.go | 192 +++++++++++++++ bwt/rsa_bitvector_test.go | 353 ++++++++++++++++++++++++++++ bwt/wavelet.go | 433 ++++++++++++++++++++++++++++++++++ bwt/wavelet_test.go | 285 ++++++++++++++++++++++ 10 files changed, 2455 insertions(+) create mode 100644 bwt/bitvector.go create mode 100644 bwt/bitvector_test.go create mode 100644 bwt/bwt.go create mode 100644 bwt/bwt_test.go create mode 100644 bwt/example_test.go create mode 100644 bwt/rsa_bitvector.go create mode 100644 bwt/rsa_bitvector_test.go create mode 100644 bwt/wavelet.go create mode 100644 bwt/wavelet_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e1f8cc00..2db7607d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Basic BWT for sub-sequence count and offset for sequence alignment. Only supports exact matches for now. + ## [0.30.0] - 2023-12-18 Oops, we weren't keeping a changelog before this tag! diff --git a/bwt/bitvector.go b/bwt/bitvector.go new file mode 100644 index 00000000..d6cbf814 --- /dev/null +++ b/bwt/bitvector.go @@ -0,0 +1,77 @@ +package bwt + +import ( + "fmt" + "math" +) + +const wordSize = 64 + +// bitvector a sequence of 1's and 0's. You can also think +// of this as an array of bits. This allows us to encode +// data in a memory efficient manner. +type bitvector struct { + bits []uint64 + numberOfBits int +} + +// newBitVector will return an initialized bitvector with +// the specified number of zeroed bits. 
+func newBitVector(initialNumberOfBits int) bitvector { + capacity := getNumOfBitSetsNeededForNumOfBits(initialNumberOfBits) + bits := make([]uint64, capacity) + return bitvector{ + bits: bits, + numberOfBits: initialNumberOfBits, + } +} + +// getBitSet gets the while word as some offset from the +// bitvector. Useful if you'd prefer to work with the +// word rather than with individual bits. +func (b bitvector) getBitSet(bitSetPos int) uint64 { + return b.bits[bitSetPos] +} + +// getBit returns the value of the bit at a given offset +// True represents 1 +// False represents 0 +func (b bitvector) getBit(i int) bool { + b.checkBounds(i) + + chunkStart := i / wordSize + offset := i % wordSize + + return (b.bits[chunkStart] & (uint64(1) << (63 - offset))) != 0 +} + +// setBit sets the value of the bit at a given offset +// True represents 1 +// False represents 0 +func (b bitvector) setBit(i int, val bool) { + b.checkBounds(i) + + chunkStart := i / wordSize + offset := i % wordSize + + if val { + b.bits[chunkStart] |= uint64(1) << (63 - offset) + } else { + b.bits[chunkStart] &= ^(uint64(1) << (63 - offset)) + } +} + +func (b bitvector) checkBounds(i int) { + if i >= b.len() || i < 0 { + msg := fmt.Sprintf("access of %d is out of bounds for bitvector with length %d", i, b.len()) + panic(msg) + } +} + +func (b bitvector) len() int { + return b.numberOfBits +} + +func getNumOfBitSetsNeededForNumOfBits(n int) int { + return int(math.Ceil(float64(n) / wordSize)) +} diff --git a/bwt/bitvector_test.go b/bwt/bitvector_test.go new file mode 100644 index 00000000..44a1ac11 --- /dev/null +++ b/bwt/bitvector_test.go @@ -0,0 +1,119 @@ +package bwt + +import ( + "testing" +) + +type GetBitTestCase struct { + position int + expected bool +} + +func TestBitVector(t *testing.T) { + initialNumberOfBits := wordSize*10 + 1 + + bv := newBitVector(initialNumberOfBits) + + if bv.len() != initialNumberOfBits { + t.Fatalf("expected len to be %d but got %d", initialNumberOfBits, bv.len()) + 
} + + for i := 0; i < initialNumberOfBits; i++ { + bv.setBit(i, true) + } + + bv.setBit(3, false) + bv.setBit(11, false) + bv.setBit(13, false) + bv.setBit(23, false) + bv.setBit(24, false) + bv.setBit(25, false) + bv.setBit(42, false) + bv.setBit(63, false) + bv.setBit(64, false) + bv.setBit(255, false) + bv.setBit(256, false) + + getBitTestCases := []GetBitTestCase{ + {0, true}, + {1, true}, + {2, true}, + {3, false}, + {4, true}, + {7, true}, + {8, true}, + {9, true}, + {10, true}, + {11, false}, + {12, true}, + {13, false}, + {23, false}, + {24, false}, + {25, false}, + {42, false}, + {15, true}, + {16, true}, + {62, true}, + {63, false}, + {64, false}, + // Test past the first word + {65, true}, + {72, true}, + {79, true}, + {80, true}, + {255, false}, + {256, false}, + {511, true}, + {512, true}, + } + + for _, v := range getBitTestCases { + actual := bv.getBit(v.position) + if actual != v.expected { + t.Fatalf("expected %dth bit to be %t but got %t", v.position, v.expected, actual) + } + } +} + +func TestBitVectorBoundPanic_GetBit_Lower(t *testing.T) { + defer func() { _ = recover() }() + + initialNumberOfBits := wordSize*10 + 1 + bv := newBitVector(initialNumberOfBits) + bv.getBit(-1) + + t.Fatalf("expected get bit lower bound panic") +} + +func TestBitVectorBoundPanic_GetBit_Upper(t *testing.T) { + defer func() { _ = recover() }() + initialNumberOfBits := wordSize*10 + 1 + bv := newBitVector(initialNumberOfBits) + bv.getBit(initialNumberOfBits) + + t.Fatalf("expected get bit upper bound panic") +} + +func TestBitVectorBoundPanic_SetBit_Lower(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + t.Fatalf("expected set bit lower bound panic") + }() + initialNumberOfBits := wordSize*10 + 1 + bv := newBitVector(initialNumberOfBits) + bv.setBit(-1, true) +} + +func TestBitVectorBoundPanic_SetBit_Upper(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + t.Fatalf("expected set bit upper bound panic") + }() + 
initialNumberOfBits := wordSize*10 + 1 + bv := newBitVector(initialNumberOfBits) + bv.setBit(initialNumberOfBits, true) +} diff --git a/bwt/bwt.go b/bwt/bwt.go new file mode 100644 index 00000000..66020166 --- /dev/null +++ b/bwt/bwt.go @@ -0,0 +1,482 @@ +package bwt + +import ( + "errors" + "fmt" + "math" + "strings" + + "golang.org/x/exp/slices" +) + +/* + +For the BWT usage, please read the BWT methods +below. To understand what it is and how +it works for either curiosity or maintenance, then read below. + +# BWT + +BWT Stands for (B)urrows-(W)heeler (T)ransform. The BWT aids in +text compression and acts as a search index for any arbitrary +sequence of characters. With the BWT and some auxiliary data +structures, we can analyze a sequence in a memory and run time +efficient manner. + +## BWT Transform + +The first step to build the BWT is to get the BWT itself. + +This is done by: +1. Appending a null terminating character to the end of a sequence +2. Rotate the sequence so that the last character is now the first +3. Repeat 2. N times where N is the length of the sequence +4. Lexicographically sort the NxN matrix of rotated sequences where + the null termination character is always the least-valued +5. Store the first and last column of the matrix. The last column + is the output of the BWT. The first column is needed to run queries + on the BWT of the original sequence. + +Lets use banana as an example. + +banana$ $banana +$banana a$banan +a$banan ana$ban +na$bana => anana$b +ana$ban banana$ +nana$ba na$bana +anana$b nana$ba + +Output: + +Last Column (BWT): annb$aa +First Column: $aaabnn + +## LF Mapping Properties + +From now on we will refer to the Last Column as L and the First as F + +There are a few special properties here to be aware of. 
First, notice +how the characters of the same rank show up in the same order for each +column: + +L: a0 n0 n1 b0 $0 a1 a2 + +F: $0 a0 a1 a2 b0 n0 n1 + +That is to say, the characters' ranks in each column appear in ascending +order. For example: a0 < a1 < a2. This is true for all BWTs. + +The other important property to observe is that since the BWT is the +result of rotating each string, each character in the L column precedes +the corresponding character in the F column. + +The best way to show this is to rebuild the original sequence +using the F and L columns. We do this by rebuilding the original string +in reverse order starting with the nullChar. + +Original string: ______$0 +F($0) -> L(a0) -> _____a0$0 +F(a0) -> L(n0) -> ____n0a0$0 +F(n0) -> L(a1) -> ___a1n0a0$0 +F(a1) -> L(n1) -> __n1a1n0a0$0 +F(n1) -> L(a2) -> _a2n1a1n0a0$0 +F(a2) -> L(b0) -> b0a2n1a1n0a0$0 +F(b0) -> L($0) -> Complete + +If we take the rank subscripts away from: b0a2n1a1n0a0$0 +We get... "banana$" ! + +## LF Mapping Usage + +From these properties, the most important concept emerges- the LF Mapping. +The LF mapping is what enables us to query and analyze the BWT to gain +insight about the original sequence. + +For example, let's say we wanted to count the number of occurrences of the +pattern "ana" in "banana". We can do this by: + +1. Look up the last char of the pattern, a, in the F column +2. Find that range of a's, [1, 4) +3. Take the next previous character in the pattern, n +4. Find the rank of n before the range from 2. [0, 1) = 0 +5. Find the rank of n in the range from 2. [1, 4) = 2 +6. Look up the start range of the n's in the F column, 5 +7. Add the results from 4 and 5 respectively to the start from 6 to form the next + L search range: [5+0, 5+2) = [5, 7) +8. Take the next previous character in the pattern, a +9. Take the rank of "a" before position 5, which is 1 +10. Take the rank of "a" before position 7, which is 3 +11. 
Look up the a's in the F column again, but add the results + from 9 and 10 to the start range to get the next search + range = [1+1, 1+3) = [2, 4) +12. That is the beginning of our pattern, so we subtract the start from the end + of the search range to get our count, 4-2=2 + +Another way to look at this is that we are constantly refining our search +range for each character of the pattern we are searching for. Once we +reach the beginning of the pattern, our final range represents the a's which +start our pattern. If the width of the range is <= 0, then at some point our search range +has collapsed and we can conclude that there is no matching pattern. + +## Suffix Array + +For other operations such as Locate and Extract, we need another auxiliary +data structure, the suffix array. Since rows of the BWT can map to any +position within the original sequence, we need some kind of reference as to +which BWT rows map to which positions in the original sequence. We can do this by storing, +for each row of the BWT, the position in the original sequence of the character +in that row's F column. With our banana example: + +F: $0 a0 a1 a2 b0 n0 n1 +SA: [6 5 3 1 0 4 2] + +If we take our count example for the pattern "ana" above, you'll remember +that our final search range was [2, 4). You'll also remember that we counted +2 occurrences of "ana" by subtracting the start of the range from the end, 4-2=2. +If we iterate from 2 to 4, we can look up the corresponding SA entry for the BWT rows 2 and 3. +If we look up 2 in the SA, we'll find that our first offset is at position 3 in the original sequence ban"ana" +If we look up 3 in the SA, we'll find that our second offset is at position 1 in the original sequence b"ana"na + +## Notes on Performance + +The explanation above leads to a very naive implementation. For example, +having the full SA would take way more memory than the BWT itself. Assuming +int64, that would be 8 times the amount of memory of the BWT in its plain text +representation! 
In the implementation below, we may instead sample the SA +and do additional look ups as needed to find the offsets we need. + +Similarly, storing both the F and L column as plain text would take double the +amount of memory to store the original sequence... BWT is used for text +compression, not expansion! That's why in the below implementation, you +will see other data structures that lower the amount of memory +needed. You will also notice that we can make huge improvements by +compressing sequences by runs of characters like with the F column. + +Instead of: + +F: $0 a0 a1 a2 b0 n0 n1 + +Since F is lexicographically sorted, we can have: + +F: {$: [0, 1)}, {a: [1, 4)}, {b: [4, 5)} {n: [5, 7)} + +Although these performance enhancements may lead to a different implementation to what is +described above, any implementation will just be an LF mapping- just with a few more steps. + + +NOTE: The above is just to explain what is happening at a high level. Please +reference the implementation below to see how the BWT is actually currently +working + +Many of the Ideas come from Ben Langmead. +He has a whole YouTube playlist about BWT Indexing: https://www.youtube.com/watch?v=5G2Db41pSHE&list=PL2mpR0RYFQsADmYpW2YWBrXJZ_6EL_3nu +*/ + +const nullChar = "$" + +// BWT Burrows-Wheeler Transform +// Compresses and Indexes a given sequence so that it can be +// be used for search, alignment, and text extraction. This is +// useful for sequences so large that it would be beneficial +// to reduce its memory footprint while also maintaining a way +// to analyze and work with the sequence. +type BWT struct { + // firstColumnSkipList is the first column of the BWT. It is + // represented as a list of skipEntries because the first column of + // the BWT is always lexicographically ordered. This saves time and memory. + firstColumnSkipList []skipEntry + // Column last column of the BWT- the actual textual representation + // of the BWT. 
+ lastColumn waveletTree + // suffixArray an array that allows us to map a position in the first + // column to a position in the original sequence. This is needed to be + // able to extract text from the BWT. + suffixArray []int +} + +// Count represents the number of times the provided pattern +// shows up in the original sequence. +func (bwt BWT) Count(pattern string) (count int, err error) { + defer bwtRecovery("Count", &err) + err = isValidPattern(pattern) + if err != nil { + return 0, err + } + + searchRange := bwt.lfSearch(pattern) + return searchRange.end - searchRange.start, nil +} + +// Locate returns a list of offsets at which the beginning +// of the provided pattern occurs in the original +// sequence. +func (bwt BWT) Locate(pattern string) (offsets []int, err error) { + defer bwtRecovery("Locate", &err) + err = isValidPattern(pattern) + if err != nil { + return nil, err + } + + searchRange := bwt.lfSearch(pattern) + if searchRange.start >= searchRange.end { + return nil, nil + } + + numOfOffsets := searchRange.end - searchRange.start + offsets = make([]int, numOfOffsets) + for i := 0; i < numOfOffsets; i++ { + offsets[i] = bwt.suffixArray[searchRange.start+i] + } + + return offsets, nil +} + +// Extract this allows us to extract parts of the original +// sequence from the BWT. +// start is the beginning of the range of text to extract inclusive. +// end is the end of the range of text to extract exclusive. +// If either start or end are out of bounds, Extract will panic. 
+func (bwt BWT) Extract(start, end int) (extracted string, err error) { + defer bwtRecovery("Extract", &err) + err = validateRange(start, end) + if err != nil { + return "", err + } + + if end > bwt.getLenOfOriginalStringWithNullChar()-1 { + return "", fmt.Errorf("end [%d] exceeds the max range of the BWT [%d]", end, bwt.getLenOfOriginalStringWithNullChar()-1) + } + + if start < 0 { + return "", fmt.Errorf("start [%d] exceeds the min range of the BWT [0]", start) + } + + strB := strings.Builder{} + for i := start; i < end; i++ { + fPos := bwt.getFCharPosFromOriginalSequenceCharPos(i) + skip := bwt.lookupSkipByOffset(fPos) + strB.WriteByte(skip.char) + } + + return strB.String(), nil +} + +// Len return the length of the sequence used to build the BWT +func (bwt BWT) Len() int { + return bwt.getLenOfOriginalStringWithNullChar() - 1 +} + +// GetTransform returns the last column of the BWT transform of the original sequence. +func (bwt BWT) GetTransform() string { + return bwt.lastColumn.reconstruct() +} + +// getFCharPosFromOriginalSequenceCharPos looks up mapping from the original position +// of the sequence to its corresponding position in the First Column of the BWT +// NOTE: This clearly isn't ideal. Instead of improving this implementation, this will be replaced with +// something like r-index in the near future. +func (bwt BWT) getFCharPosFromOriginalSequenceCharPos(originalPos int) int { + for i := range bwt.suffixArray { + if bwt.suffixArray[i] == originalPos { + return i + } + } + panic("Unable to find the corresponding original position for a character in the original sequence in the suffix array. This should not be possible and indicates a malformed BWT.") +} + +// lfSearch LF Search- Last First Search. +// Finds the valid range within the BWT index where the provided pattern is possible. +// If the final range is <= 0, then the pattern does not exist in the original sequence. 
+func (bwt BWT) lfSearch(pattern string) interval { + searchRange := interval{start: 0, end: bwt.getLenOfOriginalStringWithNullChar()} + for i := 0; i < len(pattern); i++ { + if searchRange.end-searchRange.start <= 0 { + return interval{} + } + + c := pattern[len(pattern)-1-i] + skip, ok := bwt.lookupSkipByChar(c) + if !ok { + return interval{} + } + searchRange.start = skip.openEndedInterval.start + bwt.lastColumn.Rank(c, searchRange.start) + searchRange.end = skip.openEndedInterval.start + bwt.lastColumn.Rank(c, searchRange.end) + } + return searchRange +} + +// lookupSkipByChar looks up a skipEntry by its character in the First Column +func (bwt BWT) lookupSkipByChar(c byte) (entry skipEntry, ok bool) { + for i := range bwt.firstColumnSkipList { + if bwt.firstColumnSkipList[i].char == c { + return bwt.firstColumnSkipList[i], true + } + } + return skipEntry{}, false +} + +// lookupSkipByOffset looks up a skipEntry based off of an +// offset of the Fist Column of the BWT. +func (bwt BWT) lookupSkipByOffset(offset int) skipEntry { + if offset > bwt.getLenOfOriginalStringWithNullChar()-1 { + msg := fmt.Sprintf("offset [%d] exceeds the max bound of the BWT [%d]", offset, bwt.getLenOfOriginalStringWithNullChar()-1) + panic(msg) + } + if offset < 0 { + msg := fmt.Sprintf("offset [%d] exceeds the min bound of the BWT [0]", offset) + panic(msg) + } + + for skipIndex := range bwt.firstColumnSkipList { + if bwt.firstColumnSkipList[skipIndex].openEndedInterval.start <= offset && offset < bwt.firstColumnSkipList[skipIndex].openEndedInterval.end { + return bwt.firstColumnSkipList[skipIndex] + } + } + msg := fmt.Sprintf("could not find the skip entry that falls within the range of the skip column at a given offset. 
range: [0, %d) offset: %d", bwt.getLenOfOriginalStringWithNullChar(), offset) + panic(msg) +} + +func (bwt BWT) getLenOfOriginalStringWithNullChar() int { + return bwt.firstColumnSkipList[len(bwt.firstColumnSkipList)-1].openEndedInterval.end +} + +type interval struct { + start int + end int +} + +type skipEntry struct { + char byte + // openEndedInterval start is inclusive and end is exclusive + openEndedInterval interval +} + +// New returns a BWT of the provided sequence +// The provided sequence must not contain the nullChar +// defined in this package. If it does, New will return +// an error. +func New(sequence string) (BWT, error) { + err := validateSequenceBeforeTransforming(&sequence) + if err != nil { + return BWT{}, err + } + + sequence += nullChar + + prefixArray := make([]string, len(sequence)) + for i := 0; i < len(sequence); i++ { + prefixArray[i] = sequence[len(sequence)-i-1:] + } + + sortPrefixArray(prefixArray) + + suffixArray := make([]int, len(sequence)) + lastColBuilder := strings.Builder{} + for i := 0; i < len(prefixArray); i++ { + currChar := sequence[getBWTIndex(len(sequence), len(prefixArray[i]))] + lastColBuilder.WriteByte(currChar) + + suffixArray[i] = len(sequence) - len(prefixArray[i]) + } + fb := strings.Builder{} + for i := 0; i < len(prefixArray); i++ { + fb.WriteByte(prefixArray[i][0]) + } + + wt, err := newWaveletTreeFromString(lastColBuilder.String()) + if err != nil { + return BWT{}, err + } + + return BWT{ + firstColumnSkipList: buildSkipList(prefixArray), + lastColumn: wt, + suffixArray: suffixArray, + }, nil +} + +// buildSkipList compressed the First Column of the BWT into a skip list +func buildSkipList(prefixArray []string) []skipEntry { + prevChar := prefixArray[0][0] + skipList := []skipEntry{{char: prevChar, openEndedInterval: interval{start: 0}}} + for i := 1; i < len(prefixArray); i++ { + currChar := prefixArray[i][0] + if currChar != prevChar { + skipList[len(skipList)-1].openEndedInterval.end = i + skipList = 
append(skipList, skipEntry{ + char: currChar, + openEndedInterval: interval{start: i}, + }) + prevChar = currChar + } + } + skipList[len(skipList)-1].openEndedInterval.end = len(prefixArray) + return skipList +} + +// getBWTIndex helps us calculate the corresponding character that would +// be in the L column without having to rotate the full string. +// For example: +// Original string: banana$ +// Rotation: ana$___ +// Position: 7-4-1= 2 +// Original[3]: n +func getBWTIndex(lenOfSequenceBeingBuilt, lenOfSuffixArrayVisited int) int { + bwtCharIndex := lenOfSequenceBeingBuilt - lenOfSuffixArrayVisited - 1 + if bwtCharIndex == -1 { + bwtCharIndex = lenOfSequenceBeingBuilt - 1 + } + return bwtCharIndex +} + +func sortPrefixArray(prefixArray []string) { + slices.SortFunc(prefixArray, func(a, b string) bool { + minLen := int(math.Min(float64(len(a)), float64(len(b)))) + for i := 0; i < minLen; i++ { + if a[i] == b[i] { + continue + } + if a[i] == nullChar[0] { + return true + } + if b[i] == nullChar[0] { + return false + } + return a[i] < b[i] + } + + return len(a) < len(b) + }) +} + +func bwtRecovery(operation string, err *error) { + if r := recover(); r != nil { + rErr := fmt.Errorf("BWT %s InternalError=%s", operation, r) + *err = rErr + } +} + +func isValidPattern(s string) (err error) { + if len(s) == 0 { + return errors.New("Pattern can not be empty") + } + return nil +} + +func validateRange(start, end int) (err error) { + if start >= end { + return errors.New("Start must be strictly less than end") + } + return nil +} + +func validateSequenceBeforeTransforming(sequence *string) (err error) { + if len(*sequence) == 0 { + return fmt.Errorf("Provided sequence must not by empty. BWT cannot be constructed") + } + if strings.Contains(*sequence, nullChar) { + return fmt.Errorf("Provided sequence contains the nullChar %s. 
BWT cannot be constructed", nullChar) + } + return nil +} diff --git a/bwt/bwt_test.go b/bwt/bwt_test.go new file mode 100644 index 00000000..7b5512ff --- /dev/null +++ b/bwt/bwt_test.go @@ -0,0 +1,419 @@ +package bwt + +import ( + "strings" + "testing" + + "golang.org/x/exp/slices" +) + +type BWTCountTestCase struct { + seq string + expected int +} + +func TestBWT_Count(t *testing.T) { + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + testTable := []BWTCountTestCase{ + {"uick", 3}, + {"the", 6}, + {"over", 6}, + {"own", 12}, + {"ana", 6}, + {"an", 9}, + {"na", 9}, + {"rown", 6}, + {"townthe", 2}, + + // patterns that should not exist + {"zzz", 0}, + {"clown", 0}, + {"crown", 0}, + {"spark", 0}, + {"brawn", 0}, + } + + for _, v := range testTable { + count, err := bwt.Count(v.seq) + if err != nil { + t.Fatalf("seq=%s unexpectedError=%s", v.seq, err) + } + if count != v.expected { + t.Fatalf("seq=%s expectedCount=%v actualCount=%v", v.seq, v.expected, count) + } + } +} + +func TestBWT_Count_EmptyPattern(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + _, err = bwt.Count("") + if err == nil { + t.Fatal("Expected error for empty pattern but got nil") + } +} + +type BWTLocateTestCase struct { + seq string + expected []int +} + +func TestBWT_Locate(t *testing.T) { + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" // len == 112 + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + testTable := []BWTLocateTestCase{ + {"uick", []int{4, 117, 230}}, + {"the", []int{0, 25, 113, 138, 226, 251}}, + {"over", []int{21, 41, 
134, 154, 247, 267}}, + {"own", []int{10, 48, 106, 110, 123, 161, 219, 223, 236, 274, 332, 336}}, + {"ana", []int{87, 89, 200, 202, 313, 315}}, + {"an", []int{39, 87, 89, 152, 200, 202, 265, 313, 315}}, + {"na", []int{50, 88, 90, 163, 201, 203, 276, 314, 316}}, + {"rown", []int{9, 47, 122, 160, 235, 273}}, + {"townthe", []int{109, 222}}, + + // patterns that should not exist + {"zzz", nil}, + {"clown", nil}, + {"crown", nil}, + {"spark", nil}, + {"brawn", nil}, + } + + for _, v := range testTable { + offsets, err := bwt.Locate(v.seq) + if err != nil { + t.Fatalf("seq=%s unexpectedError=%s", v.seq, err) + } + slices.Sort(offsets) + if len(offsets) != len(v.expected) { + t.Fatalf("seq=%s expectedOffsets=%v actualOffsets=%v", v.seq, v.expected, offsets) + } + for i := range offsets { + if offsets[i] != v.expected[i] { + t.Fatalf("seq=%s expectedOffsets=%v actualOffsets=%v", v.seq, v.expected, offsets) + } + } + } +} + +func TestBWT_Locate_EmptyPattern(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + _, err = bwt.Locate("") + if err == nil { + t.Fatal("Expected error for empty pattern but got nil") + } +} + +type BWTExtractTestCase struct { + start int + end int + expected string +} + +func TestBWT_Extract(t *testing.T) { + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" // len == 112 + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + testTable := []BWTExtractTestCase{ + {4, 8, "uick"}, + {117, 121, "uick"}, + {230, 234, "uick"}, + {0, 3, "the"}, + {25, 28, "the"}, + {113, 116, "the"}, + {138, 141, "the"}, + {226, 229, "the"}, + {251, 254, "the"}, + {21, 25, "over"}, + {41, 45, "over"}, + {134, 138, "over"}, + {154, 158, "over"}, + {247, 251, "over"}, + {267, 271, "over"}, + {10, 13, "own"}, + {48, 51, "own"}, + {106, 109, "own"}, + 
{123, 126, "own"}, + {161, 164, "own"}, + {219, 222, "own"}, + {223, 226, "own"}, + {236, 239, "own"}, + {274, 277, "own"}, + {332, 335, "own"}, + {336, 339, "own"}, + {87, 90, "ana"}, + {89, 92, "ana"}, + {200, 203, "ana"}, + {202, 205, "ana"}, + {313, 316, "ana"}, + {315, 318, "ana"}, + {39, 41, "an"}, + {87, 89, "an"}, + {152, 154, "an"}, + {200, 202, "an"}, + {202, 204, "an"}, + {265, 267, "an"}, + {313, 315, "an"}, + {50, 52, "na"}, + {88, 90, "na"}, + {163, 165, "na"}, + {201, 203, "na"}, + {203, 205, "na"}, + {276, 278, "na"}, + {314, 316, "na"}, + {316, 318, "na"}, + {9, 13, "rown"}, + {47, 51, "rown"}, + {122, 126, "rown"}, + {160, 164, "rown"}, + {235, 239, "rown"}, + {273, 277, "rown"}, + {109, 116, "townthe"}, + {222, 229, "townthe"}, + } + + for _, v := range testTable { + str, err := bwt.Extract(v.start, v.end) + if err != nil { + t.Fatalf("extractRange=(%d, %d) unexpectedError=%s", v.start, v.end, err) + } + if str != v.expected { + t.Fatalf("extractRange=(%d, %d) expected=%s actual=%s", v.start, v.end, v.expected, str) + } + } +} + +func TestBWT_Extract_InvalidRanges(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + _, err = bwt.Extract(5, 4) + if err == nil { + t.Fatal("Expected error for invalid range but got nil") + } + _, err = bwt.Extract(4, 4) + if err == nil { + t.Fatal("Expected error for invalid range but got nil") + } +} + +func TestBWT_Extract_DoNotAllowExtractionOfLastNullChar(t *testing.T) { + testStr := "banana" + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + str, err := bwt.Extract(0, 6) + if err != nil { + t.Fatalf("extractRange=(%d, %d) unexpectedError=%s", 0, 6, err) + } + if str != testStr { + t.Fatalf("extractRange=(%d, %d) expected=%s actual=%s", 0, 6, testStr, str) + } + + _, err = bwt.Extract(0, 7) + + if err == nil { + t.Fatalf("extractRange=(%d, %d) expected err but was nil", 0, 7) + } + + if !strings.Contains(err.Error(), "exceeds the max range") { 
+ t.Fatalf("expected error to contain \"exceeds the max range\" but received \"%s\"", err) + } +} + +func TestBWT_Len(t *testing.T) { + testStr := "banana" + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + if bwt.Len() != len(testStr) { + t.Fatalf("expected Len to be %d but got %d", len(testStr), bwt.Len()) + } +} + +func TestNewBWTWithSequenceContainingNullChar(t *testing.T) { + nc := nullChar + testStr := "banana" + nc + + _, err := New(testStr) + if err == nil { + t.Fatal("expected error but got nil") + } +} + +func TestNewBWTEmptySequence(t *testing.T) { + testStr := "" + + _, err := New(testStr) + if err == nil { + t.Fatal("expected error but got nil") + } +} + +// TestBWTReconstruction this helps us ensure that the LF mapping is correct and that the suffix array lookup +// must be well formed. Otherwise, we would not be able to recreate the original sequence. +func TestBWTReconstruction(t *testing.T) { + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + extracted, err := bwt.Extract(0, bwt.Len()) + if err != nil { + t.Fatal(err) + } + if extracted != testStr { + t.Log("Reconstruction failed") + t.Log("Expected:\t", testStr) + t.Log("Actual:\t", extracted) + t.Fail() + } + + // This will either result in an even or all alphabet. The alphabet matters. + testStrWithOneMoreAlpha := testStr + "!" 
+ bwt, err = New(testStrWithOneMoreAlpha) + if err != nil { + t.Fatal(err) + } + extracted, err = bwt.Extract(0, bwt.Len()) + if err != nil { + t.Fatal(err) + } + if extracted != testStrWithOneMoreAlpha { + t.Log("Reconstruction failed with extra alpha character") + t.Log("Expected:\t", testStr) + t.Log("Actual:\t", extracted) + t.Fail() + } +} + +func TestBWTStartError(t *testing.T) { + testStr := "banana" + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + _, err = bwt.Extract(-1, 6) + if err == nil { + t.Fatal("expected error but got nil") + } +} +func TestBWT_GetFCharPosFromOriginalSequenceCharPos_Panic(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + // Call the function with an invalid original position + originalPos := -1 + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic, but it did not occur") + } + }() + bwt.getFCharPosFromOriginalSequenceCharPos(originalPos) +} +func TestBWT_LFSearch_InvalidChar(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + pattern := "x" // Invalid character + + result := bwt.lfSearch(pattern) + + if result.start != 0 || result.end != 0 { + t.Fatalf("Expected search range to be (0, 0), but got (%d, %d)", result.start, result.end) + } +} +func TestBWT_LookupSkipByOffset_PanicOffsetExceedsMaxBound(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + offset := bwt.getLenOfOriginalStringWithNullChar() + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic, but it did not occur") + } + }() + bwt.lookupSkipByOffset(offset) +} + +func TestBWT_LookupSkipByOffset_PanicOffsetExceedsMinBound(t *testing.T) { + testStr := "banana" + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + offset := -1 + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic, but it did not occur") + } + }() + 
bwt.lookupSkipByOffset(offset) +} + +func TestBWTRecovery(t *testing.T) { + // Test panic recovery for bwtRecovery function + var err error + operation := "test operation" + + defer func() { + if err == nil { + t.Fatal("expected bwtRecovery to recover from the panic and set an error message, but got nil") + } + }() + defer bwtRecovery(operation, &err) + doPanic() +} + +func doPanic() { + panic("test panic") +} diff --git a/bwt/example_test.go b/bwt/example_test.go new file mode 100644 index 00000000..2d1e1b71 --- /dev/null +++ b/bwt/example_test.go @@ -0,0 +1,92 @@ +package bwt_test + +import ( + "fmt" + "log" + + "github.com/bebop/poly/bwt" + "golang.org/x/exp/slices" +) + +// This example shows how BWT can be used for exact pattern +// matching by returning the offsets at which the pattern exists. +// This can be useful for alignment when you need need to reduce +// the memory footprint of a reference sequence without loosing +// any data since BWT is a lossless compression. +func ExampleBWT_basic() { + inputSequence := "AACCTGCCGTCGGGGCTGCCCGTCGCGGGACGTCGAAACGTGGGGCGAAACGTG" + + bwt, err := bwt.New(inputSequence) + if err != nil { + log.Fatal(err) + } + + offsets, err := bwt.Locate("GCC") + if err != nil { + log.Fatal(err) + } + slices.Sort(offsets) + fmt.Println(offsets) + // Output: [5 17] +} + +func ExampleBWT_Count() { + inputSequence := "AACCTGCCGTCGGGGCTGCCCGTCGCGGGACGTCGAAACGTGGGGCGAAACGTG" + + bwt, err := bwt.New(inputSequence) + if err != nil { + log.Fatal(err) + } + + count, err := bwt.Count("CG") + if err != nil { + log.Fatal(err) + } + fmt.Println(count) + // Output: 10 +} + +func ExampleBWT_Locate() { + inputSequence := "AACCTGCCGTCGGGGCTGCCCGTCGCGGGACGTCGAAACGTGGGGCGAAACGTG" + + bwt, err := bwt.New(inputSequence) + if err != nil { + log.Fatal(err) + } + + offsets, err := bwt.Locate("CG") + if err != nil { + log.Fatal(err) + } + slices.Sort(offsets) + fmt.Println(offsets) + // Output: [7 10 20 23 25 30 33 38 45 50] +} + +func ExampleBWT_Extract() { 
+ inputSequence := "AACCTGCCGTCGGGGCTGCCCGTCGCGGGACGTCGAAACGTGGGGCGAAACGTG" + + bwt, err := bwt.New(inputSequence) + if err != nil { + log.Fatal(err) + } + + extracted, err := bwt.Extract(48, 54) + if err != nil { + log.Fatal(err) + } + fmt.Println(extracted) + // Output: AACGTG +} + +func ExampleBWT_GetTransform() { + inputSequence := "banana" + + bwt, err := bwt.New(inputSequence) + if err != nil { + log.Fatal(err) + } + + fmt.Println(bwt.GetTransform()) + // Output: annb$aa +} diff --git a/bwt/rsa_bitvector.go b/bwt/rsa_bitvector.go new file mode 100644 index 00000000..fc2e9ceb --- /dev/null +++ b/bwt/rsa_bitvector.go @@ -0,0 +1,192 @@ +package bwt + +import "math/bits" + +// rsaBitVector allows us to perform RSA: (R)ank, (S)elect, and (A)ccess +// queries in a memory performant and memory compact way. +// To learn about how Rank, Select, and Access work, take a look at the +// examples in each respective method. +type rsaBitVector struct { + bv bitvector + totalOnesRank int + jrc []chunk + jrSubChunksPerChunk int + jrBitsPerChunk int + jrBitsPerSubChunk int + oneSelectMap map[int]int + zeroSelectMap map[int]int +} + +// newRSABitVectorFromBitVector allows us to build the auxiliary components +// needed to perform RSA queries on top of the provided bitvector. +// WARNING: Do not modify the underlying bitvector. The rsaBitvector will +// get out of sync with the original bitvector. +func newRSABitVectorFromBitVector(bv bitvector) rsaBitVector { + jacobsonRankChunks, jrSubChunksPerChunk, jrBitsPerSubChunk, totalOnesRank := buildJacobsonRank(bv) + ones, zeros := buildSelectMaps(bv) + + return rsaBitVector{ + bv: bv, + totalOnesRank: totalOnesRank, + jrc: jacobsonRankChunks, + jrSubChunksPerChunk: jrSubChunksPerChunk, + jrBitsPerChunk: jrSubChunksPerChunk * jrBitsPerSubChunk, + jrBitsPerSubChunk: jrBitsPerSubChunk, + oneSelectMap: ones, + zeroSelectMap: zeros, + } +} + +// Rank returns the rank of the given value up to, but not including +// the ith bit. 
+// For Example: +// Given the bitvector 001000100001 +// Rank(true, 1) = 0 +// Rank(true, 2) = 0 +// Rank(true, 3) = 1 +// Rank(true, 8) = 2 +// Rank(false, 8) = 6 +func (rsa rsaBitVector) Rank(val bool, i int) int { + if i == rsa.bv.len() { + if val { + return rsa.totalOnesRank + } + return rsa.bv.len() - rsa.totalOnesRank + } + + chunkPos := (i / rsa.jrBitsPerChunk) + chunk := rsa.jrc[chunkPos] + + subChunkPos := (i % rsa.jrBitsPerChunk) / rsa.jrBitsPerSubChunk + subChunk := chunk.subChunks[subChunkPos] + + bitOffset := i % rsa.jrBitsPerSubChunk + + bitSet := rsa.bv.getBitSet(chunkPos*rsa.jrSubChunksPerChunk + subChunkPos) + + shiftRightAmount := uint64(rsa.jrBitsPerSubChunk - bitOffset) + if val { + remaining := bitSet >> shiftRightAmount + return chunk.onesCumulativeRank + subChunk.onesCumulativeRank + bits.OnesCount64(remaining) + } + remaining := ^bitSet >> shiftRightAmount + + // cumulative ranks for 0 should just be the sum of the compliment of cumulative ranks for 1 + return (chunkPos*rsa.jrBitsPerChunk - chunk.onesCumulativeRank) + (subChunkPos*rsa.jrBitsPerSubChunk - subChunk.onesCumulativeRank) + bits.OnesCount64(remaining) +} + +// Select returns the position of the given value with the provided Rank +// For Example: +// Given the bitvector 001000100001 +// Select(true, 1) = 2 +// Rank(false, 5) = 5 +// Rank(false, 1) = 1 +// Rank(false, 0) = 0 +func (rsa rsaBitVector) Select(val bool, rank int) (i int, ok bool) { + if val { + i, ok := rsa.oneSelectMap[rank] + return i, ok + } else { + i, ok := rsa.zeroSelectMap[rank] + return i, ok + } +} + +// Access returns the value of a bit at a given offset +func (rsa rsaBitVector) Access(i int) bool { + return rsa.bv.getBit(i) +} + +type chunk struct { + subChunks []subChunk + onesCumulativeRank int +} + +type subChunk struct { + onesCumulativeRank int +} + +/* +buildJacobsonRank Jacobson rank is a succinct data structure. 
This allows us to represent something that +would normally require O(N) worth of memory with less than N memory. Jacobson Rank allows for +sub linear growth. Jacobson rank also allows us to look up the rank for some value of a bitvector in O(1) +time. Theoretically, Jacobson Rank Requires: +1. Creating log(N) "Chunks" +2. Creating 2log(N) "Sub Chunks" +3. Having "Sub Chunks" be 0.5log(N) in length +4. For each "Chunk", store the cumulative rank of set bits relative to the overall bitvector +5. For each "Sub Chunk", store the cumulative rank of set bits relative to the parent "Chunk" +6. We can One's count the N bit word if possible. We will only consider this possibility :) + +For simplicity and all around decent results, we just have "Sub Chunks" of size 64 bits. + +It is O(1) because given some offset i, all we have to do to calculate the rank is: +rank = CumulativeRank(ChunkOfi(i)) + CumulativeRank(SubChunkOfi(i)) + OnesCount(SubChunkOfi(i)) + +To understand why it is sub linear in space, you can refer to Ben Langmead and other literature that +describes the space complexity. 
+https://www.youtube.com/watch?v=M1sUZxXVjG8&list=PL2mpR0RYFQsADmYpW2YWBrXJZ_6EL_3nu&index=7 +*/ +func buildJacobsonRank(inBv bitvector) (jacobsonRankChunks []chunk, numOfSubChunksPerChunk, numOfBitsPerSubChunk, totalRank int) { + numOfSubChunksPerChunk = 4 + + totalRank = 0 + chunkCumulativeRank := 0 + subChunkCumulativeRank := 0 + + var currSubChunks []subChunk + for i := range inBv.bits { + if len(currSubChunks) == numOfSubChunksPerChunk { + jacobsonRankChunks = append(jacobsonRankChunks, chunk{ + subChunks: currSubChunks, + onesCumulativeRank: chunkCumulativeRank, + }) + + chunkCumulativeRank += subChunkCumulativeRank + + currSubChunks = nil + subChunkCumulativeRank = 0 + } + currSubChunks = append(currSubChunks, subChunk{ + onesCumulativeRank: subChunkCumulativeRank, + }) + + onesCount := bits.OnesCount64(inBv.getBitSet(i)) + subChunkCumulativeRank += onesCount + totalRank += onesCount + } + + if currSubChunks != nil { + jacobsonRankChunks = append(jacobsonRankChunks, chunk{ + subChunks: currSubChunks, + onesCumulativeRank: chunkCumulativeRank, + }) + } + + return jacobsonRankChunks, numOfSubChunksPerChunk, wordSize, totalRank +} + +// This is not good. 
We should find a better means of select- like Clark's Select +func buildSelectMaps(inBv bitvector) (oneSelectMap, zeroSelectMap map[int]int) { + oneSelectMap = make(map[int]int) + zeroSelectMap = make(map[int]int) + oneCount := 0 + zeroCount := 0 + for i := 0; i < inBv.len(); i++ { + bit := inBv.getBit(i) + if bit { + oneSelectMap[oneCount] = i + oneCount++ + } else { + zeroSelectMap[zeroCount] = i + zeroCount++ + } + } + + // Account for the case where we need to find the + // position for the max rank for both 0's and 1's + oneSelectMap[oneCount] = inBv.len() + zeroSelectMap[zeroCount] = inBv.len() + + return oneSelectMap, zeroSelectMap +} diff --git a/bwt/rsa_bitvector_test.go b/bwt/rsa_bitvector_test.go new file mode 100644 index 00000000..d09a9eb2 --- /dev/null +++ b/bwt/rsa_bitvector_test.go @@ -0,0 +1,353 @@ +package bwt + +import ( + "testing" +) + +type rsaRankTestCase struct { + val bool + bitPosition int + expectedRank int +} + +func TestRSARank_singlePartialChunk(t *testing.T) { + if wordSize != 64 { + t.Skip() + } + + bitsToTruncate := 22 + initialNumberOfBits := wordSize*2 - bitsToTruncate + + rsa := newTestRSAFromWords(initialNumberOfBits, + 0xffffffff00000000, + 0x00000000ffc00000, + ) + + testCases := []rsaRankTestCase{ + {true, 0, 0}, {false, 0, 0}, + + {true, 64, 32}, {false, 64, 32}, + + {true, 96, 32}, {false, 96, 64}, + + {true, 105, 41}, {false, 105, 64}, + } + + for _, tc := range testCases { + rank := rsa.Rank(tc.val, tc.bitPosition) + if rank != tc.expectedRank { + t.Fatalf("expected rank(%t, %d) to be %d but got %d", tc.val, tc.bitPosition, tc.expectedRank, rank) + } + } +} + +func TestRSARank_singleCompleteChunk_PastBounds_Ones(t *testing.T) { + rsa := newTestRSAFromWords(64*4, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + ) + + testCases := []rsaRankTestCase{ + {true, 0, 0}, {false, 0, 0}, + {true, 255, 127}, {false, 255, 128}, + {true, 256, 128}, {false, 256, 128}, + } + + for _, tc := range 
testCases { + rank := rsa.Rank(tc.val, tc.bitPosition) + if rank != tc.expectedRank { + t.Fatalf("expected rank(%t, %d) to be %d but got %d", tc.val, tc.bitPosition, tc.expectedRank, rank) + } + } +} + +func TestRSARank_singleCompleteChunk_PastBounds_Zeros(t *testing.T) { + rsa := newTestRSAFromWords(64*4, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + ) + + testCases := []rsaRankTestCase{ + {true, 0, 0}, {false, 0, 0}, + {true, 255, 128}, {false, 255, 127}, + {true, 256, 128}, {false, 256, 128}, + } + + for _, tc := range testCases { + rank := rsa.Rank(tc.val, tc.bitPosition) + if rank != tc.expectedRank { + t.Fatalf("expected rank(%t, %d) to be %d but got %d", tc.val, tc.bitPosition, tc.expectedRank, rank) + } + } +} + +func TestRSARank_singleCompleteChunk(t *testing.T) { + initialNumberOfBits := wordSize * 4 + + rsa := newTestRSAFromWords(initialNumberOfBits, + 0x8000000000000001, + 0xff0f30fffacea80d, + 0x90e0a0e0b0e0cf0c, + 0x3d0f064f7206f717, + ) + + testCases := []rsaRankTestCase{ + {true, 0, 0}, {false, 0, 0}, + {true, 1, 1}, {false, 1, 0}, + {true, 2, 1}, {false, 2, 1}, + {true, 3, 1}, {false, 3, 2}, + {true, 62, 1}, {false, 62, 61}, + {true, 63, 1}, {false, 63, 62}, + + {true, 64, 2}, {false, 64, 62}, + {true, 65, 3}, {false, 65, 62}, + {true, 72, 10}, {false, 72, 62}, + {true, 127, 40}, {false, 127, 87}, + + {true, 128, 41}, {false, 128, 87}, + {true, 129, 42}, {false, 129, 87}, + {true, 130, 42}, {false, 130, 88}, + {true, 131, 42}, {false, 131, 89}, + {true, 132, 43}, {false, 132, 89}, + {true, 133, 43}, {false, 133, 90}, + {true, 159, 51}, {false, 159, 108}, + {true, 160, 51}, {false, 160, 109}, + {true, 161, 52}, {false, 161, 109}, + {true, 162, 52}, {false, 162, 110}, + {true, 163, 53}, {false, 163, 110}, + {true, 164, 54}, {false, 164, 110}, + {true, 165, 54}, {false, 165, 111}, + {true, 176, 57}, {false, 176, 119}, + {true, 177, 58}, {false, 177, 119}, + {true, 178, 59}, {false, 178, 119}, + {true, 179, 
59}, {false, 179, 120}, + {true, 180, 59}, {false, 180, 121}, + {true, 183, 62}, {false, 183, 121}, + {true, 184, 63}, {false, 184, 121}, + {true, 185, 63}, {false, 185, 122}, + {true, 186, 63}, {false, 186, 123}, + {true, 187, 63}, {false, 187, 124}, + {true, 188, 63}, {false, 188, 125}, + {true, 189, 64}, {false, 189, 125}, + {true, 190, 65}, {false, 190, 125}, + {true, 191, 65}, {false, 191, 126}, + + {true, 192, 65}, {false, 192, 127}, + {true, 193, 65}, {false, 193, 128}, + {true, 194, 65}, {false, 194, 129}, + {true, 195, 66}, {false, 195, 129}, + {true, 196, 67}, {false, 196, 129}, + {true, 248, 94}, {false, 248, 154}, + {true, 249, 94}, {false, 249, 155}, + {true, 250, 94}, {false, 250, 156}, + {true, 251, 94}, {false, 251, 157}, + {true, 252, 95}, {false, 252, 157}, + {true, 253, 95}, {false, 253, 158}, + {true, 254, 96}, {false, 254, 158}, + {true, 255, 97}, {false, 255, 158}, + } + + for _, tc := range testCases { + rank := rsa.Rank(tc.val, tc.bitPosition) + if rank != tc.expectedRank { + t.Fatalf("expected rank(%t, %d) to be %d but got %d", tc.val, tc.bitPosition, tc.expectedRank, rank) + } + } +} + +func TestRSARank_multipleChunks(t *testing.T) { + rsa := newTestRSAFromWords((8*4+3)*64, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + + // If Jacobson rank is still there, this should go past the first + // chunk + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 
0x0000000000000000, + + // If Jacobson rank is still there, this should go past the second + // chunk + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + ) + + testCases := []rsaRankTestCase{ + {true, 0, 0}, {false, 0, 0}, + + {true, 64, 0}, {false, 64, 64}, + {true, 128, 64}, {false, 128, 64}, + {true, 192, 64}, {false, 192, 128}, + {true, 256, 128}, {false, 256, 128}, + + {true, 320, 192}, {false, 256, 128}, + {true, 384, 192}, {false, 384, 192}, + {true, 448, 256}, {false, 448, 192}, + {true, 512, 256}, {false, 512, 256}, + + {true, 576, 256}, {false, 576, 320}, + {true, 640, 320}, {false, 640, 320}, + {true, 704, 320}, {false, 704, 384}, + {true, 768, 384}, {false, 768, 384}, + + {true, 832, 448}, {false, 832, 384}, + {true, 896, 448}, {false, 896, 448}, + + {true, 1024, 512}, {false, 1024, 512}, + + {true, 2048, 1024}, {false, 2048, 1024}, + } + + for _, tc := range testCases { + rank := rsa.Rank(tc.val, tc.bitPosition) + if rank != tc.expectedRank { + t.Fatalf("expected rank(%t, %d) to be %d but got %d", tc.val, tc.bitPosition, tc.expectedRank, rank) + } + } +} + +type rsaSelectTestCase struct { + val bool + rank int + expectedPosition int +} + +func TestRSASelect(t *testing.T) { + bitsToTruncate := 17 + initialNumberOfBits := wordSize*4 - bitsToTruncate + rsa := newTestRSAFromWords(initialNumberOfBits, + 0x8010000000010000, // 1Count = 3 + 0xfff1ffffffffffff, // 1Count = 63 + 0x0000010000000000, // 1Count = 1 + 0xffffffffffffffff, // Possible 1Count = 47 + ) + + testCases := []rsaSelectTestCase{ + {true, 0, 0}, + {true, 1, 11}, + {true, 2, 47}, + {false, 0, 1}, + {false, 1, 2}, + {false, 3, 4}, + {false, 8, 9}, + {false, 9, 10}, + {false, 10, 12}, + {false, 11, 13}, + {false, 60, 63}, + + {true, 3, 64}, + {true, 9, 70}, + {true, 13, 74}, + {true, 14, 75}, + {true, 15, 79}, + {true, 16, 80}, + {true, 63, 127}, + {false, 61, 76}, + {false, 62, 77}, + {false, 63, 78}, + + {true, 64, 151}, + {true, 65, 192}, + {true, 111, 238}, + {false, 64, 
128}, + + {false, 126, 191}, + + // Select of penultimate ranks should be the positions at which they appear. + {true, 111, rsa.bv.len() - 1}, + {false, 126, 191}, + + // Max bitvector positions for the max rank should be at the ends of the bitvector + {true, 112, rsa.bv.len()}, + {false, 127, rsa.bv.len()}, + } + + for _, tc := range testCases { + position, ok := rsa.Select(tc.val, tc.rank) + + if !ok { + t.Fatalf("expected select(%t, %d) to be %d but went out of range", tc.val, tc.rank, tc.expectedPosition) + } + + if position != tc.expectedPosition { + t.Fatalf("expected select(%t, %d) to be %d but got %d", tc.val, tc.rank, tc.expectedPosition, position) + } + } +} + +func TestRSASelect_notOk(t *testing.T) { + bitsToTruncate := 17 + initialNumberOfBits := wordSize*4 - bitsToTruncate + rsa := newTestRSAFromWords(initialNumberOfBits, + 0x8010000000010000, + 0xfff1ffffffffffff, + 0x0000010000000000, + 0xffffffffffffffff, + ) + + if _, ok := rsa.Select(true, -1); ok { + t.Fatalf("expected select(true, -1) to be not ok but somehow returned a value") + } + + pos, ok := rsa.Select(true, 111) + if !ok { + t.Fatalf("expected select(true, 111) to be ok but somehow got not ok") + } + + if pos != 238 { + t.Fatalf("expected select(true, 111) to be 238 but got %d", pos) + } + + if _, ok := rsa.Select(true, 239); ok { + t.Fatalf("expected select(true, 239) to be not ok but somehow returned a value") + } +} + +func newTestRSAFromWords(sizeInBits int, wordsToCopy ...uint64) rsaBitVector { + bv := newBitVector(sizeInBits) + for i := 0; i < sizeInBits; i++ { + w := wordsToCopy[i/64] + mask := uint64(1) << uint64(63-i%64) + bit := w&mask != 0 + bv.setBit(i, bit) + } + return newRSABitVectorFromBitVector(bv) +} diff --git a/bwt/wavelet.go b/bwt/wavelet.go new file mode 100644 index 00000000..af20d5c3 --- /dev/null +++ b/bwt/wavelet.go @@ -0,0 +1,433 @@ +package bwt + +import ( + "errors" + "fmt" + "math" + + "golang.org/x/exp/slices" +) + +/* + +For the waveletTree's usage, please 
read its +method documentation. To understand what it is and how +it works for either curiosity or maintenance, then read below. + +# WaveletTree + +The Wavelet Tree allows us to conduct RSA queries on strings in +a memory and run time efficient manner. +RSA stands for (R)ank, (S)elect, (A)ccess. + +See this blog post by Alex Bowe for an additional explanation: +https://www.alexbowe.com/wavelet-trees/ + +## The Character's Path Encoding + +Each character from a sequence's alphabet will be assigned a path. +This path encoding represents a path from the Wavelet Tree's root to some +leaf node that represents a character. +For example, given the alphabet A B C D E F G H, a possible encoding is: + +A: 000 +B: 001 +C: 010 +D: 011 +E: 100 +F: 101 +G: 110 +H: 111 + +If we wanted to get to the leaf that represents the character D, we'd have +to use D's path encoding to traverse the tree. +Consider 0 as the left and 1 as the right. +If we follow D's encoding, 011, then we'd take a path that looks like: + + root + / +left + \ + right + \ + right + +## The Data Represented at each node + +Let us consider the sequence "bananas" +It has the alphabet b, a, n, s +Let's say it has the encoding: +a: 00 +n: 01 +b: 10 +s: 11 +and that 0 is left and 1 is right +We can represent this tree with bitvectors: + + 0010101 + bananas + / \ + 1000 001 + baaa nns + / \ / \ +a b n s + +If we translate each bit vector to its corresponding string, then it becomes: + + bananas + / \ + baaa nns + / \ / \ +a b n s + +Each node of the tree consists of a bitvector whose values indicate whether +the character at a particular index is in the left (0) or right (1) child of the +tree. + +## RSA + +At this point, we can talk about RSA. RSA stands for (R)ank, (S)elect, (A)ccess. + +### Rank Example + +WaveletTree.Rank(c, n) returns the rank of character c at index n in a sequence, i.e. how many +times c has occurred in a sequence before index n. 
+ +To get WaveletTree.Rank(a, 4) of bananas where a's encoding is 00 +1. root.Rank(0, 4) of 0010101 is 3 +2. Visit Left Child +3. child.Rank(0, 3) of 1000 is 2 +4. Visit Left Child +5. We are at a leaf node, so return our last recorded rank: 2 + +### Select Example + +To get WaveletTree.Select(n, 0) of bananas where n's encoding is 01 +1. Go down to n's leaf using its path encoding, 01 +2. Go back to n's leaf's parent +3. parent.Select(0, 0) of 001 is 0 +4. Go to the next parent +5. parent.Select(1, 0) of 0010101 is 2 +6. return 2 since we are at the root. + +### Access Example + +Take the tree we constructed earlier to represent the sequence "bananas". + + 0010101 + / \ + 1000 001 + / \ / \ +a b n s + +To access the 4th character of the sequence, we would call WaveletTree.Access(3), +which performs the following operations: + +1. root[3] is 0 and root.Rank(0, 3) is 2 +2. Since root[3] is 0, visit left child +3. child[2] is 0 and child.Rank(0, 2) is 1 +4. Since child[2] is 0, visit left child +5. Left child is a leaf, so we've found our value (a)! + +NOTE: The waveletTree does not literally have to be a tree. There are other forms that it may +exist in like the concatenation of the level order representation of all its nodes' bitvectors... +as one example. Please reference the implementation if you'd like to understand how this +specific waveletTree works. + +*/ + +// waveletTree is a data structure that allows us to index a sequence +// in a memory efficient way that allows us to conduct RSA, (R)ank (S)elect (A)ccess +// queries on strings. 
This is very useful in situations where you'd like to understand +// certain aspects of a sequence like: +// * the number of times a character appears +// * counting how the frequency of a character up to certain offset +// * locating characters of certain rank within the sequence +// * accessing the character at a given position +type waveletTree struct { + root *node + alpha []charInfo + length int +} + +// Access will return the ith character of the original +// string used to build the waveletTree +func (wt waveletTree) Access(i int) byte { + if wt.root.isLeaf() { + return *wt.root.char + } + + curr := wt.root + for !curr.isLeaf() { + bit := curr.data.Access(i) + i = curr.data.Rank(bit, i) + if bit { + curr = curr.right + } else { + curr = curr.left + } + } + return *curr.char +} + +// Rank allows us to get the rank of a specified character in +// the original string +func (wt waveletTree) Rank(char byte, i int) int { + if wt.root.isLeaf() { + return wt.root.data.Rank(true, i) + } + + curr := wt.root + ci := wt.lookupCharInfo(char) + level := 0 + var rank int + for !curr.isLeaf() { + pathBit := ci.path.getBit(ci.path.len() - 1 - level) + rank = curr.data.Rank(pathBit, i) + if pathBit { + curr = curr.right + } else { + curr = curr.left + } + level++ + i = rank + } + return rank +} + +// Select allows us to get the corresponding position of a character +// in the original string given its rank. 
+func (wt waveletTree) Select(char byte, rank int) int { + if wt.root.isLeaf() { + s, ok := wt.root.data.Select(true, rank) + if !ok { + msg := fmt.Sprintf("could not find a corresponding bit for node.Select(true, %d) root as leaf node", rank) + panic(msg) + } + return s + } + + curr := wt.root + ci := wt.lookupCharInfo(char) + level := 0 + + for !curr.isLeaf() { + pathBit := ci.path.getBit(ci.path.len() - 1 - level) + if pathBit { + curr = curr.right + } else { + curr = curr.left + } + level++ + } + + for curr.parent != nil { + curr = curr.parent + level-- + pathBit := ci.path.getBit(ci.path.len() - 1 - level) + nextRank, ok := curr.data.Select(pathBit, rank) + if !ok { + msg := fmt.Sprintf("could not find a corresponding bit for node.Select(%t, %d) for characterInfo %+v", pathBit, rank, ci) + panic(msg) + } + rank = nextRank + } + + return rank +} + +func (wt waveletTree) lookupCharInfo(char byte) charInfo { + for i := range wt.alpha { + if wt.alpha[i].char == char { + return wt.alpha[i] + } + } + msg := fmt.Sprintf("could not find character %s in alphabet %+v. 
this should not be possible and indicates that the WaveletTree is malformed", string(char), wt.alpha) + panic(msg) +} + +func (wt waveletTree) reconstruct() string { + str := "" + for i := 0; i < wt.length; i++ { + str += string(wt.Access(i)) + } + return str +} + +type node struct { + data rsaBitVector + char *byte + parent *node + left *node + right *node +} + +func (n node) isLeaf() bool { + return n.char != nil +} + +type charInfo struct { + char byte + maxRank int + path bitvector +} + +func newWaveletTreeFromString(str string) (waveletTree, error) { + err := validateWaveletTreeBuildInput(&str) + if err != nil { + return waveletTree{}, err + } + + bytes := []byte(str) + + alpha := getCharInfoDescByRank(bytes) + root := buildWaveletTree(0, alpha, bytes) + + // Handle the case where the provided sequence only has an alphabet + // of size 1 + if root.isLeaf() { + bv := newBitVector(len(bytes)) + for i := 0; i < bv.len(); i++ { + bv.setBit(i, true) + } + root.data = newRSABitVectorFromBitVector(bv) + } + + return waveletTree{ + root: root, + alpha: alpha, + length: len(str), + }, nil +} + +func buildWaveletTree(currentLevel int, alpha []charInfo, bytes []byte) *node { + if len(alpha) == 0 { + return nil + } + + if len(alpha) == 1 { + return &node{char: &alpha[0].char} + } + + leftAlpha, rightAlpha := partitionAlpha(currentLevel, alpha) + + var leftBytes []byte + var rightBytes []byte + + bv := newBitVector(len(bytes)) + for i := range bytes { + if isInAlpha(rightAlpha, bytes[i]) { + bv.setBit(i, true) + rightBytes = append(rightBytes, bytes[i]) + } else { + leftBytes = append(leftBytes, bytes[i]) + } + } + + root := &node{ + data: newRSABitVectorFromBitVector(bv), + } + + leftTree := buildWaveletTree(currentLevel+1, leftAlpha, leftBytes) + rightTree := buildWaveletTree(currentLevel+1, rightAlpha, rightBytes) + + root.left = leftTree + root.right = rightTree + + if leftTree != nil { + leftTree.parent = root + } + if rightTree != nil { + rightTree.parent = root + } 
+ + return root +} + +func isInAlpha(alpha []charInfo, b byte) bool { + for _, a := range alpha { + if a.char == b { + return true + } + } + return false +} + +// partitionAlpha partitions the alphabet in half based on whether its corresponding path bit +// is a 0 or 1. 0 will comprise the left tree while 1 will comprise the right. The alphabet +// should be sorted in such a way that we remove the most amount of characters nearest to the +// root of the tree to reduce the memory footprint as much as possible. +func partitionAlpha(currentLevel int, alpha []charInfo) (left []charInfo, right []charInfo) { + for _, a := range alpha { + if a.path.getBit(a.path.len() - 1 - currentLevel) { + right = append(right, a) + } else { + left = append(left, a) + } + } + + return left, right +} + +// getCharInfoDescByRank takes in the bytes of the original +// string and return a sorted list of character metadata descending +// by rank. The character metadata is important for building the rest +// of the tree along with querying it later on. The sorting is important +// because this allows us to build the tree in the most memory efficient +// way since the characters with the greatest counts will be removed first +// before build the subsequent nodes in the lower levels. 
+// NOTE: alphabets are expected to be small for real usecases +func getCharInfoDescByRank(b []byte) []charInfo { + ranks := make(map[byte]int) + for i := 0; i < len(b); i++ { + if _, ok := ranks[b[i]]; ok { + ranks[b[i]] += 1 + } else { + ranks[b[i]] = 0 + } + } + + var sortedInfo []charInfo + for k := range ranks { + sortedInfo = append(sortedInfo, charInfo{char: k, maxRank: ranks[k]}) + } + + slices.SortFunc(sortedInfo, func(a, b charInfo) bool { + if a.maxRank == b.maxRank { + return a.char < b.char + } + return a.maxRank > b.maxRank + }) + + numOfBits := getTreeHeight(sortedInfo) + for i := range sortedInfo { + bv := newBitVector(numOfBits) + encodeCharPathIntoBitVector(bv, uint64(i)) + sortedInfo[i].path = bv + } + + return sortedInfo +} + +func encodeCharPathIntoBitVector(bv bitvector, n uint64) { + shift := 0 + for n>>shift > 0 { + if n>>shift%2 == 1 { + bv.setBit(bv.len()-1-shift, true) + } else { + bv.setBit(bv.len()-1-shift, false) + } + shift++ + } +} + +func getTreeHeight(alpha []charInfo) int { + return int(math.Log2(float64(len(alpha)))) + 1 +} + +func validateWaveletTreeBuildInput(sequence *string) error { + if len(*sequence) == 0 { + return errors.New("Sequence can not be empty") + } + return nil +} diff --git a/bwt/wavelet_test.go b/bwt/wavelet_test.go new file mode 100644 index 00000000..432f4a85 --- /dev/null +++ b/bwt/wavelet_test.go @@ -0,0 +1,285 @@ +package bwt + +import ( + "strings" + "testing" +) + +type WaveletTreeAccessTestCase struct { + pos int + expected string +} + +func TestWaveletTree_Access(t *testing.T) { + testStr := "AAAACCCCTTTTGGGG" + "ACTG" + "TGCA" + "TTAA" + "CCGG" + "GGGGTTTTCCCCAAAA" + wt, err := newWaveletTreeFromString(testStr) + if err != nil { + t.Fatal(err) + } + + testCases := []WaveletTreeAccessTestCase{ + {0, "A"}, + {3, "A"}, + {4, "C"}, + {7, "C"}, + {8, "T"}, + {9, "T"}, + {11, "T"}, + {12, "G"}, + {13, "G"}, + {15, "G"}, + + {16, "A"}, + {17, "C"}, + {18, "T"}, + {19, "G"}, + + {20, "T"}, + {21, "G"}, + {22, 
"C"}, + {23, "A"}, + + {24, "T"}, + {25, "T"}, + {26, "A"}, + {27, "A"}, + + {28, "C"}, + {29, "C"}, + {30, "G"}, + {31, "G"}, + + {32, "G"}, + {35, "G"}, + {36, "T"}, + {39, "T"}, + {40, "C"}, + {41, "C"}, + {43, "C"}, + {44, "A"}, + {46, "A"}, + {47, "A"}, + } + + for _, tc := range testCases { + actual := string(wt.Access(tc.pos)) + if actual != tc.expected { + t.Fatalf("expected access(%d) to be %s but got %s", tc.pos, tc.expected, actual) + } + } +} + +type WaveletTreeRankTestCase struct { + char string + pos int + expected int +} + +func TestWaveletTree_Rank_Genomic(t *testing.T) { + testStr := "AAAACCCCTTTTGGGG" + "ACTG" + "TGCA" + "TTAA" + "CCGG" + "GGGGTTTTCCCCAAAA" + wt, err := newWaveletTreeFromString(testStr) + if err != nil { + t.Fatal(err) + } + + testCases := []WaveletTreeRankTestCase{ + {"A", 0, 0}, + {"A", 2, 2}, + {"A", 3, 3}, + {"A", 8, 4}, + {"C", 4, 0}, + {"C", 6, 2}, + {"C", 12, 4}, + {"T", 2, 0}, + {"T", 8, 0}, + {"T", 12, 4}, + {"T", 15, 4}, + {"G", 15, 3}, + + {"A", 16, 4}, + {"A", 17, 5}, + {"G", 16, 4}, + + {"T", 20, 5}, + {"A", 23, 5}, + + {"T", 24, 6}, + {"T", 27, 8}, + + {"C", 28, 6}, + {"G", 31, 7}, + + {"G", 32, 8}, + {"G", 33, 9}, + {"T", 36, 8}, + {"T", 38, 10}, + {"C", 40, 8}, + {"C", 43, 11}, + {"A", 44, 8}, + {"A", 47, 11}, + } + + for _, tc := range testCases { + actual := wt.Rank(tc.char[0], tc.pos) + if actual != tc.expected { + t.Fatalf("expected rank(%s, %d) to be %d but got %d", tc.char, tc.pos, tc.expected, actual) + } + } +} + +type WaveletTreeSelectTestCase struct { + char string + rank int + expected int +} + +func TestWaveletTree_Select(t *testing.T) { + testStr := "AAAACCCCTTTTGGGG" + "ACTG" + "TGCA" + "TTAA" + "CCGG" + "GGGGTTTTCCCCAAAA" + wt, err := newWaveletTreeFromString(testStr) + if err != nil { + t.Fatal(err) + } + + testCases := []WaveletTreeSelectTestCase{ + {"A", 0, 0}, + {"A", 1, 1}, + {"A", 2, 2}, + {"A", 3, 3}, + {"C", 0, 4}, + {"C", 3, 7}, + + {"A", 4, 16}, + {"C", 4, 17}, + {"T", 4, 18}, + {"G", 4, 
19}, + + {"T", 5, 20}, + {"G", 5, 21}, + {"C", 5, 22}, + {"A", 5, 23}, + + {"T", 6, 24}, + {"T", 7, 25}, + {"A", 6, 26}, + + {"C", 6, 28}, + {"G", 6, 30}, + {"G", 7, 31}, + + {"G", 8, 32}, + {"A", 11, 47}, + } + + for _, tc := range testCases { + actual := wt.Select(tc.char[0], tc.rank) + if actual != tc.expected { + t.Fatalf("expected select(%s, %d) to be %d but got %d", tc.char, tc.rank, tc.expected, actual) + } + } +} + +// TestWaveletTree_Access_Reconstruction these tests are to ensure that the wavelet tree is formed correctly. If we can reconstruct the string, we can be +// fairly confident that the WaveletTree is well formed. +func TestWaveletTree_Access_Reconstruction(t *testing.T) { + // Build with a fair sized alphabet + enhancedQuickBrownFox := "the quick brown fox jumps over the lazy dog with an overt frown after fumbling its parallelogram shaped bananagram all around downtown" + enhancedQuickBrownFoxRepeated := strings.Join([]string{enhancedQuickBrownFox, enhancedQuickBrownFox, enhancedQuickBrownFox, enhancedQuickBrownFox, enhancedQuickBrownFox}, " ") + // Make it very large to account for any succinct data structures being used under the hood. For example, this helped uncover and errors + // diagnose issues with the Jacobson's Rank used under the hood. 
+ enhancedQuickBrownFoxSuperLarge := "" + for i := 0; i < 100; i++ { + enhancedQuickBrownFoxSuperLarge += enhancedQuickBrownFoxRepeated + } + + testCases := []string{ + "the quick brown fox jumped over the lazy dog", + "the quick brown fox jumped over the lazy dog!", // odd numbered alphabet + enhancedQuickBrownFox, + enhancedQuickBrownFoxRepeated, + enhancedQuickBrownFoxSuperLarge, + } + + for _, str := range testCases { + wt, err := newWaveletTreeFromString(str) + if err != nil { + t.Fatal(err) + } + actual := wt.reconstruct() + if actual != str { + t.Fatalf("expected to rebuild:\n%s\nbut instead got:\n%s", str, actual) + } + } +} + +func TestWaveletTreeEmptyStr(t *testing.T) { + str := "" + _, err := newWaveletTreeFromString(str) + if err == nil { + t.Fatal("expected error but got nil") + } +} + +func TestWaveletTreeSingleChar(t *testing.T) { + char := "l" + wt, err := newWaveletTreeFromString(char) + if err != nil { + t.Fatal(err) + } + r := wt.Rank(char[0], 1) + s := wt.Select(char[0], 0) + a := wt.Access(0) + + if r != 1 { + t.Fatalf("expected Rank(%s, %d) to be %d but got %d", char, 1, 1, r) + } + if s != 0 { + t.Fatalf("expected Select(%s, %d) to be %d but got %d", char, 0, 0, s) + } + if a != char[0] { + t.Fatalf("expected Access(%d) to be %d but got %d", 0, char[0], a) + } +} + +func TestWaveletTreeSingleAlpha(t *testing.T) { + str := "lll" + wt, err := newWaveletTreeFromString(str) + if err != nil { + t.Fatal(err) + } + r := wt.Rank(str[0], 1) + s := wt.Select(str[0], 1) + a := wt.Access(0) + + if r != 1 { + t.Fatalf("expected Rank(%s, %d) to be %d but got %d", str, 1, 1, r) + } + if s != 1 { + t.Fatalf("expected Select(%s, %d) to be %d but got %d", str, 1, 1, s) + } + if a != str[0] { + t.Fatalf("expected Access(%d) to be %d but got %d", 0, str[0], a) + } +} +func TestBuildWaveletTree_ZeroAlpha(t *testing.T) { + bytes := []byte("AAAACCCCTTTTGGGG") + alpha := []charInfo{} + + root := buildWaveletTree(0, alpha, bytes) + + if root != nil { + t.Fatalf("expected root 
to be nil but got %v", root) + } +} +func TestWaveletTree_LookupCharInfo_Panic(t *testing.T) { + wt := waveletTree{ + alpha: []charInfo{}, + } + + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got nil") + } + }() + + wt.lookupCharInfo('B') +}