Commit: improvements

strickvl committed May 6, 2024
1 parent 9cd5f87 commit 799daab
Showing 7 changed files with 65 additions and 51 deletions.
1 change: 1 addition & 0 deletions llm-complete-guide/constants.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

+
# Vector Store constants
CHUNK_SIZE = 2000
CHUNK_OVERLAP = 50
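For context: CHUNK_SIZE and CHUNK_OVERLAP control how documents are cut into overlapping chunks before embedding. The snippet below is a minimal sketch of how the two values interact, not the project's implementation (the real chunking is split_text in utils/llm_utils.py, which also handles separators and token counting):

CHUNK_SIZE = 2000
CHUNK_OVERLAP = 50


def chunk_starts(n_tokens: int) -> list:
    """Illustrative only: start offsets of consecutive overlapping chunks."""
    step = CHUNK_SIZE - CHUNK_OVERLAP
    return list(range(0, max(n_tokens - CHUNK_OVERLAP, 1), step))


print(chunk_starts(5000))  # [0, 1950, 3900]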
27 changes: 11 additions & 16 deletions llm-complete-guide/pipelines/llm_eval.py
@@ -14,32 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from steps.eval_e2e import e2e_evaluation
+from steps.eval_e2e import e2e_evaluation, e2e_evaluation_llm_judged
from steps.eval_retrieval import (
retrieval_evaluation_full,
retrieval_evaluation_small,
)
-from steps.eval_visualisation import visualize_evaluation_results
from zenml import pipeline


@pipeline
def llm_eval() -> None:
"""Executes the pipeline to evaluate a RAG pipeline."""
# Retrieval evals
failure_rate_retrieval = retrieval_evaluation_small()
-    (
-        failure_rate_bad_answers,
-        failure_rate_bad_immediate_responses,
-        failure_rate_good_responses,
-    ) = e2e_evaluation()
-    full_retrieval_answers = retrieval_evaluation_full()

+    full_failure_rate_retrieval = retrieval_evaluation_full()
+    # E2E evals
+    e2e_eval_tuple = e2e_evaluation()
+    # e2e_llm_judged_tuple = e2e_evaluation_llm_judged()

-    visualize_evaluation_results(
-        failure_rate_retrieval,
-        failure_rate_bad_answers,
-        failure_rate_bad_immediate_responses,
-        failure_rate_good_responses,
-    )

+    e2e_evaluation_llm_judged()
+    # visualize_evaluation_results(
+    #     failure_rate_retrieval,
+    #     e2e_answers,
+    #     full_retrieval_answers,
+    # )
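For orientation: a ZenML @pipeline function such as llm_eval executes on the active stack when called. A minimal sketch of invoking it directly, assuming the project layout above (the repository's own run.py adds CLI options and materializer registration first):

from pipelines.llm_eval import llm_eval

if __name__ == "__main__":
    llm_eval()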
1 change: 1 addition & 0 deletions llm-complete-guide/requirements.txt
@@ -16,6 +16,7 @@ umap-learn
matplotlib
pyarrow
rerankers[all]
+datasets

# optional requirements for S3 artifact store
# s3fs>2022.3.0
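The new datasets dependency backs the load_dataset import added to steps/eval_e2e.py below. A hedged sketch of that usage; the dataset identifier here is a placeholder, not the one the eval step actually pins:

from datasets import load_dataset

# Placeholder dataset ID, for illustration only.
eval_questions = load_dataset("your-org/rag-eval-questions", split="train")
print(eval_questions[0])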
4 changes: 3 additions & 1 deletion llm-complete-guide/run.py
@@ -139,5 +139,7 @@ def main(
if __name__ == "__main__":
# use custom materializer for documents
# register early
-    materializer_registry.register_materializer_type(Document, DocumentMaterializer)
+    materializer_registry.register_materializer_type(
+        Document, DocumentMaterializer
+    )
main()
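Registering the materializer early tells ZenML how to persist Document artifacts passed between steps. The class below is only a rough sketch of what such a materializer can look like with the load/save API of recent ZenML releases; the project's actual DocumentMaterializer may serialize more fields or use a different format:

import json
import os
from typing import Type

from structures import Document
from zenml.enums import ArtifactType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer


class DocumentMaterializer(BaseMaterializer):
    """Sketch: persist a Document as JSON inside the artifact store."""

    ASSOCIATED_TYPES = (Document,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA

    def load(self, data_type: Type[Document]) -> Document:
        with fileio.open(os.path.join(self.uri, "document.json"), "r") as f:
            return Document(**json.load(f))

    def save(self, data: Document) -> None:
        with fileio.open(os.path.join(self.uri, "document.json"), "w") as f:
            # Embedding and token counts are omitted in this sketch.
            json.dump({"page_content": data.page_content}, f)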
52 changes: 24 additions & 28 deletions llm-complete-guide/steps/eval_e2e.py
@@ -21,6 +21,7 @@
from datasets import load_dataset
from litellm import completion
from pydantic import BaseModel, conint
+from structures import TestResult
from utils.llm_utils import process_input_with_retrieval
from zenml import step

@@ -58,13 +59,6 @@
]


-class TestResult(BaseModel):
-    success: bool
-    question: str
-    keyword: str = ""
-    response: str


def test_content_for_bad_words(
item: dict, n_items_retrieved: int = 5
) -> TestResult:
@@ -290,24 +284,6 @@ def run_llm_judged_tests(
)


-@step(enable_cache=False)
-def e2e_evaluation_llm_judged() -> (
-    Tuple[
-        Annotated[float, "average_toxicity_score"],
-        Annotated[float, "average_faithfulness_score"],
-        Annotated[float, "average_helpfulness_score"],
-        Annotated[float, "average_relevance_score"],
-    ]
-):
-    """Executes the end-to-end evaluation step.
-    Returns:
-        Tuple: The average toxicity, faithfulness, helpfulness, and relevance scores.
-    """
-    logging.info("Starting end-to-end evaluation...")
-    return run_llm_judged_tests(llm_judged_test_e2e)


def run_simple_tests(test_data: list, test_function: Callable) -> float:
"""
Run tests for bad answers.
@@ -337,9 +313,11 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:

@step
def e2e_evaluation() -> (
-    Annotated[float, "failure_rate_bad_answers"],
-    Annotated[float, "failure_rate_bad_immediate_responses"],
-    Annotated[float, "failure_rate_good_responses"],
+    Tuple[
+        Annotated[float, "failure_rate_bad_answers"],
+        Annotated[float, "failure_rate_bad_immediate_responses"],
+        Annotated[float, "failure_rate_good_responses"],
+    ]
):
"""Executes the end-to-end evaluation step."""
logging.info("Testing bad answers...")
@@ -368,3 +346,21 @@ def e2e_evaluation() -> (
failure_rate_bad_immediate_responses,
failure_rate_good_responses,
)
+
+
+@step(enable_cache=False)
+def e2e_evaluation_llm_judged() -> (
+    Tuple[
+        Annotated[float, "average_toxicity_score"],
+        Annotated[float, "average_faithfulness_score"],
+        Annotated[float, "average_helpfulness_score"],
+        Annotated[float, "average_relevance_score"],
+    ]
+):
+    """Executes the end-to-end evaluation step.
+    Returns:
+        Tuple: The average toxicity, faithfulness, helpfulness, and relevance scores.
+    """
+    logging.info("Starting end-to-end evaluation...")
+    return run_llm_judged_tests(llm_judged_test_e2e)
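The body of run_simple_tests is collapsed in the hunk above. A plausible shape, reconstructed from its docstring and the TestResult fields (the real implementation may log or aggregate differently):

import logging
from typing import Callable

from structures import TestResult


def run_simple_tests(test_data: list, test_function: Callable) -> float:
    """Runs each test item and returns the failure rate as a percentage."""
    failures = 0
    for item in test_data:
        result: TestResult = test_function(item)
        if not result.success:
            logging.error(
                "Test failed for question '%s' (keyword '%s'): %s",
                result.question,
                result.keyword,
                result.response,
            )
            failures += 1
    return round((failures / len(test_data)) * 100, 2)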
8 changes: 8 additions & 0 deletions llm-complete-guide/structures.py
@@ -16,6 +16,7 @@
from typing import List, Optional

import numpy as np
+from pydantic import BaseModel


@dataclass
@@ -39,3 +40,10 @@ class Document:
embedding: Optional[np.ndarray] = None
token_count: Optional[int] = None
generated_questions: Optional[List[str]] = None
+
+
+class TestResult(BaseModel):
+    success: bool
+    question: str
+    keyword: str = ""
+    response: str
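Moving TestResult into structures.py lets the eval steps and any other module share one result schema. A hypothetical construction, mirroring how a failed keyword check might be reported:

from structures import TestResult

result = TestResult(
    success=False,
    question="How do I deploy ZenML on Kubernetes?",
    keyword="kubernetes",
    response="...the model's answer text...",
)
assert not result.success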
23 changes: 17 additions & 6 deletions llm-complete-guide/utils/llm_utils.py
@@ -45,11 +45,12 @@
from sentence_transformers import SentenceTransformer
from structures import Document


logger = logging.getLogger(__name__)


-def split_text_with_regex(text: str, separator: str, keep_separator: bool) -> List[str]:
+def split_text_with_regex(
+    text: str, separator: str, keep_separator: bool
+) -> List[str]:
"""Splits a given text using a specified separator.
This function splits the input text using the provided separator. The separator can be included or excluded
@@ -66,7 +67,9 @@ def split_text_with_regex(text: str, separator: str, keep_separator: bool) -> List[str]:
if separator:
if keep_separator:
_splits = re.split(f"({separator})", text)
-            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
+            splits = [
+                _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
+            ]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
@@ -125,7 +128,9 @@ def split_text(
current_chunk += split + _separator
else:
if current_chunk:
-                token_count = len(encoding.encode(current_chunk.rstrip(_separator)))
+                token_count = len(
+                    encoding.encode(current_chunk.rstrip(_separator))
+                )
chunks.append(
Document(
page_content=current_chunk.rstrip(_separator),
@@ -155,7 +160,9 @@ def split_text(
final_chunks.append(chunks[i])
else:
overlap = chunks[i - 1].page_content[-chunk_overlap:]
-            token_count = len(encoding.encode(overlap + chunks[i].page_content))
+            token_count = len(
+                encoding.encode(overlap + chunks[i].page_content)
+            )
final_chunks.append(
Document(
page_content=overlap + chunks[i].page_content,
@@ -241,7 +248,11 @@ def get_db_password() -> str:
if not password:
from zenml.client import Client

-        password = Client().get_secret("supabase_postgres_db").secret_values["password"]
+        password = (
+            Client()
+            .get_secret("supabase_postgres_db")
+            .secret_values["password"]
+        )
return password


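When no password has been set elsewhere, get_db_password falls back to reading a ZenML secret named supabase_postgres_db. A hedged sketch of registering that secret; this assumes the Client.create_secret API of recent ZenML releases and uses an obvious placeholder value:

from zenml.client import Client

Client().create_secret(
    name="supabase_postgres_db",
    values={"password": "<your-database-password>"},
)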
