diff --git a/pyserini/prebuilt_index_info.py b/pyserini/prebuilt_index_info.py index 9f2ea2f38..6f1e34993 100644 --- a/pyserini/prebuilt_index_info.py +++ b/pyserini/prebuilt_index_info.py @@ -1771,3 +1771,85 @@ "texts": "wikipedia-dpr" } } + +JASS_INDEX_INFO = { + "jass-msmarco-passage-bm25": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring", + "filename": "jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz" + ], + "md5": "9add4b1f754c5f33d31501c65e5e92d3", + "size compressed (bytes)": 629101230, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "jass-msmarco-passage-d2q-t5": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring over a DocT5Query expanded collection", + "filename": "jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz" + ], + "md5": "9be8d8890d60410243a8c7323849ecc9", + "size compressed (bytes)": 832303111, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "jass-msmarco-passage-deepimpact": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with DeepImpact scoring", + "filename": "jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz" + ], + "md5": "d9ed05d97e1f07373d7a98a1dd9f6fac", + "size compressed (bytes)": 1217477634, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "jass-msmarco-passage-unicoil-d2q": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with 
uniCOIL scoring over a DocT5Query expanded collection", + "filename" : "jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz" + ], + "md5": "24bab2ef23914ab124d4f0eba8dc866c", + "size compressed (bytes)": 1084195359, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "jass-msmarco-unicoil-tilde": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a TILDE expanded collection", + "filename": "jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz" + ], + "md5": "705c3e72cff189265de9b5c509be00a6", + "size compressed (bytes)": 1724440877, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "jass-msmarco-passage-distill-splade-max": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with distill-splade-max scoring", + "filename": "jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz" + ], + "md5": "f6bf3cdf983d4e1aaee8677acbcdb47f", + "size compressed (bytes)": 3530600632, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + } +} + diff --git a/pyserini/search/__init__.py b/pyserini/search/__init__.py index beb12d643..ae8e27a75 100644 --- a/pyserini/search/__init__.py +++ b/pyserini/search/__init__.py @@ -19,13 +19,12 @@ from .lucene import JLuceneSearcherResult, LuceneSimilarities, LuceneFusionSearcher, LuceneSearcher from .lucene import JImpactSearcherResult, LuceneImpactSearcher from ._deprecated import SimpleSearcher, 
ImpactSearcher, SimpleFusionSearcher - from .faiss import DenseSearchResult, PRFDenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \ DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AutoQueryEncoder +from .jass import JASSv2Searcher from .faiss import AnceEncoder from .faiss import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf - __all__ = ['JQuery', 'LuceneSimilarities', 'LuceneFusionSearcher', @@ -51,10 +50,10 @@ 'BprQueryEncoder', 'DkrrDprQueryEncoder', 'TctColBertQueryEncoder', + 'JASSv2Searcher', 'AnceEncoder', 'AnceQueryEncoder', 'AutoQueryEncoder', 'DenseVectorAveragePrf', 'DenseVectorRocchioPrf', - 'DenseVectorAncePrf'] - + 'DenseVectorAncePrf'] \ No newline at end of file diff --git a/pyserini/search/jass/__init__.py b/pyserini/search/jass/__init__.py new file mode 100644 index 000000000..6f6ce7ff1 --- /dev/null +++ b/pyserini/search/jass/__init__.py @@ -0,0 +1,20 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from ._searcher import JASSv2Searcher , JASSv2SearcherResult + + +__all__ = ['JASSv2Searcher', 'JASSv2SearcherResult'] \ No newline at end of file diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py new file mode 100644 index 000000000..a2dbd16a1 --- /dev/null +++ b/pyserini/search/jass/__main__.py @@ -0,0 +1,111 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyserini.search import JASSv2Searcher +import argparse +import os +from tqdm import tqdm + +from pyserini.output_writer import OutputFormat, get_output_writer +from pyserini.query_iterator import get_query_iterator, TopicsFormat + + + +def define_search_args(parser): + parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, + help="Path to pyJass index") + parser.add_argument('--rho', type=int, default=1000000000, help='rho: how many postings to process') + parser.add_argument('--basic-parser', default=False, action='store_true', help="Use the basic query parser; Default is to use the ASCII parser") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Search a pyJass index.') + define_search_args(parser) + parser.add_argument('--topics', type=str, metavar='topic_name', required=True, + help="Name of topics. 
Available: robust04, robust05, core17, core18.") + parser.add_argument('--hits', type=int, metavar='num', + required=False, default=1000, help="Number of hits.") + parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value, + help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}") + parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value, + help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}") + parser.add_argument('--output', type=str, metavar='path', + help="Path to output file.") + parser.add_argument('--batch-size', type=int, metavar='num', required=False, + default=1, help="Specify batch size to search the collection concurrently.") + parser.add_argument('--threads', type=int, metavar='num', required=False, + default=1, help="Maximum number of threads to use.") + parser.add_argument('--impact', action='store_true', help="Use Impact.") + + args = parser.parse_args() + + query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format)) + topics = query_iterator.topics + + if os.path.exists(args.index): + searcher = JASSv2Searcher(args.index, 2) + else: + searcher = JASSv2Searcher.from_prebuilt_index(args.index) + + if not searcher: + exit() + + # JASS does not (yet) support field-based retrieval + fields = None + + if not args.impact: + print("Enforcing --impact; JASS requires impact-based retrieval.") + + # JASS Parser Option + if args.basic_parser: + searcher.set_basic_parser() + + # build output path + output_path = args.output + if output_path is None: + tokens = ['run', args.topics, '_'.join(['rho',str(args.rho)]), 'txt'] # we use the rho output + output_path = '.'.join(tokens) + + print(f'Running {args.topics} topics, saving to {output_path}...') + tag = output_path[:-4] if args.output is None else 'JaSS' + + output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w', + 
max_hits=args.hits, tag=tag, topics=topics) + + with output_writer: + batch_topics = list() + batch_topic_ids = list() + for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))): + if args.batch_size <= 1 and args.threads <= 1: + hits = searcher.search(text, args.hits, args.rho) + results = [(topic_id, hits)] + else: + batch_topic_ids.append(str(topic_id)) + batch_topics.append(text) + if (index + 1) % args.batch_size == 0 or \ + index == len(topics.keys()) - 1: + results = searcher.batch_search( + batch_topics, batch_topic_ids, args.hits, args.rho, args.threads) + results = [(id_, results[id_]) for id_ in batch_topic_ids] + batch_topic_ids.clear() + batch_topics.clear() + else: + continue + + for topic, hits in results: + # write results + output_writer.write(topic, hits) + + results.clear() diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py new file mode 100644 index 000000000..896a78ab8 --- /dev/null +++ b/pyserini/search/jass/_searcher.py @@ -0,0 +1,219 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides Pyserini's Python search interface to JASSv2. The main entry point is the ``JASSv2Searcher`` +class, which wraps the C++ ``JASS_anytime_api``. 
+""" + +from dataclasses import dataclass +import logging +import pyjass +from typing import Dict, List, Optional, Union +from pyserini.trectools import TrecRun +from pyserini.util import download_prebuilt_index +logger = logging.getLogger(__name__) + +# Wrappers around JASS classes + +@dataclass +class JASSv2SearcherResult: + docid: str # doc id + score: float # score in float + #TODO Implement the following attributes specially for JASSv2 + # query: str #query + # postings_processed: int # no of posting processed + + +class JASSv2Searcher: + + # Constants + EXPECTED_ENTRIES = 6 + DOCID_POS = 2 + SCORE_POS = 4 + ONE_BILLION = 1000000000 + + """Wrapper class for the ``JASS_anytime_api`` in JASSv2. + + Parameters + ---------- + index_dir : str + Path to JASS index directory. + """ + + def __init__(self, index_dir: str, version: int = 2): + self.index_dir = index_dir + self.object = pyjass.anytime() + self.set_ascii_parser() + index = self.object.load_index(version,index_dir) + self.num_docs = self.object.get_document_count() + if index != 0: + raise Exception('Unable to load index - error code' + str(index)) + + + @classmethod + def from_prebuilt_index(cls, prebuilt_index_name: str): + """Build a searcher from a pre-built index; download the index if necessary. + + Parameters + ---------- + prebuilt_index_name : str + Prebuilt index name. + + Returns + ------- + JASSv2Searcher + Searcher built from the prebuilt index. + """ + print(f'Attempting to initialize pre-built index {prebuilt_index_name}.') + try: + index_dir = download_prebuilt_index(prebuilt_index_name) + except ValueError as e: + print(str(e)) + return None + + print(f'Initializing {prebuilt_index_name}...') + return cls(index_dir) + + + def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult]: + """Process a pyJass result string and return the results as a list of JASSv2SearcherResult. + + Parameters + ---------- + result_list : str + Result string from pyjass. 
Multiple queries are stored with a newline token as the separator. + + Returns + ------- + List[JASSv2SearcherResult] + List of JASSv2SearcherResult which contains the DocID and also the score pair. + """ + docid_score_pair = list() + results = result_list.split('\n') + for res in results: + # Split by space. We expect the `trec` format; skip any line that does not match it + result_data = res.split(' ') + if len(result_data) == self.EXPECTED_ENTRIES: + # All is well, append the [docid, score] tuple. + docid_score_pair.append(JASSv2SearcherResult(result_data[self.DOCID_POS], float(result_data[self.SCORE_POS]))) + return docid_score_pair + + + def search(self, q: str, k: int = 10, rho: int = ONE_BILLION) -> List[JASSv2SearcherResult]: + """Search the collection for a single query. + + Parameters + ---------- + q : str + Query string. + k : int + Number of results to return. + rho : int + Value of rho to use. + + Returns + ------- + List[JASSv2SearcherResult] + List of search results. + + """ + + self.object.set_top_k(k) + self.object.set_postings_to_process(rho) + # JASS expects queries to be an identifier followed by terms, delimited by either ':', '\t', or ' ' + # We do not want to split on spaces as it may result in discarded terms. + split_query = q.split(":\t") + # Assume the first field is the identifier... + if len(split_query) == 2: + results = self.object.search(q) + else: + results = self.object.search("0:"+q) # prepending `0:` so JASS consumes it as the identifier + return (self.convert_to_search_result(results.results_list)) + + + def __list_to_strvector(self, qids: List[str] ,queries: List[str]) -> pyjass.JASS_string_vector: + """Convert a list of queries to a c++ string_vector. + + Parameters + ---------- + qids : List[str] + List of query ids. + queries : List[str] + List of queries. + + Returns + ------- + pyjass.JASS_string_vector + c++ string_vector to be consumed by Jass. 
+ + """ + return(pyjass.JASS_string_vector([':'.join(map(str, i)) for i in zip(qids, queries)])) + + + + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: int = ONE_BILLION, threads: int = 1) -> Dict[str, List[JASSv2SearcherResult]]: + + """Search the collection concurrently for multiple queries, using multiple threads. + + Parameters + ---------- + queries : List[str] + List of queries. + qids : List[str] + List of query ids. + k : int + Number of results to return. + rho : int + Value of rho to use. + threads : int + Number of threads to use. + + Returns + ------- + Dict[str, List[JASSv2SearcherResult]] + Dictionary holding the search results, with the query ids as keys and the corresponding lists of search + results as the values. + """ + + self.object.set_top_k(k) + output = dict() + self.object.set_postings_to_process(rho) + results = self.object.threaded_search(self.__list_to_strvector(qids, queries), threads) + for i in range(len(results)): + if len(results[i].results) > 0: + for key in results[i].results.asdict().keys(): + output[key] = self.convert_to_search_result(results[i].results[key].results_list) + + return output + + def set_ascii_parser(self) -> int: + """Set Jass to use ascii parser.""" + return(self.object.use_ascii_parser()) + + def set_basic_parser(self) -> int: + """Set Jass to use query parser.""" + return(self.object.use_query_parser()) + + + def __get_time_taken(self) -> float: + """Get the time taken to perform the search. + Returns + ------- + float + Time taken to perform the search. 
+ """ + raise NotImplementedError("This method is not implemented in JASSv2Searcher.") diff --git a/pyserini/search/jass/test.py b/pyserini/search/jass/test.py new file mode 100644 index 000000000..a0760fd23 --- /dev/null +++ b/pyserini/search/jass/test.py @@ -0,0 +1,10 @@ +from pyserini.search import JASSv2Searcher +import jass + + + +searcher = JASSv2Searcher('msmarco-passage') +hits = searcher.search('what is a lobster roll?') + +for i in range(0, 10): + print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') \ No newline at end of file diff --git a/pyserini/util.py b/pyserini/util.py index 6c831f6b0..d938ddb7a 100644 --- a/pyserini/util.py +++ b/pyserini/util.py @@ -28,7 +28,7 @@ from pyserini.encoded_query_info import QUERY_INFO from pyserini.evaluate_script_info import EVALUATION_INFO -from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO +from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO, JASS_INDEX_INFO logger = logging.getLogger(__name__) @@ -173,8 +173,10 @@ def check_downloaded(index_name): target_index = TF_INDEX_INFO[index_name] elif index_name in IMPACT_INDEX_INFO: target_index = IMPACT_INDEX_INFO[index_name] - else: + elif index_name in FAISS_INDEX_INFO: target_index = FAISS_INDEX_INFO[index_name] + else: + target_index = JASS_INDEX_INFO[index_name] index_url = target_index['urls'][0] index_md5 = target_index['md5'] index_name = index_url.split('/')[-1] @@ -216,14 +218,16 @@ def get_dense_indexes_info(): def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None): - if index_name not in TF_INDEX_INFO and index_name not in FAISS_INDEX_INFO and index_name not in IMPACT_INDEX_INFO: + if index_name not in TF_INDEX_INFO and index_name not in FAISS_INDEX_INFO and index_name not in IMPACT_INDEX_INFO and index_name not in JASS_INDEX_INFO: raise ValueError(f'Unrecognized index name {index_name}') if index_name in TF_INDEX_INFO: target_index = 
TF_INDEX_INFO[index_name] elif index_name in IMPACT_INDEX_INFO: target_index = IMPACT_INDEX_INFO[index_name] - else: + elif index_name in FAISS_INDEX_INFO: target_index = FAISS_INDEX_INFO[index_name] + else: + target_index = JASS_INDEX_INFO[index_name] index_md5 = target_index['md5'] for url in target_index['urls']: local_filename = target_index['filename'] if 'filename' in target_index else None diff --git a/requirements.txt b/requirements.txt index d3927eb67..4c37a72c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ nmslib>=2.1.1 onnxruntime>=1.8.1 lightgbm>=3.3.2 spacy>=3.2.1 +pyjass>=0.2a7 diff --git a/tests/test_search_pyjass.py b/tests/test_search_pyjass.py new file mode 100644 index 000000000..c3638b240 --- /dev/null +++ b/tests/test_search_pyjass.py @@ -0,0 +1,166 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import tarfile +import unittest +from random import randint +from typing import List, Dict +from urllib.request import urlretrieve + +from pyserini.search.jass import JASSv2Searcher, JASSv2SearcherResult +import pyjass +from pyserini.index import Document + + +class TestSearchPyJass(unittest.TestCase): + def setUp(self): + # Download pre-built CACM index; append a random value to avoid filename clashes. + #TODO To-be filled in by the test runner. 
+ r = randint(0, 10000000) + self.collection_url = 'https://github.com/prasys/anserini-data/raw/master/CACM/jass-index.cacm.tar.gz' # to be replaced + self.tarball_name = 'jass-index.cacm-{}.tar.gz'.format(r) + self.index_dir = 'jass{}/'.format(r) + + filename, headers = urlretrieve(self.collection_url, self.tarball_name) + + tarball = tarfile.open(self.tarball_name) + tarball.extractall(self.index_dir) + tarball.close() + + self.searcher = JASSv2Searcher(f'{self.index_dir}jass-index.cacm') + + def test_basic(self): + hits = self.searcher.search('information retrieval') + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(hits, List)) + + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertEqual(hits[0].score, 664.0) + + + + self.assertTrue(isinstance(hits[9], JASSv2SearcherResult)) + self.assertEqual(hits[9].docid, 'CACM-2631') + self.assertEqual(hits[9].score, 589.0) + + hits = self.searcher.search('search') + + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(hits[0].docid, 'CACM-3041') + self.assertEqual(hits[0].score, 413.0) + + self.assertTrue(isinstance(hits[9], JASSv2SearcherResult)) + self.assertEqual(hits[9].docid, 'CACM-1815') + self.assertEqual(hits[9].score, 392.0) + + def test_batch(self): + results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], threads=2) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(results, Dict)) + + self.assertTrue(isinstance(results['q1'], List)) + self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) + self.assertEqual(results['q1'][0].docid, 'CACM-3134') + self.assertEqual(results['q1'][0].score, 664.0) + + self.assertTrue(isinstance(results['q1'][9], JASSv2SearcherResult)) + self.assertEqual(results['q1'][9].docid, 'CACM-2631') + self.assertEqual(results['q1'][9].score, 589.0) + + self.assertTrue(isinstance(results['q2'], 
List)) + self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) + self.assertEqual(results['q2'][0].docid, 'CACM-3041') + self.assertEqual(results['q2'][0].score, 413.0) + + self.assertTrue(isinstance(results['q2'][9], JASSv2SearcherResult)) + self.assertEqual(results['q2'][9].docid, 'CACM-1815') + self.assertEqual(results['q2'][9].score, 392.0) + + def test_basic_k(self): + hits = self.searcher.search('information retrieval', k=88) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(hits, List)) + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(len(hits), 88) + + def test_batch_k(self): + results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=88, threads=2) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(results, Dict)) + self.assertTrue(isinstance(results['q1'], List)) + self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) + self.assertEqual(len(results['q1']), 88) + self.assertTrue(isinstance(results['q2'], List)) + self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) + self.assertEqual(len(results['q2']), 88) + + def test_basic_rho(self): + hits = self.searcher.search('information retrieval', k=42, rho=50) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(hits, List)) + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(hits[9].docid, 'CACM-1725') + self.assertEqual(hits[9].score, 362.0) + self.assertEqual(len(hits), 42) + + def test_batch_rho(self): + # This test just provides a sanity check, it's not that interesting as it only searches one field. 
+ results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=42, + threads=2, rho=50) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(results, Dict)) + self.assertTrue(isinstance(results['q1'], List)) + self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) + self.assertEqual(len(results['q1']), 42) + self.assertEqual(results['q1'][9].docid, 'CACM-1725') + self.assertEqual(results['q1'][9].score, 362.0) + + self.assertTrue(isinstance(results['q2'], List)) + self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) + self.assertEqual(len(results['q2']), 42) + self.assertEqual(results['q2'][9].docid, 'CACM-1815') + self.assertEqual(results['q2'][9].score, 392.0) + + # def test_different_similarity(self): + + def test_ascii(self): + output = self.searcher.set_ascii_parser() + self.assertEqual(0, output) + + + + def test_basic_parser(self): + output = self.searcher.set_basic_parser() + self.assertEqual(0, output) + + + + def tearDown(self): + os.remove(self.tarball_name) + shutil.rmtree(self.index_dir) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file