Skip to content

Commit

Permalink
Merge branch 'main' into james/v0.0.6
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
jamescalam committed May 28, 2024
2 parents 2c26456 + 8bacc9b commit b1c3d59
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 20 deletions.
13 changes: 9 additions & 4 deletions semantic_chunkers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from semantic_chunkers.chunkers import BaseChunker
from semantic_chunkers.chunkers import ConsecutiveChunker
from semantic_chunkers.chunkers import CumulativeChunker
from semantic_chunkers.chunkers import StatisticalChunker
from semantic_chunkers.chunkers import (
BaseChunker,
ConsecutiveChunker,
CumulativeChunker,
StatisticalChunker,
)
from semantic_chunkers.splitters import BaseSplitter, RegexSplitter

__all__ = [
"BaseChunker",
"ConsecutiveChunker",
"CumulativeChunker",
"StatisticalChunker",
"BaseSplitter",
"RegexSplitter",
]

__version__ = "0.0.6"
5 changes: 3 additions & 2 deletions semantic_chunkers/chunkers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.sentence import regex_splitter
from semantic_chunkers.splitters.base import BaseSplitter


class BaseChunker(BaseModel):
name: str
encoder: BaseEncoder
splitter: BaseSplitter

class Config:
extra = Extra.allow
Expand All @@ -19,7 +20,7 @@ def __call__(self, docs: List[str]) -> List[List[Chunk]]:
raise NotImplementedError("Subclasses must implement this method")

def _split(self, doc: str) -> List[str]:
return regex_splitter(doc)
return self.splitter(doc)

def _chunk(self, splits: List[Any]) -> List[Chunk]:
raise NotImplementedError("Subclasses must implement this method")
Expand Down
5 changes: 4 additions & 1 deletion semantic_chunkers/chunkers/consecutive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter


class ConsecutiveChunker(BaseChunker):
Expand All @@ -16,10 +18,11 @@ class ConsecutiveChunker(BaseChunker):
def __init__(
self,
encoder: BaseEncoder,
splitter: BaseSplitter = RegexSplitter(),
name: str = "consecutive_chunker",
score_threshold: float = 0.45,
):
super().__init__(name=name, encoder=encoder)
super().__init__(name=name, encoder=encoder, splitter=splitter)
encoder.score_threshold = score_threshold
self.score_threshold = score_threshold

Expand Down
5 changes: 4 additions & 1 deletion semantic_chunkers/chunkers/cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from semantic_router.encoders import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter


class CumulativeChunker(BaseChunker):
Expand All @@ -17,10 +19,11 @@ class CumulativeChunker(BaseChunker):
def __init__(
self,
encoder: BaseEncoder,
splitter: BaseSplitter = RegexSplitter(),
name: str = "cumulative_chunker",
score_threshold: float = 0.45,
):
super().__init__(name=name, encoder=encoder)
super().__init__(name=name, encoder=encoder, splitter=splitter)
encoder.score_threshold = score_threshold
self.score_threshold = score_threshold

Expand Down
5 changes: 4 additions & 1 deletion semantic_chunkers/chunkers/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.logger import logger

Expand Down Expand Up @@ -41,6 +43,7 @@ class StatisticalChunker(BaseChunker):
def __init__(
self,
encoder: BaseEncoder,
splitter: BaseSplitter = RegexSplitter(),
name="statistical_chunker",
threshold_adjustment=0.01,
dynamic_threshold: bool = True,
Expand All @@ -51,7 +54,7 @@ def __init__(
plot_chunks=False,
enable_statistics=False,
):
super().__init__(name=name, encoder=encoder)
super().__init__(name=name, encoder=encoder, splitter=splitter)
self.calculated_threshold: float
self.encoder = encoder
self.threshold_adjustment = threshold_adjustment
Expand Down
8 changes: 8 additions & 0 deletions semantic_chunkers/splitters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter


# Public re-exports of the splitters subpackage: the abstract interface
# plus the default regex-based sentence splitter.
__all__ = [
    "BaseSplitter",
    "RegexSplitter",
]
11 changes: 11 additions & 0 deletions semantic_chunkers/splitters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List

from pydantic.v1 import BaseModel, Extra


class BaseSplitter(BaseModel):
    """Abstract base for document splitters.

    A splitter is a callable that takes a single document string and
    returns it broken into a list of smaller strings (e.g. sentences).
    Concrete implementations (such as ``RegexSplitter``) override
    ``__call__``.
    """

    class Config:
        # Permit subclasses to carry extra attributes not declared on
        # the pydantic model (e.g. compiled patterns, tuning options).
        extra = Extra.allow

    def __call__(self, doc: str) -> List[str]:
        """Split ``doc`` into a list of strings.

        Raises:
            NotImplementedError: always, on the base class — subclasses
                must provide the actual splitting logic.
        """
        raise NotImplementedError("Subclasses must implement this method")
20 changes: 10 additions & 10 deletions semantic_chunkers/splitters/sentence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import regex
from typing import List

from semantic_chunkers.splitters.base import BaseSplitter

def regex_splitter(text: str) -> list[str]:

class RegexSplitter(BaseSplitter):
"""
Enhanced regex pattern to split a given text into sentences more accurately.
Expand All @@ -11,13 +14,8 @@ def regex_splitter(text: str) -> list[str]:
- Decimal numbers and dates.
- Ellipses and other punctuation marks used in informal text.
- Removing control characters and format characters.
Args:
text (str): The text to split into sentences.
Returns:
list: A list of sentences extracted from the text.
"""

regex_pattern = r"""
# Negative lookbehind for word boundary, word char, dot, word char
(?<!\b\w\.\w.)
Expand Down Expand Up @@ -51,6 +49,8 @@ def regex_splitter(text: str) -> list[str]:
# Matches and removes control characters and format characters
[\p{Cc}\p{Cf}]+
"""
sentences = regex.split(regex_pattern, text, flags=regex.VERBOSE)
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
return sentences

def __call__(self, doc: str) -> List[str]:
sentences = regex.split(self.regex_pattern, doc, flags=regex.VERBOSE)
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
return sentences
9 changes: 8 additions & 1 deletion tests/unit/test_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_chunkers import BaseChunker
from semantic_chunkers import BaseSplitter
from semantic_chunkers import ConsecutiveChunker
from semantic_chunkers import CumulativeChunker

Expand Down Expand Up @@ -106,7 +107,13 @@ def base_splitter_instance():
mock_encoder = Mock(spec=BaseEncoder)
mock_encoder.name = "mock_encoder"
mock_encoder.score_threshold = 0.5
return BaseChunker(name="test_splitter", encoder=mock_encoder, score_threshold=0.5)
mock_splitter = Mock(spec=BaseSplitter)
return BaseChunker(
name="test_splitter",
encoder=mock_encoder,
splitter=mock_splitter,
score_threshold=0.5,
)


def test_base_splitter_call_not_implemented(base_splitter_instance):
Expand Down

0 comments on commit b1c3d59

Please sign in to comment.