diff --git a/Cargo.toml b/Cargo.toml index 68075af718..9fbdb0a1f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,11 +21,11 @@ byteorder = "1.4.3" crc32fast = "1.3.2" once_cell = "1.10.0" regex = { version = "1.5.5", default-features = false, features = [ - "std", - "unicode", + "std", + "unicode", ] } aho-corasick = "1.0" -tantivy-fst = "0.5" +tantivy-fst = { git = "https://github.com/paradedb/fst.git" } memmap2 = { version = "0.9.0", optional = true } lz4_flex = { version = "0.11", default-features = false, optional = true } zstd = { version = "0.13", optional = true, default-features = false } @@ -40,7 +40,7 @@ crossbeam-channel = "0.5.4" rust-stemmers = "1.2.0" downcast-rs = "1.2.1" bitpacking = { version = "0.9.2", default-features = false, features = [ - "bitpacker4x", + "bitpacker4x", ] } census = "0.4.2" rustc-hash = "2.0.0" @@ -129,14 +129,14 @@ compare_hash_only = ["stacker/compare_hash_only"] [workspace] members = [ - "query-grammar", - "bitpacker", - "common", - "ownedbytes", - "stacker", - "sstable", - "tokenizer-api", - "columnar", + "query-grammar", + "bitpacker", + "common", + "ownedbytes", + "stacker", + "sstable", + "tokenizer-api", + "columnar", ] # Following the "fail" crate best practises, we isolate diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 7b1f21b296..edcf8c01e8 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -160,6 +160,11 @@ use itertools::Itertools; use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer, Serialize}; +#[allow(unused)] +pub(crate) fn invalid_agg_request(message: String) -> crate::TantivyError { + crate::TantivyError::AggregationError(AggregationError::InvalidRequest(message)) +} + fn parse_str_into_f64(value: &str) -> Result { let parsed = value .parse::() diff --git a/src/block_join_collector.rs b/src/block_join_collector.rs new file mode 100644 index 0000000000..ec15a6aec0 --- /dev/null +++ b/src/block_join_collector.rs @@ -0,0 +1,113 @@ +use crate::collector::Collector; +use crate::query::Scorer; +use crate::DocId; +use crate::Result; +use crate::Score; +use crate::SegmentReader; +use common::BitSet; + +/// A conceptual `BlockJoinCollector` that aims to mimic Lucene's BlockJoinCollector. +/// It collects parent documents and, for each one, stores which child docs matched. +/// After search, you can retrieve these "groups". +/// +/// NOTE: This is a conceptual implementation. Adjust as per Tantivy's Collector API. +/// In Tantivy, you'd typically implement `Collector` and `SegmentCollector`. +pub struct BlockJoinCollector { + // For simplicity, store doc groups in memory: + groups: Vec<(DocId, Vec, Vec)>, + current_reader_base: DocId, +} + +impl BlockJoinCollector { + pub fn new() -> BlockJoinCollector { + BlockJoinCollector { + groups: Vec::new(), + current_reader_base: 0, + } + } + + /// Retrieve the collected groups: + pub fn get_groups(&self) -> &[(DocId, Vec, Vec)] { + &self.groups + } +} + +impl Collector for BlockJoinCollector { + type Fruit = (); + + fn set_segment( + &mut self, + _segment_id: u32, + reader: &SegmentReader, + ) -> Result>> { + let base = self.current_reader_base; + self.current_reader_base += reader.max_doc(); + let mut parent_bitset = BitSet::with_max_value(reader.max_doc()); + // In a real scenario, you'd identify the parent docs here using a filter. + // For this conceptual example, we assume parents are known externally. + // You might need to pass that information in or have a filter pre-applied. + + Ok(Box::new(BlockJoinSegmentCollector { + parent_bitset, + parent_groups: &mut self.groups, + base, + })) + } + + fn requires_scoring(&self) -> bool { + true + } + + fn collect(&mut self, _doc: DocId, _score: Score) -> Result<()> { + // This method won't be called directly if we rely on segment collectors. + Ok(()) + } + + fn harvest(self) -> Result { + Ok(()) + } +} + +struct BlockJoinSegmentCollector<'a> { + parent_bitset: BitSet, + parent_groups: &'a mut Vec<(DocId, Vec, Vec)>, + base: DocId, +} + +impl<'a> crate::collector::SegmentCollector for BlockJoinSegmentCollector<'a> { + type Fruit = (); + + fn collect(&mut self, doc: DocId, score: Score) { + // In a more complete implementation, you'd need + // logic to detect transitions from child docs to parent doc. + // + // This is a simplified conceptual collector. In practice: + // 1. Identify if `doc` is a parent or child. + // 2. If child, associate with last-seen parent. + // 3. If parent, start a new group. + + // Without full integration it's hard to do. For now, + // assume that the scoring and doc iteration are done by + // BlockJoinScorer and that we only collect parents when + // we hit them: + if self.parent_bitset.contains(doc) { + // It's a parent doc + self.parent_groups + .push((self.base + doc, Vec::new(), Vec::new())); + } else { + // It's a child doc - associate it with last parent + if let Some(last) = self.parent_groups.last_mut() { + last.1.push(self.base + doc); + last.2.push(score); + } + } + } + + fn set_scorer(&mut self, _scorer: Box) { + // Not implemented - you'd store the scorer if needed. + } + + fn harvest(self) -> Result { + Ok(()) + } +} diff --git a/src/directory/directory_lock.rs b/src/directory/directory_lock.rs index a6321b50b0..0ae1bc9209 100644 --- a/src/directory/directory_lock.rs +++ b/src/directory/directory_lock.rs @@ -58,3 +58,8 @@ pub static META_LOCK: Lazy = Lazy::new(|| Lock { filepath: PathBuf::from(".tantivy-meta.lock"), is_blocking: true, }); + +pub static MANAGED_LOCK: Lazy = Lazy::new(|| Lock { + filepath: PathBuf::from(".tantivy-managed.lock"), + is_blocking: true, +}); diff --git a/src/directory/managed_directory.rs b/src/directory/managed_directory.rs index c24ec5534f..2edea828cd 100644 --- a/src/directory/managed_directory.rs +++ b/src/directory/managed_directory.rs @@ -11,7 +11,7 @@ use crate::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteEr use crate::directory::footer::{Footer, FooterProxy}; use crate::directory::{ DirectoryLock, FileHandle, FileSlice, GarbageCollectionResult, Lock, WatchCallback, - WatchHandle, WritePtr, META_LOCK, + WatchHandle, WritePtr, MANAGED_LOCK, META_LOCK, }; use crate::error::DataCorruption; use crate::Directory; @@ -39,7 +39,6 @@ fn is_managed(path: &Path) -> bool { #[derive(Debug)] pub struct ManagedDirectory { directory: Box, - meta_informations: Arc>, } #[derive(Debug, Default)] @@ -51,9 +50,9 @@ struct MetaInformation { /// that were created by tantivy. fn save_managed_paths( directory: &dyn Directory, - wlock: &RwLockWriteGuard<'_, MetaInformation>, + managed_paths: &HashSet, ) -> io::Result<()> { - let mut w = serde_json::to_vec(&wlock.managed_paths)?; + let mut w = serde_json::to_vec(managed_paths)?; writeln!(&mut w)?; directory.atomic_write(&MANAGED_FILEPATH, &w[..])?; Ok(()) @@ -62,7 +61,11 @@ fn save_managed_paths( impl ManagedDirectory { /// Wraps a directory as managed directory. pub fn wrap(directory: Box) -> crate::Result { - match directory.atomic_read(&MANAGED_FILEPATH) { + Ok(ManagedDirectory { directory }) + } + + pub fn get_managed_paths(&self) -> crate::Result> { + match self.directory.atomic_read(&MANAGED_FILEPATH) { Ok(data) => { let managed_files_json = String::from_utf8_lossy(&data); let managed_files: HashSet = serde_json::from_str(&managed_files_json) @@ -72,17 +75,9 @@ impl ManagedDirectory { format!("Managed file cannot be deserialized: {e:?}. "), ) })?; - Ok(ManagedDirectory { - directory, - meta_informations: Arc::new(RwLock::new(MetaInformation { - managed_paths: managed_files, - })), - }) + Ok(managed_files) } - Err(OpenReadError::FileDoesNotExist(_)) => Ok(ManagedDirectory { - directory, - meta_informations: Arc::default(), - }), + Err(OpenReadError::FileDoesNotExist(_)) => Ok(HashSet::new()), io_err @ Err(OpenReadError::IoError { .. }) => Err(io_err.err().unwrap().into()), Err(OpenReadError::IncompatibleIndex(incompatibility)) => { // For the moment, this should never happen `meta.json` @@ -110,9 +105,11 @@ impl ManagedDirectory { &mut self, get_living_files: L, ) -> crate::Result { - info!("Garbage collect"); let mut files_to_delete = vec![]; + // We're about to do an atomic write to managed.json, lock it down + let _lock = self.acquire_lock(&MANAGED_LOCK)?; + let managed_paths = self.get_managed_paths()?; // It is crucial to get the living files after acquiring the // read lock of meta information. That way, we // avoid the following scenario. @@ -124,11 +121,6 @@ impl ManagedDirectory { // // releasing the lock as .delete() will use it too. { - let meta_informations_rlock = self - .meta_informations - .read() - .expect("Managed directory rlock poisoned in garbage collect."); - // The point of this second "file" lock is to enforce the following scenario // 1) process B tries to load a new set of searcher. // The list of segments is loaded @@ -138,7 +130,7 @@ impl ManagedDirectory { match self.acquire_lock(&META_LOCK) { Ok(_meta_lock) => { let living_files = get_living_files(); - for managed_path in &meta_informations_rlock.managed_paths { + for managed_path in &managed_paths { if !living_files.contains(managed_path) { files_to_delete.push(managed_path.clone()); } @@ -181,16 +173,12 @@ impl ManagedDirectory { if !deleted_files.is_empty() { // update the list of managed files by removing // the file that were removed. - let mut meta_informations_wlock = self - .meta_informations - .write() - .expect("Managed directory wlock poisoned (2)."); - let managed_paths_write = &mut meta_informations_wlock.managed_paths; + let mut managed_paths_write = managed_paths; for delete_file in &deleted_files { managed_paths_write.remove(delete_file); } self.directory.sync_directory()?; - save_managed_paths(self.directory.as_mut(), &meta_informations_wlock)?; + save_managed_paths(self.directory.as_mut(), &managed_paths_write)?; } Ok(GarbageCollectionResult { @@ -215,15 +203,20 @@ impl ManagedDirectory { if !is_managed(filepath) { return Ok(()); } - let mut meta_wlock = self - .meta_informations - .write() - .expect("Managed file lock poisoned"); - let has_changed = meta_wlock.managed_paths.insert(filepath.to_owned()); + + // We're about to do an atomic write to managed.json, lock it down + let _lock = self + .acquire_lock(&MANAGED_LOCK) + .expect("must be able to acquire lock for managed.json"); + + let mut managed_paths = self + .get_managed_paths() + .expect("reading managed files should not fail"); + let has_changed = managed_paths.insert(filepath.to_owned()); if !has_changed { return Ok(()); } - save_managed_paths(self.directory.as_ref(), &meta_wlock)?; + save_managed_paths(self.directory.as_ref(), &managed_paths)?; // This is not the first file we add. // Therefore, we are sure that `.managed.json` has been already // properly created and we do not need to sync its parent directory. @@ -231,11 +224,12 @@ impl ManagedDirectory { // (It might seem like a nicer solution to create the managed_json on the // creation of the ManagedDirectory instance but it would actually // prevent the use of read-only directories..) - let managed_file_definitely_already_exists = meta_wlock.managed_paths.len() > 1; + let managed_file_definitely_already_exists = managed_paths.len() > 1; if managed_file_definitely_already_exists { return Ok(()); } self.directory.sync_directory()?; + Ok(()) } @@ -258,13 +252,11 @@ impl ManagedDirectory { /// List all managed files pub fn list_managed_files(&self) -> HashSet { - let managed_paths = self - .meta_informations - .read() - .expect("Managed directory rlock poisoned in list damaged.") - .managed_paths - .clone(); - managed_paths + let _lock = self + .acquire_lock(&MANAGED_LOCK) + .expect("must be able to acquire lock for managed.json"); + self.get_managed_paths() + .expect("reading managed files should not fail") } } @@ -329,7 +321,6 @@ impl Clone for ManagedDirectory { fn clone(&self) -> ManagedDirectory { ManagedDirectory { directory: self.directory.box_clone(), - meta_informations: Arc::clone(&self.meta_informations), } } } diff --git a/src/directory/mod.rs b/src/directory/mod.rs index 93c9225679..7eccbeb2a8 100644 --- a/src/directory/mod.rs +++ b/src/directory/mod.rs @@ -24,7 +24,7 @@ pub use common::{AntiCallToken, OwnedBytes, TerminatingWrite}; pub(crate) use self::composite_file::{CompositeFile, CompositeWrite}; pub use self::directory::{Directory, DirectoryClone, DirectoryLock}; -pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK}; +pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, MANAGED_LOCK, META_LOCK}; pub use self::ram_directory::RamDirectory; pub use self::watch_event_router::{WatchCallback, WatchCallbackList, WatchHandle}; diff --git a/src/index/index.rs b/src/index/index.rs index 052bc4f920..9c51df3704 100644 --- a/src/index/index.rs +++ b/src/index/index.rs @@ -12,7 +12,9 @@ use crate::core::{Executor, META_FILEPATH}; use crate::directory::error::OpenReadError; #[cfg(feature = "mmap")] use crate::directory::MmapDirectory; -use crate::directory::{Directory, ManagedDirectory, RamDirectory, INDEX_WRITER_LOCK}; +use crate::directory::{ + Directory, DirectoryLock, ManagedDirectory, RamDirectory, INDEX_WRITER_LOCK, +}; use crate::error::{DataCorruption, TantivyError}; use crate::index::{IndexMeta, SegmentId, SegmentMeta, SegmentMetaInventory}; use crate::indexer::index_writer::{MAX_NUM_THREAD, MEMORY_BUDGET_NUM_BYTES_MIN}; diff --git a/src/index/segment.rs b/src/index/segment.rs index 4c9382cb03..fcd32a1fff 100644 --- a/src/index/segment.rs +++ b/src/index/segment.rs @@ -46,7 +46,7 @@ impl Segment { /// /// This method is only used when updating `max_doc` from 0 /// as we finalize a fresh new segment. - pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment { + pub fn with_max_doc(self, max_doc: u32) -> Segment { Segment { index: self.index, meta: self.meta.with_max_doc(max_doc), diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index 6719afe9f9..b2a5a1d09a 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -714,6 +714,28 @@ impl IndexWriter { Ok(opstamp) } + /// Adds multiple documents as a block. + /// + /// This method allows adding multiple documents together as a single block. + /// This is important for nested documents, where child documents need to be + /// added before their parent document, and they need to be stored together + /// in the same block. + /// + /// The opstamp returned is the opstamp of the last document added. + pub fn add_documents(&self, documents: Vec) -> crate::Result { + let count = documents.len() as u64; + if count == 0 { + return Ok(self.stamper.stamp()); + } + let (batch_opstamp, stamps) = self.get_batch_opstamps(count); + let mut adds = AddBatch::default(); + for (document, opstamp) in documents.into_iter().zip(stamps) { + adds.push(AddOperation { opstamp, document }); + } + self.send_add_documents_batch(adds)?; + Ok(batch_opstamp) + } + /// Gets a range of stamps from the stamper and "pops" the last stamp /// from the range returning a tuple of the last optstamp and the popped /// range. @@ -820,6 +842,7 @@ mod tests { STRING, TEXT, }; use crate::store::DOCSTORE_CACHE_CAPACITY; + use crate::Result; use crate::{ DateTime, DocAddress, Index, IndexSettings, IndexWriter, ReloadPolicy, TantivyDocument, Term, @@ -1234,6 +1257,188 @@ mod tests { Ok(()) } + #[test] + fn test_add_documents() -> Result<()> { + // Create a simple schema with one text field + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + + // Create an index in RAM + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + + // Create multiple documents + let docs = vec![ + doc!(text_field => "hello"), + doc!(text_field => "world"), + doc!(text_field => "tantivy"), + ]; + + // Add documents using add_documents + let opstamp = index_writer.add_documents(docs)?; + assert_eq!(opstamp, 3u64); // Since we have three documents, opstamp should be 3 + + // Commit the changes + index_writer.commit()?; + + // Create a reader and searcher + let reader = index.reader()?; + reader.reload()?; + let searcher = reader.searcher(); + + // Verify that the documents are indexed correctly + let term = Term::from_field_text(text_field, "hello"); + let query = TermQuery::new(term, IndexRecordOption::Basic); + let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + assert_eq!(top_docs.len(), 1); + + let term = Term::from_field_text(text_field, "world"); + let query = TermQuery::new(term, IndexRecordOption::Basic); + let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + assert_eq!(top_docs.len(), 1); + + let term = Term::from_field_text(text_field, "tantivy"); + let query = TermQuery::new(term, IndexRecordOption::Basic); + let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + assert_eq!(top_docs.len(), 1); + + Ok(()) + } + + #[test] + fn test_add_documents_empty() -> Result<()> { + // Test adding an empty list of documents + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + + let docs: Vec = Vec::new(); + let opstamp = index_writer.add_documents(docs)?; + assert_eq!(opstamp, 0u64); + + // Since no documents were added, committing should not change anything + index_writer.commit()?; + + let reader = index.reader()?; + reader.reload()?; + let searcher = reader.searcher(); + + // Search for any documents, expecting none + let term = Term::from_field_text(text_field, "any"); + let query = TermQuery::new(term, IndexRecordOption::Basic); + let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + assert_eq!(top_docs.len(), 0); + + Ok(()) + } + + #[test] + fn test_add_documents_order() -> Result<()> { + // Test that documents are indexed in the order they are added + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT | STORED); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + + // Create multiple documents + let docs = vec![ + doc!(text_field => "doc1"), + doc!(text_field => "doc2"), + doc!(text_field => "doc3"), + ]; + + // Add documents using add_documents + index_writer.add_documents(docs)?; + index_writer.commit()?; + + // Create a reader and searcher + let reader = index.reader()?; + reader.reload()?; + let searcher = reader.searcher(); + + // Collect documents and verify their order + let all_docs = searcher + .segment_readers() + .iter() + .flat_map(|segment_reader| { + let store_reader = segment_reader.get_store_reader(1000).unwrap(); + segment_reader + .doc_ids_alive() + .map(move |doc_id| store_reader.get::(doc_id).unwrap()) + }) + .collect::>(); + + assert_eq!(all_docs.len(), 3); + assert_eq!( + all_docs[0].get_first(text_field).unwrap().as_str(), + Some("doc1") + ); + assert_eq!( + all_docs[1].get_first(text_field).unwrap().as_str(), + Some("doc2") + ); + assert_eq!( + all_docs[2].get_first(text_field).unwrap().as_str(), + Some("doc3") + ); + + Ok(()) + } + + #[test] + fn test_add_documents_concurrency() -> Result<()> { + // Test adding documents concurrently + use std::sync::mpsc; + use std::thread; + + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + + // Create a channel to send documents to the indexer + let (doc_sender, doc_receiver) = mpsc::channel(); + + // Spawn a thread to add documents + let sender_clone = doc_sender.clone(); + let handle = thread::spawn(move || { + let docs = vec![doc!(text_field => "threaded")]; + for doc in docs { + sender_clone.send(doc).unwrap(); + } + }); + + // Drop the extra sender to close the channel when done + drop(doc_sender); + + // Indexer thread + for doc in doc_receiver { + index_writer.add_document(doc)?; + } + + index_writer.commit()?; + handle.join().unwrap(); + + let reader = index.reader()?; + reader.reload()?; + let searcher = reader.searcher(); + + let term = Term::from_field_text(text_field, "threaded"); + let query = TermQuery::new(term, IndexRecordOption::Basic); + let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + assert_eq!(top_docs.len(), 1); + + Ok(()) + } + #[test] fn test_add_then_delete_all_documents() { let mut schema_builder = schema::Schema::builder(); diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index c583a4b5c5..76f805754c 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -35,7 +35,7 @@ pub use self::index_writer::IndexWriter; pub use self::log_merge_policy::LogMergePolicy; pub use self::merge_operation::MergeOperation; pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy}; -use self::operation::AddOperation; +pub use self::operation::AddOperation; pub use self::operation::UserOperation; pub use self::prepared_commit::PreparedCommit; pub use self::segment_entry::SegmentEntry; diff --git a/src/lib.rs b/src/lib.rs index cf4308d557..8887089016 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(warnings)] #![doc(html_logo_url = "http://fulmicoton.com/tantivy-logo/tantivy-logo.png")] #![cfg_attr(all(feature = "unstable", test), feature(test))] #![doc(test(attr(allow(unused_variables), deny(warnings))))] diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index 5f1053fb67..8be9f897e9 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -1,15 +1,18 @@ +use crate::postings::TermInfo; +use crate::query::fuzzy_query::DfaWrapper; +use crate::query::score_combiner::SumCombiner; + +use std::any::{Any, TypeId}; use std::io; use std::sync::Arc; -use common::BitSet; use tantivy_fst::Automaton; use super::phrase_prefix_query::prefix_end; use crate::index::SegmentReader; -use crate::postings::TermInfo; -use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight}; +use crate::query::{BufferedUnionScorer, ConstScorer, Explanation, Scorer, Weight}; use crate::schema::{Field, IndexRecordOption}; -use crate::termdict::{TermDictionary, TermStreamer}; +use crate::termdict::{TermDictionary, TermWithStateStreamer}; use crate::{DocId, Score, TantivyError}; /// A weight struct for Fuzzy Term and Regex Queries @@ -52,9 +55,9 @@ where fn automaton_stream<'a>( &'a self, term_dict: &'a TermDictionary, - ) -> io::Result> { + ) -> io::Result> { let automaton: &A = &self.automaton; - let mut term_stream_builder = term_dict.search(automaton); + let mut term_stream_builder = term_dict.search_with_state(automaton); if let Some(json_path_bytes) = &self.json_path_bytes { term_stream_builder = term_stream_builder.ge(json_path_bytes); @@ -85,35 +88,27 @@ where A::State: Clone, { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { - let max_doc = reader.max_doc(); - let mut doc_bitset = BitSet::with_max_value(max_doc); let inverted_index = reader.inverted_index(self.field)?; let term_dict = inverted_index.terms(); let mut term_stream = self.automaton_stream(term_dict)?; - while term_stream.advance() { - let term_info = term_stream.value(); - let mut block_segment_postings = inverted_index - .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?; - loop { - let docs = block_segment_postings.docs(); - if docs.is_empty() { - break; - } - for &doc in docs { - doc_bitset.insert(doc); - } - block_segment_postings.advance(); - } + + let mut scorers = vec![]; + while let Some((_term, term_info, state)) = term_stream.next() { + let score = automaton_score(self.automaton.as_ref(), state); + let segment_postings = + inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?; + let scorer = ConstScorer::new(segment_postings, boost * score); + scorers.push(scorer); } - let doc_bitset = BitSetDocSet::from(doc_bitset); - let const_scorer = ConstScorer::new(doc_bitset, boost); - Ok(Box::new(const_scorer)) + + let scorer = BufferedUnionScorer::build(scorers, SumCombiner::default); + Ok(Box::new(scorer)) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { let mut scorer = self.scorer(reader, 1.0)?; if scorer.seek(doc) == doc { - Ok(Explanation::new("AutomatonScorer", 1.0)) + Ok(Explanation::new("AutomatonScorer", scorer.score())) } else { Err(TantivyError::InvalidArgument( "Document does not exist".to_string(), @@ -122,6 +117,25 @@ where } } +fn automaton_score(automaton: &A, state: A::State) -> f32 +where + A: Automaton + Send + Sync + 'static, + A::State: Clone, +{ + if TypeId::of::() == automaton.type_id() && TypeId::of::() == state.type_id() { + let dfa = automaton as *const A as *const DfaWrapper; + let dfa = unsafe { &*dfa }; + + let id = &state as *const A::State as *const u32; + let id = unsafe { *id }; + + let dist = dfa.0.distance(id).to_u8() as f32; + 1.0 / (1.0 + dist) + } else { + 1.0 + } +} + #[cfg(test)] mod tests { use tantivy_fst::Automaton; diff --git a/src/query/block_join_query.rs b/src/query/block_join_query.rs new file mode 100644 index 0000000000..3d66e230a9 --- /dev/null +++ b/src/query/block_join_query.rs @@ -0,0 +1,1474 @@ +// src/query/block_join_query.rs + +use crate::core::searcher::Searcher; +use crate::query::{EnableScoring, Explanation, Query, QueryClone, Scorer, Weight}; +use crate::schema::Term; +use crate::{DocAddress, DocId, DocSet, Result, Score, SegmentReader, TERMINATED}; +use common::BitSet; +use std::fmt; + +/// How scores should be aggregated from child documents. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ScoreMode { + /// Use the average of all child scores as the parent score. + Avg, + /// Use the maximum child score as the parent score. + Max, + /// Sum all child scores for the parent score. + Total, + /// Do not score parent docs from child docs. Just rely on parent scoring. + None, +} + +impl Default for ScoreMode { + fn default() -> Self { + ScoreMode::Avg + } +} + +/// `BlockJoinQuery` performs a join from child documents to parent documents, +/// based on a block structure: child documents are indexed before their parent. +/// The `parents_filter` identifies the parent documents in each segment. +pub struct BlockJoinQuery { + child_query: Box, + parents_filter: Box, + score_mode: ScoreMode, +} + +impl Clone for BlockJoinQuery { + fn clone(&self) -> Self { + BlockJoinQuery { + child_query: self.child_query.box_clone(), + parents_filter: self.parents_filter.box_clone(), + score_mode: self.score_mode, + } + } +} + +impl BlockJoinQuery { + /// Creates a new `BlockJoinQuery`. + /// + /// # Arguments + /// + /// * `child_query` - The query to match child documents. + /// * `parents_filter` - The query to identify parent documents. + /// * `score_mode` - The mode to aggregate scores from child documents. + pub fn new( + child_query: Box, + parents_filter: Box, + score_mode: ScoreMode, + ) -> BlockJoinQuery { + BlockJoinQuery { + child_query, + parents_filter, + score_mode, + } + } +} + +impl fmt::Debug for BlockJoinQuery { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "BlockJoinQuery(child_query: {:?}, parents_filter: {:?}, score_mode: {:?})", + self.child_query, self.parents_filter, self.score_mode + ) + } +} + +impl Query for BlockJoinQuery { + fn weight(&self, enable_scoring: EnableScoring<'_>) -> Result> { + println!("BlockJoinQuery::weight() - Creating weights"); + let child_weight = self.child_query.weight(enable_scoring.clone())?; + println!("BlockJoinQuery::weight() - Created child weight"); + let parents_weight = self.parents_filter.weight(enable_scoring)?; + println!("BlockJoinQuery::weight() - Created parent weight"); + + Ok(Box::new(BlockJoinWeight { + child_weight, + parents_weight, + score_mode: self.score_mode, + })) + } + + fn explain(&self, searcher: &Searcher, doc_address: DocAddress) -> Result { + let reader = searcher.segment_reader(doc_address.segment_ord); + let mut scorer = self + .weight(EnableScoring::enabled_from_searcher(searcher))? + .scorer(reader, 1.0)?; + + // Perform an initial advance to move the scorer to the first matching document + let mut current_doc = scorer.advance(); + + // Continue advancing until the target doc_id is reached or surpassed + while current_doc != TERMINATED && current_doc < doc_address.doc_id { + current_doc = scorer.advance(); + } + + let score = if current_doc == doc_address.doc_id { + scorer.score() + } else { + 0.0 + }; + + let mut explanation = Explanation::new("BlockJoinQuery", score); + explanation.add_detail(Explanation::new("score", score)); + Ok(explanation) + } + + fn count(&self, searcher: &Searcher) -> Result { + let weight = self.weight(EnableScoring::disabled_from_searcher(searcher))?; + let mut total_count = 0; + for reader in searcher.segment_readers() { + total_count += weight.count(reader)? as usize; + } + Ok(total_count) + } + + fn query_terms<'a>(&'a self, visitor: &mut dyn FnMut(&'a Term, bool)) { + self.child_query.query_terms(visitor); + self.parents_filter.query_terms(visitor); + } +} + +struct BlockJoinWeight { + child_weight: Box, + parents_weight: Box, + score_mode: ScoreMode, +} + +impl Weight for BlockJoinWeight { + fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { + println!( + "BlockJoinWeight::scorer() - Creating scorer with boost {}", + boost + ); + + let max_doc = reader.max_doc(); + println!("BlockJoinWeight::scorer() - Max doc value: {}", max_doc); + let mut parents_bitset = BitSet::with_max_value(max_doc); + + // Create a scorer for parent documents + println!("BlockJoinWeight::scorer() - Creating parent scorer"); + let mut parents_scorer = self.parents_weight.scorer(reader, boost.clone())?; + println!("BlockJoinWeight::scorer() - Parent scorer created"); + + // Iterate through all parent documents and filter based on child matches + let mut found_parent = false; + let mut parent_count = 0; + let mut previous_parent = TERMINATED; + + while parents_scorer.doc() != TERMINATED { + let parent_doc = parents_scorer.doc(); + println!( + "BlockJoinWeight::scorer() - Found parent doc: {}", + parent_doc + ); + + // Define the range of child documents for this parent + let start_doc = if previous_parent == TERMINATED { + 0 + } else { + previous_parent + 1 + }; + let end_doc = parent_doc; + + // Create a new child scorer for each parent to check for matching children + let mut child_scorer = self.child_weight.scorer(reader, boost.clone())?; + // Advance the child scorer to the start of the current parent's children + while child_scorer.doc() != TERMINATED && child_scorer.doc() < start_doc { + child_scorer.advance(); + } + + // Check if any child within the block matches the child query + let mut has_matching_child = false; + while child_scorer.doc() != TERMINATED && child_scorer.doc() < end_doc { + let score = child_scorer.score(); + if score > 0.0 { + has_matching_child = true; + break; + } + child_scorer.advance(); + } + + if has_matching_child { + parents_bitset.insert(parent_doc); + found_parent = true; + parent_count += 1; + } + + previous_parent = parent_doc; + parents_scorer.advance(); + } + + println!( + "BlockJoinWeight::scorer() - Found {} parent documents with matching children", + parent_count + ); + + if !found_parent { + println!("BlockJoinWeight::scorer() - No parents found with matching children, returning empty scorer"); + return Ok(Box::new(EmptyScorer)); + } + + // Initialize with the first matching parent + let mut first_parent = TERMINATED; + for i in 0..=max_doc { + if parents_bitset.contains(i) { + first_parent = i; + break; + } + } + + println!( + "BlockJoinWeight::scorer() - Creating BlockJoinScorer (first_parent: {})", + first_parent + ); + let scorer = BlockJoinScorer { + child_scorer: self.child_weight.scorer(reader, boost)?, + parent_docs: parents_bitset, + score_mode: self.score_mode, + current_parent: first_parent, + previous_parent: None, + current_score: 1.0, + initialized: false, + has_more: first_parent != TERMINATED, + }; + Ok(Box::new(scorer)) + } + + fn explain(&self, _reader: &SegmentReader, _doc: DocId) -> crate::Result { + unimplemented!("Explain is not implemented for BlockJoinWeight"); + } + + fn count(&self, reader: &SegmentReader) -> crate::Result { + let mut count = 0; + let mut scorer = self.scorer(reader, 1.0)?; + while scorer.doc() != TERMINATED { + count += 1; + scorer.advance(); + } + Ok(count) + } + + /// Correctly implemented `for_each_pruning` method + fn for_each_pruning( + &self, + threshold: Score, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score) -> Score, + ) -> crate::Result<()> { + println!( + "BlockJoinWeight::for_each_pruning() - Starting with threshold {}", + threshold + ); + + // Create a scorer for parent documents + let mut parents_scorer = self.parents_weight.scorer(reader, 1.0)?; + println!("BlockJoinWeight::for_each_pruning() - Parent scorer created"); + + let mut previous_parent = TERMINATED; + + // Iterate through all parent documents + while parents_scorer.doc() != TERMINATED { + let parent_doc = parents_scorer.doc(); + println!( + "BlockJoinWeight::for_each_pruning() - Found parent doc: {}", + parent_doc + ); + + // Define the range of child documents for this parent + let start_doc = if previous_parent == TERMINATED { + 0 + } else { + previous_parent + 1 + }; + let end_doc = parent_doc; + + // Create a new child scorer for each parent to check for matching children + let mut child_scorer = self.child_weight.scorer(reader, 1.0)?; + // Advance the child scorer to the start of the current parent's children + while child_scorer.doc() != TERMINATED && child_scorer.doc() < start_doc { + child_scorer.advance(); + } + + // Check if any child within the block matches the child query + let mut has_matching_child = false; + while child_scorer.doc() != TERMINATED && child_scorer.doc() < end_doc { + let score = child_scorer.score(); + if score > 0.0 { + has_matching_child = true; + break; + } + child_scorer.advance(); + } + + if has_matching_child { + // Assign a score based on ScoreMode + let score = match self.score_mode { + ScoreMode::Avg | ScoreMode::Max | ScoreMode::Total => { + // Simplified: assign a fixed score. + // Implement actual score calculations based on ScoreMode if needed. + 1.0 + } + ScoreMode::None => 1.0, + }; + + if score >= threshold { + println!( + "BlockJoinWeight::for_each_pruning() - Processing parent doc: {}, score: {}", + parent_doc, score + ); + let new_threshold = callback(parent_doc, score); + println!( + "BlockJoinWeight::for_each_pruning() - New threshold after callback: {}", + new_threshold + ); + + // Update the threshold + if new_threshold > score { + // If the new threshold is higher than the current score, we can stop early + println!( + "BlockJoinWeight::for_each_pruning() - Early termination as new threshold {} > score {}", + new_threshold, score + ); + break; + } + } + } + + previous_parent = parent_doc; + + // Advance to the next parent document + parents_scorer.advance(); + } + + println!("BlockJoinWeight::for_each_pruning() - Completed"); + Ok(()) + } +} + +struct EmptyScorer; + +impl DocSet for EmptyScorer { + fn advance(&mut self) -> DocId { + TERMINATED + } + + fn doc(&self) -> DocId { + TERMINATED + } + + fn size_hint(&self) -> u32 { + 0 + } +} + +impl Scorer for EmptyScorer { + fn score(&mut self) -> Score { + 0.0 + } +} + +struct BlockJoinScorer { + child_scorer: Box, + parent_docs: BitSet, + score_mode: ScoreMode, + current_parent: DocId, + previous_parent: Option, + current_score: Score, + initialized: bool, + has_more: bool, +} + +impl DocSet for BlockJoinScorer { + fn advance(&mut self) -> DocId { + if !self.has_more { + return TERMINATED; + } + + if !self.initialized { + self.initialized = true; + self.previous_parent = None; + self.collect_matches(); + return self.current_parent; + } + + // Find next parent after current one + let next_parent = self.find_next_parent(self.current_parent + 1); + if next_parent == TERMINATED { + self.has_more = false; + self.current_parent = TERMINATED; + return TERMINATED; + } + + self.previous_parent = Some(self.current_parent); + self.current_parent = next_parent; + self.collect_matches(); + self.current_parent + } + + fn doc(&self) -> DocId { + if !self.initialized { + TERMINATED + } else if self.has_more { + self.current_parent + } else { + TERMINATED + } + } + + fn size_hint(&self) -> u32 { + self.parent_docs.len() as u32 + } +} + +impl BlockJoinScorer { + fn initialize(&mut self) { + println!("Initializing BlockJoinScorer..."); + if !self.initialized { + // Initialize the child scorer + let _child_doc = self.child_scorer.advance(); + + // Find the first parent + let first_parent = self.find_next_parent(0); + if first_parent != TERMINATED { + self.current_parent = first_parent; + self.has_more = true; + self.collect_matches(); + } else { + self.has_more = false; + self.current_parent = TERMINATED; + } + + self.initialized = true; + println!( + "Initialization complete: current_parent={}, has_more={}", + self.current_parent, self.has_more + ); + } + } + + fn find_next_parent(&self, from: DocId) -> DocId { + println!("Finding next parent from {}", from); + let mut current = from; + let max_val = self.parent_docs.max_value(); + + while current <= max_val { + if self.parent_docs.contains(current) { + println!("Found parent at {}", current); + return current; + } + current += 1; + } + println!("No more parents found"); + TERMINATED + } + + fn collect_matches(&mut self) { + let mut child_scores = Vec::new(); + + // Determine the starting document ID for collecting child documents + let start_doc = match self.previous_parent { + Some(prev_parent_doc) => prev_parent_doc + 1, + None => 0, + }; + + // Advance the child_scorer to the start_doc if necessary + let mut current_child = self.child_scorer.doc(); + while current_child != TERMINATED && current_child < start_doc { + current_child = self.child_scorer.advance(); + } + + let end_doc = self.current_parent; + + // Collect all child documents between start_doc and end_doc + while current_child != TERMINATED && current_child < end_doc { + child_scores.push(self.child_scorer.score()); + current_child = self.child_scorer.advance(); + } + + // Aggregate the scores according to the score_mode + self.current_score = match self.score_mode { + ScoreMode::Avg => { + if child_scores.is_empty() { + 1.0 + } else { + child_scores.iter().sum::() / child_scores.len() as Score + } + } + ScoreMode::Max => { + if child_scores.is_empty() { + 1.0 + } else { + child_scores.iter().cloned().fold(f32::MIN, f32::max) + } + } + ScoreMode::Total => { + if child_scores.is_empty() { + 1.0 + } else { + child_scores.iter().sum() + } + } + ScoreMode::None => 1.0, + }; + } +} + +impl Scorer for BlockJoinScorer { + fn score(&mut self) -> Score { + self.current_score + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::collector::TopDocs; + use crate::query::TermQuery; + use crate::schema::{Field, IndexRecordOption, Schema, Value, STORED, STRING}; + use crate::{DocAddress, Index, IndexWriter, TantivyDocument, Term}; + + /// Helper function to create a test index with parent and child documents. + fn create_test_index() -> crate::Result<(Index, Field, Field, Field, Field)> { + let mut schema_builder = Schema::builder(); + let name_field = schema_builder.add_text_field("name", STRING | STORED); + let country_field = schema_builder.add_text_field("country", STRING | STORED); + let skill_field = schema_builder.add_text_field("skill", STRING | STORED); + let doc_type_field = schema_builder.add_text_field("docType", STRING); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // First block: + // children docs first, parent doc last + index_writer.add_documents(vec![ + doc!( + skill_field => "java", + doc_type_field => "job" + ), + doc!( + skill_field => "python", + doc_type_field => "job" + ), + doc!( + skill_field => "java", + doc_type_field => "job" + ), + // parent last in this block + doc!( + name_field => "Lisa", + country_field => "United Kingdom", + doc_type_field => "resume" // Consistent identifier for parent + ), + ])?; + + // Second block: + index_writer.add_documents(vec![ + doc!( + skill_field => "ruby", + doc_type_field => "job" + ), + doc!( + skill_field => "java", + doc_type_field => "job" + ), + // parent last in this block + doc!( + name_field => "Frank", + country_field => "United States", + doc_type_field => "resume" // Consistent identifier for parent + ), + ])?; + + index_writer.commit()?; + } + Ok(( + index, + name_field, + country_field, + skill_field, + doc_type_field, + )) + } + + #[test] + pub fn test_simple_block_join() -> crate::Result<()> { + let (index, name_field, _country_field, skill_field, doc_type_field) = create_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Updated from "parent" to "resume" + IndexRecordOption::Basic, + ); + + let child_query = TermQuery::new( + Term::from_field_text(skill_field, "java"), + IndexRecordOption::Basic, + ); + + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::Avg, + ); + + let top_docs = searcher.search(&block_join_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find 1 top document"); + + let doc: TantivyDocument = searcher.doc(top_docs[0].1)?; + assert_eq!( + doc.get_first(name_field).unwrap().as_str().unwrap(), + "Lisa", + "Expected top document to be 'Lisa'" + ); + + Ok(()) + } + + #[test] + pub fn test_block_join_no_matches() -> crate::Result<()> { + let (index, name_field, country_field, skill_field, doc_type_field) = create_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + // Use "ruby" to match only "Frank"'s child + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), + IndexRecordOption::Basic, + ); + + let child_query = TermQuery::new( + Term::from_field_text(skill_field, "ruby"), + IndexRecordOption::Basic, + ); + + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::Avg, + ); + + let top_docs = searcher.search(&block_join_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find 1 top document"); + + let doc: TantivyDocument = searcher.doc(top_docs[0].1)?; + assert_eq!( + doc.get_first(name_field).unwrap().as_str().unwrap(), + "Frank", + "Expected top document to be 'Frank'" + ); + + Ok(()) + } + + #[test] + pub fn test_block_join_scoring() -> crate::Result<()> { + let (index, _name_field, _country_field, skill_field, doc_type_field) = + create_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Updated from "parent" to "resume" + IndexRecordOption::WithFreqs, + ); + + let child_query = TermQuery::new( + Term::from_field_text(skill_field, "java"), + IndexRecordOption::WithFreqs, + ); + + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::Avg, + ); + + let top_docs = searcher.search(&block_join_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find 1 top document"); + + // Score should be influenced by children, ensure it's not zero + assert!( + top_docs[0].0 > 0.0, + "Top document score should be greater than 0.0" + ); + + Ok(()) + } + + #[test] + pub fn test_explain_block_join() -> crate::Result<()> { + let (index, _name_field, country_field, skill_field, doc_type_field) = create_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), + IndexRecordOption::Basic, + ); + + let child_query = TermQuery::new( + Term::from_field_text(skill_field, "ruby"), // Changed to "ruby" to match "Frank"'s child + IndexRecordOption::Basic, + ); + + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::None, // Ensures a fixed score + ); + + // The parent doc for "Frank" is doc6 in the first segment + let explanation = block_join_query.explain(&searcher, DocAddress::new(0, 6))?; + assert!( + explanation.value() > 0.0, + "Explanation score should be greater than 0.0" + ); + + Ok(()) + } +} + +// src/query/block_join_query.rs + +#[cfg(test)] +mod atomic_tests { + use super::*; + use crate::collector::TopDocs; + use crate::query::TermQuery; + use crate::schema::{Field, IndexRecordOption, Schema, Value, STORED, STRING}; + use crate::{DocAddress, Index, IndexWriter, TantivyDocument, Term}; + + /// Helper function to create a very simple test index with just one parent and one child + fn create_minimal_index() -> crate::Result<(Index, Field, Field)> { + let mut schema_builder = Schema::builder(); + let content_field = schema_builder.add_text_field("content", STRING | STORED); + let doc_type_field = schema_builder.add_text_field("docType", STRING); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // Add one child and one parent + index_writer.add_documents(vec![ + doc!( + content_field => "child content", + doc_type_field => "child" + ), + doc!( + content_field => "first resume", // Changed from "parent" to "resume" + doc_type_field => "resume" // Changed from "parent" to "resume" + ), + ])?; + + index_writer.commit()?; + } + Ok((index, content_field, doc_type_field)) + } + + #[test] + fn test_parent_filter_only() -> crate::Result<()> { + let (index, _content_field, doc_type_field) = create_minimal_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + + // Just search for parents directly + let top_docs = searcher.search(&parent_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find exactly one parent document"); + + Ok(()) + } + + #[test] + fn test_child_query_only() -> crate::Result<()> { + let (index, _content_field, doc_type_field) = create_minimal_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + // Just search for children directly + let top_docs = searcher.search(&child_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find exactly one child document"); + + Ok(()) + } + + #[test] + fn test_parent_bitset_creation() -> crate::Result<()> { + let (index, _content_field, doc_type_field) = create_minimal_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + + let parent_weight = + parent_query.weight(EnableScoring::disabled_from_searcher(&reader.searcher()))?; + let mut parent_scorer = parent_weight.scorer(segment_reader, 1.0)?; + + let mut parent_docs = Vec::new(); + while parent_scorer.doc() != TERMINATED { + parent_docs.push(parent_scorer.doc()); + parent_scorer.advance(); + } + + assert_eq!( + parent_docs.len(), + 1, + "Should find exactly one parent document" + ); + assert_eq!(parent_docs[0], 1, "Parent document should be at position 1"); + + Ok(()) + } + + #[test] + fn test_minimal_block_join() -> crate::Result<()> { + let (index, content_field, doc_type_field) = create_minimal_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::None, // Start with simplest scoring mode + ); + + let top_docs = searcher.search(&block_join_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find exactly one document"); + + let doc: TantivyDocument = searcher.doc(top_docs[0].1)?; + let content = doc.get_first(content_field).unwrap().as_str().unwrap(); + assert_eq!(content, "first resume", "Should retrieve parent document"); + + Ok(()) + } +} + +// src/query/block_join_query.rs + +#[cfg(test)] +mod atomic_scorer_tests { + use super::*; + use crate::collector::TopDocs; + use crate::query::TermQuery; + use crate::schema::{Field, IndexRecordOption, Schema, Value, STORED, STRING}; + use crate::{DocAddress, Index, IndexWriter, TantivyDocument, Term}; + + /// Creates a test index with a very specific document arrangement for testing scorer behavior + fn create_scorer_test_index() -> crate::Result<(Index, Field, Field)> { + let mut schema_builder = Schema::builder(); + let content_field = schema_builder.add_text_field("content", STRING | STORED); + let doc_type_field = schema_builder.add_text_field("docType", STRING); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // Create a very specific arrangement: + // doc0: child + // doc1: resume + // doc2: child + // doc3: resume + index_writer.add_documents(vec![ + // First block + doc!( + content_field => "first child", + doc_type_field => "child" + ), + doc!( + content_field => "first resume", // Changed from "parent" to "resume" + doc_type_field => "resume" // Changed from "parent" to "resume" + ), + // Second block + doc!( + content_field => "second child", + doc_type_field => "child" + ), + doc!( + content_field => "second resume", // Changed from "parent" to "resume" + doc_type_field => "resume" // Changed from "parent" to "resume" + ), + ])?; + + index_writer.commit()?; + } + Ok((index, content_field, doc_type_field)) + } + + #[test] + pub fn test_parent_filter_only() -> crate::Result<()> { + let (index, _content_field, doc_type_field) = create_scorer_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + + let top_docs = searcher.search(&parent_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find exactly one parent document"); + + Ok(()) + } + + #[test] + pub fn test_child_query_only() -> crate::Result<()> { + let (index, _content_field, doc_type_field) = create_scorer_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let top_docs = searcher.search(&child_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find exactly one child document"); + + Ok(()) + } + + #[test] + pub fn test_parent_bitset_creation() -> crate::Result<()> { + let (index, _content_field, doc_type_field) = create_scorer_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + + let parent_weight = + parent_query.weight(EnableScoring::disabled_from_searcher(&searcher))?; + let mut parent_scorer = parent_weight.scorer(segment_reader, 1.0)?; + + let mut parent_docs = Vec::new(); + while parent_scorer.doc() != TERMINATED { + parent_docs.push(parent_scorer.doc()); + parent_scorer.advance(); + } + + assert_eq!( + parent_docs.len(), + 2, + "Should find exactly two parent documents" + ); + assert_eq!( + parent_docs, + vec![1, 3], + "Parents should be at positions 1 and 3" + ); + + Ok(()) + } + + #[test] + pub fn test_minimal_block_join() -> crate::Result<()> { + let (index, content_field, doc_type_field) = create_scorer_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), + IndexRecordOption::Basic, + ); + + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::None, // Start with simplest scoring mode + ); + + let top_docs = searcher.search(&block_join_query, &TopDocs::with_limit(1))?; + assert_eq!(top_docs.len(), 1, "Should find exactly one document"); + + let doc: TantivyDocument = searcher.doc(top_docs[0].1)?; + let content = doc.get_first(content_field).unwrap().as_str().unwrap(); + assert_eq!(content, "first resume", "Should retrieve parent document"); + + Ok(()) + } +} + +// src/query/block_join_query.rs + +#[cfg(test)] +mod first_advance_tests { + use super::*; + use crate::query::TermQuery; + use crate::schema::{Field, IndexRecordOption, Schema, STRING}; + use crate::{DocAddress, Index, IndexWriter, TantivyDocument, Term}; + + /// Creates a minimal test index with exactly one child followed by one parent + fn create_single_block_index() -> crate::Result<(Index, Field)> { + let mut schema_builder = Schema::builder(); + let doc_type_field = schema_builder.add_text_field("docType", STRING); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // Single block: one child, one parent + index_writer.add_documents(vec![ + doc!(doc_type_field => "child"), + doc!(doc_type_field => "resume"), // Changed from "parent" to "resume" + ])?; + + index_writer.commit()?; + } + Ok((index, doc_type_field)) + } + + #[test] + fn test_first_advance_behavior() -> crate::Result<()> { + let (index, doc_type_field) = create_single_block_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_weight = BlockJoinWeight { + child_weight: child_query.weight(EnableScoring::disabled_from_searcher(&searcher))?, + parents_weight: parent_query + .weight(EnableScoring::disabled_from_searcher(&searcher))?, + score_mode: ScoreMode::None, + }; + + let mut scorer = block_join_weight.scorer(segment_reader, 1.0)?; + + println!("Initial doc: {}", scorer.doc()); + + // First advance should find the parent + let first_doc = scorer.advance(); + println!("After first advance: {}", first_doc); + + assert_eq!( + first_doc, 1, + "First advance should find parent at position 1" + ); + + // Subsequent advance should find TERMINATED + let next_doc = scorer.advance(); + println!("After second advance: {}", next_doc); + + assert_eq!( + next_doc, TERMINATED, + "Second advance should return TERMINATED" + ); + + Ok(()) + } + + #[test] + fn test_block_join_scoring() -> crate::Result<()> { + let (index, doc_type_field) = create_single_block_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_weight = BlockJoinWeight { + child_weight: child_query.weight(EnableScoring::disabled_from_searcher(&searcher))?, + parents_weight: parent_query + .weight(EnableScoring::disabled_from_searcher(&searcher))?, + score_mode: ScoreMode::None, + }; + + let mut scorer = block_join_weight.scorer(segment_reader, 1.0)?; + + // Advance to first parent + let doc = scorer.advance(); + assert_eq!(doc, 1, "Should find parent at position 1"); + + // Check the score + let score = scorer.score(); + assert_eq!(score, 1.0, "Score should be 1.0 with ScoreMode::None"); + + Ok(()) + } +} + +// src/query/block_join_query.rs + +#[cfg(test)] +mod advancement_tests { + use super::*; + use crate::query::TermQuery; + use crate::schema::{Field, IndexRecordOption, Schema, STRING}; + use crate::{DocAddress, Index, IndexWriter, TantivyDocument, Term}; + + /// Creates a test index with a specific pattern to test block membership: + /// doc0: child1 + /// doc1: resume1 + /// doc2: child2 + /// doc3: resume2 + fn create_block_test_index() -> crate::Result<(Index, Field)> { + let mut schema_builder = Schema::builder(); + let doc_type_field = schema_builder.add_text_field("docType", STRING); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // First block + index_writer.add_documents(vec![ + doc!(doc_type_field => "child"), // doc0 + doc!(doc_type_field => "resume"), // doc1 + doc!(doc_type_field => "child"), // doc2 + doc!(doc_type_field => "resume"), // doc3 + ])?; + + index_writer.commit()?; + } + Ok((index, doc_type_field)) + } + + #[test] + fn test_initial_scorer_state() -> crate::Result<()> { + let (index, doc_type_field) = create_block_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_weight = BlockJoinWeight { + child_weight: child_query.weight(EnableScoring::disabled_from_searcher(&searcher))?, + parents_weight: parent_query + .weight(EnableScoring::disabled_from_searcher(&searcher))?, + score_mode: ScoreMode::None, + }; + + let mut scorer = block_join_weight.scorer(segment_reader, 1.0)?; + + // Initial doc should be TERMINATED + assert_eq!(scorer.doc(), TERMINATED, "Should start at TERMINATED"); + + // First advance should find the first parent + let first_doc = scorer.advance(); + assert_eq!( + first_doc, 1, + "First advance should find parent at position 1" + ); + + // Second advance should find the second parent + let second_doc = scorer.advance(); + assert_eq!( + second_doc, 3, + "Second advance should find parent at position 3" + ); + + // Third advance should return TERMINATED + let third_doc = scorer.advance(); + assert_eq!( + third_doc, TERMINATED, + "Third advance should return TERMINATED" + ); + + Ok(()) + } + + #[test] + fn test_block_join_scoring() -> crate::Result<()> { + let (index, doc_type_field) = create_block_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), // Changed from "parent" to "resume" + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_weight = BlockJoinWeight { + child_weight: child_query.weight(EnableScoring::disabled_from_searcher(&searcher))?, + parents_weight: parent_query + .weight(EnableScoring::disabled_from_searcher(&searcher))?, + score_mode: ScoreMode::None, + }; + + let mut scorer = block_join_weight.scorer(segment_reader, 1.0)?; + + // Advance to first parent + let first_doc = scorer.advance(); + assert_eq!(first_doc, 1, "Should find parent at position 1"); + + // Check the score + let score = scorer.score(); + assert_eq!(score, 1.0, "Score should be 1.0 with ScoreMode::None"); + + // Advance to second parent + let second_doc = scorer.advance(); + assert_eq!(second_doc, 3, "Should find parent at position 3"); + + // Check the score + let score = scorer.score(); + assert_eq!(score, 1.0, "Score should be 1.0 with ScoreMode::None"); + + Ok(()) + } +} + +#[cfg(test)] +mod block_membership_tests { + use super::*; + use crate::collector::TopDocs; + use crate::query::TermQuery; + use crate::schema::{Field, IndexRecordOption, Schema, STORED, STRING}; + use crate::{Index, IndexWriter, Term}; + + /// Creates a test index with a specific pattern to test block membership: + /// doc0: child1 + /// doc1: child2 + /// doc2: resume1 + /// doc3: child3 + /// doc4: resume2 + fn create_block_test_index() -> crate::Result<(Index, Field)> { + let mut schema_builder = Schema::builder(); + let doc_type_field = schema_builder.add_text_field("docType", STRING); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // First block + index_writer.add_documents(vec![ + doc!(doc_type_field => "child"), // doc0 + doc!(doc_type_field => "child"), // doc1 + doc!(doc_type_field => "resume"), // doc2 + ])?; + + // Second block + index_writer.add_documents(vec![ + doc!(doc_type_field => "child"), // doc3 + doc!(doc_type_field => "resume"), // doc4 + ])?; + + index_writer.commit()?; + } + Ok((index, doc_type_field)) + } + + #[test] + fn test_child_block_membership() -> crate::Result<()> { + let (index, doc_type_field) = create_block_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + let block_join_weight = BlockJoinWeight { + child_weight: child_query.weight(EnableScoring::disabled_from_searcher(&searcher))?, + parents_weight: parent_query + .weight(EnableScoring::disabled_from_searcher(&searcher))?, + score_mode: ScoreMode::None, + }; + + let mut scorer = block_join_weight.scorer(segment_reader, 1.0)?; + + // Get first parent + let first_doc = scorer.advance(); + assert_eq!(first_doc, 2, "First parent should be at position 2"); + assert_eq!( + scorer.score(), + 1.0, + "Score should be 1.0 with ScoreMode::None" + ); + + // Get second parent + let second_doc = scorer.advance(); + assert_eq!(second_doc, 4, "Second parent should be at position 4"); + assert_eq!( + scorer.score(), + 1.0, + "Score should be 1.0 with ScoreMode::None" + ); + + Ok(()) + } + + #[test] + fn test_collect_matches_block_boundaries() -> crate::Result<()> { + let (index, doc_type_field) = create_block_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + // First verify parents are correctly indexed + let parent_docs = searcher.search(&parent_query, &TopDocs::with_limit(10))?; + println!( + "\n=== Parent documents found directly ({} results) ===", + parent_docs.len() + ); + for (i, doc) in parent_docs.iter().enumerate() { + println!("Parent {}: doc_id={}, score={}", i, doc.1.doc_id, doc.0); + } + + // Test the block join query scoring + println!("\n=== Testing block join query ==="); + let block_join_query = BlockJoinQuery::new( + Box::new(child_query.clone()), + Box::new(parent_query.clone()), + ScoreMode::None, + ); + + let collector = TopDocs::with_limit(10); + println!("\n=== Searching with collector (limit: 10) ==="); + + let top_docs = searcher.search(&block_join_query, &collector)?; + println!( + "\n=== Search completed, found {} results ===", + top_docs.len() + ); + + for (i, doc) in top_docs.iter().enumerate() { + println!("Result {}: doc_id={}, score={}", i, doc.1.doc_id, doc.0); + } + + println!("\nAsserting results..."); + assert_eq!(top_docs.len(), 2, "Should find both parent documents"); + + let mut result_ids: Vec = top_docs.iter().map(|(_score, doc)| doc.doc_id).collect(); + result_ids.sort_unstable(); + println!("Sorted result IDs: {:?}", result_ids); + + assert_eq!(result_ids[0], 2, "Should find parent at position 2"); + assert_eq!(result_ids[1], 4, "Should find parent at position 4"); + + Ok(()) + } + + #[test] + fn test_scorer_behavior() -> crate::Result<()> { + let (index, doc_type_field) = create_block_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let parent_query = TermQuery::new( + Term::from_field_text(doc_type_field, "resume"), + IndexRecordOption::Basic, + ); + let child_query = TermQuery::new( + Term::from_field_text(doc_type_field, "child"), + IndexRecordOption::Basic, + ); + + println!("\n=== Creating block join scorer ==="); + let block_join_query = BlockJoinQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::None, + ); + + let weight = block_join_query.weight(EnableScoring::disabled_from_searcher(&searcher))?; + let mut scorer = weight.scorer(segment_reader, 1.0)?; + + println!("\n=== Testing scorer directly ==="); + let mut docs = Vec::new(); + let initial_doc = scorer.doc(); + println!("Initial doc: {}", initial_doc); + + // First advance + let mut current = scorer.advance(); + println!("First advance: {}", current); + + while current != TERMINATED { + docs.push(current); + current = scorer.advance(); + println!("Advanced to: {}", current); + } + + println!("\nCollected docs: {:?}", docs); + assert!(!docs.is_empty(), "Scorer should find documents"); + assert_eq!(docs.len(), 2, "Should find both parents"); + assert_eq!(docs[0], 2, "First parent should be at position 2"); + assert_eq!(docs[1], 4, "Second parent should be at position 4"); + + Ok(()) + } +} diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index 143eed1c77..2f692992f2 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -295,7 +295,7 @@ mod test { let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; assert_eq!(top_docs.len(), 1, "Expected only 1 document"); let (score, _) = top_docs[0]; - assert_nearly_equals!(1.0, score); + assert_nearly_equals!(0.5, score); } // fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n') diff --git a/src/query/mod.rs b/src/query/mod.rs index 23e64f1894..894598f64a 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -1,6 +1,7 @@ mod all_query; mod automaton_weight; mod bitset; +mod block_join_query; mod bm25; mod boolean_query; mod boost_query; @@ -14,6 +15,7 @@ mod explanation; mod fuzzy_query; mod intersection; mod more_like_this; +mod nested_document_query; mod phrase_prefix_query; mod phrase_query; mod query; diff --git a/src/query/nested_document_query.rs b/src/query/nested_document_query.rs new file mode 100644 index 0000000000..972d3fec0e --- /dev/null +++ b/src/query/nested_document_query.rs @@ -0,0 +1,479 @@ +use crate::core::searcher::Searcher; +use crate::query::{EnableScoring, Explanation, Query, QueryClone, Scorer, Weight}; +use crate::schema::Term; +use crate::{DocId, DocSet, Result, Score, SegmentReader, TERMINATED}; +use common::BitSet; +use std::fmt; + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ScoreMode { + Avg, + Max, + Total, +} + +impl Default for ScoreMode { + fn default() -> Self { + ScoreMode::Avg + } +} + +pub struct NestedDocumentQuery { + child_query: Box, + parents_filter: Box, + score_mode: ScoreMode, +} + +impl Clone for NestedDocumentQuery { + fn clone(&self) -> Self { + NestedDocumentQuery { + child_query: self.child_query.box_clone(), + parents_filter: self.parents_filter.box_clone(), + score_mode: self.score_mode, + } + } +} + +impl NestedDocumentQuery { + pub fn new( + child_query: Box, + parents_filter: Box, + score_mode: ScoreMode, + ) -> NestedDocumentQuery { + NestedDocumentQuery { + child_query, + parents_filter, + score_mode, + } + } +} + +impl fmt::Debug for NestedDocumentQuery { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "NestedDocumentQuery(child_query: {:?}, parents_filter: {:?}, score_mode: {:?})", + self.child_query, self.parents_filter, self.score_mode + ) + } +} + +impl Query for NestedDocumentQuery { + fn weight(&self, enable_scoring: EnableScoring<'_>) -> Result> { + println!("Creating weights for NestedDocumentQuery"); + let child_weight = self.child_query.weight(enable_scoring)?; + let parents_weight = self.parents_filter.weight(enable_scoring)?; + + Ok(Box::new(NestedDocumentWeight { + child_weight, + parents_weight, + score_mode: self.score_mode, + })) + } + + fn explain( + &self, + _searcher: &Searcher, + _doc_address: crate::DocAddress, + ) -> Result { + unimplemented!("Explain is not implemented for NestedDocumentQuery"); + } + + fn count(&self, searcher: &Searcher) -> Result { + let weight = self.weight(EnableScoring::disabled_from_searcher(searcher))?; + let mut total_count = 0; + for reader in searcher.segment_readers() { + total_count += weight.count(reader)?; + } + Ok(total_count as usize) + } + + fn query_terms<'a>(&'a self, visitor: &mut dyn FnMut(&'a Term, bool)) { + self.child_query.query_terms(visitor); + self.parents_filter.query_terms(visitor); + } +} + +struct NestedDocumentWeight { + child_weight: Box, + parents_weight: Box, + score_mode: ScoreMode, +} + +impl Weight for NestedDocumentWeight { + fn scorer(&self, reader: &SegmentReader, boost: Score) -> Result> { + println!("\n=== Creating NestedDocumentScorer for segment ==="); + println!("Max doc in segment: {}", reader.max_doc()); + + // Create parents bitset + let mut parents_bitset = BitSet::with_max_value(reader.max_doc()); + + println!("Building parents bitset with boost: {}", boost); + let mut parents_scorer = self.parents_weight.scorer(reader, boost)?; + + // Collect all parent documents + let mut found_parent = false; + while parents_scorer.doc() != TERMINATED { + let parent_doc = parents_scorer.doc(); + parents_bitset.insert(parent_doc); + println!( + "Found parent doc: {} (total parents: {})", + parent_doc, + parents_bitset.len() + ); + found_parent = true; + parents_scorer.advance(); + } + + // If no parents in this segment, return empty scorer + if !found_parent { + println!("No parents found in segment, returning empty scorer"); + return Ok(Box::new(EmptyScorer)); + } + + println!("Total parent docs found: {}", parents_bitset.len()); + + // Get child scorer + let child_scorer = self.child_weight.scorer(reader, boost)?; + + Ok(Box::new(NestedDocumentScorer { + child_scorer, + parent_docs: parents_bitset, + score_mode: self.score_mode, + current_parent: 0, + current_score: 0.0, + initialized: false, + has_more: true, + })) + } + + fn explain(&self, _reader: &SegmentReader, _doc: DocId) -> Result { + unimplemented!("Explain is not implemented for NestedDocumentWeight"); + } + + fn count(&self, reader: &SegmentReader) -> Result { + let mut count = 0; + let mut scorer = self.scorer(reader, 1.0)?; + while scorer.doc() != TERMINATED { + count += 1; + scorer.advance(); + } + Ok(count) + } +} + +struct EmptyScorer; + +impl DocSet for EmptyScorer { + fn advance(&mut self) -> DocId { + TERMINATED + } + + fn doc(&self) -> DocId { + TERMINATED + } + + fn size_hint(&self) -> u32 { + 0 + } +} + +impl Scorer for EmptyScorer { + fn score(&mut self) -> Score { + 0.0 + } +} + +struct NestedDocumentScorer { + child_scorer: Box, + parent_docs: BitSet, + score_mode: ScoreMode, + current_parent: DocId, + current_score: Score, + initialized: bool, + has_more: bool, +} + +impl DocSet for NestedDocumentScorer { + fn advance(&mut self) -> DocId { + // If we are out of docs, just return TERMINATED. + if !self.has_more { + return TERMINATED; + } + + // If this is the first time we advance, initialize the child scorer. + if !self.initialized { + // Advance the child scorer once to position it properly. + self.child_scorer.advance(); + self.initialized = true; + } + + // Keep looping until we find a parent with children or run out of parents. + loop { + let start = if self.current_parent == TERMINATED { + // Start from doc 0 if we haven't found any parent yet. + 0 + } else { + self.current_parent + 1 + }; + + // Find the next parent + self.current_parent = self.find_next_parent(start); + if self.current_parent == TERMINATED { + self.has_more = false; + return TERMINATED; + } + + // Collect matches for this parent + let doc_id = self.collect_matches(); + if doc_id != TERMINATED { + // We found a parent with children + return doc_id; + } + // If no children were found, try the next parent + } + } + + fn doc(&self) -> DocId { + if self.has_more { + self.current_parent + } else { + TERMINATED + } + } + + fn size_hint(&self) -> u32 { + self.parent_docs.len() as u32 + } +} + +impl NestedDocumentScorer { + fn find_next_parent(&self, from: DocId) -> DocId { + println!(">>> Looking for next parent starting from {}", from); + let mut current = from; + while current < self.parent_docs.max_value() { + if self.parent_docs.contains(current) { + println!(">>> Found next parent: {}", current); + return current; + } + current += 1; + } + println!(">>> No more parents found after {}", from); + TERMINATED + } + + fn collect_matches(&mut self) -> DocId { + println!( + "\n>>> Collecting matches for parent: {}", + self.current_parent + ); + + let mut child_doc = self.child_scorer.doc(); + println!(">>> Using current child doc: {}", child_doc); + println!(">>> Current parent: {}", self.current_parent); + println!(">>> Has more: {}", self.has_more); + + let mut child_scores = Vec::new(); + + // Gather all valid children for this parent + while child_doc != TERMINATED && child_doc < self.current_parent { + println!( + ">>> Examining child doc {} for parent {}", + child_doc, self.current_parent + ); + + // Check if there is another parent in between child and this parent + let mut is_valid = true; + for doc_id in (child_doc + 1)..self.current_parent { + if self.parent_docs.contains(doc_id) { + println!( + ">>> Found intervening parent {} between child {} and parent {}", + doc_id, child_doc, self.current_parent + ); + is_valid = false; + break; + } + } + + if is_valid { + let score = self.child_scorer.score(); + println!( + ">>> Child {} is valid for parent {} with score {}", + child_doc, self.current_parent, score + ); + child_scores.push(score); + } else { + println!( + ">>> Child {} is not valid for parent {}", + child_doc, self.current_parent + ); + } + + child_doc = self.child_scorer.advance(); + println!(">>> Advanced child scorer to: {}", child_doc); + } + + if child_scores.is_empty() { + println!( + ">>> No valid children for parent {}, skipping this parent", + self.current_parent + ); + // Return TERMINATED to indicate no matches for this parent + // and let `advance()` try the next one. + TERMINATED + } else { + println!( + ">>> Found {} valid children for parent {}", + child_scores.len(), + self.current_parent + ); + + self.current_score = match self.score_mode { + ScoreMode::Avg => { + let avg = child_scores.iter().sum::() / child_scores.len() as Score; + println!( + ">>> Average score for parent {}: {}", + self.current_parent, avg + ); + avg + } + ScoreMode::Max => { + let max = child_scores.iter().copied().fold(f32::MIN, f32::max); + println!(">>> Max score for parent {}: {}", self.current_parent, max); + max + } + ScoreMode::Total => { + let sum: Score = child_scores.iter().sum(); + println!( + ">>> Total score for parent {}: {}", + self.current_parent, sum + ); + sum + } + }; + + println!( + ">>> Returning parent {} with score {}", + self.current_parent, self.current_score + ); + self.current_parent + } + } +} + +impl Scorer for NestedDocumentScorer { + fn score(&mut self) -> Score { + self.current_score + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::collector::TopDocs; + use crate::query::{BooleanQuery, Occur, RangeQuery, TermQuery}; + use crate::schema::{IndexRecordOption, Schema, INDEXED, STORED, STRING}; + use crate::{Index, Result, Term}; + use std::ops::Bound; + + #[test] + fn test_nested_document_query() -> Result<()> { + println!("\n=== Starting test_nested_document_query ==="); + + // Create schema + let mut schema_builder = Schema::builder(); + let name = schema_builder.add_text_field("name", STORED); + let country = schema_builder.add_text_field("country", STRING | STORED); + let doc_type = schema_builder.add_text_field("doc_type", STRING | STORED); + let skill = schema_builder.add_text_field("skill", STRING | STORED); + let year = schema_builder.add_u64_field("year", STORED | INDEXED); + let schema = schema_builder.build(); + + println!("Created schema"); + + // Create index + let index = Index::create_in_ram(schema); + let mut writer = index.writer(50_000_000)?; + + println!("Created index and writer"); + + // Add documents using add_documents in a single batch + writer.add_documents(vec![ + doc!(skill => "java", year => 2006u64), + doc!(skill => "python", year => 2010u64), + doc!(name => "Lisa", country => "United Kingdom", doc_type => "resume"), + doc!(skill => "ruby", year => 2005u64), + doc!(skill => "java", year => 2007u64), + doc!(name => "Frank", country => "United States", doc_type => "resume"), + ])?; + + println!("Added all documents"); + writer.commit()?; + + let reader = index.reader()?; + let searcher = reader.searcher(); + + println!("Created reader and searcher"); + + // Verify document count + let total_docs: u32 = searcher + .segment_readers() + .iter() + .map(|reader| reader.num_docs()) + .sum(); + println!("Total documents in index: {}", total_docs); + assert_eq!(total_docs, 6, "Should have 6 documents total"); + + // Create parent query + let parent_query = TermQuery::new( + Term::from_field_text(doc_type, "resume"), + IndexRecordOption::Basic, + ); + + // Test parent query + let parent_results = searcher.search(&parent_query, &TopDocs::with_limit(10))?; + println!("Parent query found {} results", parent_results.len()); + assert_eq!(parent_results.len(), 2, "Should find 2 parent documents"); + + // Create child query + let child_query = BooleanQuery::new(vec![ + ( + Occur::Must, + Box::new(TermQuery::new( + Term::from_field_text(skill, "java"), + IndexRecordOption::Basic, + )), + ), + ( + Occur::Must, + Box::new(RangeQuery::new( + Bound::Included(Term::from_field_u64(year, 2006u64)), + Bound::Included(Term::from_field_u64(year, 2011u64)), + )), + ), + ]); + + // Test child query + let child_results = searcher.search(&child_query, &TopDocs::with_limit(10))?; + println!("Child query found {} results", child_results.len()); + assert_eq!(child_results.len(), 2, "Should find 2 child documents"); + + // Create nested query + let nested_query = NestedDocumentQuery::new( + Box::new(child_query), + Box::new(parent_query), + ScoreMode::Avg, + ); + + // Test nested query + let nested_results = searcher.search(&nested_query, &TopDocs::with_limit(10))?; + println!("Nested query found {} results", nested_results.len()); + assert_eq!( + nested_results.len(), + 2, + "Should find 2 nested document matches" + ); + + Ok(()) + } +} diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index fd8fbc7c14..13389dd73b 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -4,6 +4,7 @@ mod term_weight; pub use self::term_query::TermQuery; pub use self::term_scorer::TermScorer; + #[cfg(test)] mod tests { diff --git a/src/termdict/fst_termdict/mod.rs b/src/termdict/fst_termdict/mod.rs index 4201df6a4e..673b569c02 100644 --- a/src/termdict/fst_termdict/mod.rs +++ b/src/termdict/fst_termdict/mod.rs @@ -24,5 +24,7 @@ mod term_info_store; mod termdict; pub use self::merger::TermMerger; -pub use self::streamer::{TermStreamer, TermStreamerBuilder}; +pub use self::streamer::{ + TermStreamer, TermStreamerBuilder, TermWithStateStreamer, TermWithStateStreamerBuilder, +}; pub use self::termdict::{TermDictionary, TermDictionaryBuilder}; diff --git a/src/termdict/fst_termdict/streamer.rs b/src/termdict/fst_termdict/streamer.rs index d2e31421f0..8bdf3c3989 100644 --- a/src/termdict/fst_termdict/streamer.rs +++ b/src/termdict/fst_termdict/streamer.rs @@ -1,7 +1,7 @@ use std::io; use tantivy_fst::automaton::AlwaysMatch; -use tantivy_fst::map::{Stream, StreamBuilder}; +use tantivy_fst::map::{Stream, StreamBuilder, StreamWithState}; use tantivy_fst::{Automaton, IntoStreamer, Streamer}; use super::TermDictionary; @@ -11,14 +11,16 @@ use crate::termdict::TermOrdinal; /// `TermStreamerBuilder` is a helper object used to define /// a range of terms that should be streamed. pub struct TermStreamerBuilder<'a, A = AlwaysMatch> -where A: Automaton +where + A: Automaton, { fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>, } impl<'a, A> TermStreamerBuilder<'a, A> -where A: Automaton +where + A: Automaton, { pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self { TermStreamerBuilder { @@ -73,7 +75,8 @@ where A: Automaton /// `TermStreamer` acts as a cursor over a range of terms of a segment. /// Terms are guaranteed to be sorted. pub struct TermStreamer<'a, A = AlwaysMatch> -where A: Automaton +where + A: Automaton, { pub(crate) fst_map: &'a TermDictionary, pub(crate) stream: Stream<'a, A>, @@ -83,7 +86,8 @@ where A: Automaton } impl TermStreamer<'_, A> -where A: Automaton +where + A: Automaton, { /// Advance position the stream on the next item. /// Before the first call to `.advance()`, the stream @@ -145,3 +149,153 @@ where A: Automaton } } } + +/// `TermWithStateStreamerBuilder` is a helper object used to define +/// a range of terms that should be streamed. +pub struct TermWithStateStreamerBuilder<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + fst_map: &'a TermDictionary, + stream_builder: StreamBuilder<'a, A>, +} + +impl<'a, A> TermWithStateStreamerBuilder<'a, A> +where + A: Automaton, + A::State: Clone, +{ + pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self { + TermWithStateStreamerBuilder { + fst_map, + stream_builder, + } + } + + /// Limit the range to terms greater or equal to the bound + pub fn ge>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.ge(bound); + self + } + + /// Limit the range to terms strictly greater than the bound + pub fn gt>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.gt(bound); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn le>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.le(bound); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn lt>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.lt(bound); + self + } + + /// Iterate over the range backwards. + pub fn backward(mut self) -> Self { + self.stream_builder = self.stream_builder.backward(); + self + } + + /// Creates the stream corresponding to the range + /// of terms defined using the `TermWithStateStreamerBuilder`. + pub fn into_stream(self) -> io::Result> { + Ok(TermWithStateStreamer { + fst_map: self.fst_map, + stream: self.stream_builder.with_state().into_stream(), + term_ord: 0u64, + current_key: Vec::with_capacity(100), + current_value: TermInfo::default(), + current_state: None, + }) + } +} + +/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment. +/// Terms are guaranteed to be sorted. +pub struct TermWithStateStreamer<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + fst_map: &'a TermDictionary, + stream: StreamWithState<'a, A>, + term_ord: TermOrdinal, + current_key: Vec, + current_value: TermInfo, + current_state: Option, +} + +impl<'a, A> TermWithStateStreamer<'a, A> +where + A: Automaton, + A::State: Clone, +{ + /// Advance position the stream on the next item. + /// Before the first call to `.advance()`, the stream + /// is an unitialized state. + pub fn advance(&mut self) -> bool { + if let Some((term, term_ord, state)) = self.stream.next() { + self.current_key.clear(); + self.current_key.extend_from_slice(term); + self.term_ord = term_ord; + self.current_value = self.fst_map.term_info_from_ord(term_ord); + self.current_state = Some(state); + true + } else { + false + } + } + + /// Returns the `TermOrdinal` of the given term. + /// + /// May panic if the called as `.advance()` as never + /// been called before. + pub fn term_ord(&self) -> TermOrdinal { + self.term_ord + } + + /// Accesses the current key. + /// + /// `.key()` should return the key that was returned + /// by the `.next()` method. + /// + /// If the end of the stream as been reached, and `.next()` + /// has been called and returned `None`, `.key()` remains + /// the value of the last key encountered. + /// + /// Before any call to `.next()`, `.key()` returns an empty array. + pub fn key(&self) -> &[u8] { + &self.current_key + } + + /// Accesses the current value. + /// + /// Calling `.value()` after the end of the stream will return the + /// last `.value()` encountered. + /// + /// # Panics + /// + /// Calling `.value()` before the first call to `.advance()` returns + /// `V::default()`. + pub fn value(&self) -> &TermInfo { + &self.current_value + } + + /// Return the next `(key, value, state)` triplet. + #[cfg_attr(feature = "cargo-clippy", allow(clippy::should_implement_trait))] + pub fn next(&mut self) -> Option<(&[u8], &TermInfo, A::State)> { + if self.advance() { + let state = self.current_state.take().unwrap(); // always Some(_) after advance + Some((self.key(), self.value(), state)) + } else { + None + } + } +} diff --git a/src/termdict/fst_termdict/termdict.rs b/src/termdict/fst_termdict/termdict.rs index 23ed3606d4..5bc9196942 100644 --- a/src/termdict/fst_termdict/termdict.rs +++ b/src/termdict/fst_termdict/termdict.rs @@ -6,7 +6,7 @@ use tantivy_fst::raw::Fst; use tantivy_fst::Automaton; use super::term_info_store::{TermInfoStore, TermInfoStoreWriter}; -use super::{TermStreamer, TermStreamerBuilder}; +use super::{TermStreamer, TermStreamerBuilder, TermWithStateStreamerBuilder}; use crate::directory::{FileSlice, OwnedBytes}; use crate::postings::TermInfo; use crate::termdict::TermOrdinal; @@ -217,4 +217,15 @@ impl TermDictionary { let stream_builder = self.fst_index.search(automaton); TermStreamerBuilder::::new(self, stream_builder) } + + /// Returns a search builder, to stream all of the terms + /// within the Automaton + pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A> + where + A: Automaton + 'a, + A::State: Clone, + { + let stream_builder = self.fst_index.search(automaton); + TermWithStateStreamerBuilder::::new(self, stream_builder) + } } diff --git a/src/termdict/mod.rs b/src/termdict/mod.rs index 01c9591ee3..dd0ccdd954 100644 --- a/src/termdict/mod.rs +++ b/src/termdict/mod.rs @@ -40,11 +40,12 @@ use common::file_slice::FileSlice; use common::BinarySerializable; use tantivy_fst::Automaton; +use self::fst_termdict::TermWithStateStreamerBuilder; use self::termdict::{ TermDictionary as InnerTermDict, TermDictionaryBuilder as InnerTermDictBuilder, TermStreamerBuilder, }; -pub use self::termdict::{TermMerger, TermStreamer}; +pub use self::termdict::{TermMerger, TermStreamer, TermWithStateStreamer}; use crate::postings::TermInfo; #[derive(Debug, Eq, PartialEq)] @@ -183,6 +184,16 @@ impl TermDictionary { ) -> FileSlice { self.0.file_slice_for_range(key_range, limit) } + + /// Returns a search builder, to stream all of the terms + /// within the Automaton + pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A> + where + A: Automaton + 'a, + A::State: Clone, + { + self.0.search_with_state(automaton) + } } /// A TermDictionaryBuilder wrapping either an FST or a SSTable dictionary builder. diff --git a/sstable/src/dictionary.rs b/sstable/src/dictionary.rs index 7c61df09a8..0b8c1dd7d0 100644 --- a/sstable/src/dictionary.rs +++ b/sstable/src/dictionary.rs @@ -443,16 +443,26 @@ impl Dictionary { let mut current_sstable_delta_reader = self.sstable_delta_reader_block(current_block_addr.clone())?; let mut current_ordinal = 0; + let mut prev_ord = None; for ord in ord { - assert!(ord >= current_ordinal); - // check if block changed for new term_ord - let new_block_addr = self.sstable_index.get_block_with_ord(ord); - if new_block_addr != current_block_addr { - current_block_addr = new_block_addr; - current_ordinal = current_block_addr.first_ordinal; - current_sstable_delta_reader = - self.sstable_delta_reader_block(current_block_addr.clone())?; - bytes.clear(); + + // only advance forward if the new ord is different than the one we just processed + // + // this allows the input TermOrdinal iterator to contain duplicates, so long as it's + // still sorted + if Some(ord) != prev_ord { + assert!(ord >= current_ordinal); + // check if block changed for new term_ord + let new_block_addr = self.sstable_index.get_block_with_ord(ord); + if new_block_addr != current_block_addr { + current_block_addr = new_block_addr; + current_ordinal = current_block_addr.first_ordinal; + current_sstable_delta_reader = + self.sstable_delta_reader_block(current_block_addr.clone())?; + bytes.clear(); + } + + prev_ord = Some(ord); } // move to ord inside that block diff --git a/tests/fuzzy_scoring.rs b/tests/fuzzy_scoring.rs new file mode 100644 index 0000000000..33e076ae74 --- /dev/null +++ b/tests/fuzzy_scoring.rs @@ -0,0 +1,128 @@ +#[cfg(test)] +mod test { + use maplit::hashmap; + use tantivy::collector::TopDocs; + use tantivy::query::FuzzyTermQuery; + use tantivy::schema::{Schema, Value, STORED, TEXT}; + use tantivy::{doc, Index, TantivyDocument, Term}; + + #[test] + pub fn test_fuzzy_term() { + // Define a list of documents to be indexed. Each entry represents a text + // that will be associated with the field "country" in the index. + let docs = vec![ + "WENN ROT WIE RUBIN", + "WENN ROT WIE ROBIN", + "WHEN RED LIKE ROBIN", + "WENN RED AS ROBIN", + "WHEN ROYAL BLUE ROBIN", + "IF RED LIKE RUBEN", + "WHEN GREEN LIKE ROBIN", + "WENN ROSE LIKE ROBIN", + "IF PINK LIKE ROBIN", + "WENN ROT WIE RABIN", + "WENN BLU WIE ROBIN", + "WHEN YELLOW LIKE RABBIT", + "IF BLUE LIKE ROBIN", + "WHEN ORANGE LIKE RIBBON", + "WENN VIOLET WIE RUBIX", + "WHEN INDIGO LIKE ROBBIE", + "IF TEAL LIKE RUBY", + "WHEN GOLD LIKE ROB", + "WENN SILVER WIE ROBY", + "IF BRONZE LIKE ROBE", + ]; + + // Define the expected scores when queried with "robin" and a fuzziness of 2. + // This map associates each document text with its expected score. + let expected_scores = hashmap! { + "WHEN GREEN LIKE ROBIN" => 1.0, + "WENN RED AS ROBIN" => 1.0, + "WHEN RED LIKE ROBIN" => 1.0, + "WENN ROSE LIKE ROBIN" => 1.0, + "WENN ROT WIE ROBIN" => 1.0, + "WHEN ROYAL BLUE ROBIN" => 1.0, + "IF PINK LIKE ROBIN" => 1.0, + "IF BLUE LIKE ROBIN" => 1.0, + "WENN BLU WIE ROBIN" => 1.0, + "WENN ROT WIE RUBIN" => 0.5, + "WENN ROT WIE RABIN" => 0.5, + "IF RED LIKE RUBEN" => 0.33333334, + "WENN VIOLET WIE RUBIX" => 0.33333334, + "IF BRONZE LIKE ROBE" => 0.33333334, + "WENN SILVER WIE ROBY" => 0.33333334, + "WHEN GOLD LIKE ROB" => 0.33333334, + "WHEN INDIGO LIKE ROBBIE" => 0.33333334, + }; + + // Build a schema for the index. + // The schema determines how documents are indexed and searched. + let mut schema_builder = Schema::builder(); + + // Add a text field named "country" to the schema. This field will store the text and + // is indexed in a way that makes it searchable. + let country_field = schema_builder.add_text_field("country", TEXT | STORED); + // Build the schema based on the provided definitions. + let schema = schema_builder.build(); + // Create a new index in RAM based on the defined schema. + let index = Index::create_in_ram(schema); + { + // Create an index writer with one thread and a certain memory limit. + // The writer allows us to add documents to the index. + let mut index_writer = index.writer_with_num_threads(1, 15_000_000).unwrap(); + + // Index each document in the docs list. + for &doc in &docs { + index_writer + .add_document(doc!(country_field => doc)) + .unwrap(); + } + + // Commit changes to the index. This finalizes the addition of documents. + index_writer.commit().unwrap(); + } + + // Create a reader for the index to search the indexed documents. + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + { + // Define a term based on the field "country" and the text "robin". + let term = Term::from_field_text(country_field, "robin"); + + // Create a fuzzy query for "robin", a fuzziness of 2, and a prefix length of 0. + let fuzzy_query = FuzzyTermQuery::new(term, 2, true); + + // Search the index with the fuzzy query and retrieve up to 100 top documents. + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(100)) + .unwrap(); + + // Print out the scores and documents retrieved by the search. + for (score, adr) in &top_docs { + let doc: TantivyDocument = searcher.doc(*adr).expect("document"); + println!("{score}, {:?}", doc.field_values().next().unwrap().1); + } + + // Assert that 17 documents match the fuzzy query criteria. + // We don't expect anything that has a larger fuzziness than 2 + // to be returned in the query, leaving us with 17 expected results. + assert_eq!(top_docs.len(), 17, "Expected 17 documents"); + + // Check the scores of the returned documents against the expected scores. + for (score, adr) in &top_docs { + let doc: TantivyDocument = searcher.doc(*adr).expect("document"); + let doc_text = doc.field_values().next().unwrap().1.as_str().unwrap(); + + // Ensure the retrieved score for each document is close to the expected score. + assert!( + (score - expected_scores[doc_text]).abs() < f32::EPSILON, + "Unexpected score for document {}. Expected: {}, Actual: {}", + doc_text, + expected_scores[doc_text], + score + ); + } + } + } +} diff --git a/tokenizer-api/src/lib.rs b/tokenizer-api/src/lib.rs index 2ba38d82ae..e5d3a5baff 100644 --- a/tokenizer-api/src/lib.rs +++ b/tokenizer-api/src/lib.rs @@ -157,6 +157,77 @@ pub trait TokenFilter: 'static + Send + Sync { fn transform(self, tokenizer: T) -> Self::Tokenizer; } +/// An optional [`TokenFilter`]. +impl TokenFilter for Option { + type Tokenizer = OptionalTokenizer, T>; + + #[inline] + fn transform(self, tokenizer: T) -> Self::Tokenizer { + match self { + Some(filter) => OptionalTokenizer::Enabled(filter.transform(tokenizer)), + None => OptionalTokenizer::Disabled(tokenizer), + } + } +} + +/// A [`Tokenizer`] derived from a [`TokenFilter::transform`] on an +/// [`Option`] token filter. +#[derive(Clone)] +pub enum OptionalTokenizer { + Enabled(E), + Disabled(D), +} + +impl Tokenizer for OptionalTokenizer { + type TokenStream<'a> = OptionalTokenStream, D::TokenStream<'a>>; + + #[inline] + fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> { + match self { + Self::Enabled(tokenizer) => { + let token_stream = tokenizer.token_stream(text); + OptionalTokenStream::Enabled(token_stream) + } + Self::Disabled(tokenizer) => { + let token_stream = tokenizer.token_stream(text); + OptionalTokenStream::Disabled(token_stream) + } + } + } +} + +/// A [`TokenStream`] derived from a [`Tokenizer::token_stream`] on an [`OptionalTokenizer`]. +pub enum OptionalTokenStream { + Enabled(E), + Disabled(D), +} + +impl TokenStream for OptionalTokenStream { + #[inline] + fn advance(&mut self) -> bool { + match self { + Self::Enabled(t) => t.advance(), + Self::Disabled(t) => t.advance(), + } + } + + #[inline] + fn token(&self) -> &Token { + match self { + Self::Enabled(t) => t.token(), + Self::Disabled(t) => t.token(), + } + } + + #[inline] + fn token_mut(&mut self) -> &mut Token { + match self { + Self::Enabled(t) => t.token_mut(), + Self::Disabled(t) => t.token_mut(), + } + } +} + #[cfg(test)] mod test { use super::*;