Commit

text-splitters: Inconsistent results with NLTKTextSplitter's `add_start_index=True` (#27782)

This PR closes #27781

# Problem
The current implementation of `NLTKTextSplitter` uses `sent_tokenize`, which drops the characters between two tokenized sentences. Because of this, chunks can no longer be mapped back to their positions in the source text, and `add_start_index=True` produces inconsistent results, as described in issue #27781. In particular:
```python
from nltk.tokenize import sent_tokenize

output1 = sent_tokenize("Innovation drives our success. Collaboration fosters creative solutions. Efficiency enhances data management.", language="english")
print(output1)
output2 = sent_tokenize("Innovation drives our success.        Collaboration fosters creative solutions. Efficiency enhances data management.", language="english")
print(output2)
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
```
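Because the whitespace between sentences is discarded, a chunk rebuilt from these tokenized sentences can no longer be located in the original text (`add_start_index` relies on roughly a substring search to compute offsets). A minimal sketch of the failure mode, using the same example:
```python
from nltk.tokenize import sent_tokenize

text = (
    "Innovation drives our success.        "
    "Collaboration fosters creative solutions."
)
# Rebuild a chunk the way the splitter merges sentences, i.e. joined
# with a single separator:
chunk = " ".join(sent_tokenize(text, language="english"))
# The eight spaces have collapsed into one, so the chunk is no longer
# a substring of the original text and the offset lookup fails:
print(text.find(chunk))  # -1
```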

# Solution
With the new `use_span_tokenize` parameter, we still use NLTK to find sentences (now via `span_tokenize`), but we re-attach the characters that fall between consecutive spans, so each chunk can still be mapped back to the original text. A minimal sketch of the idea follows.
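For illustration, here is a sketch of the mechanism outside the splitter, using NLTK's public `PunktSentenceTokenizer` as a stand-in for the private `_get_punkt_tokenizer` helper used in the implementation (both expose the same `span_tokenize` method):
```python
from nltk.tokenize.punkt import PunktSentenceTokenizer

text = (
    "Innovation drives our success.        "
    "Collaboration fosters creative solutions."
)
spans = list(PunktSentenceTokenizer().span_tokenize(text))
print(spans)  # [(0, 30), (38, 80)] -- offsets into the original text

# Re-attach the characters between consecutive spans, as the new code does,
# so that concatenating the sentences reproduces the original text exactly:
sentences = []
for i, (start, end) in enumerate(spans):
    prefix = text[spans[i - 1][1] : start] if i > 0 else ""
    sentences.append(prefix + text[start:end])
assert "".join(sentences) == text
```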

---------

Co-authored-by: Erick Friis <[email protected]>
3 people authored Dec 16, 2024
1 parent d262d41 commit b2102b8
Showing 7 changed files with 1,848 additions and 27 deletions.
3 changes: 3 additions & 0 deletions libs/text-splitters/Makefile
@@ -9,6 +9,9 @@ TEST_FILE ?= tests/unit_tests/
 test tests:
 	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
 
+integration_test integration_tests:
+	poetry run pytest tests/integration_tests/
+
 test_watch:
 	poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
 
36 changes: 30 additions & 6 deletions libs/text-splitters/langchain_text_splitters/nltk.py
@@ -9,23 +9,47 @@ class NLTKTextSplitter(TextSplitter):
     """Splitting text using NLTK package."""
 
     def __init__(
-        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
+        self,
+        separator: str = "\n\n",
+        language: str = "english",
+        *,
+        use_span_tokenize: bool = False,
+        **kwargs: Any,
     ) -> None:
         """Initialize the NLTK splitter."""
         super().__init__(**kwargs)
+        self._separator = separator
+        self._language = language
+        self._use_span_tokenize = use_span_tokenize
+        if self._use_span_tokenize and self._separator != "":
+            raise ValueError("When use_span_tokenize is True, separator should be ''")
         try:
-            from nltk.tokenize import sent_tokenize
+            if self._use_span_tokenize:
+                from nltk.tokenize import _get_punkt_tokenizer
 
+                self._tokenizer = _get_punkt_tokenizer(self._language)
+            else:
+                from nltk.tokenize import sent_tokenize
+
+                self._tokenizer = sent_tokenize
-            self._tokenizer = sent_tokenize
         except ImportError:
             raise ImportError(
                 "NLTK is not installed, please install it with `pip install nltk`."
             )
-        self._separator = separator
-        self._language = language
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
-        splits = self._tokenizer(text, language=self._language)
+        if self._use_span_tokenize:
+            spans = list(self._tokenizer.span_tokenize(text))
+            splits = []
+            for i, (start, end) in enumerate(spans):
+                if i > 0:
+                    prev_end = spans[i - 1][1]
+                    sentence = text[prev_end:start] + text[start:end]
+                else:
+                    sentence = text[start:end]
+                splits.append(sentence)
+        else:
+            splits = self._tokenizer(text, language=self._language)
         return self._merge_splits(splits, self._separator)
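With the change in place, the splitter can be driven as below — a usage sketch mirroring the new test further down; `separator=""` is required because the inter-sentence characters are already carried by the chunks:
```python
from langchain_core.documents import Document
from langchain_text_splitters.nltk import NLTKTextSplitter

text = (
    "Innovation drives our success.        "
    "Collaboration fosters creative solutions."
)
splitter = NLTKTextSplitter(
    chunk_size=80,
    chunk_overlap=0,
    separator="",  # must be empty when use_span_tokenize=True
    use_span_tokenize=True,
    add_start_index=True,
)
for chunk in splitter.split_documents([Document(text)]):
    start = chunk.metadata["start_index"]
    # Each chunk now maps back onto the original text exactly:
    assert chunk.page_content == text[start : start + len(chunk.page_content)]
```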
2 changes: 1 addition & 1 deletion libs/text-splitters/langchain_text_splitters/sentence_transformers.py
@@ -22,7 +22,7 @@ def __init__(
             from sentence_transformers import SentenceTransformer
         except ImportError:
             raise ImportError(
-                "Could not import sentence_transformer python package. "
+                "Could not import sentence_transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                 "Please install it with `pip install sentence-transformers`."
             )
1,720 changes: 1,718 additions & 2 deletions libs/text-splitters/poetry.lock

Large diffs are not rendered by default.

41 changes: 29 additions & 12 deletions libs/text-splitters/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = [ "poetry-core>=1.0.0",]
+requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
@@ -14,7 +14,20 @@ repository = "https://github.com/langchain-ai/langchain"
 [tool.mypy]
 disallow_untyped_defs = "True"
 [[tool.mypy.overrides]]
-module = [ "transformers", "sentence_transformers", "nltk.tokenize", "konlpy.tag", "bs4", "pytest", "spacy", "spacy.lang.en", "numpy",]
+module = [
+    "transformers",
+    "sentence_transformers",
+    "nltk.tokenize",
+    "konlpy.tag",
+    "bs4",
+    "pytest",
+    "spacy",
+    "spacy.lang.en",
+    "numpy",
+    "nltk",
+    "spacy.cli",
+    "torch",
+]
 ignore_missing_imports = "True"
 
 [tool.poetry.urls]
@@ -26,15 +39,18 @@ python = ">=3.9,<4.0"
 langchain-core = "^0.3.25"
 
 [tool.ruff.lint]
-select = [ "E", "F", "I", "T201", "D",]
-ignore = [ "D100",]
+select = ["E", "F", "I", "T201", "D"]
+ignore = ["D100"]
 
 [tool.coverage.run]
-omit = [ "tests/*",]
+omit = ["tests/*"]
 
 [tool.pytest.ini_options]
 addopts = "--strict-markers --strict-config --durations=5"
-markers = [ "requires: mark tests as requiring a specific library", "compile: mark placeholder test used to compile integration tests without running them",]
+markers = [
+    "requires: mark tests as requiring a specific library",
+    "compile: mark placeholder test used to compile integration tests without running them",
+]
 asyncio_mode = "auto"
 
 [tool.poetry.group.lint]
@@ -53,19 +69,17 @@ optional = true
 convention = "google"
 
 [tool.ruff.lint.per-file-ignores]
-"tests/**" = [ "D",]
+"tests/**" = ["D"]
 
 [tool.poetry.group.lint.dependencies]
 ruff = "^0.5"
 
-
 [tool.poetry.group.typing.dependencies]
 mypy = "^1.10"
 lxml-stubs = "^0.5.1"
 types-requests = "^2.31.0.20240218"
 tiktoken = "^0.8.0"
 
-
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
 
@@ -78,20 +92,23 @@ pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
 pytest-socket = "^0.7.0"
+
 [tool.poetry.group.test_integration]
 optional = true
+
 [tool.poetry.group.test_integration.dependencies]
-
+spacy = { version = "*", python = "<3.13" }
+nltk = "^3.9.1"
+transformers = "^4.47.0"
+sentence-transformers = { version = ">=2.6.0", python = "<3.13" }
 
 [tool.poetry.group.lint.dependencies.langchain-core]
 path = "../core"
 develop = true
 
-
 [tool.poetry.group.dev.dependencies.langchain-core]
 path = "../core"
 develop = true
 
-
 [tool.poetry.group.test.dependencies.langchain-core]
 path = "../core"
 develop = true
56 changes: 53 additions & 3 deletions libs/text-splitters/tests/integration_tests/test_nlp_text_splitters.py
@@ -1,18 +1,36 @@
 """Test text splitting functionality using NLTK and Spacy based sentence splitters."""
 
+from typing import Any
+
+import nltk
 import pytest
+from langchain_core.documents import Document
 
 from langchain_text_splitters.nltk import NLTKTextSplitter
 from langchain_text_splitters.spacy import SpacyTextSplitter
 
 
+def setup_module() -> None:
+    nltk.download("punkt_tab")
+
+
+@pytest.fixture()
+def spacy() -> Any:
+    try:
+        import spacy
+    except ImportError:
+        pytest.skip("Spacy not installed.")
+    spacy.cli.download("en_core_web_sm")  # type: ignore
+    return spacy
+
+
 def test_nltk_text_splitting_args() -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):
         NLTKTextSplitter(chunk_size=2, chunk_overlap=4)
 
 
-def test_spacy_text_splitting_args() -> None:
+def test_spacy_text_splitting_args(spacy: Any) -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):
         SpacyTextSplitter(chunk_size=2, chunk_overlap=4)
@@ -29,7 +47,7 @@ def test_nltk_text_splitter() -> None:
 
 
 @pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
-def test_spacy_text_splitter(pipeline: str) -> None:
+def test_spacy_text_splitter(pipeline: str, spacy: Any) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
@@ -40,7 +58,7 @@ def test_spacy_text_splitter(pipeline: str) -> None:
 
 
 @pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
-def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
+def test_spacy_text_splitter_strip_whitespace(pipeline: str, spacy: Any) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
@@ -50,3 +68,35 @@ def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
     output = splitter.split_text(text)
     expected_output = [f"This is sentence one. {separator}And this is sentence two."]
     assert output == expected_output
+
+
+def test_nltk_text_splitter_args() -> None:
+    """Test invalid arguments for NLTKTextSplitter."""
+    with pytest.raises(ValueError):
+        NLTKTextSplitter(
+            chunk_size=80,
+            chunk_overlap=0,
+            separator="\n\n",
+            use_span_tokenize=True,
+        )
+
+
+def test_nltk_text_splitter_with_add_start_index() -> None:
+    splitter = NLTKTextSplitter(
+        chunk_size=80,
+        chunk_overlap=0,
+        separator="",
+        use_span_tokenize=True,
+        add_start_index=True,
+    )
+    txt = (
+        "Innovation drives our success. "
+        "Collaboration fosters creative solutions. "
+        "Efficiency enhances data management."
+    )
+    docs = [Document(txt)]
+    chunks = splitter.split_documents(docs)
+    assert len(chunks) == 2
+    for chunk in chunks:
+        s_i = chunk.metadata["start_index"]
+        assert chunk.page_content == txt[s_i : s_i + len(chunk.page_content)]
17 changes: 14 additions & 3 deletions libs/text-splitters/tests/integration_tests/test_text_splitter.py
@@ -1,5 +1,7 @@
 """Test text splitters that require an integration."""
 
+from typing import Any
+
 import pytest
 
 from langchain_text_splitters import (
@@ -11,6 +13,15 @@
 )
 
 
+@pytest.fixture()
+def sentence_transformers() -> Any:
+    try:
+        import sentence_transformers
+    except ImportError:
+        pytest.skip("SentenceTransformers not installed.")
+    return sentence_transformers
+
+
 def test_huggingface_type_check() -> None:
     """Test that type checks are done properly on input."""
     with pytest.raises(ValueError):
@@ -52,7 +63,7 @@ def test_token_text_splitter_from_tiktoken() -> None:
     assert expected_tokenizer == actual_tokenizer
 
 
-def test_sentence_transformers_count_tokens() -> None:
+def test_sentence_transformers_count_tokens(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(
         model_name="sentence-transformers/paraphrase-albert-small-v2"
     )
@@ -67,7 +78,7 @@ def test_sentence_transformers_count_tokens() -> None:
     assert expected_token_count == token_count
 
 
-def test_sentence_transformers_split_text() -> None:
+def test_sentence_transformers_split_text(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(
         model_name="sentence-transformers/paraphrase-albert-small-v2"
     )
@@ -77,7 +88,7 @@ def test_sentence_transformers_split_text() -> None:
     assert expected_text_chunks == text_chunks
 
 
-def test_sentence_transformers_multiple_tokens() -> None:
+def test_sentence_transformers_multiple_tokens(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
     text = "Lorem "
 
