From 04fc187bc44c39fcd5d8c29ce7f27f91667065ce Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 12 Dec 2024 09:27:19 +0100 Subject: [PATCH 01/14] chore: remove deprecation warnings related to `store_full_path` (#8626) * remove deprecation warnings related to store_full_path * unused imports --- haystack/components/converters/azure.py | 7 ------- haystack/components/converters/csv.py | 8 -------- haystack/components/converters/docx.py | 8 -------- haystack/components/converters/html.py | 7 ------- haystack/components/converters/json.py | 7 ------- haystack/components/converters/markdown.py | 8 -------- haystack/components/converters/pdfminer.py | 7 ------- haystack/components/converters/pptx.py | 7 ------- haystack/components/converters/pypdf.py | 8 +------- haystack/components/converters/tika.py | 7 ------- haystack/components/converters/txt.py | 7 ------- 11 files changed, 1 insertion(+), 80 deletions(-) diff --git a/haystack/components/converters/azure.py b/haystack/components/converters/azure.py index 0c1172e6dc..d55a6b8ac8 100644 --- a/haystack/components/converters/azure.py +++ b/haystack/components/converters/azure.py @@ -5,7 +5,6 @@ import copy import hashlib import os -import warnings from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Union @@ -143,12 +142,6 @@ def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[D azure_output.append(result.to_dict()) merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) diff --git a/haystack/components/converters/csv.py b/haystack/components/converters/csv.py index 1a007dc2ea..248ce69620 100644 --- a/haystack/components/converters/csv.py +++ b/haystack/components/converters/csv.py @@ -4,7 +4,6 @@ import io import os -import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -94,13 +93,6 @@ def run( merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) - if not self.store_full_path and "file_path" in bytestream.meta: file_path = bytestream.meta.get("file_path") if file_path: # Ensure the value is not None for pylint diff --git a/haystack/components/converters/docx.py b/haystack/components/converters/docx.py index fcd7cbf33f..b9d59bd564 100644 --- a/haystack/components/converters/docx.py +++ b/haystack/components/converters/docx.py @@ -5,7 +5,6 @@ import csv import io import os -import warnings from dataclasses import dataclass from enum import Enum from io import StringIO @@ -189,13 +188,6 @@ def run( ) continue - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. 
" - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) - docx_metadata = self._get_docx_metadata(document=docx_document) merged_metadata = {**bytestream.meta, **metadata, "docx": docx_metadata} diff --git a/haystack/components/converters/html.py b/haystack/components/converters/html.py index 79baecac4f..10509e1fab 100644 --- a/haystack/components/converters/html.py +++ b/haystack/components/converters/html.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 import os -import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -123,12 +122,6 @@ def run( merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) if not self.store_full_path and "file_path" in bytestream.meta: file_path = bytestream.meta.get("file_path") if file_path: # Ensure the value is not None for pylint diff --git a/haystack/components/converters/json.py b/haystack/components/converters/json.py index 8a39237035..3a8c6f52f0 100644 --- a/haystack/components/converters/json.py +++ b/haystack/components/converters/json.py @@ -4,7 +4,6 @@ import json import os -import warnings from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union @@ -280,12 +279,6 @@ def run( data = self._get_content_and_meta(bytestream) - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) for text, extra_meta in data: merged_metadata = {**bytestream.meta, **metadata, **extra_meta} diff --git a/haystack/components/converters/markdown.py b/haystack/components/converters/markdown.py index cf57af557c..2ffbe4b745 100644 --- a/haystack/components/converters/markdown.py +++ b/haystack/components/converters/markdown.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 import os -import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -112,13 +111,6 @@ def run( merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) - if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) diff --git a/haystack/components/converters/pdfminer.py b/haystack/components/converters/pdfminer.py index fe9a28cad7..8642447816 100644 --- a/haystack/components/converters/pdfminer.py +++ b/haystack/components/converters/pdfminer.py @@ -4,7 +4,6 @@ import io import os -import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -172,12 +171,6 @@ def run( ) merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. 
" - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) diff --git a/haystack/components/converters/pptx.py b/haystack/components/converters/pptx.py index 468d843bd3..7282cc5ddb 100644 --- a/haystack/components/converters/pptx.py +++ b/haystack/components/converters/pptx.py @@ -4,7 +4,6 @@ import io import os -import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -104,12 +103,6 @@ def run( continue merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) diff --git a/haystack/components/converters/pypdf.py b/haystack/components/converters/pypdf.py index df8be1ad79..19a4e2e453 100644 --- a/haystack/components/converters/pypdf.py +++ b/haystack/components/converters/pypdf.py @@ -4,7 +4,6 @@ import io import os -import warnings from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -220,12 +219,7 @@ def run( ) merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) + if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) document.meta = merged_metadata diff --git a/haystack/components/converters/tika.py b/haystack/components/converters/tika.py index a6a27f584d..980fb00911 100644 --- a/haystack/components/converters/tika.py +++ b/haystack/components/converters/tika.py @@ -4,7 +4,6 @@ import io import os -import warnings from html.parser import HTMLParser from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -139,12 +138,6 @@ def run( continue merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. " - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) diff --git a/haystack/components/converters/txt.py b/haystack/components/converters/txt.py index ea29e3f078..0ebbda8dfc 100644 --- a/haystack/components/converters/txt.py +++ b/haystack/components/converters/txt.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 import os -import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -93,12 +92,6 @@ def run( continue merged_metadata = {**bytestream.meta, **metadata} - warnings.warn( - "The `store_full_path` parameter defaults to True, storing full file paths in metadata. 
" - "In the 2.9.0 release, the default value for `store_full_path` will change to False, " - "storing only file names to improve privacy.", - DeprecationWarning, - ) if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) From 6cceaac15f029d629f2be5ac6412b0121220f3db Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 12 Dec 2024 15:16:54 +0100 Subject: [PATCH 02/14] docs: add deprecation warning nltk document splitter (#8628) * adding deprecation warning * adding release notes * adding release notes * updating message * Update haystack/components/preprocessors/nltk_document_splitter.py Co-authored-by: Daria Fokina --------- Co-authored-by: Daria Fokina --- .../components/preprocessors/nltk_document_splitter.py | 8 ++++++++ ...deprecating-NLTKDocumentSplitter-e9a621e025e9a49f.yaml | 4 ++++ 2 files changed, 12 insertions(+) create mode 100644 releasenotes/notes/deprecating-NLTKDocumentSplitter-e9a621e025e9a49f.yaml diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py index d6f947ebfc..eb242d9013 100644 --- a/haystack/components/preprocessors/nltk_document_splitter.py +++ b/haystack/components/preprocessors/nltk_document_splitter.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import warnings from copy import deepcopy from typing import Any, Callable, Dict, List, Literal, Optional, Tuple @@ -52,6 +53,13 @@ def __init__( # pylint: disable=too-many-positional-arguments representing the chunks after splitting. """ + warnings.warn( + "The NLTKDocumentSplitter is deprecated and will be removed in the next release. " + "See DocumentSplitter which now supports the functionalities of the NLTKDocumentSplitter, i.e.: " + "using NLTK to detect sentence boundaries.", + DeprecationWarning, + ) + super(NLTKDocumentSplitter, self).__init__( split_by=split_by, split_length=split_length, diff --git a/releasenotes/notes/deprecating-NLTKDocumentSplitter-e9a621e025e9a49f.yaml b/releasenotes/notes/deprecating-NLTKDocumentSplitter-e9a621e025e9a49f.yaml new file mode 100644 index 0000000000..e0331a00af --- /dev/null +++ b/releasenotes/notes/deprecating-NLTKDocumentSplitter-e9a621e025e9a49f.yaml @@ -0,0 +1,4 @@ +--- +deprecations: + - | + The NLTKDocumentSplitter will deprecated and will be removed in the next release. The DocumentSplitter will instead support the functionality of the NLTKDocumentSplitter. From 3f77d3ab6c665312c20907860091ad4aa33c6edb Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Thu, 12 Dec 2024 15:22:27 +0100 Subject: [PATCH 03/14] !feat: unify NLTKDocumentSplitter and DocumentSplitter (#8617) * wip: initial import * wip: refactoring * wip: refactoring tests * wip: refactoring tests * making all NLTKSplitter related tests work * refactoring * docstrings * refactoring and removing NLTKDocumentSplitter * fixing tests for custom sentence tokenizer * fixing tests for custom sentence tokenizer * cleaning up * adding release notes * reverting some changes * cleaning up tests * fixing serialisation and adding tests * cleaning up * wip * renaming and cleaning * adding NLTK files * updating docstring * adding import to init * Update haystack/components/preprocessors/document_splitter.py Co-authored-by: Stefano Fiorucci * updating tests * wip * adding sentence/period change warning * fixing LICENSE header * Update haystack/components/preprocessors/document_splitter.py Co-authored-by: Stefano Fiorucci --------- Co-authored-by: Stefano Fiorucci --- .../preprocessors/document_splitter.py | 274 +++++++++++++-- .../preprocessors/sentence_tokenizer.py | 19 +- ...-and-nltkdocsplitter-f01a983c7e7f3ed3.yaml | 4 + .../preprocessors/test_document_splitter.py | 322 +++++++++++++++++- .../preprocessors/test_sentence_tokenizer.py | 67 ++++ .../test_sentence_window_retriever.py | 2 +- 6 files changed, 635 insertions(+), 53 deletions(-) create mode 100644 releasenotes/notes/unifying-docsplitter-and-nltkdocsplitter-f01a983c7e7f3ed3.yaml create mode 100644 test/components/preprocessors/test_sentence_tokenizer.py diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py index 86d95f412a..b3e99924a7 100644 --- a/haystack/components/preprocessors/document_splitter.py +++ b/haystack/components/preprocessors/document_splitter.py @@ -2,20 +2,21 @@ # # SPDX-License-Identifier: Apache-2.0 +import warnings from copy import deepcopy from typing import Any, Callable, Dict, List, Literal, Optional, Tuple from more_itertools import windowed from haystack import Document, component, logging +from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter, nltk_imports from haystack.core.serialization import default_from_dict, default_to_dict from haystack.utils import deserialize_callable, serialize_callable logger = logging.getLogger(__name__) -# Maps the 'split_by' argument to the actual char used to split the Documents. -# 'function' is not in the mapping cause it doesn't split on chars. -_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"} +# mapping of split by character, 'function' and 'sentence' don't split by character +_CHARACTER_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "period": ".", "word": " ", "line": "\n"} @component @@ -23,8 +24,7 @@ class DocumentSplitter: """ Splits long documents into smaller chunks. - This is a common preprocessing step during indexing. - It helps Embedders create meaningful semantic representations + This is a common preprocessing step during indexing. It helps Embedders create meaningful semantic representations and prevents exceeding language model context limits. 
The DocumentSplitter is compatible with the following DocumentStores: @@ -54,18 +54,27 @@ class DocumentSplitter: def __init__( # pylint: disable=too-many-positional-arguments self, - split_by: Literal["function", "page", "passage", "sentence", "word", "line"] = "word", + split_by: Literal["function", "page", "passage", "period", "word", "line", "sentence"] = "word", split_length: int = 200, split_overlap: int = 0, split_threshold: int = 0, splitting_function: Optional[Callable[[str], List[str]]] = None, + respect_sentence_boundary: bool = False, + language: Language = "en", + use_split_rules: bool = True, + extend_abbreviations: bool = True, ): """ Initialize DocumentSplitter. - :param split_by: The unit for splitting your documents. Choose from `word` for splitting by spaces (" "), - `sentence` for splitting by periods ("."), `page` for splitting by form feed ("\\f"), - `passage` for splitting by double line breaks ("\\n\\n") or `line` for splitting each line ("\\n"). + :param split_by: The unit for splitting your documents. Choose from: + - `word` for splitting by spaces (" ") + - `period` for splitting by periods (".") + - `page` for splitting by form feed ("\\f") + - `passage` for splitting by double line breaks ("\\n\\n") + - `line` for splitting each line ("\\n") + - `sentence` for splitting by NLTK sentence tokenizer + :param split_length: The maximum number of units in each split. :param split_overlap: The number of overlapping units for each split. :param split_threshold: The minimum number of units per split. If a split has fewer units @@ -73,21 +82,87 @@ def __init__( # pylint: disable=too-many-positional-arguments :param splitting_function: Necessary when `split_by` is set to "function". This is a function which must accept a single `str` as input and return a `list` of `str` as output, representing the chunks after splitting. + :param respect_sentence_boundary: Choose whether to respect sentence boundaries when splitting by "word". + If True, uses NLTK to detect sentence boundaries, ensuring splits occur only between sentences. + :param language: Choose the language for the NLTK tokenizer. The default is English ("en"). + :param use_split_rules: Choose whether to use additional split rules when splitting by `sentence`. + :param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list + of curated abbreviations, if available. This is currently supported for English ("en") and German ("de"). 
""" self.split_by = split_by - if split_by not in ["function", "page", "passage", "sentence", "word", "line"]: - raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.") + self.split_length = split_length + self.split_overlap = split_overlap + self.split_threshold = split_threshold + self.splitting_function = splitting_function + self.respect_sentence_boundary = respect_sentence_boundary + self.language = language + self.use_split_rules = use_split_rules + self.extend_abbreviations = extend_abbreviations + + self._init_checks( + split_by=split_by, + split_length=split_length, + split_overlap=split_overlap, + splitting_function=splitting_function, + respect_sentence_boundary=respect_sentence_boundary, + ) + + if split_by == "sentence" or (respect_sentence_boundary and split_by == "word"): + nltk_imports.check() + self.sentence_splitter = SentenceSplitter( + language=language, + use_split_rules=use_split_rules, + extend_abbreviations=extend_abbreviations, + keep_white_spaces=True, + ) + + if split_by == "sentence": + # ToDo: remove this warning in the next major release + msg = ( + "The `split_by='sentence'` no longer splits by '.' and now relies on custom sentence tokenizer " + "based on NLTK. To achieve the previous behaviour use `split_by='period'." + ) + warnings.warn(msg) + + def _init_checks( + self, + *, + split_by: str, + split_length: int, + split_overlap: int, + splitting_function: Optional[Callable], + respect_sentence_boundary: bool, + ) -> None: + """ + Validates initialization parameters for DocumentSplitter. + + :param split_by: The unit for splitting documents + :param split_length: The maximum number of units in each split + :param split_overlap: The number of overlapping units for each split + :param splitting_function: Custom function for splitting when split_by="function" + :param respect_sentence_boundary: Whether to respect sentence boundaries when splitting + :raises ValueError: If any parameter is invalid + """ + valid_split_by = ["function", "page", "passage", "period", "word", "line", "sentence"] + if split_by not in valid_split_by: + raise ValueError(f"split_by must be one of {', '.join(valid_split_by)}.") + if split_by == "function" and splitting_function is None: raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.") + if split_length <= 0: raise ValueError("split_length must be greater than 0.") - self.split_length = split_length + if split_overlap < 0: raise ValueError("split_overlap must be greater than or equal to 0.") - self.split_overlap = split_overlap - self.split_threshold = split_threshold - self.splitting_function = splitting_function + + if respect_sentence_boundary and split_by != "word": + logger.warning( + "The 'respect_sentence_boundary' option is only supported for `split_by='word'`. " + "The option `respect_sentence_boundary` will be set to `False`." + ) + self.respect_sentence_boundary = False @component.output_types(documents=List[Document]) def run(self, documents: List[Document]): @@ -98,7 +173,6 @@ def run(self, documents: List[Document]): and an overlap of `split_overlap`. :param documents: The documents to split. - :returns: A dictionary with the following key: - `documents`: List of documents with the split texts. Each document includes: - A metadata field `source_id` to track the original document. @@ -121,39 +195,69 @@ def run(self, documents: List[Document]): if doc.content == "": logger.warning("Document ID {doc_id} has an empty content. 
Skipping this document.", doc_id=doc.id) continue - split_docs += self._split(doc) + + split_docs += self._split_document(doc) return {"documents": split_docs} - def _split(self, to_split: Document) -> List[Document]: - # We already check this before calling _split but - # we need to make linters happy - if to_split.content is None: - return [] + def _split_document(self, doc: Document) -> List[Document]: + if self.split_by == "sentence" or self.respect_sentence_boundary: + return self._split_by_nltk_sentence(doc) if self.split_by == "function" and self.splitting_function is not None: - splits = self.splitting_function(to_split.content) - docs: List[Document] = [] - for s in splits: - meta = deepcopy(to_split.meta) - meta["source_id"] = to_split.id - docs.append(Document(content=s, meta=meta)) - return docs - - split_at = _SPLIT_BY_MAPPING[self.split_by] - units = to_split.content.split(split_at) + return self._split_by_function(doc) + + return self._split_by_character(doc) + + def _split_by_nltk_sentence(self, doc: Document) -> List[Document]: + split_docs = [] + + result = self.sentence_splitter.split_sentences(doc.content) # type: ignore # None check is done in run() + units = [sentence["sentence"] for sentence in result] + + if self.respect_sentence_boundary: + text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount( + sentences=units, split_length=self.split_length, split_overlap=self.split_overlap + ) + else: + text_splits, splits_pages, splits_start_idxs = self._concatenate_units( + elements=units, + split_length=self.split_length, + split_overlap=self.split_overlap, + split_threshold=self.split_threshold, + ) + metadata = deepcopy(doc.meta) + metadata["source_id"] = doc.id + split_docs += self._create_docs_from_splits( + text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata + ) + + return split_docs + + def _split_by_character(self, doc) -> List[Document]: + split_at = _CHARACTER_SPLIT_BY_MAPPING[self.split_by] + units = doc.content.split(split_at) # Add the delimiter back to all units except the last one for i in range(len(units) - 1): units[i] += split_at - text_splits, splits_pages, splits_start_idxs = self._concatenate_units( units, self.split_length, self.split_overlap, self.split_threshold ) - metadata = deepcopy(to_split.meta) - metadata["source_id"] = to_split.id + metadata = deepcopy(doc.meta) + metadata["source_id"] = doc.id return self._create_docs_from_splits( text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata ) + def _split_by_function(self, doc) -> List[Document]: + # the check for None is done already in the run method + splits = self.splitting_function(doc.content) # type: ignore + docs: List[Document] = [] + for s in splits: + meta = deepcopy(doc.meta) + meta["source_id"] = doc.id + docs.append(Document(content=s, meta=meta)) + return docs + def _concatenate_units( self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int ) -> Tuple[List[str], List[int], List[int]]: @@ -265,6 +369,10 @@ def to_dict(self) -> Dict[str, Any]: split_length=self.split_length, split_overlap=self.split_overlap, split_threshold=self.split_threshold, + respect_sentence_boundary=self.respect_sentence_boundary, + language=self.language, + use_split_rules=self.use_split_rules, + extend_abbreviations=self.extend_abbreviations, ) if self.splitting_function: serialized["init_parameters"]["splitting_function"] = 
serialize_callable(self.splitting_function) @@ -282,3 +390,99 @@ def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter": init_params["splitting_function"] = deserialize_callable(splitting_function) return default_from_dict(cls, data) + + @staticmethod + def _concatenate_sentences_based_on_word_amount( + sentences: List[str], split_length: int, split_overlap: int + ) -> Tuple[List[str], List[int], List[int]]: + """ + Groups the sentences into chunks of `split_length` words while respecting sentence boundaries. + + This function is only used when splitting by `word` and `respect_sentence_boundary` is set to `True`, i.e.: + with NLTK sentence tokenizer. + + :param sentences: The list of sentences to split. + :param split_length: The maximum number of words in each split. + :param split_overlap: The number of overlapping words in each split. + :returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices. + """ + # chunk information + chunk_word_count = 0 + chunk_starting_page_number = 1 + chunk_start_idx = 0 + current_chunk: List[str] = [] + # output lists + split_start_page_numbers = [] + list_of_splits: List[List[str]] = [] + split_start_indices = [] + + for sentence_idx, sentence in enumerate(sentences): + current_chunk.append(sentence) + chunk_word_count += len(sentence.split()) + next_sentence_word_count = ( + len(sentences[sentence_idx + 1].split()) if sentence_idx < len(sentences) - 1 else 0 + ) + + # Number of words in the current chunk plus the next sentence is larger than the split_length, + # or we reached the last sentence + if (chunk_word_count + next_sentence_word_count) > split_length or sentence_idx == len(sentences) - 1: + # Save current chunk and start a new one + list_of_splits.append(current_chunk) + split_start_page_numbers.append(chunk_starting_page_number) + split_start_indices.append(chunk_start_idx) + + # Get the number of sentences that overlap with the next chunk + num_sentences_to_keep = DocumentSplitter._number_of_sentences_to_keep( + sentences=current_chunk, split_length=split_length, split_overlap=split_overlap + ) + # Set up information for the new chunk + if num_sentences_to_keep > 0: + # Processed sentences are the ones that are not overlapping with the next chunk + processed_sentences = current_chunk[:-num_sentences_to_keep] + chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences) + chunk_start_idx += len("".join(processed_sentences)) + # Next chunk starts with the sentences that were overlapping with the previous chunk + current_chunk = current_chunk[-num_sentences_to_keep:] + chunk_word_count = sum(len(s.split()) for s in current_chunk) + else: + # Here processed_sentences is the same as current_chunk since there is no overlap + chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk) + chunk_start_idx += len("".join(current_chunk)) + current_chunk = [] + chunk_word_count = 0 + + # Concatenate the sentences together within each split + text_splits = [] + for split in list_of_splits: + text = "".join(split) + if len(text) > 0: + text_splits.append(text) + + return text_splits, split_start_page_numbers, split_start_indices + + @staticmethod + def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int: + """ + Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`. + + :param sentences: The list of sentences to split. 
+ :param split_length: The maximum number of words in each split. + :param split_overlap: The number of overlapping words in each split. + :returns: The number of sentences to keep in the next chunk. + """ + # If the split_overlap is 0, we don't need to keep any sentences + if split_overlap == 0: + return 0 + + num_sentences_to_keep = 0 + num_words = 0 + # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence + for sent in reversed(sentences[1:]): + num_words += len(sent.split()) + # If the number of words is larger than the split_length then don't add any more sentences + if num_words > split_length: + break + num_sentences_to_keep += 1 + if num_words > split_overlap: + break + return num_sentences_to_keep diff --git a/haystack/components/preprocessors/sentence_tokenizer.py b/haystack/components/preprocessors/sentence_tokenizer.py index 4932513452..505126e901 100644 --- a/haystack/components/preprocessors/sentence_tokenizer.py +++ b/haystack/components/preprocessors/sentence_tokenizer.py @@ -186,11 +186,16 @@ def _needs_join( """ Checks if the spans need to be joined as parts of one sentence. + This method determines whether two adjacent sentence spans should be joined back together as a single sentence. + It's used to prevent incorrect sentence splitting in specific cases like quotations, numbered lists, + and parenthetical expressions. + :param text: The text containing the spans. - :param span: The current sentence span within text. - :param next_span: The next sentence span within text. + :param span: Tuple of (start, end) positions for the current sentence span. + :param next_span: Tuple of (start, end) positions for the next sentence span. :param quote_spans: All quoted spans within text. - :returns: True if the spans needs to be joined. + :returns: + True if the spans needs to be joined. """ start, end = span next_start, next_end = next_span @@ -216,16 +221,16 @@ def _needs_join( return re.search(r"^\s*[\(\[]", text[next_start:next_end]) is not None @staticmethod - def _read_abbreviations(language: Language) -> List[str]: + def _read_abbreviations(lang: Language) -> List[str]: """ Reads the abbreviations for a given language from the abbreviations file. - :param language: The language to read the abbreviations for. + :param lang: The language to read the abbreviations for. :returns: List of abbreviations. """ - abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{language}.txt" + abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{lang}.txt" if not abbreviations_file.exists(): - logger.warning("No abbreviations file found for {language}.Using default abbreviations.", language=language) + logger.warning("No abbreviations file found for {language}. Using default abbreviations.", language=lang) return [] abbreviations = abbreviations_file.read_text().split("\n") diff --git a/releasenotes/notes/unifying-docsplitter-and-nltkdocsplitter-f01a983c7e7f3ed3.yaml b/releasenotes/notes/unifying-docsplitter-and-nltkdocsplitter-f01a983c7e7f3ed3.yaml new file mode 100644 index 0000000000..2b2a9d5f1b --- /dev/null +++ b/releasenotes/notes/unifying-docsplitter-and-nltkdocsplitter-f01a983c7e7f3ed3.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + The `NLTKDocumentSplitter` was merged into the `DocumentSplitter` which now provides the same functionality as the `NLTKDocumentSplitter`. The `split_by="sentence"` now uses a custom sentence boundary detection based on the `nltk` library. 
The previous `sentence` behaviour can still be achieved by `split_by="period"` diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py index 25872626c1..78767dbccd 100644 --- a/test/components/preprocessors/test_document_splitter.py +++ b/test/components/preprocessors/test_document_splitter.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from typing import List + import re import pytest @@ -36,7 +38,7 @@ def merge_documents(documents): return merged_text -class TestDocumentSplitter: +class TestSplittingByFunctionOrCharacterRegex: def test_non_text_document(self): with pytest.raises( ValueError, match="DocumentSplitter only works with text documents but content for document ID" @@ -56,11 +58,13 @@ def test_empty_list(self): assert res == {"documents": []} def test_unsupported_split_by(self): - with pytest.raises( - ValueError, match="split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'." - ): + with pytest.raises(ValueError, match="split_by must be one of "): DocumentSplitter(split_by="unsupported") + def test_undefined_function(self): + with pytest.raises(ValueError, match="When 'split_by' is set to 'function', a valid 'splitting_function'"): + DocumentSplitter(split_by="function", splitting_function=None) + def test_unsupported_split_length(self): with pytest.raises(ValueError, match="split_length must be greater than 0."): DocumentSplitter(split_length=0) @@ -125,8 +129,8 @@ def test_split_by_word_multiple_input_docs(self): assert docs[4].meta["split_id"] == 2 assert docs[4].meta["split_idx_start"] == text2.index(docs[4].content) - def test_split_by_sentence(self): - splitter = DocumentSplitter(split_by="sentence", split_length=1) + def test_split_by_period(self): + splitter = DocumentSplitter(split_by="period", split_length=1) text = "This is a text with some words. There is a second sentence. And there is a third sentence." result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] @@ -275,8 +279,8 @@ def test_add_page_number_to_metadata_with_no_overlap_word_split(self): for doc, p in zip(result["documents"], expected_pages): assert doc.meta["page_number"] == p - def test_add_page_number_to_metadata_with_no_overlap_sentence_split(self): - splitter = DocumentSplitter(split_by="sentence", split_length=1) + def test_add_page_number_to_metadata_with_no_overlap_period_split(self): + splitter = DocumentSplitter(split_by="period", split_length=1) doc1 = Document(content="This is some text.\f This text is on another page.") doc2 = Document(content="This content has two.\f\f page brakes.") result = splitter.run(documents=[doc1, doc2]) @@ -326,8 +330,8 @@ def test_add_page_number_to_metadata_with_overlap_word_split(self): for doc, p in zip(result["documents"], expected_pages): assert doc.meta["page_number"] == p - def test_add_page_number_to_metadata_with_overlap_sentence_split(self): - splitter = DocumentSplitter(split_by="sentence", split_length=2, split_overlap=1) + def test_add_page_number_to_metadata_with_overlap_period_split(self): + splitter = DocumentSplitter(split_by="period", split_length=2, split_overlap=1) doc1 = Document(content="This is some text. And this is more text.\f This text is on another page. End.") doc2 = Document(content="This content has two.\f\f page brakes. 
More text.") result = splitter.run(documents=[doc1, doc2]) @@ -494,3 +498,301 @@ def test_run_document_only_whitespaces(self): doc = Document(content=" ") results = splitter.run([doc]) assert results["documents"][0].content == " " + + +class TestSplittingNLTKSentenceSplitter: + @pytest.mark.parametrize( + "sentences, expected_num_sentences", + [ + (["The sun set.", "Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0), + (["The sun set.", "It was a dark night ..."], 0), + (["The sun set.", " The moon was full."], 1), + (["The sun.", " The moon."], 1), # Ignores the first sentence + (["Sun", "Moon"], 1), # Ignores the first sentence even if its inclusion would be < split_overlap + ], + ) + def test_number_of_sentences_to_keep(self, sentences: List[str], expected_num_sentences: int) -> None: + num_sentences = DocumentSplitter._number_of_sentences_to_keep( + sentences=sentences, split_length=5, split_overlap=2 + ) + assert num_sentences == expected_num_sentences + + def test_number_of_sentences_to_keep_split_overlap_zero(self) -> None: + sentences = [ + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.", + " It was a dark night ...", + " The moon was full.", + ] + num_sentences = DocumentSplitter._number_of_sentences_to_keep( + sentences=sentences, split_length=5, split_overlap=0 + ) + assert num_sentences == 0 + + def test_run_split_by_sentence_1(self) -> None: + document_splitter = DocumentSplitter( + split_by="sentence", + split_length=2, + split_overlap=0, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + ) + + text = ( + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night ... " + "The moon was full." + ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 2 + assert ( + documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped " + "everything. It was a dark night ... " + ) + assert documents[1].content == "The moon was full." + + def test_run_split_by_sentence_2(self) -> None: + document_splitter = DocumentSplitter( + split_by="sentence", + split_length=1, + split_overlap=0, + split_threshold=0, + language="en", + use_split_rules=False, + extend_abbreviations=True, + ) + + text = ( + "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + "This is another test sentence. (This is a third test sentence.) " + "This is the last test sentence." + ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 4 + assert ( + documents[0].content + == "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + ) + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "This is another test sentence. " + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "(This is a third test sentence.) 
" + assert documents[2].meta["page_number"] == 1 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + assert documents[3].content == "This is the last test sentence." + assert documents[3].meta["page_number"] == 1 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + def test_run_split_by_sentence_3(self) -> None: + document_splitter = DocumentSplitter( + split_by="sentence", + split_length=1, + split_overlap=0, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + ) + + text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 4 + assert documents[0].content == "Sentence on page 1.\f" + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "Sentence on page 2. \f" + assert documents[1].meta["page_number"] == 2 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "Sentence on page 3. \f\f " + assert documents[2].meta["page_number"] == 3 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + assert documents[3].content == "Sentence on page 5." + assert documents[3].meta["page_number"] == 5 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + def test_run_split_by_sentence_4(self) -> None: + document_splitter = DocumentSplitter( + split_by="sentence", + split_length=2, + split_overlap=1, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + ) + + text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 3 + assert documents[0].content == "Sentence on page 1.\fSentence on page 2. \f" + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "Sentence on page 2. \fSentence on page 3. \f\f " + assert documents[1].meta["page_number"] == 2 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "Sentence on page 3. \f\f Sentence on page 5." + assert documents[2].meta["page_number"] == 3 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + + def test_run_split_by_word_respect_sentence_boundary(self) -> None: + document_splitter = DocumentSplitter( + split_by="word", + split_length=3, + split_overlap=0, + split_threshold=0, + language="en", + respect_sentence_boundary=True, + ) + + text = ( + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night.\f" + "The moon was full." 
+ ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 3 + assert documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. " + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "It was a dark night.\f" + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "The moon was full." + assert documents[2].meta["page_number"] == 2 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + + def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None: + document_splitter = DocumentSplitter( + split_by="word", + split_length=13, + split_overlap=3, + split_threshold=0, + language="en", + respect_sentence_boundary=True, + use_split_rules=False, + extend_abbreviations=False, + ) + text = ( + "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + "This is another test sentence. (This is a third test sentence.) " + "This is the last test sentence." + ) + documents = document_splitter.run([Document(content=text)])["documents"] + assert len(documents) == 3 + assert ( + documents[0].content + == "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + ) + assert "This is a test sentence with many many words" not in documents[1].content + assert "This is a test sentence with many many words" not in documents[2].content + + def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page_breaks(self) -> None: + document_splitter = DocumentSplitter( + split_by="word", + split_length=8, + split_overlap=1, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + respect_sentence_boundary=True, + ) + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\f Sentence on page 5." + ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 6 + assert documents[0].content == "Sentence on page 1. Another on page 1.\f" + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "Another on page 1.\fSentence on page 2. " + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "Sentence on page 2. Another on page 2.\f" + assert documents[2].meta["page_number"] == 2 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + assert documents[3].content == "Another on page 2.\fSentence on page 3. " + assert documents[3].meta["page_number"] == 2 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + assert documents[4].content == "Sentence on page 3. 
Another on page 3.\f\f " + assert documents[4].meta["page_number"] == 3 + assert documents[4].meta["split_id"] == 4 + assert documents[4].meta["split_idx_start"] == text.index(documents[4].content) + assert documents[5].content == "Another on page 3.\f\f Sentence on page 5." + assert documents[5].meta["page_number"] == 3 + assert documents[5].meta["split_id"] == 5 + assert documents[5].meta["split_idx_start"] == text.index(documents[5].content) + + def test_respect_sentence_boundary_checks(self): + # this combination triggers the warning + splitter = DocumentSplitter(split_by="sentence", split_length=10, respect_sentence_boundary=True) + assert splitter.respect_sentence_boundary == False + + def test_sentence_serialization(self): + """Test serialization with NLTK sentence splitting configuration and using non-default values""" + splitter = DocumentSplitter( + split_by="sentence", + language="de", + use_split_rules=False, + extend_abbreviations=False, + respect_sentence_boundary=False, + ) + serialized = splitter.to_dict() + deserialized = DocumentSplitter.from_dict(serialized) + + assert deserialized.split_by == "sentence" + assert hasattr(deserialized, "sentence_splitter") + assert deserialized.language == "de" + assert deserialized.use_split_rules == False + assert deserialized.extend_abbreviations == False + assert deserialized.respect_sentence_boundary == False + + def test_nltk_serialization_roundtrip(self): + """Test complete serialization roundtrip with actual document splitting""" + splitter = DocumentSplitter( + split_by="sentence", + language="de", + use_split_rules=False, + extend_abbreviations=False, + respect_sentence_boundary=False, + ) + serialized = splitter.to_dict() + deserialized_splitter = DocumentSplitter.from_dict(serialized) + assert splitter.split_by == deserialized_splitter.split_by + + def test_respect_sentence_boundary_serialization(self): + """Test serialization with respect_sentence_boundary option""" + splitter = DocumentSplitter(split_by="word", respect_sentence_boundary=True, language="de") + serialized = splitter.to_dict() + deserialized = DocumentSplitter.from_dict(serialized) + + assert deserialized.respect_sentence_boundary == True + assert hasattr(deserialized, "sentence_splitter") + assert deserialized.language == "de" diff --git a/test/components/preprocessors/test_sentence_tokenizer.py b/test/components/preprocessors/test_sentence_tokenizer.py new file mode 100644 index 0000000000..bf9aab9a9e --- /dev/null +++ b/test/components/preprocessors/test_sentence_tokenizer.py @@ -0,0 +1,67 @@ +import pytest +from unittest.mock import patch +from pathlib import Path + +from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter + +from pytest import LogCaptureFixture + + +def test_apply_split_rules_no_join() -> None: + text = "This is a test. This is another test. And a third test." + spans = [(0, 15), (16, 36), (37, 54)] + result = SentenceSplitter._apply_split_rules(text, spans) + assert len(result) == 3 + assert result == [(0, 15), (16, 36), (37, 54)] + + +def test_apply_split_rules_join_case_1(): + text = 'He said "This is sentence one. This is sentence two." Then he left.' + result = SentenceSplitter._apply_split_rules(text, [(0, 30), (31, 53), (54, 67)]) + assert len(result) == 2 + assert result == [(0, 53), (54, 67)] + + +def test_apply_split_rules_join_case_3(): + splitter = SentenceSplitter(language="en", use_split_rules=True) + text = """ + 1. First item + 2. Second item + 3. 
Third item."""
+    spans = [(0, 7), (8, 25), (26, 44), (45, 56)]
+    result = splitter._apply_split_rules(text, spans)
+    assert len(result) == 1
+    assert result == [(0, 56)]
+
+
+def test_apply_split_rules_join_case_4() -> None:
+    text = "This is a test. (With a parenthetical statement.) And another sentence."
+    spans = [(0, 15), (16, 50), (51, 74)]
+    result = SentenceSplitter._apply_split_rules(text, spans)
+    assert len(result) == 2
+    assert result == [(0, 50), (51, 74)]
+
+
+@pytest.fixture
+def mock_file_content():
+    return "Mr.\nDr.\nProf."
+
+
+def test_read_abbreviations_existing_file(tmp_path, mock_file_content):
+    abbrev_dir = tmp_path / "data" / "abbreviations"
+    abbrev_dir.mkdir(parents=True)
+    abbrev_file = abbrev_dir / f"en.txt"
+    abbrev_file.write_text(mock_file_content)
+
+    with patch("haystack.components.preprocessors.sentence_tokenizer.Path") as mock_path:
+        mock_path.return_value.parent.parent = tmp_path
+        result = SentenceSplitter._read_abbreviations("en")
+        assert result == ["Mr.", "Dr.", "Prof."]
+
+
+def test_read_abbreviations_missing_file(caplog: LogCaptureFixture):
+    with patch("haystack.components.preprocessors.sentence_tokenizer.Path") as mock_path:
+        mock_path.return_value.parent.parent = Path("/nonexistent")
+        result = SentenceSplitter._read_abbreviations("pt")
+        assert result == []
+        assert "No abbreviations file found for pt. Using default abbreviations." in caplog.text
diff --git a/test/components/retrievers/test_sentence_window_retriever.py b/test/components/retrievers/test_sentence_window_retriever.py
index 4979fa334d..04e03befbe 100644
--- a/test/components/retrievers/test_sentence_window_retriever.py
+++ b/test/components/retrievers/test_sentence_window_retriever.py
@@ -176,7 +176,7 @@ def test_context_documents_returned_are_ordered_by_split_idx_start(self):
 
     @pytest.mark.integration
     def test_run_with_pipeline(self):
-        splitter = DocumentSplitter(split_length=1, split_overlap=0, split_by="sentence")
+        splitter = DocumentSplitter(split_length=1, split_overlap=0, split_by="period")
         text = (
             "This is a text with some words. There is a second sentence. And there is also a third sentence. "
             "It also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence"

From 2a9a6401d25e7e67765668cdbeec00be94cd17cf Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Thu, 12 Dec 2024 16:26:38 +0100
Subject: [PATCH 04/14] chore: pin `openai>=1.56.1` (#8632)

* pin openai>=1.56.1

* release note
---
 pyproject.toml                                             | 2 +-
 releasenotes/notes/pin-openai-1-56-1-43d50ebfbb8b5a8d.yaml | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)
 create mode 100644 releasenotes/notes/pin-openai-1-56-1-43d50ebfbb8b5a8d.yaml

diff --git a/pyproject.toml b/pyproject.toml
index 6adf05bf0a..c41c429ced 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
   "tqdm",
   "tenacity!=8.4.0",
   "lazy-imports",
-  "openai>=1.1.0",
+  "openai>=1.56.1",
   "Jinja2",
   "posthog",  # telemetry
   "pyyaml",
diff --git a/releasenotes/notes/pin-openai-1-56-1-43d50ebfbb8b5a8d.yaml b/releasenotes/notes/pin-openai-1-56-1-43d50ebfbb8b5a8d.yaml
new file mode 100644
index 0000000000..73f8d3458f
--- /dev/null
+++ b/releasenotes/notes/pin-openai-1-56-1-43d50ebfbb8b5a8d.yaml
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Pin OpenAI client to >=1.56.1 to avoid issues related to changes in the httpx library.
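PATCH 03 above folds the NLTK-based splitting into `DocumentSplitter` and renames the old naive behaviour to `split_by="period"`. For reference, a minimal usage sketch of the unified API, based only on the signatures visible in the diff; it assumes the optional `nltk` dependency is installed, since `split_by="sentence"` now routes through the NLTK-backed `SentenceSplitter`:

```python
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

doc = Document(content="Moonlight shimmered softly. It was a dark night ... The moon was full.")

# New behaviour: "sentence" uses the NLTK tokenizer (and emits a warning about the change).
nltk_splitter = DocumentSplitter(split_by="sentence", split_length=1)

# Previous "sentence" behaviour, i.e. a plain split on ".": now spelled "period".
period_splitter = DocumentSplitter(split_by="period", split_length=1)

# Word-based splitting that never cuts through a sentence.
word_splitter = DocumentSplitter(split_by="word", split_length=10, respect_sentence_boundary=True)

for splitter in (nltk_splitter, period_splitter, word_splitter):
    result = splitter.run(documents=[doc])
    # Each chunk keeps source_id, page_number, split_id and split_idx_start in its meta.
    print([d.content for d in result["documents"]])
```

The two spellings differ on the ellipsis: per the tests above, the NLTK path keeps "It was a dark night ..." as one sentence, while `split_by="period"` breaks on every ".".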
From f2b5f123b32a52321560c4ac476277e4d851b845 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 13 Dec 2024 09:50:23 +0100 Subject: [PATCH 05/14] del HF token in tests (#8634) --- .../classifiers/test_zero_shot_document_classifier.py | 2 ++ test/components/generators/chat/test_hugging_face_local.py | 1 + .../components/generators/test_hugging_face_local_generator.py | 1 + .../components/rankers/test_sentence_transformers_diversity.py | 2 +- test/components/rankers/test_transformers_similarity.py | 1 + test/components/readers/test_extractive.py | 3 +++ test/components/routers/test_transformers_text_router.py | 3 +++ test/components/routers/test_zero_shot_text_router.py | 2 ++ 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/test/components/classifiers/test_zero_shot_document_classifier.py b/test/components/classifiers/test_zero_shot_document_classifier.py index 7d679e3d21..be4d04a9fe 100644 --- a/test/components/classifiers/test_zero_shot_document_classifier.py +++ b/test/components/classifiers/test_zero_shot_document_classifier.py @@ -45,6 +45,7 @@ def test_to_dict(self): def test_from_dict(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", "init_parameters": { @@ -73,6 +74,7 @@ def test_from_dict(self, monkeypatch): def test_from_dict_no_default_parameters(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]}, diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 433917ec23..8f6749c2d8 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -166,6 +166,7 @@ def test_from_dict(self, model_info_mock): @patch("haystack.components.generators.chat.hugging_face_local.pipeline") def test_warm_up(self, pipeline_mock, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) generator = HuggingFaceLocalChatGenerator( model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index 5c3b162a31..bded2e8d47 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -18,6 +18,7 @@ class TestHuggingFaceLocalGenerator: @patch("haystack.utils.hf.model_info") def test_init_default(self, model_info_mock, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) model_info_mock.return_value.pipeline_tag = "text2text-generation" generator = HuggingFaceLocalGenerator() diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index eabd2ac375..018b443987 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -273,7 +273,7 @@ def 
test_warm_up(self, similarity, monkeypatch): Test that ranker loads the SentenceTransformer model correctly during warm up. """ monkeypatch.delenv("HF_API_TOKEN", raising=False) - + monkeypatch.delenv("HF_TOKEN", raising=False) mock_model_class = MagicMock() mock_model_instance = MagicMock() mock_model_class.return_value = mock_model_instance diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index 6031d85e15..616bfa6647 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -313,6 +313,7 @@ def test_device_map_and_device_raises(self, caplog): @patch("haystack.components.rankers.transformers_similarity.AutoModelForSequenceClassification.from_pretrained") def test_device_map_dict(self, mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) ranker = TransformersSimilarityRanker("model", model_kwargs={"device_map": {"layer_1": 1, "classifier": "cpu"}}) class MockedModel: diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index aedfaa13bc..a2f658b79b 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -519,6 +519,7 @@ def __init__(self): @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") def test_device_map_auto(mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) reader = ExtractiveReader("deepset/roberta-base-squad2", model_kwargs={"device_map": "auto"}) auto_device = ComponentDevice.resolve_device(None) @@ -537,6 +538,7 @@ def __init__(self): @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") def test_device_map_str(mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) reader = ExtractiveReader("deepset/roberta-base-squad2", model_kwargs={"device_map": "cpu:0"}) class MockedModel: @@ -554,6 +556,7 @@ def __init__(self): @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") def test_device_map_dict(mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) reader = ExtractiveReader( "deepset/roberta-base-squad2", model_kwargs={"device_map": {"layer_1": 1, "classifier": "cpu"}} ) diff --git a/test/components/routers/test_transformers_text_router.py b/test/components/routers/test_transformers_text_router.py index 8a0dca8d63..67ec163524 100644 --- a/test/components/routers/test_transformers_text_router.py +++ b/test/components/routers/test_transformers_text_router.py @@ -54,6 +54,7 @@ def test_to_dict_with_cpu_device(self, mock_auto_config_from_pretrained): def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", "init_parameters": { @@ -84,6 +85,7 @@ def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch): def 
test_from_dict_no_default_parameters(self, mock_auto_config_from_pretrained, monkeypatch): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", "init_parameters": {"model": "papluca/xlm-roberta-base-language-detection"}, @@ -105,6 +107,7 @@ def test_from_dict_no_default_parameters(self, mock_auto_config_from_pretrained, def test_from_dict_with_cpu_device(self, mock_auto_config_from_pretrained, monkeypatch): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", "init_parameters": { diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 8e9759f361..3b931c39bb 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -28,6 +28,7 @@ def test_to_dict(self): def test_from_dict(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter", "init_parameters": { @@ -56,6 +57,7 @@ def test_from_dict(self, monkeypatch): def test_from_dict_no_default_parameters(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter", "init_parameters": {"labels": ["query", "passage"]}, From 176db5dbf9d5be87122e3feafa19593fed418cde Mon Sep 17 00:00:00 2001 From: "David S. 
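Patch 05 repeats one pattern across the test suite: every test that constructs or deserializes a Hugging Face component now clears both the legacy HF_API_TOKEN variable and the newer HF_TOKEN variable read by huggingface_hub, so a developer's real credentials cannot leak into a test run. A minimal, runnable sketch of that pattern, assuming only pytest's built-in monkeypatch fixture (the test name below is illustrative, not from the patch):

    import os

    def test_hf_env_vars_are_cleared(monkeypatch):
        # Clear both spellings of the Hugging Face token; raising=False keeps
        # the call from failing when a variable is not set in the first place.
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)
        assert "HF_API_TOKEN" not in os.environ
        assert "HF_TOKEN" not in os.environ
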
Batista" Date: Fri, 13 Dec 2024 12:12:40 +0100 Subject: [PATCH 06/14] initial import (#8635) --- e2e/pipelines/test_dense_doc_search.py | 2 +- e2e/pipelines/test_preprocessing_pipeline.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/e2e/pipelines/test_dense_doc_search.py b/e2e/pipelines/test_dense_doc_search.py index 39a587a106..f348b6f0e5 100644 --- a/e2e/pipelines/test_dense_doc_search.py +++ b/e2e/pipelines/test_dense_doc_search.py @@ -26,7 +26,7 @@ def test_dense_doc_search_pipeline(tmp_path, samples_path): indexing_pipeline.add_component(instance=DocumentJoiner(), name="joiner") indexing_pipeline.add_component(instance=DocumentCleaner(), name="cleaner") indexing_pipeline.add_component( - instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter" + instance=DocumentSplitter(split_by="period", split_length=250, split_overlap=30), name="splitter" ) indexing_pipeline.add_component( instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder" diff --git a/e2e/pipelines/test_preprocessing_pipeline.py b/e2e/pipelines/test_preprocessing_pipeline.py index 82375f89d8..8894113913 100644 --- a/e2e/pipelines/test_preprocessing_pipeline.py +++ b/e2e/pipelines/test_preprocessing_pipeline.py @@ -25,9 +25,7 @@ def test_preprocessing_pipeline(tmp_path): instance=MetadataRouter(rules={"en": {"field": "language", "operator": "==", "value": "en"}}), name="router" ) preprocessing_pipeline.add_component(instance=DocumentCleaner(), name="cleaner") - preprocessing_pipeline.add_component( - instance=DocumentSplitter(split_by="sentence", split_length=1), name="splitter" - ) + preprocessing_pipeline.add_component(instance=DocumentSplitter(split_by="period", split_length=1), name="splitter") preprocessing_pipeline.add_component( instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder" ) From db89b9a2e59da5a0fe59135ef4ab1f6252e2a7db Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 13 Dec 2024 12:35:58 +0100 Subject: [PATCH 07/14] fix: removing unused import (#8636) --- e2e/pipelines/test_preprocessing_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/e2e/pipelines/test_preprocessing_pipeline.py b/e2e/pipelines/test_preprocessing_pipeline.py index 8894113913..4667454276 100644 --- a/e2e/pipelines/test_preprocessing_pipeline.py +++ b/e2e/pipelines/test_preprocessing_pipeline.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import json - from haystack import Pipeline from haystack.components.classifiers import DocumentLanguageClassifier from haystack.components.converters import TextFileToDocument From a5b57f4b1fd4ef4227d7d54170f99b142836a04c Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Mon, 16 Dec 2024 13:57:41 +0100 Subject: [PATCH 08/14] adding SentenceSplitter to init imports (#8644) --- haystack/components/preprocessors/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/haystack/components/preprocessors/__init__.py b/haystack/components/preprocessors/__init__.py index f7e132077a..467f16ceeb 100644 --- a/haystack/components/preprocessors/__init__.py +++ b/haystack/components/preprocessors/__init__.py @@ -5,6 +5,7 @@ from .document_cleaner import DocumentCleaner from .document_splitter import DocumentSplitter from .nltk_document_splitter import NLTKDocumentSplitter +from .sentence_tokenizer import SentenceSplitter from .text_cleaner import TextCleaner -__all__ = ["DocumentSplitter", "DocumentCleaner", "TextCleaner", "NLTKDocumentSplitter"] +__all__ = ["DocumentSplitter", "DocumentCleaner", "NLTKDocumentSplitter", "SentenceSplitter", "TextCleaner"] From ea3602643aa52c27f3bea7bf5bc90b97f568dcdc Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 17 Dec 2024 17:02:04 +0100 Subject: [PATCH 09/14] feat!: new `ChatMessage` (#8640) * draft * del HF token in tests * adaptations * progress * fix type * import sorting * more control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * Update chat_message.py --------- Co-authored-by: Daria Fokina --- .../builders/chat_prompt_builder.py | 6 +- .../generators/chat/hugging_face_api.py | 7 +- .../components/generators/openai_utils.py | 9 +- haystack/dataclasses/__init__.py | 5 +- haystack/dataclasses/chat_message.py | 336 +++++++++++++++--- .../new-chatmessage-7f47d5bdeb6ad6f5.yaml | 23 ++ .../builders/test_chat_prompt_builder.py | 22 +- .../generators/chat/test_hugging_face_api.py | 7 - .../generators/test_openai_utils.py | 7 - .../routers/test_conditional_router.py | 6 +- test/core/pipeline/features/test_run.py | 6 +- test/dataclasses/test_chat_message.py | 283 +++++++++++---- 12 files changed, 560 insertions(+), 157 deletions(-) create mode 100644 releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index fd9969f5b7..33e2feda2d 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -9,7 +9,7 @@ from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent logger = logging.getLogger(__name__) @@ -197,10 +197,10 @@ def run( if message.text is None: raise ValueError(f"The provided ChatMessage has no text. 
ChatMessage: {message}") compiled_template = self._env.from_string(message.text) - rendered_content = compiled_template.render(template_variables_combined) + rendered_text = compiled_template.render(template_variables_combined) # deep copy the message to avoid modifying the original message rendered_message: ChatMessage = deepcopy(message) - rendered_message.content = rendered_content + rendered_message._content = [TextContent(text=rendered_text)] processed_messages.append(rendered_message) else: processed_messages.append(message) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index d4ecd53f10..8711a9175a 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -25,13 +25,8 @@ def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: :returns: A dictionary with the following keys: - `role` - `content` - - `name` (optional) """ - formatted_msg = {"role": message.role.value, "content": message.content} - if message.name: - formatted_msg["name"] = message.name - - return formatted_msg + return {"role": message.role.value, "content": message.text or ""} @component diff --git a/haystack/components/generators/openai_utils.py b/haystack/components/generators/openai_utils.py index 5b1838c386..ab6d5e7b1d 100644 --- a/haystack/components/generators/openai_utils.py +++ b/haystack/components/generators/openai_utils.py @@ -13,16 +13,11 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]: See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. - :returns: A dictionary with the following key: + :returns: A dictionary with the following keys: - `role` - `content` - - `name` (optional) """ if message.text is None: raise ValueError(f"The provided ChatMessage has no text. 
ChatMessage: {message}") - openai_msg = {"role": message.role.value, "content": message.text} - if message.name: - openai_msg["name"] = message.name - - return openai_msg + return {"role": message.role.value, "content": message.text} diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 231ce80713..91e8f0408f 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -4,7 +4,7 @@ from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer from haystack.dataclasses.byte_stream import ByteStream -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult from haystack.dataclasses.document import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.dataclasses.streaming_chunk import StreamingChunk @@ -17,6 +17,9 @@ "ByteStream", "ChatMessage", "ChatRole", + "ToolCall", + "ToolCallResult", + "TextContent", "StreamingChunk", "SparseEmbedding", ] diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index fb15ee6f5e..5aadb9f752 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -5,104 +5,318 @@ import warnings from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Sequence, Union + +LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"} class ChatRole(str, Enum): - """Enumeration representing the roles within a chat.""" + """ + Enumeration representing the roles within a chat. + """ - ASSISTANT = "assistant" + #: The user role. A message from the user contains only text. USER = "user" + + #: The system role. A message from the system contains only text. SYSTEM = "system" + + #: The assistant role. A message from the assistant can contain text and Tool calls. It can also store metadata. + ASSISTANT = "assistant" + + #: The tool role. A message from a tool contains the result of a Tool invocation. + TOOL = "tool" + + #: The function role. Deprecated in favor of `TOOL`. FUNCTION = "function" + @staticmethod + def from_str(string: str) -> "ChatRole": + """ + Convert a string to a ChatRole enum. + """ + enum_map = {e.value: e for e in ChatRole} + role = enum_map.get(string) + if role is None: + msg = f"Unknown chat role '{string}'. Supported roles are: {list(enum_map.keys())}" + raise ValueError(msg) + return role + + +@dataclass +class ToolCall: + """ + Represents a Tool call prepared by the model, usually contained in an assistant message. + + :param id: The ID of the Tool call. + :param tool_name: The name of the Tool to call. + :param arguments: The arguments to call the Tool with. + """ + + tool_name: str + arguments: Dict[str, Any] + id: Optional[str] = None # noqa: A003 + + +@dataclass +class ToolCallResult: + """ + Represents the result of a Tool invocation. + + :param result: The result of the Tool invocation. + :param origin: The Tool call that produced this result. + :param error: Whether the Tool invocation resulted in an error. + """ + + result: str + origin: ToolCall + error: bool + + +@dataclass +class TextContent: + """ + The textual content of a chat message. + + :param text: The text content of the message. 
+ """ + + text: str + + +ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult] + @dataclass class ChatMessage: """ Represents a message in a LLM chat conversation. - :param content: The text content of the message. - :param role: The role of the entity sending the message. - :param name: The name of the function being called (only applicable for role FUNCTION). - :param meta: Additional metadata associated with the message. + Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a ChatMessage. """ - content: str - role: ChatRole - name: Optional[str] - meta: Dict[str, Any] = field(default_factory=dict, hash=False) + _role: ChatRole + _content: Sequence[ChatMessageContentT] + _name: Optional[str] = None + _meta: Dict[str, Any] = field(default_factory=dict, hash=False) - @property - def text(self) -> Optional[str]: + def __new__(cls, *args, **kwargs): """ - Returns the textual content of the message. + This method is reimplemented to make the changes to the `ChatMessage` dataclass more visible. """ - # Currently, this property mirrors the `content` attribute. This will change in 2.9.0. - # The current actual return type is str. We are using Optional[str] to be ready for 2.9.0, - # when None will be a valid value for `text`. - return object.__getattribute__(self, "content") + + general_msg = ( + "Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a " + "ChatMessage. For more information about the new API and how to migrate, see the documentation:" + " https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + + if any(param in kwargs for param in LEGACY_INIT_PARAMETERS): + raise TypeError( + "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. " + f"{general_msg}" + ) + + allowed_content_types = (TextContent, ToolCall, ToolCallResult) + if len(args) > 1 and not isinstance(args[1], allowed_content_types): + raise TypeError( + "The `_content` parameter of `ChatMessage` must be one of the following types: " + f"{', '.join(t.__name__ for t in allowed_content_types)}. " + f"{general_msg}" + ) + + return super(ChatMessage, cls).__new__(cls) + + def __post_init__(self): + if self._role == ChatRole.FUNCTION: + msg = "The `FUNCTION` role has been deprecated in favor of `TOOL` and will be removed in 2.10.0. " + warnings.warn(msg, DeprecationWarning) def __getattribute__(self, name): - # this method is reimplemented to warn about the deprecation of the `content` attribute + """ + This method is reimplemented to make the `content` attribute removal more visible. + """ + if name == "content": msg = ( - "The `content` attribute of `ChatMessage` will be removed in Haystack 2.9.0. " - "Use the `text` property to access the textual value." + "The `content` attribute of `ChatMessage` has been removed. " + "Use the `text` property to access the textual value. " + "For more information about the new API and how to migrate, see the documentation: " + "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" ) - warnings.warn(msg, DeprecationWarning) + raise AttributeError(msg) return object.__getattribute__(self, name) - def is_from(self, role: ChatRole) -> bool: + def __len__(self): + return len(self._content) + + @property + def role(self) -> ChatRole: + """ + Returns the role of the entity sending the message. + """ + return self._role + + @property + def meta(self) -> Dict[str, Any]: + """ + Returns the metadata associated with the message. 
+ """ + return self._meta + + @property + def name(self) -> Optional[str]: + """ + Returns the name associated with the message. + """ + return self._name + + @property + def texts(self) -> List[str]: + """ + Returns the list of all texts contained in the message. + """ + return [content.text for content in self._content if isinstance(content, TextContent)] + + @property + def text(self) -> Optional[str]: + """ + Returns the first text contained in the message. + """ + if texts := self.texts: + return texts[0] + return None + + @property + def tool_calls(self) -> List[ToolCall]: + """ + Returns the list of all Tool calls contained in the message. + """ + return [content for content in self._content if isinstance(content, ToolCall)] + + @property + def tool_call(self) -> Optional[ToolCall]: + """ + Returns the first Tool call contained in the message. + """ + if tool_calls := self.tool_calls: + return tool_calls[0] + return None + + @property + def tool_call_results(self) -> List[ToolCallResult]: + """ + Returns the list of all Tool call results contained in the message. + """ + return [content for content in self._content if isinstance(content, ToolCallResult)] + + @property + def tool_call_result(self) -> Optional[ToolCallResult]: + """ + Returns the first Tool call result contained in the message. + """ + if tool_call_results := self.tool_call_results: + return tool_call_results[0] + return None + + def is_from(self, role: Union[ChatRole, str]) -> bool: """ Check if the message is from a specific role. :param role: The role to check against. :returns: True if the message is from the specified role, False otherwise. """ - return self.role == role + if isinstance(role, str): + role = ChatRole.from_str(role) + return self._role == role @classmethod - def from_assistant(cls, content: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": + def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage": """ - Create a message from the assistant. + Create a message from the user. - :param content: The text content of the message. + :param text: The text content of the message. :param meta: Additional metadata associated with the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.ASSISTANT, None, meta or {}) + return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}, _name=name) @classmethod - def from_user(cls, content: str) -> "ChatMessage": + def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage": """ - Create a message from the user. + Create a message from the system. - :param content: The text content of the message. + :param text: The text content of the message. + :param meta: Additional metadata associated with the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.USER, None) + return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}, _name=name) @classmethod - def from_system(cls, content: str) -> "ChatMessage": + def from_assistant( + cls, + text: Optional[str] = None, + meta: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + tool_calls: Optional[List[ToolCall]] = None, + ) -> "ChatMessage": """ - Create a message from the system. 
+ Create a message from the assistant. - :param content: The text content of the message. + :param text: The text content of the message. + :param meta: Additional metadata associated with the message. + :param tool_calls: The Tool calls to include in the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.SYSTEM, None) + content: List[ChatMessageContentT] = [] + if text is not None: + content.append(TextContent(text=text)) + if tool_calls: + content.extend(tool_calls) + + return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}, _name=name) + + @classmethod + def from_tool( + cls, tool_result: str, origin: ToolCall, error: bool = False, meta: Optional[Dict[str, Any]] = None + ) -> "ChatMessage": + """ + Create a message from a Tool. + + :param tool_result: The result of the Tool invocation. + :param origin: The Tool call that produced this result. + :param error: Whether the Tool invocation resulted in an error. + :param meta: Additional metadata associated with the message. + :returns: A new ChatMessage instance. + """ + return cls( + _role=ChatRole.TOOL, + _content=[ToolCallResult(result=tool_result, origin=origin, error=error)], + _meta=meta or {}, + ) @classmethod def from_function(cls, content: str, name: str) -> "ChatMessage": """ - Create a message from a function call. + Create a message from a function call. Deprecated in favor of `from_tool`. :param content: The text content of the message. :param name: The name of the function being called. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.FUNCTION, name) + msg = ( + "The `from_function` method is deprecated and will be removed in version 2.10.0. " + "Its behavior has changed: it now attempts to convert legacy function messages to tool messages. " + "This conversion is not guaranteed to succeed in all scenarios. " + "Please migrate to `ChatMessage.from_tool` and carefully verify the results if you " + "continue to use this method." + ) + warnings.warn(msg) + + return cls.from_tool(content, ToolCall(id=None, tool_name=name, arguments={}), error=False) def to_dict(self) -> Dict[str, Any]: """ @@ -111,10 +325,23 @@ def to_dict(self) -> Dict[str, Any]: :returns: Serialized version of the object. """ - data = asdict(self) - data["role"] = self.role.value + serialized: Dict[str, Any] = {} + serialized["_role"] = self._role.value + serialized["_meta"] = self._meta + serialized["_name"] = self._name + content: List[Dict[str, Any]] = [] + for part in self._content: + if isinstance(part, TextContent): + content.append({"text": part.text}) + elif isinstance(part, ToolCall): + content.append({"tool_call": asdict(part)}) + elif isinstance(part, ToolCallResult): + content.append({"tool_call_result": asdict(part)}) + else: + raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.") - return data + serialized["_content"] = content + return serialized @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": @@ -126,6 +353,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": :returns: The created object. """ - data["role"] = ChatRole(data["role"]) + if any(param in data for param in LEGACY_INIT_PARAMETERS): + raise TypeError( + "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. 
" + "For more information about the new API and how to migrate, see the documentation: " + "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + + data["_role"] = ChatRole(data["_role"]) + + content: List[ChatMessageContentT] = [] + + for part in data["_content"]: + if "text" in part: + content.append(TextContent(text=part["text"])) + elif "tool_call" in part: + content.append(ToolCall(**part["tool_call"])) + elif "tool_call_result" in part: + result = part["tool_call_result"]["result"] + origin = ToolCall(**part["tool_call_result"]["origin"]) + error = part["tool_call_result"]["error"] + tcr = ToolCallResult(result=result, origin=origin, error=error) + content.append(tcr) + else: + raise ValueError(f"Unsupported content in serialized ChatMessage: `{part}`") + + data["_content"] = content return cls(**data) diff --git a/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml b/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml new file mode 100644 index 0000000000..b9e590e590 --- /dev/null +++ b/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml @@ -0,0 +1,23 @@ +--- +highlights: > + We are introducing a refactored ChatMessage dataclass. It is more flexible, future-proof, and compatible with + different types of content: text, tool calls, tool calls results. + For information about the new API and how to migrate, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage +upgrade: + - | + The refactoring of the ChatMessage dataclass includes some breaking changes, involving ChatMessage creation and + accessing attributes. If you have a Pipeline containing a ChatPromptBuilder, serialized using Haystack<2.9.0, + deserialization may break. + For detailed information about the changes and how to migrate, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage +features: + - | + Changed the ChatMessage dataclass to support different types of content, including tool calls, and tool call + results. +deprecations: + - | + The function role and ChatMessage.from_function class method have been deprecated and will be removed in + Haystack 2.10.0. ChatMessage.from_function also attempts to produce a valid tool message. 
+ For more information, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index 5e1ae6132e..a8fb8bc5b8 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -13,8 +13,8 @@ class TestChatPromptBuilder: def test_init(self): builder = ChatPromptBuilder( template=[ - ChatMessage.from_user(content="This is a {{ variable }}"), - ChatMessage.from_system(content="This is a {{ variable2 }}"), + ChatMessage.from_user("This is a {{ variable }}"), + ChatMessage.from_system("This is a {{ variable2 }}"), ] ) assert builder.required_variables == [] @@ -531,8 +531,13 @@ def test_to_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, - {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None}, + { + "_content": [{"text": "content {required_var}"}], + "_role": "assistant", + "_meta": {}, + "_name": None, + }, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], @@ -545,8 +550,13 @@ def test_from_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, - {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None}, + { + "_content": [{"text": "content {required_var}"}], + "_role": "assistant", + "_meta": {}, + "_name": None, + }, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index 3d7fd617c0..e60ec863ab 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -68,13 +68,6 @@ def test_convert_message_to_hfapi_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"} - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_hfapi_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } - class TestHuggingFaceAPIGenerator: def test_init_invalid_api_type(self): diff --git a/test/components/generators/test_openai_utils.py b/test/components/generators/test_openai_utils.py index 226b32f811..916a3e3d70 100644 --- a/test/components/generators/test_openai_utils.py +++ b/test/components/generators/test_openai_utils.py @@ -14,10 +14,3 @@ def test_convert_message_to_openai_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} - - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_openai_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } diff --git a/test/components/routers/test_conditional_router.py 
b/test/components/routers/test_conditional_router.py index e0f3552319..66d941b645 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -349,7 +349,7 @@ def test_unsafe(self): ] router = ConditionalRouter(routes, unsafe=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") res = router.run(streams=streams, message=message) assert res == {"message": message} @@ -370,7 +370,7 @@ def test_validate_output_type_without_unsafe(self): ] router = ConditionalRouter(routes, validate_output_type=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"): router.run(streams=streams, message=message) @@ -391,7 +391,7 @@ def test_validate_output_type_with_unsafe(self): ] router = ConditionalRouter(routes, unsafe=True, validate_output_type=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") res = router.run(streams=streams, message=message) assert isinstance(res["message"], ChatMessage) diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index d7001a0187..8f07dfec99 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -1657,7 +1657,7 @@ def run(self, query: str): class ToolExtractor: @component.output_types(output=List[str]) def run(self, messages: List[ChatMessage]): - prompt: str = messages[-1].content + prompt: str = messages[-1].text lines = prompt.strip().split("\n") for line in reversed(lines): pattern = r"Action:\s*(\w+)\[(.*?)\]" @@ -1678,14 +1678,14 @@ def __init__(self, suffix: str = ""): @component.output_types(output=List[ChatMessage]) def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]): - content = current_prompt[-1].content + replies[-1].content + self._suffix + content = current_prompt[-1].text + replies[-1].text + self._suffix return {"output": [ChatMessage.from_user(content)]} @component class SearchOutputAdapter: @component.output_types(output=List[ChatMessage]) def run(self, replies: List[ChatMessage]): - content = f"Observation: {replies[-1].content}\n" + content = f"Observation: {replies[-1].text}\n" return {"output": [ChatMessage.from_assistant(content)]} pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator()) diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 30ad51630e..832617e712 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -4,64 +4,240 @@ import pytest from transformers import AutoTokenizer -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent from haystack.components.generators.openai_utils import _convert_message_to_openai_format +def test_tool_call_init(): + tc = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + assert tc.id == "123" + assert tc.tool_name == "mytool" + assert tc.arguments == {"a": 1} + + +def test_tool_call_result_init(): + tcr = ToolCallResult(result="result", origin=ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), error=True) + assert tcr.result == "result" + assert tcr.origin == 
ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + assert tcr.error + + +def test_text_content_init(): + tc = TextContent(text="Hello") + assert tc.text == "Hello" + + def test_from_assistant_with_valid_content(): - content = "Hello, how can I assist you?" - message = ChatMessage.from_assistant(content) - assert message.content == content - assert message.text == content + text = "Hello, how can I assist you?" + message = ChatMessage.from_assistant(text) + assert message.role == ChatRole.ASSISTANT + assert message._content == [TextContent(text)] + assert message.name is None + + assert message.text == text + assert message.texts == [text] + + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_assistant_with_tool_calls(): + tool_calls = [ + ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), + ToolCall(id="456", tool_name="mytool2", arguments={"b": 2}), + ] + + message = ChatMessage.from_assistant(tool_calls=tool_calls) + + assert message.role == ChatRole.ASSISTANT + assert message._content == tool_calls + + assert message.tool_calls == tool_calls + assert message.tool_call == tool_calls[0] + + assert not message.texts + assert not message.text + assert not message.tool_call_results + assert not message.tool_call_result def test_from_user_with_valid_content(): - content = "I have a question." - message = ChatMessage.from_user(content) - assert message.content == content - assert message.text == content + text = "I have a question." + message = ChatMessage.from_user(text=text) + assert message.role == ChatRole.USER + assert message._content == [TextContent(text)] + assert message.name is None + + assert message.text == text + assert message.texts == [text] + + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_user_with_name(): + text = "I have a question." + message = ChatMessage.from_user(text=text, name="John") + + assert message.name == "John" + assert message.role == ChatRole.USER + assert message._content == [TextContent(text)] def test_from_system_with_valid_content(): - content = "System message." - message = ChatMessage.from_system(content) - assert message.content == content - assert message.text == content + text = "I have a question." 
+ message = ChatMessage.from_system(text=text) + assert message.role == ChatRole.SYSTEM + assert message._content == [TextContent(text)] + assert message.text == text + assert message.texts == [text] -def test_with_empty_content(): - message = ChatMessage.from_user("") - assert message.content == "" - assert message.text == "" - assert message.role == ChatRole.USER + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_tool_with_valid_content(): + tool_result = "Tool result" + origin = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + message = ChatMessage.from_tool(tool_result, origin, error=False) + + tcr = ToolCallResult(result=tool_result, origin=origin, error=False) + + assert message._content == [tcr] + assert message.role == ChatRole.TOOL + + assert message.tool_call_result == tcr + assert message.tool_call_results == [tcr] + + assert not message.tool_calls + assert not message.tool_call + assert not message.texts + assert not message.text + + +def test_multiple_text_segments(): + texts = [TextContent(text="Hello"), TextContent(text="World")] + message = ChatMessage(_role=ChatRole.USER, _content=texts) + + assert message.texts == ["Hello", "World"] + assert len(message) == 2 + + +def test_mixed_content(): + content = [TextContent(text="Hello"), ToolCall(id="123", tool_name="mytool", arguments={"a": 1})] + + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=content) + assert len(message) == 2 + assert message.texts == ["Hello"] + assert message.text == "Hello" -def test_from_function_with_empty_name(): - content = "Function call" - message = ChatMessage.from_function(content, "") - assert message.content == content - assert message.text == content - assert message.name == "" - assert message.role == ChatRole.FUNCTION + assert message.tool_calls == [content[1]] + assert message.tool_call == content[1] -def test_to_openai_format(): - message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"} +def test_from_function(): + # check warning is raised + with pytest.warns(): + message = ChatMessage.from_function("Result of function invocation", "my_function") - message = ChatMessage.from_user("I have a question") - assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} + assert message.role == ChatRole.TOOL + assert message.tool_call_result == ToolCallResult( + result="Result of function invocation", + origin=ToolCall(id=None, tool_name="my_function", arguments={}), + error=False, + ) + + +def test_serde(): + # the following message is created just for testing purposes and does not make sense in a real use case + + role = ChatRole.ASSISTANT + + text_content = TextContent(text="Hello") + tool_call = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + tool_call_result = ToolCallResult(result="result", origin=tool_call, error=False) + meta = {"some": "info"} - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_openai_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", + message = ChatMessage(_role=role, _content=[text_content, tool_call, tool_call_result], _meta=meta) + + serialized_message = message.to_dict() + assert serialized_message == { + "_content": [ + {"text": "Hello"}, + {"tool_call": {"id": "123", 
"tool_name": "mytool", "arguments": {"a": 1}}}, + { + "tool_call_result": { + "result": "result", + "error": False, + "origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}, + } + }, + ], + "_role": "assistant", + "_name": None, + "_meta": {"some": "info"}, } + deserialized_message = ChatMessage.from_dict(serialized_message) + assert deserialized_message == message + + +def test_to_dict_with_invalid_content_type(): + text_content = TextContent(text="Hello") + invalid_content = "invalid" + + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[text_content, invalid_content]) + + with pytest.raises(TypeError): + message.to_dict() + + +def test_from_dict_with_invalid_content_type(): + data = {"_role": "assistant", "_content": [{"text": "Hello"}, "invalid"]} + with pytest.raises(ValueError): + ChatMessage.from_dict(data) + + data = {"_role": "assistant", "_content": [{"text": "Hello"}, {"invalid": "invalid"}]} + with pytest.raises(ValueError): + ChatMessage.from_dict(data) + + +def test_from_dict_with_legacy_init_parameters(): + with pytest.raises(TypeError): + ChatMessage.from_dict({"role": "user", "content": "This is a message"}) + + +def test_chat_message_content_attribute_removed(): + message = ChatMessage.from_user(text="This is a message") + with pytest.raises(AttributeError): + message.content + + +def test_chat_message_init_parameters_removed(): + with pytest.raises(TypeError): + ChatMessage(role="irrelevant", content="This is a message") + + +def test_chat_message_init_content_parameter_type(): + with pytest.raises(TypeError): + ChatMessage(ChatRole.USER, "This is a message") + + +def test_chat_message_function_role_deprecated(): + with pytest.warns(DeprecationWarning): + ChatMessage(ChatRole.FUNCTION, TextContent("This is a message")) + @pytest.mark.integration def test_apply_chat_templating_on_chat_message(): @@ -93,40 +269,3 @@ def test_apply_custom_chat_templating_on_chat_message(): formatted_messages, chat_template=anthropic_template, tokenize=False ) assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:" - - -def test_to_dict(): - content = "content" - role = "user" - meta = {"some": "some"} - - message = ChatMessage.from_user(content) - message.meta.update(meta) - - assert message.text == content - assert message.to_dict() == {"content": content, "role": role, "name": None, "meta": meta} - - -def test_from_dict(): - assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage.from_user( - "text" - ) - - -def test_from_dict_with_meta(): - data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}} - assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"}) - - -def test_content_deprecation_warning(recwarn): - message = ChatMessage.from_user("my message") - - # accessing the content attribute triggers the deprecation warning - _ = message.content - assert len(recwarn) == 1 - wrn = recwarn.pop(DeprecationWarning) - assert "`content` attribute" in wrn.message.args[0] - - # accessing the text property does not trigger a warning - assert message.text == "my message" - assert len(recwarn) == 0 From 96b4a1d2fd82f3e072060d8c11ae3e1fc230d681 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 12:36:44 +0100 Subject: [PATCH 10/14] feat: `Tool` dataclass - unified abstraction to represent tools (#8652) * draft * del HF token in tests * adaptations * progress * fix type * import sorting * more 
control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * port Tool from experimental * release note * docs upd * Update tool.py --------- Co-authored-by: Daria Fokina --- docs/pydoc/config/data_classess_api.yml | 2 +- haystack/dataclasses/__init__.py | 2 + haystack/dataclasses/tool.py | 243 ++++++++++++++ pyproject.toml | 3 +- .../tool-dataclass-12756077bbfea3a1.yaml | 8 + test/dataclasses/test_tool.py | 305 ++++++++++++++++++ 6 files changed, 561 insertions(+), 2 deletions(-) create mode 100644 haystack/dataclasses/tool.py create mode 100644 releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml create mode 100644 test/dataclasses/test_tool.py diff --git a/docs/pydoc/config/data_classess_api.yml b/docs/pydoc/config/data_classess_api.yml index a67f28db9d..71ea77513a 100644 --- a/docs/pydoc/config/data_classess_api.yml +++ b/docs/pydoc/config/data_classess_api.yml @@ -2,7 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/dataclasses] modules: - ["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding"] + ["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding", "tool"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 91e8f0408f..97f253e805 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -8,6 +8,7 @@ from haystack.dataclasses.document import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.dataclasses.streaming_chunk import StreamingChunk +from haystack.dataclasses.tool import Tool __all__ = [ "Document", @@ -22,4 +23,5 @@ "TextContent", "StreamingChunk", "SparseEmbedding", + "Tool", ] diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py new file mode 100644 index 0000000000..3df3fd18f2 --- /dev/null +++ b/haystack/dataclasses/tool.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from dataclasses import asdict, dataclass +from typing import Any, Callable, Dict, Optional + +from pydantic import create_model + +from haystack.lazy_imports import LazyImport +from haystack.utils import deserialize_callable, serialize_callable + +with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: + from jsonschema import Draft202012Validator + from jsonschema.exceptions import SchemaError + + +class ToolInvocationError(Exception): + """ + Exception raised when a Tool invocation fails. + """ + + pass + + +class SchemaGenerationError(Exception): + """ + Exception raised when automatic schema generation fails. + """ + + pass + + +@dataclass +class Tool: + """ + Data class representing a Tool that Language Models can prepare a call for. + + Accurate definitions of the textual attributes such as `name` and `description` + are important for the Language Model to correctly prepare the call. + + :param name: + Name of the Tool. + :param description: + Description of the Tool. + :param parameters: + A JSON schema defining the parameters expected by the Tool. + :param function: + The function that will be invoked when the Tool is called. 
+ """ + + name: str + description: str + parameters: Dict[str, Any] + function: Callable + + def __post_init__(self): + jsonschema_import.check() + # Check that the parameters define a valid JSON schema + try: + Draft202012Validator.check_schema(self.parameters) + except SchemaError as e: + raise ValueError("The provided parameters do not define a valid JSON schema") from e + + @property + def tool_spec(self) -> Dict[str, Any]: + """ + Return the Tool specification to be used by the Language Model. + """ + return {"name": self.name, "description": self.description, "parameters": self.parameters} + + def invoke(self, **kwargs) -> Any: + """ + Invoke the Tool with the provided keyword arguments. + """ + + try: + result = self.function(**kwargs) + except Exception as e: + raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e + return result + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the Tool to a dictionary. + + :returns: + Dictionary with serialized data. + """ + + serialized = asdict(self) + serialized["function"] = serialize_callable(self.function) + return serialized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tool": + """ + Deserializes the Tool from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized Tool. + """ + data["function"] = deserialize_callable(data["function"]) + return cls(**data) + + @classmethod + def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool": + """ + Create a Tool instance from a function. + + ### Usage example + + ```python + from typing import Annotated, Literal + from haystack.dataclasses import Tool + + def get_weather( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"): + '''A simple function to get the current weather for a location.''' + return f"Weather report for {city}: 20 {unit}, sunny" + + tool = Tool.from_function(get_weather) + + print(tool) + >>> Tool(name='get_weather', description='A simple function to get the current weather for a location.', + >>> parameters={ + >>> 'type': 'object', + >>> 'properties': { + >>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'}, + >>> 'unit': { + >>> 'type': 'string', + >>> 'enum': ['Celsius', 'Fahrenheit'], + >>> 'description': 'the unit for the temperature', + >>> 'default': 'Celsius', + >>> }, + >>> } + >>> }, + >>> function=) + ``` + + :param function: + The function to be converted into a Tool. + The function must include type hints for all parameters. + If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description. + :param name: + The name of the Tool. If not provided, the name of the function will be used. + :param description: + The description of the Tool. If not provided, the docstring of the function will be used. + To intentionally leave the description empty, pass an empty string. + + :returns: + The Tool created from the function. + + :raises ValueError: + If any parameter of the function lacks a type hint. + :raises SchemaGenerationError: + If there is an error generating the JSON schema for the Tool. 
+ """ + + tool_description = description if description is not None else (function.__doc__ or "") + + signature = inspect.signature(function) + + # collect fields (types and defaults) and descriptions from function parameters + fields: Dict[str, Any] = {} + descriptions = {} + + for param_name, param in signature.parameters.items(): + if param.annotation is param.empty: + raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") + + # if the parameter has not a default value, Pydantic requires an Ellipsis (...) + # to explicitly indicate that the parameter is required + default = param.default if param.default is not param.empty else ... + fields[param_name] = (param.annotation, default) + + if hasattr(param.annotation, "__metadata__"): + descriptions[param_name] = param.annotation.__metadata__[0] + + # create Pydantic model and generate JSON schema + try: + model = create_model(function.__name__, **fields) + schema = model.model_json_schema() + except Exception as e: + raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e + + # we don't want to include title keywords in the schema, as they contain redundant information + # there is no programmatic way to prevent Pydantic from adding them, so we remove them later + # see https://github.com/pydantic/pydantic/discussions/8504 + _remove_title_from_schema(schema) + + # add parameters descriptions to the schema + for param_name, param_description in descriptions.items(): + if param_name in schema["properties"]: + schema["properties"][param_name]["description"] = param_description + + return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) + + +def _remove_title_from_schema(schema: Dict[str, Any]): + """ + Remove the 'title' keyword from JSON schema and contained property schemas. + + :param schema: + The JSON schema to remove the 'title' keyword from. + """ + schema.pop("title", None) + + for property_schema in schema["properties"].values(): + for key in list(property_schema.keys()): + if key == "title": + del property_schema[key] + + +def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): + """ + Deserialize Tools in a dictionary inplace. + + :param data: + The dictionary with the serialized data. + :param key: + The key in the dictionary where the Tools are stored. 
+ """ + if key in data: + serialized_tools = data[key] + + if serialized_tools is None: + return + + if not isinstance(serialized_tools, list): + raise TypeError(f"The value of '{key}' is not a list") + + deserialized_tools = [] + for tool in serialized_tools: + if not isinstance(tool, dict): + raise TypeError(f"Serialized tool '{tool}' is not a dictionary") + deserialized_tools.append(Tool.from_dict(tool)) + + data[key] = deserialized_tools diff --git a/pyproject.toml b/pyproject.toml index c41c429ced..c1fddc8704 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "tenacity!=8.4.0", "lazy-imports", "openai>=1.56.1", + "pydantic", "Jinja2", "posthog", # telemetry "pyyaml", @@ -113,7 +114,7 @@ extra-dependencies = [ "jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions "openapi3", - # Validation + # JsonSchemaValidator, Tool "jsonschema", # Tracing diff --git a/releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml b/releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml new file mode 100644 index 0000000000..b6255ee1a9 --- /dev/null +++ b/releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml @@ -0,0 +1,8 @@ +--- +highlights: > + We are introducing the `Tool` dataclass: a simple and unified abstraction to represent tools throughout the framework. + By building on this abstraction, we will enable support for tools in Chat Generators, + providing a consistent experience across models. +features: + - | + Added a new `Tool` dataclass to represent a tool for which Language Models can prepare calls. diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py new file mode 100644 index 0000000000..db9719a7f3 --- /dev/null +++ b/test/dataclasses/test_tool.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal, Optional + +import pytest + +from haystack.dataclasses.tool import ( + SchemaGenerationError, + Tool, + ToolInvocationError, + _remove_title_from_schema, + deserialize_tools_inplace, +) + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + + +def get_weather_report(city: str) -> str: + return f"Weather report for {city}: 20°C, sunny" + + +parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + + +def function_with_docstring(city: str) -> str: + """Get weather report for a city.""" + return f"Weather report for {city}: 20°C, sunny" + + +class TestTool: + def test_init(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.name == "weather" + assert tool.description == "Get weather report" + assert tool.parameters == parameters + assert tool.function == get_weather_report + + def test_init_invalid_parameters(self): + parameters = {"type": "invalid", "properties": {"city": {"type": "string"}}} + + with pytest.raises(ValueError): + Tool(name="irrelevant", description="irrelevant", parameters=parameters, function=get_weather_report) + + def test_tool_spec(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.tool_spec == {"name": "weather", "description": "Get weather report", "parameters": parameters} + + def test_invoke(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert 
tool.invoke(city="Berlin") == "Weather report for Berlin: 20°C, sunny" + + def test_invoke_fail(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + with pytest.raises(ToolInvocationError): + tool.invoke() + + def test_to_dict(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.to_dict() == { + "name": "weather", + "description": "Get weather report", + "parameters": parameters, + "function": "test_tool.get_weather_report", + } + + def test_from_dict(self): + tool_dict = { + "name": "weather", + "description": "Get weather report", + "parameters": parameters, + "function": "test_tool.get_weather_report", + } + + tool = Tool.from_dict(tool_dict) + + assert tool.name == "weather" + assert tool.description == "Get weather report" + assert tool.parameters == parameters + assert tool.function == get_weather_report + + def test_from_function_description_from_docstring(self): + tool = Tool.from_function(function=function_with_docstring) + + assert tool.name == "function_with_docstring" + assert tool.description == "Get weather report for a city." + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_with_empty_description(self): + tool = Tool.from_function(function=function_with_docstring, description="") + + assert tool.name == "function_with_docstring" + assert tool.description == "" + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_with_custom_description(self): + tool = Tool.from_function(function=function_with_docstring, description="custom description") + + assert tool.name == "function_with_docstring" + assert tool.description == "custom description" + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_with_custom_name(self): + tool = Tool.from_function(function=function_with_docstring, name="custom_name") + + assert tool.name == "custom_name" + assert tool.description == "Get weather report for a city." 
+ assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_missing_type_hint(self): + def function_missing_type_hint(city) -> str: + return f"Weather report for {city}: 20°C, sunny" + + with pytest.raises(ValueError): + Tool.from_function(function=function_missing_type_hint) + + def test_from_function_schema_generation_error(self): + def function_with_invalid_type_hint(city: "invalid") -> str: + return f"Weather report for {city}: 20°C, sunny" + + with pytest.raises(SchemaGenerationError): + Tool.from_function(function=function_with_invalid_type_hint) + + def test_from_function_annotated(self): + def function_with_annotations( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius", + nullable_param: Annotated[Optional[str], "a nullable parameter"] = None, + ) -> str: + """A simple function to get the current weather for a location.""" + return f"Weather report for {city}: 20 {unit}, sunny" + + tool = Tool.from_function(function=function_with_annotations) + + assert tool.name == "function_with_annotations" + assert tool.description == "A simple function to get the current weather for a location." + assert tool.parameters == { + "type": "object", + "properties": { + "city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"}, + "unit": { + "type": "string", + "enum": ["Celsius", "Fahrenheit"], + "description": "the unit for the temperature", + "default": "Celsius", + }, + "nullable_param": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "description": "a nullable parameter", + "default": None, + }, + }, + } + + +def test_deserialize_tools_inplace(): + tool = Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report) + serialized_tool = tool.to_dict() + print(serialized_tool) + + data = {"tools": [serialized_tool.copy()]} + deserialize_tools_inplace(data) + assert data["tools"] == [tool] + + data = {"mytools": [serialized_tool.copy()]} + deserialize_tools_inplace(data, key="mytools") + assert data["mytools"] == [tool] + + data = {"no_tools": 123} + deserialize_tools_inplace(data) + assert data == {"no_tools": 123} + + +def test_deserialize_tools_inplace_failures(): + data = {"key": "value"} + deserialize_tools_inplace(data) + assert data == {"key": "value"} + + data = {"tools": None} + deserialize_tools_inplace(data) + assert data == {"tools": None} + + data = {"tools": "not a list"} + with pytest.raises(TypeError): + deserialize_tools_inplace(data) + + data = {"tools": ["not a dictionary"]} + with pytest.raises(TypeError): + deserialize_tools_inplace(data) + + +def test_remove_title_from_schema(): + complex_schema = { + "properties": { + "parameter1": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + "default": "default_value", + "title": "Parameter1", + }, + "parameter2": { + "default": [1, 2, 3], + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "title": "Parameter2", + "type": "array", + }, + "parameter3": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, + ], + "default": 42, + "title": "Parameter3", + }, + "parameter4": { + "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], + 
"default": {"key": "value"}, + "title": "Parameter4", + }, + }, + "title": "complex_function", + "type": "object", + } + + _remove_title_from_schema(complex_schema) + + assert complex_schema == { + "properties": { + "parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"}, + "parameter2": { + "default": [1, 2, 3], + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + }, + "parameter3": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, + ], + "default": 42, + }, + "parameter4": { + "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], + "default": {"key": "value"}, + }, + }, + "type": "object", + } + + +def test_remove_title_from_schema_do_not_remove_title_property(): + """Test that the utility function only removes the 'title' keywords and not the 'title' property (if present).""" + schema = { + "properties": { + "parameter1": {"type": "string", "title": "Parameter1"}, + "title": {"type": "string", "title": "Title"}, + }, + "title": "complex_function", + "type": "object", + } + + _remove_title_from_schema(schema) + + assert schema == {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"} + + +def test_remove_title_from_schema_handle_no_title_in_top_level(): + schema = { + "properties": { + "parameter1": {"type": "string", "title": "Parameter1"}, + "parameter2": {"type": "integer", "title": "Parameter2"}, + }, + "type": "object", + } + + _remove_title_from_schema(schema) + + assert schema == { + "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}}, + "type": "object", + } From 91619a79c11db84fe643e40c72476c6fec85296f Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 18 Dec 2024 21:34:57 +0100 Subject: [PATCH 11/14] fix: fix deserialization issues in multi-threading environments (#8651) --- haystack/core/pipeline/base.py | 5 ++--- haystack/utils/type_serialization.py | 20 ++++++++++++++++++- ...d-safe-module-import-ed04ad216820ab85.yaml | 4 ++++ 3 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 31ad2ad93c..d8f2a65932 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import importlib import itertools from collections import defaultdict from copy import deepcopy @@ -26,7 +25,7 @@ from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict from haystack.core.type_utils import _type_name, _types_are_compatible from haystack.marshal import Marshaller, YamlMarshaller -from haystack.utils import is_in_jupyter +from haystack.utils import is_in_jupyter, type_serialization from .descriptions import find_pipeline_inputs, find_pipeline_outputs from .draw import _to_mermaid_image @@ -161,7 +160,7 @@ def from_dict( # Import the module first... 
module, _ = component_data["type"].rsplit(".", 1) logger.debug("Trying to import module {module_name}", module_name=module) - importlib.import_module(module) + type_serialization.thread_safe_import(module) # ...then try again if component_data["type"] not in component.registry: raise PipelineError( diff --git a/haystack/utils/type_serialization.py b/haystack/utils/type_serialization.py index b2dd319d52..5ffb505bb1 100644 --- a/haystack/utils/type_serialization.py +++ b/haystack/utils/type_serialization.py @@ -6,10 +6,14 @@ import inspect import sys import typing +from threading import Lock +from types import ModuleType from typing import Any, get_args, get_origin from haystack import DeserializationError +_import_lock = Lock() + def serialize_type(target: Any) -> str: """ @@ -132,7 +136,7 @@ def parse_generic_args(args_str): module = sys.modules.get(module_name) if not module: try: - module = importlib.import_module(module_name) + module = thread_safe_import(module_name) except ImportError as e: raise DeserializationError(f"Could not import the module: {module_name}") from e @@ -141,3 +145,17 @@ def parse_generic_args(args_str): raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}") return deserialized_type + + +def thread_safe_import(module_name: str) -> ModuleType: + """ + Import a module in a thread-safe manner. + + Importing modules in a multi-threaded environment can lead to race conditions. + This function ensures that the module is imported in a thread-safe manner without having impact + on the performance of the import for single-threaded environments. + + :param module_name: the module to import + """ + with _import_lock: + return importlib.import_module(module_name) diff --git a/releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml b/releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml new file mode 100644 index 0000000000..3f1a0a2e78 --- /dev/null +++ b/releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes issues with deserialization of components in multi-threaded environments. From c306bee66563c23a08769c7bd6a56f01d5cac138 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Thu, 19 Dec 2024 11:08:29 +0100 Subject: [PATCH 12/14] fix: adding missing abbreviations files for SentenceSplitter (#8660) * adding missing abbreviations files for SentenceSplitter * fixing tests path --- .pre-commit-config.yaml | 1 + .../preprocessors/sentence_tokenizer.py | 2 +- haystack/data/abbreviations/de.txt | 1097 +++++++++++++++++ haystack/data/abbreviations/en.txt | 975 +++++++++++++++ .../preprocessors/test_sentence_tokenizer.py | 2 +- 5 files changed, 2075 insertions(+), 2 deletions(-) create mode 100644 haystack/data/abbreviations/de.txt create mode 100644 haystack/data/abbreviations/en.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 483ed06b08..306e205e41 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,7 @@ repos: rev: v2.3.0 hooks: - id: codespell + exclude: "haystack/data/abbreviations" args: ["--toml", "pyproject.toml"] additional_dependencies: - tomli diff --git a/haystack/components/preprocessors/sentence_tokenizer.py b/haystack/components/preprocessors/sentence_tokenizer.py index 505126e901..5dd6ad97ee 100644 --- a/haystack/components/preprocessors/sentence_tokenizer.py +++ b/haystack/components/preprocessors/sentence_tokenizer.py @@ -228,7 +228,7 @@ def _read_abbreviations(lang: Language) -> List[str]: :param lang: The language to read the abbreviations for. :returns: List of abbreviations. """ - abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{lang}.txt" + abbreviations_file = Path(__file__).parent.parent.parent / f"data/abbreviations/{lang}.txt" if not abbreviations_file.exists(): logger.warning("No abbreviations file found for {language}. Using default abbreviations.", language=lang) return [] diff --git a/haystack/data/abbreviations/de.txt b/haystack/data/abbreviations/de.txt new file mode 100644 index 0000000000..848b07541d --- /dev/null +++ b/haystack/data/abbreviations/de.txt @@ -0,0 +1,1097 @@ +abb +abds +abfr +abg +abgek +abh +abk +abl +ableg +ableu +abm +abn +abr +abs +abschn +abst +abulg +abw +abzgl +accel +add +adhrsg +adj +adr +adv +adyg +aengl +afghan +afr +afrik +afrk +afrs +ags +ahd +aind +akad +akk +akkad +akt +al +alb +alban +alem +alemann +allg +allj +allm +alltagsspr +alphanum +altgr +althochdt +altis +altisländ +altröm +alttest +amer +amerik +amerikan +amhar +amtl +amtlbegr +amtlmitt +amtm +amtsbl +amtsdt +amtsspr +anat +anatom +andalus +ang +ange +angekl +angelsächs +angest +angloamerik +anglofrz +angloind +anh +ank +ankl +anl +anm +annamit +anord +anschl +anschr +antarkt +anthrop +anw +anwbl +anz +aobd +apl +aplprof +apostr +apr +arab +aram +aran +arbrgeg +archäol +arg +argent +arkt +art +aserbaidsch +aslaw +assyr +ast +astron +asächs +attr +auff +aufl +aug +ausdr +ausf +ausg +ausl +aussch +ausschl +ausspr +ausst +austral +außenpol +awar +awest +az +aztek +bab +babl +babyl +bair +bakt +balt +baltoslaw +bankw +banz +baschk +bask +bauf +baupol +bauw +bay +baybgm +baygvbl +bayjmbl +baymabl +bayr +bayvbl +bayärztebl +bbg +bd +bde +bearb +bed +begl +begr +beif +beigel +bej +bek +bekl +bem +ber +berbersprachl +bergb +berger +berl +berlärztebl +berufsbez +bes +besbed +besch +beschl +beschw +beschwger +bestr +betr +betriebswiss +betrverf +bev +bew +bez +bezw +bf +bfn +bft +bg +bgbl +bgm +bhf +bibl +bildl +bildungspol +bildungsspr +biol +bj +bl +blk +bln +bodenk +bpräs +br-dr +br-drs +br-prot +brak-mitt +bras +brem +bremgbl +bret +breton +brn +bruchz +bsd +bsp +bspw +bstbl +bt-dr +bt-drs +btl +btto +bttr +buchw +buddh +bulg +bulgar +bundespol +burjat +burmes 
+bw +byzant +bz +bzb +bzbl +bzgl +bzw +börsenw +ca +carp +cf +chakass +chald +chant +chem +chilen +chin +chr +christl +chron +co +cresc +dat +ders +dez +dgl +dgr +di +dial +dig +dim +dimin +dingl +dipl +diss +do +dominikan +dor +doz +dr +drchf +drcks +drdes +dres +drhc +drphil +drrernat +drs +drucks +dt +dtl +dto +dtsch +dtzd +dvbl +dz +däbl +dän +ebd +ehem +eidg +eig +eigtl +einf +eingetr +einh +einl +einschl +einstw +einw +eisenb +elektrot +elektrotechn +em +engl +entspr +erb +erf +erg +erl +erm +ersch +erschl +erw +erzb +erzg +erzgeb +eskim +est +estn +etc +etg +etrusk +etw +eur +europ +eust +ev +evang +evtl +ew +ewen +ewenk +exkl +expl +ez +fa +fachspr +fam +feb +fem +ff +fig +finanzmath +finn +finnougr +flgh +flnr +flst +flstk +flstnr +fläm +fn +fnhd +folg +forts +fortstzg +fr +fragm +franz +französ +frdl +frh +frhr +fries +friesl +frl +frnhd +frz +fränk +frühnhd +fsm +ftm +fußn +färö +förml +gabl +gall +galloroman +gart +gaskogn +gbf +geb +gebr +ged +gef +geg +gegr +geh +geisteswissenschaftl +gek +gel +geleg +gem +gemeingerm +gen +geod +geogr +geograf +geograph +geol +geolog +geophys +georg +gep +ger +germ +ges +gesch +geschr +gespr +gest +gesundheitspol +get +gew +gez +gfsch +gft +gg +ggb +ggbfs +ggez +ggf +ggfs +ggs +ggü +ghzg +ghzgt +glchz +gleichbed +gleichz +glz +got +gr +gramm +grammat +graph +grch +grchl +grdb +grdf +grdfl +grdg +grdl +grdr +grds +grdst +griech +grz +grönländ +gvbl +gvnw +gvobl +gyn +gynäk +gz +gäl +hait +halbs +hamb +handw +hbf +hd +hdb +hebr +hess +hethit +hf +hg +hindust +hinr +hins +hinw +hist +hjber +hkl +hl +hmb +hochd +hochspr +hom +hptpl +hpts +hptst +hqu +hr +hrn +hrsg +hs +hubbr +hubr +hw +hyaz +hydr +hydrol +hzm +iallg +iber +ibid +ident +idg +ie +illyr +imkerspr +inc +ind +indef +indekl +indian +indiff +indir +indiv +indog +indogerm +indogerman +indoiran +indon +indones +inf +ing +inh +inkl +inn +innenpol +insb +insbes +int +iron +isl +islam +isländ +it +ital +jabl +jahrh +jakut +jan +jap +japan +jav +jbl +jdn +jem +jg +jh +jhd +jhdt +jhs +jidd +jmbl +jmd +jmdm +jmdn +jmds +jn +journ +jr +jt +jtsd +jugendspr +jugendsprachl +jugoslaw +jul +jun +jur +jvbl +jährl +jüd +kalm +kanad +kap +karib +kastil +katal +kath +kaufm +kaukas +kbgekd +kelt +kfm +kfr +kgr +kindersprachl +kirchenlat +kirchenslaw +kirchl +kirg +kj +kl +klass +klimatol +kol +kom +komm +konf +konj +konv +kop +kopt +korean +kostrsp +kr +kreol +kret +krimgot +kriminaltechn +krit +kroat +krs +ks +ktn +kto +kuban +kurd +kurzw +kw +l +lab +labg +ladin +landespol +landsch +landw +langfr +langj +langob +langobard +lapp +lat +latein +latinis +lautl +lautm +lbd +lbdg +ldkr +led +leg +lett +lfd +lfg +lfm +lfrg +lg +lgbl +lgfr +lgft +lgj +lig +ling +lit +lrh +ls +lst +lt +ltd +luth +luxemb +ma +mabl +mag +malai +marinespr +marx +mask +math +max +mazedon +mbl +mbll +md +mdal +mdj +mdl +mdls +mdt +mech +meckl +med +melanes +mengl +merc +meteorol +meton +mexik +mfr +mfranz +mfrk +mfrz +mfränk +mgl +mglw +mhd +mhdt +mi +mia +mihd +milit +mill +min +mind +mio +mitgl +mitteld +mitteldt +mittelhochdt +mittw +mitw +mlat +mnd +mndd +mniederd +mnl +mo +mod +mong +mr +mrd +mrs +mrz +ms +mschr +msgr +msp +mtl +mundartl +mwst +mz +mär +möbl +mündl +nachf +nachm +nachw +nat +nationalsoz +natsoz +nbfl +nchf +nd +ndd +ndrl +nds +ndssog +ndsvbl +neapolit +neub +neunorweg +neutest +neutr +nhd +niederd +niederdt +niederl +niederld +niem +nl +nlat +nom +nordamerik +nordd +norddt +nordgerm +nordostd +nordostdt +nordwestd +nordwestdt +norw +norweg +nov +nr +ntw +nutzfl +nw +nwvbl +näml +nördl +obb +obd +oberlaus +obers 
+obersächs +obj +od +offiz +offz +okt +op +org +orig +orth +osk +osman +ostd +ostdt +oz +palästin +pat +pers +peruan +pet +pf +pfd +pfg +philos +phonolog +phryg +phys +phöniz +pkt +pl +plur +polit +poln +polynes +portug +pos +pp +ppa +pr +preuß +prof +prot +prov +provenz +proz +präd +prähist +präs +psych +päd +qmstr +qt +qu +quadr +quar +quart +quat +quäst +rak +rd +rderl +rdnr +reg +regbl +regt +rel +relig +rep +resp +rgbl +rglm +rgstr +rgt +rh +rh-pf +rheinhess +rhet +rhfrk +rhj +rhld +rhs +ri +richtl +rip +rk +rmbl +rn +rotw +rr +rrh +rs +rspr +rumän +russ +rvj +rzp +rätorom +röm +sa +saarl +sachs +sanskr +sbd +sc +scherzh +schles +schr +schriftl +schwed +schwäb +sdp +sek +sem +semit +sen +sep +sept +serb +serbokroat +sg +sibir +singhal +sizilian +skand +slaw +slg +slowak +slowen +sod +sof +sog +sogen +sogl +soldatenspr +solv +somal +sorb +sout +soz +sozialgesch +soziol +spez +sportspr +spr +sprachwiss +spätahd +spätgriech +spätlat +spätmhd +sr +ssp +st +staatl +std +stdl +stellv +stf +str +stud +stuzbl +subsp +subst +sumer +svw +syn +syr +sächs +sächsvbl +südafrik +südd +süddt +südl +südostdt +südwestd +süßw +tabl +taf +tamil +tatar +techn +teilw +tel +telef +terr +tfx +tgl +tgt +thrak +thür +thüring +thürvbl +ti +tib +tirol +tochar +trans +tschech +tschechoslowak +tsd +tungus +turkotat +typogr +tz +tägl +türk +uabs +udgl +ugr +ugs +ukrain +umbr +umstr +unang +unbefl +ungar +ungebr +ungel +univ +unzerbr +urgerm +urk +urkdl +urspr +ursprüngl +urt +usf +ust-idnr +usw +va +var +vbl +vchr +verf +verg +vergl +verh +verkehrspol +vern +vers +verwarch +vfg +vgl +vh +viell +vkbl +vl +vlat +vllt +vlt +vobl +volkst +vollj +vorbem +vors +vs +vsl +vt +vulg +vulgärlat +vwz +vzk +wa +weibl +weißruss +westd +westdt +westf +westfäl +westgerm +westl +wfl +wg +wh +whg +winzerspr +wirtschaftl +wiss +wj +wld +wtb +wwe +wz +xerogr +xerok +xyl +yd +yds +zb +zbsp +zeithist +zf +zi +ziff +zool +zpr +zssg +zssgn +zt +zus +zw +zz +zzgl +zzt +ägypt +änd +öbgbl +ökol +ökon +ökum +österr +östl +übertr +überw +übk +übl +üblw diff --git a/haystack/data/abbreviations/en.txt b/haystack/data/abbreviations/en.txt new file mode 100644 index 0000000000..f23cb43122 --- /dev/null +++ b/haystack/data/abbreviations/en.txt @@ -0,0 +1,975 @@ +abbrev +abd +aberd +aberdeensh +abl +abol +aborig +abp +abr +abridg +abridgem +absol +abst +abstr +acad +acc +accomm +accompl +accs +acct +accts +accus +achievem +add +addit +addr +adj +adjs +adm +admir +admon +admonit +adv +advancem +advb +advert +advoc +advs +advt +advts +aerodynam +aeronaut +aff +afr +agric +agst +al +alch +alg +alleg +allit +alm +alph +alt +amer +analyt +anat +anc +anecd +ang +angl +anglo-ind +anim +ann +anniv +annot +anon +answ +anthrop +anthropol +antiq +aphet +apoc +apol +appl +appl'n +applic +appos +apr +arb +archaeol +archit +argt +arith +arithm +arrangem +arrv +artic +artific +artill +ashm +assemb +assoc +assyriol +astr +astrol +astron +att +attrib +aug +austral +auth +autobiog +autobiogr +ave +ayrsh +bacteriol +bedfordsh +bef +belg +berks +berksh +berw +berwicksh +betw +bibliogr +biochem +biog +biogr +biol +bk +bks +blvd +bord +bp +braz +bros +bur +cal +calc +calend +calif +calligr +camb +cambr +campanol +canad +canterb +capt +cartogr +catal +catech +cath +ceram +cert +certif +cf +ch +chamb +char +charac +chas +chem +chesh +chr +chron +chronol +chrons +cinematogr +circ +cl +classif +climatol +clin +co +col +coll +colloq +com +comm +commandm +commend +commerc +commiss +commonw +communic +comp +compan +compar +compend +compl +compos +conc +conch +concl +concr +conf 
+confid +confl +confut +congr +congreg +conj +conn +cons +consc +consecr +consid +consol +const +constit +constr +contemp +contempl +contempt +contend +contin +contr +contrib +controv +conv +conversat +convoc +cor +cornw +coron +corp +corr +corresp +counc +courtsh +cpd +craniol +craniom +crim +crit +crt +crts +cryptogr +crystallogr +ct +cumb +cumberld +cumbld +cycl +cytol +dat +dau +deb +dec +declar +ded +def +deliv +dem +demonstr +dep +depred +depredat +dept +derbysh +deriv +derog +descr +deut +devel +devonsh +dict +diffic +dim +dis +discipl +discov +discrim +diss +dist +distemp +distill +distrib +div +divers +dk +doc +doct +domest +dr +drs +durh +dyslog +eccl +eccles +ecclus +ecol +econ +ed +edin +edinb +educ +edw +egypt +egyptol +electr +electro-magn +electro-physiol +elem +eliz +elizab +ellipt +emb +embryol +emph +encl +encycl +eng +engin +englishw +enq +ent +enthus +entom +entomol +enzymol +ep +eph +ephes +epil +episc +epist +epit +equip +erron +esd +esp +ess +essent +establ +esth +etc +ethnol +etym +etymol +euphem +eval +evang +evid +evol +exalt +exc +exch +exec +exerc +exhib +exod +exped +exper +explan +explic +explor +expos +ext +ezek +fab +fam +famil +farew +feb +fem +ff +fifesh +fig +fl +footpr +forfarsh +fortif +fortn +found +fr +fragm +fratern +freq +fri +friendsh +ft +furnit +fut +gal +gard +gastron +gaz +gd +gen +geo +geog +geogr +geol +geom +geomorphol +ger +glac +glasg +glos +gloss +glouc +gloucestersh +gosp +gov +govt +gr +gram +gt +gynaecol +hab +haematol +hag +hampsh +handbk +hants +heb +hebr +hen +herb +heref +herefordsh +hertfordsh +hierogl +hist +histol +hom +horol +hort +hos +hosp +househ +housek +husb +hydraul +hydrol +ichth +icthyol +ideol +idol +illustr +imag +imit +immunol +imp +imperf +impers +impf +impr +improp +inaug +inc +inclos +ind +indef +indic +indir +industr +infin +infl +innoc +inorg +inq +inst +instr +int +intell +interc +interj +interl +internat +interpr +interrog +intr +intrans +intro +introd +inv +invertebr +investig +investm +invoc +ir +irel +iron +irreg +isa +ital +jahrb +jam +jan +jap +jas +jer +joc +josh +jr +jrnl +jrnls +jud +judg +jul +jun +jurisd +jurisdict +jurispr +justif +justific +kgs +kingd +knowl +kpr +lam +lanc +lancash +lancs +lang +langs +lat +lb +ld +lds +lect +leechd +leicest +leicestersh +leics +let +lett +lev +lex +libr +limnol +lincolnsh +lincs +ling +linn +lit +lithogr +lithol +liturg +ll +lond +lt +ltd +macc +mach +mag +magn +mal +managem +manch +manip +manuf +mar +masc +matt +meas +measurem +mech +med +medit +mem +merc +merch +metall +metallif +metallogr +metamorph +metaph +metaphor +meteorol +metrop +mex +mic +mich +microbiol +microsc +midl +mil +milit +min +mineral +misc +miscell +mispr +mon +monum +morphol +mr +mrs +ms +msc +mss +mt +mtg +mts +munic +munif +munim +mus +myst +mythol +nah +narr +narrat +nat +naut +nav +navig +neh +neighb +nerv +neurol +neurosurg +newc +newspr +nom +non-conf +nonce-wd +nonconf +norf +northamptonsh +northants +northumb +northumbld +northumbr +norw +norweg +notts +nov +ns +nucl +num +numism +obad +obed +obj +obl +obs +observ +obstet +obstetr +occas +occup +occurr +oceanogr +oct +offic +okla +ont +ophthalm +ophthalmol +opp +oppress +opt +orac +ord +org +orig +orkn +ornith +ornithol +orthogr +outl +oxf +oxfordsh +oxon +oz +pa +palaeobot +palaeogr +palaeont +palaeontol +paraphr +parasitol +parl +parnass +pathol +peculat +penins +perf +perh +periodontol +pers +persec +personif +perthsh +petrogr +pf +pharm +pharmaceut +pharmacol +phd +phil +philad +philem +philipp +philol +philos +phoen +phonet 
+phonol +photog +photogr +phr +phrenol +phys +physiogr +physiol +pict +pl +plur +pol +polit +polytechn +porc +poss +posth +postm +ppl +pple +pples +pract +prec +pred +predic +predict +pref +preh +prehist +prep +prerog +pres +presb +preserv +prim +princ +priv +prob +probab +probl +proc +prod +prof +prol +pron +pronunc +prop +propr +pros +prov +provid +provinc +provis +ps +pseudo-arch +pseudo-dial +pseudo-sc +psych +psychoanal +psychoanalyt +psychol +psychopathol +pt +publ +purg +qld +quot +quots +radiol +reas +reb +rec +reclam +recoll +redempt +redupl +ref +refash +refl +refus +refut +reg +regic +regist +regr +rel +relig +reminisc +remonstr +renfrewsh +rep +repr +reprod +reps +rept +repub +res +resid +ret +retrosp +rev +revol +rhet +rich +ross-sh +roxb +roy +rudim +russ +sam +sask +sat +sc +scand +sch +sci +scot +scotl +script +sculpt +seismol +sel +sen +sep +sept +ser +serm +sess +settlem +sev +shakes +shaks +sheph +shetl +shropsh +soc +sociol +som +sonn +sp +spec +specif +specim +spectrosc +spp +sq +sr +ss +st +staffordsh +staffs +stat +statist +ste +str +stratigr +struct +sts +stud +subj +subjunct +subord +subscr +subscript +subseq +subst +suff +superl +suppl +supplic +suppress +surg +surv +sus +syll +symmetr +symp +syst +taxon +techn +technol +tel +telecomm +telegr +teleph +teratol +terminol +terrestr +textbk +theat +theatr +theol +theoret +thermonucl +thes +thess +thur +topogr +tr +trad +trag +trans +transf +transl +transubstant +trav +treas +treatm +trib +trig +trigonom +trop +troub +troubl +tue +typog +typogr +ult +univ +unkn +unnat +unoffic +unstr +usu +utilit +va +vac +valedict +var +varr +vars +vb +vbl +vbs +veg +venet +vertebr +vet +vic +vict +vind +vindic +virg +virol +viz +voc +vocab +vol +vols +voy +vs +vulg +warwicksh +wd +wed +westm +westmld +westmorld +westmrld +wilts +wiltsh +wis +wisd +wk +wkly +wks +wonderf +worc +worcestersh +worcs +writ +yearbk +yng +yorks +yorksh +yr +yrs +zech +zeitschr +zeph +zoogeogr +zool diff --git a/test/components/preprocessors/test_sentence_tokenizer.py b/test/components/preprocessors/test_sentence_tokenizer.py index bf9aab9a9e..f8aaf68fa7 100644 --- a/test/components/preprocessors/test_sentence_tokenizer.py +++ b/test/components/preprocessors/test_sentence_tokenizer.py @@ -54,7 +54,7 @@ def test_read_abbreviations_existing_file(tmp_path, mock_file_content): abbrev_file.write_text(mock_file_content) with patch("haystack.components.preprocessors.sentence_tokenizer.Path") as mock_path: - mock_path.return_value.parent.parent = tmp_path + mock_path.return_value.parent.parent.parent = tmp_path result = SentenceSplitter._read_abbreviations("en") assert result == ["Mr.", "Dr.", "Prof."] From 2bc58d298749a28372bb98a7b3c902786380ea69 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 19 Dec 2024 15:04:37 +0100 Subject: [PATCH 13/14] feat: support for tools in `HuggingFaceAPIChatGenerator` (#8661) * message conversion function * hfapi w tools * right test file + hf_hub version * release note * feedback --- .../hugging_face_api_document_embedder.py | 2 +- .../hugging_face_api_text_embedder.py | 2 +- .../generators/chat/hugging_face_api.py | 152 ++++++--- .../components/generators/hugging_face_api.py | 2 +- haystack/dataclasses/tool.py | 15 +- haystack/utils/hf.py | 42 ++- pyproject.toml | 2 +- .../notes/hfapi-tools-a7224150bce52564.yaml | 4 + .../generators/chat/test_hugging_face_api.py | 297 ++++++++++++++++-- test/dataclasses/test_tool.py | 16 + test/utils/test_hf.py | 59 +++- 11 files changed, 509 insertions(+), 84 deletions(-) create mode 
100644 releasenotes/notes/hfapi-tools-a7224150bce52564.yaml diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 43f719e27d..459e386976 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -14,7 +14,7 @@ from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index f60a9e5fd7..2cd68d34da 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -11,7 +11,7 @@ from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 8711a9175a..dab61e4d93 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -5,30 +5,25 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model +from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import: - from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient +with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import: + from huggingface_hub import ( + ChatCompletionInputTool, + ChatCompletionOutput, + ChatCompletionStreamOutput, + InferenceClient, + ) logger = logging.getLogger(__name__) -def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: - """ - Convert a message to the format expected by Hugging Face APIs. 
- - :returns: A dictionary with the following keys: - - `role` - - `content` - """ - return {"role": message.role.value, "content": message.text or ""} - - @component class HuggingFaceAPIChatGenerator: """ @@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, ): """ Initialize the HuggingFaceAPIChatGenerator instance. @@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`. - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_GENERATION_INFERENCE`. - :param token: The Hugging Face token to use as HTTP bearer authorization. + :param token: + The Hugging Face token to use as HTTP bearer authorization. Check your HF token in your [account settings](https://huggingface.co/settings/tokens). :param generation_kwargs: A dictionary with keyword arguments to customize text generation. Some examples: `max_tokens`, `temperature`, `top_p`. For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion). - :param stop_words: An optional list of strings representing the stop words. - :param streaming_callback: An optional callable for handling streaming responses. + :param stop_words: + An optional list of strings representing the stop words. + :param streaming_callback: + An optional callable for handling streaming responses. + :param tools: + A list of tools for which the model can prepare calls. + The chosen model should support tool/function calling, according to the model card. + Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience + unexpected behavior. """ huggingface_hub_import.check() @@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments msg = f"Unknown api_type {api_type}" raise ValueError(msg) + if tools: + if streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) + # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} generation_kwargs["stop"] = generation_kwargs.get("stop", []) @@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.generation_kwargs = generation_kwargs self.streaming_callback = streaming_callback self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None) + self.tools = tools def to_dict(self) -> Dict[str, Any]: """ @@ -180,6 +190,7 @@ def to_dict(self) -> Dict[str, Any]: A dictionary containing the serialized component. 
""" callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, api_type=str(self.api_type), @@ -187,6 +198,7 @@ def to_dict(self) -> Dict[str, Any]: token=self.token.to_dict() if self.token else None, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -195,6 +207,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": Deserialize this component from a dictionary. """ deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: @@ -202,12 +215,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): """ Invoke the text generation inference based on the provided messages and generation parameters. - :param messages: A list of ChatMessage objects representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. + :param messages: + A list of ChatMessage objects representing the input messages. + :param generation_kwargs: + Additional keyword arguments for text generation. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ @@ -215,12 +238,22 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # update generation kwargs by merging with the default ones generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages] + formatted_messages = [convert_message_to_hf_format(message) for message in messages] + + tools = tools or self.tools + if tools: + if self.streaming_callback: + raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") + _check_duplicate_tool_names(tools) if self.streaming_callback: return self._run_streaming(formatted_messages, generation_kwargs) - return self._run_non_streaming(formatted_messages, generation_kwargs) + hf_tools = None + if tools: + hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( @@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict generated_text = "" - for chunk in api_output: # pylint: disable=not-an-iterable - text = chunk.choices[0].delta.content + for chunk in api_output: + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = chunk.choices[0] + + text = choice.delta.content if text: generated_text += text - finish_reason = chunk.choices[0].finish_reason + + finish_reason = choice.finish_reason meta = {} if finish_reason: @@ -242,8 +281,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict stream_chunk = StreamingChunk(text, meta) self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) - message = ChatMessage.from_assistant(generated_text) - message.meta.update( + meta.update( { "model": self._client.model, "finish_reason": finish_reason, @@ -251,24 +289,48 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming } ) + + message = ChatMessage.from_assistant(text=generated_text, meta=meta) + return {"replies": [message]} def _run_non_streaming( - self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any] + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + tools: Optional[List["ChatCompletionInputTool"]] = None, ) -> Dict[str, List[ChatMessage]]: - chat_messages: List[ChatMessage] = [] - - api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs) - for choice in api_chat_output.choices: - message = ChatMessage.from_assistant(choice.message.content) - message.meta.update( - { - "model": self._client.model, - "finish_reason": choice.finish_reason, - "index": choice.index, - "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0}, - } - ) - chat_messages.append(message) - - return {"replies": chat_messages} + api_chat_output: ChatCompletionOutput = self._client.chat_completion( + messages=messages, tools=tools, **generation_kwargs + ) + + if len(api_chat_output.choices) == 0: + return {"replies": []} + + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = api_chat_output.choices[0] + + text = choice.message.content + tool_calls = [] + + if hfapi_tool_calls := choice.message.tool_calls: + for hfapi_tc in hfapi_tool_calls: + tool_call = ToolCall( + tool_name=hfapi_tc.function.name, 
arguments=hfapi_tc.function.arguments, id=hfapi_tc.id + ) + tool_calls.append(tool_call) + + meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index} + + usage = {"prompt_tokens": 0, "completion_tokens": 0} + if api_chat_output.usage: + usage = { + "prompt_tokens": api_chat_output.usage.prompt_tokens, + "completion_tokens": api_chat_output.usage.completion_tokens, + } + meta["usage"] = usage + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) + return {"replies": [message]} diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index a164c8c56c..a44ad94575 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -12,7 +12,7 @@ from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import ( InferenceClient, TextGenerationOutput, diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py index 3df3fd18f2..c6606d51e8 100644 --- a/haystack/dataclasses/tool.py +++ b/haystack/dataclasses/tool.py @@ -4,7 +4,7 @@ import inspect from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from pydantic import create_model @@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]): del property_schema[key] +def _check_duplicate_tool_names(tools: List[Tool]) -> None: + """ + Check for duplicate tool names and raises a ValueError if they are found. + + :param tools: The list of tools to check. + :raises ValueError: If duplicate tool names are found. + """ + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): """ Deserialize Tools in a dictionary inplace. diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 537b05e232..6a83594ada 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from haystack import logging -from haystack.dataclasses import StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.lazy_imports import LazyImport from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice @@ -16,7 +16,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import: import torch -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import HfApi, InferenceClient, model_info from huggingface_hub.utils import RepositoryNotFoundError @@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte ) +def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a message to the format expected by Hugging Face. 
+ """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + if len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + # HF always expects a content field, even if it is empty + hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""} + + if tool_call_results: + result = tool_call_results[0] + hf_msg["content"] = result.result + if tc_id := result.origin.id: + hf_msg["tool_call_id"] = tc_id + # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field + return hf_msg + + if text_contents: + hf_msg["content"] = text_contents[0] + if tool_calls: + hf_tool_calls = [] + for tc in tool_calls: + hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} + if tc.id is not None: + hf_tool_call["id"] = tc.id + hf_tool_calls.append(hf_tool_call) + hf_msg["tool_calls"] = hf_tool_calls + + return hf_msg + + with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer diff --git a/pyproject.toml b/pyproject.toml index c1fddc8704..6a76a2e9c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ extra-dependencies = [ "numpy>=2", # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x "transformers[torch,sentencepiece]==4.44.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... - "huggingface_hub>=0.23.0", # Hugging Face API Generators and Embedders + "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber diff --git a/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml b/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml new file mode 100644 index 0000000000..085ed35931 --- /dev/null +++ b/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add support for Tools in the Hugging Face API Chat Generator. 
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index e60ec863ab..0d0857e22a 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -5,23 +5,46 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from haystack import Pipeline +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( ChatCompletionOutput, - ChatCompletionStreamOutput, ChatCompletionOutputComplete, - ChatCompletionStreamOutputChoice, + ChatCompletionOutputFunctionDefinition, ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputDelta, ) from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators.chat.hugging_face_api import ( - HuggingFaceAPIChatGenerator, - _convert_message_to_hfapi_format, -) -from haystack.dataclasses import ChatMessage, StreamingChunk -from haystack.utils.auth import Secret -from haystack.utils.hf import HFGenerationAPIType +from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator +from haystack.dataclasses import ChatMessage, Tool, ToolCall + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] @pytest.fixture @@ -48,7 +71,7 @@ def mock_chat_completion(): id="some_id", model="some_model", system_fingerprint="some_fingerprint", - usage={"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15}, + usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), created=1710498360, ) @@ -61,15 +84,7 @@ def streaming_callback_handler(x): return x -def test_convert_message_to_hfapi_format(): - message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_hfapi_format(message) == {"role": "system", "content": "You are good assistant"} - - message = ChatMessage.from_user("I have a question") - assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"} - - -class TestHuggingFaceAPIGenerator: +class TestHuggingFaceAPIChatGenerator: def test_init_invalid_api_type(self): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) @@ -93,6 +108,29 @@ def test_init_serverless(self, mock_check_valid_model): assert generator.api_params == {"model": model} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None + + def test_init_serverless_with_tools(self, mock_check_valid_model, tools): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": model}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + tools=tools, + ) + + assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator.api_params == {"model": model} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools == tools def test_init_serverless_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") @@ -126,6 +164,7 @@ def test_init_tgi(self): assert generator.api_params == {"url": url} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None def test_init_tgi_invalid_url(self): with pytest.raises(ValueError): @@ -139,12 +178,33 @@ def test_init_tgi_no_url(self): api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} ) + def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=duplicate_tools, + ) + + def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=tools, + streaming_callback=streaming_callback_handler, + ) + def test_to_dict(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], + tools=[tool], ) result = generator.to_dict() @@ -154,15 +214,26 @@ def test_to_dict(self, mock_check_valid_model): assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert init_params["streaming_callback"] is None + assert init_params["tools"] == [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + } + ] def test_from_dict(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, + tools=[tool], ) result = generator.to_dict() @@ -172,11 +243,57 @@ def test_from_dict(self, mock_check_valid_model): assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert generator_2.token == 
Secret.from_env_var("ENV_VAR", strict=False) assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} - assert generator_2.streaming_callback is streaming_callback_handler + assert generator_2.streaming_callback is None + assert generator_2.tools == [tool] + + def test_serde_in_pipeline(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) - def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_chat_completion, chat_messages - ): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + token=Secret.from_env_var("ENV_VAR", strict=False), + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + tools=[tool], + ) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", + "init_parameters": { + "api_type": "serverless_inference_api", + "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, + "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, + "streaming_callback": None, + "tools": [ + { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + } + ], + }, + } + }, + "connections": [], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, @@ -187,9 +304,19 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( response = generator.run(messages=chat_messages) - # check kwargs passed to text_generation + # check kwargs passed to chat_completion _, kwargs = mock_chat_completion.call_args - assert kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + hf_messages = [ + {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, + {"role": "user", "content": "Tell me about Berlin"}, + ] + assert kwargs == { + "temperature": 0.6, + "stop": ["stop", "words"], + "max_tokens": 512, + "tools": None, + "messages": hf_messages, + } assert isinstance(response, dict) assert "replies" in response @@ -197,7 +324,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): + def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): streaming_call_count = 0 # Define the streaming callback function @@ -260,13 +387,78 @@ def mock_iter(self): assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - @pytest.mark.flaky(reruns=5, reruns_delay=5) + 
def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model): + component = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + streaming_callback=streaming_callback_handler, + ) + + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + component.run([message], tools=tools) + + def test_run_with_tools(self, mock_check_valid_model, tools): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, + tools=tools, + ) + + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="stop", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionOutputToolCall( + function=ChatCompletionOutputFunctionDefinition( + arguments={"city": "Paris"}, name="weather", description=None + ), + id="0", + type="function", + ) + ], + ), + logprobs=None, + ) + ], + created=1729074760, + id="", + model="meta-llama/Llama-3.1-70B-Instruct", + system_fingerprint="2.3.2-dev0-sha-28bb7ae", + usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), + ) + mock_chat_completion.return_value = completion + + messages = [ChatMessage.from_user("What is the weather in Paris?")] + response = generator.run(messages=messages) + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert response["replies"][0].tool_calls[0].tool_name == "weather" + assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} + assert response["replies"][0].tool_calls[0].id == "0" + assert response["replies"][0].meta == { + "finish_reason": "stop", + "index": 0, + "model": "meta-llama/Llama-3.1-70B-Instruct", + "usage": {"completion_tokens": 30, "prompt_tokens": 426}, + } + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - def test_run_serverless(self): + def test_live_run_serverless(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, @@ -284,13 +476,12 @@ def test_run_serverless(self): assert "prompt_tokens" in response["replies"][0].meta["usage"] assert "completion_tokens" in response["replies"][0].meta["usage"] - @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - def test_run_serverless_streaming(self): + def test_live_run_serverless_streaming(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, @@ -308,3 +499,47 @@ def test_run_serverless_streaming(self): assert "usage" in response["replies"][0].meta assert "prompt_tokens" in response["replies"][0].meta["usage"] assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + 
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    def test_live_run_with_tools(self, tools):
+        """
+        We test the round trip: generate tool call, pass tool message, generate response.
+
+        The model used here (zephyr-7b-beta) is always available and not gated.
+        Even though it does not officially support tools, TGI and the HF API make it work.
+        """
+
+        chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")]
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"temperature": 0.5},
+        )
+
+        results = generator.run(chat_messages, tools=tools)
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+
+        assert message.tool_calls
+        tool_call = message.tool_call
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert "city" in tool_call.arguments
+        assert "Paris" in tool_call.arguments["city"]
+        assert message.meta["finish_reason"] == "stop"
+
+        new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]
+
+        # the model tends to make tool calls if provided with tools, so we don't pass them here
+        results = generator.run(new_messages, generation_kwargs={"max_tokens": 50})
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_calls
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py
index db9719a7f3..9e112853f3 100644
--- a/test/dataclasses/test_tool.py
+++ b/test/dataclasses/test_tool.py
@@ -12,6 +12,7 @@
     ToolInvocationError,
     _remove_title_from_schema,
     deserialize_tools_inplace,
+    _check_duplicate_tool_names,
 )
 
 try:
@@ -303,3 +304,18 @@ def test_remove_title_from_schema_handle_no_title_in_top_level():
         "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
         "type": "object",
     }
+
+
+def test_check_duplicate_tool_names():
+    tools = [
+        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
+        Tool(name="weather", description="A different description", parameters=parameters, function=get_weather_report),
+    ]
+    with pytest.raises(ValueError):
+        _check_duplicate_tool_names(tools)
+
+    tools = [
+        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
+        Tool(name="weather2", description="Get weather report", parameters=parameters, function=get_weather_report),
+    ]
+    _check_duplicate_tool_names(tools)
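The test added above pins down the contract of _check_duplicate_tool_names: tool names must be unique across a list of tools, and a clash raises ValueError while differing descriptions are irrelevant. For reference, a minimal sketch consistent with that contract; this is an illustration only, since the real helper (presumably shipped next to the Tool dataclass) may differ in its exact error message:

    from typing import List

    def _check_duplicate_tool_names(tools: List["Tool"]) -> None:
        # Tool names must be unique; collect every name that occurs more than once
        names = [tool.name for tool in tools]
        duplicates = {name for name in names if names.count(name) > 1}
        if duplicates:
            raise ValueError(f"Duplicate tool names found: {duplicates}")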
diff --git a/test/utils/test_hf.py b/test/utils/test_hf.py
index 4350fb9fbb..d75e0b7501 100644
--- a/test/utils/test_hf.py
+++ b/test/utils/test_hf.py
@@ -2,8 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from haystack.utils.hf import resolve_hf_device_map
+
+import pytest
+
+from haystack.utils.hf import resolve_hf_device_map, convert_message_to_hf_format
 from haystack.utils.device import ComponentDevice
+from haystack.dataclasses import ChatMessage, ToolCall, ChatRole, TextContent
 
 
 def test_resolve_hf_device_map_only_device():
@@ -23,3 +27,56 @@ def test_resolve_hf_device_map_device_and_device_map(caplog):
     )
     assert "The parameters `device` and `device_map` from `model_kwargs` are both provided." in caplog.text
     assert model_kwargs["device_map"] == "cuda:0"
+
+
+def test_convert_message_to_hf_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert convert_message_to_hf_format(message) == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert convert_message_to_hf_format(message) == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})
+    assert convert_message_to_hf_format(message) == {"role": "assistant", "content": "I have an answer"}
+
+    message = ChatMessage.from_assistant(
+        tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})]
+    )
+    assert convert_message_to_hf_format(message) == {
+        "role": "assistant",
+        "content": "",
+        "tool_calls": [
+            {"id": "123", "type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}
+        ],
+    }
+
+    message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather", arguments={"city": "Paris"})])
+    assert convert_message_to_hf_format(message) == {
+        "role": "assistant",
+        "content": "",
+        "tool_calls": [{"type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}],
+    }
+
+    tool_result = {"weather": "sunny", "temperature": "25"}
+    message = ChatMessage.from_tool(
+        tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})
+    )
+    assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result, "tool_call_id": "123"}
+
+    message = ChatMessage.from_tool(
+        tool_result=tool_result, origin=ToolCall(tool_name="weather", arguments={"city": "Paris"})
+    )
+    assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result}
+
+
+def test_convert_message_to_hf_invalid():
+    message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[])
+    with pytest.raises(ValueError):
+        convert_message_to_hf_format(message)
+
+    message = ChatMessage(
+        _role=ChatRole.ASSISTANT,
+        _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")],
+    )
+    with pytest.raises(ValueError):
+        convert_message_to_hf_format(message)
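Read together, these tests fully specify the message mapping. A sketch of a convert_message_to_hf_format implementation that satisfies them follows; it assumes the ChatMessage accessors texts, tool_calls, tool_call_results, and role behave as exercised above, and the shipped version in haystack/utils/hf.py may differ in details such as error wording:

    from typing import Any, Dict

    from haystack.dataclasses import ChatMessage

    def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
        text_contents = message.texts
        tool_calls = message.tool_calls
        tool_call_results = message.tool_call_results

        if not text_contents and not tool_calls and not tool_call_results:
            raise ValueError("A ChatMessage must contain at least one TextContent, ToolCall, or ToolCallResult.")
        if len(text_contents) + len(tool_call_results) > 1:
            raise ValueError("A ChatMessage can contain at most one TextContent or one ToolCallResult.")

        # Tool messages carry the tool result and, when available, the originating call id.
        if tool_call_results:
            result = tool_call_results[0]
            hf_msg = {"role": "tool", "content": result.result}
            if result.origin.id is not None:
                hf_msg["tool_call_id"] = result.origin.id
            return hf_msg

        # System/user/assistant messages: plain text, plus OpenAI-style tool calls for assistants.
        hf_msg = {"role": message.role.value, "content": text_contents[0] if text_contents else ""}
        if tool_calls:
            hf_tool_calls = []
            for tc in tool_calls:
                hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}}
                if tc.id is not None:
                    hf_tool_call["id"] = tc.id
                hf_tool_calls.append(hf_tool_call)
            hf_msg["tool_calls"] = hf_tool_calls
        return hf_msg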
From f4d9c2bb917be0ffe132dffcc2ad4f1b0fcc5967 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Thu, 19 Dec 2024 15:12:12 +0100
Subject: [PATCH 14/14] fix: Make the `HuggingFaceLocalChatGenerator` compatible with the new `ChatMessage`; serialize `chat_template` (#8663)

* message conversion function
* hfapi w tools
* right test file + hf_hub version
* release note
* fix for new chatmessage; serialize chat_template
* feedback
---
 .../generators/chat/hugging_face_local.py     |  6 ++-
 .../hflocalchat-fixes-ddf71e8c4c73e566.yaml   |  7 ++++
 .../chat/test_hugging_face_local.py           | 37 +++++++++++++++++++
 3 files changed, 49 insertions(+), 1 deletion(-)
 create mode 100644 releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml

diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py
index 988bffc8b4..1ad152f1e3 100644
--- a/haystack/components/generators/chat/hugging_face_local.py
+++ b/haystack/components/generators/chat/hugging_face_local.py
@@ -25,6 +25,7 @@ from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
     HFTokenStreamingHandler,
     StopWordsCriteria,
+    convert_message_to_hf_format,
     deserialize_hf_model_kwargs,
     serialize_hf_model_kwargs,
 )
@@ -201,6 +202,7 @@ def to_dict(self) -> Dict[str, Any]:
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             token=self.token.to_dict() if self.token else None,
+            chat_template=self.chat_template,
         )
 
         huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
@@ -270,9 +272,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
 
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
         # Prepare the prompt for the model
         prepared_prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
         )
 
         # Avoid some unnecessary warnings in the generation pipeline call
diff --git a/releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml b/releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml
new file mode 100644
index 0000000000..fd8c96a6bb
--- /dev/null
+++ b/releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml
@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage format by converting the messages to
+    the format expected by Hugging Face.
+
+    Serialize the chat_template parameter.
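The run() change above exists because the new ChatMessage is no longer a plain dict, so the tokenizer's chat template cannot consume it directly. In isolation the pattern reduces to the following sketch; the checkpoint is only an example, and any model that ships a chat template works the same way:

    from transformers import AutoTokenizer

    from haystack.dataclasses import ChatMessage
    from haystack.utils.hf import convert_message_to_hf_format

    messages = [
        ChatMessage.from_system("You are a helpful assistant"),
        ChatMessage.from_user("Tell me about Berlin"),
    ]
    # ChatMessage objects become plain role/content dicts before templating
    hf_messages = [convert_message_to_hf_format(m) for m in messages]

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    # With chat_template left unset, the model's own template is used
    prompt = tokenizer.apply_chat_template(hf_messages, tokenize=False, add_generation_prompt=True)
    print(prompt)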
@patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format") + def test_messages_conversion_is_called(self, mock_convert, model_info_mock): + generator = HuggingFaceLocalChatGenerator(model="fake-model") + + messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")] + + with patch.object(generator, "pipeline") as mock_pipeline: + mock_pipeline.tokenizer.apply_chat_template.return_value = "test prompt" + mock_pipeline.return_value = [{"generated_text": "test response"}] + + generator.warm_up() + generator.run(messages) + + assert mock_convert.call_count == 2 + mock_convert.assert_any_call(messages[0]) + mock_convert.assert_any_call(messages[1]) + + @pytest.mark.integration + @pytest.mark.flaky(reruns=3, reruns_delay=10) + def test_live_run(self): + messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")] + + llm = HuggingFaceLocalChatGenerator( + model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50} + ) + llm.warm_up() + + result = llm.run(messages) + + assert "replies" in result + assert isinstance(result["replies"][0], ChatMessage) + assert "climate change" in result["replies"][0].text.lower()