Skip to content

Commit

Permalink
Update token handling as proposed by CodeRabbit
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjoham committed Oct 12, 2024
1 parent 4324180 commit c9e89be
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 29 deletions.
1 change: 1 addition & 0 deletions app/llm/external/PipelineEnum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ class PipelineEnum(str, Enum):
IRIS_CITATION_PIPELINE = "IRIS_CITATION_PIPELINE"
IRIS_RERANKER_PIPELINE = "IRIS_RERANKER_PIPELINE"
IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE"
IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE"
NOT_SET = "NOT_SET"
8 changes: 4 additions & 4 deletions app/llm/external/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ def chat(
options=self.options,
)
return convert_to_iris_message(
response["message"],
response["prompt_eval_count"],
response["eval_count"],
response["model"],
response.get("message"),
response.get("prompt_eval_count", 0),
response.get("eval_count", 0),
response.get("model", self.model),
)

def embed(self, text: str) -> list[float]:
Expand Down
6 changes: 3 additions & 3 deletions app/pipeline/chat/code_feedback_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __call__(
}
)
)
num_tokens = self.llm.tokens
num_tokens.pipeline = PipelineEnum.IRIS_CODE_FEEDBACK
self.tokens = num_tokens
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_CODE_FEEDBACK
self.tokens = token_usage
return response.replace("{", "{{").replace("}", "}}")
19 changes: 12 additions & 7 deletions app/pipeline/chat/exercise_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,14 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
suggestion_dto.last_message = self.exercise_chat_response
suggestion_dto.problem_statement = dto.exercise.problem_statement
suggestions = self.suggestion_pipeline(suggestion_dto)
if self.suggestion_pipeline.tokens is not None:
tokens = [self.suggestion_pipeline.tokens]
else:
tokens = []
self.callback.done(
final_result=None,
suggestions=suggestions,
tokens=[self.suggestion_pipeline.tokens],
tokens=tokens,
)
else:
# This should never happen but whatever
Expand Down Expand Up @@ -264,9 +268,7 @@ def _run_exercise_chat_pipeline(
.with_config({"run_name": "Response Drafting"})
.invoke({})
)
if self.llm.tokens is not None:
self.llm.tokens.pipeline = PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE
self.tokens.append(self.llm.tokens)
self._collect_llm_tokens()
self.callback.done()
self.prompt = ChatPromptTemplate.from_messages(
[
Expand All @@ -281,9 +283,7 @@ def _run_exercise_chat_pipeline(
.with_config({"run_name": "Response Refining"})
.invoke({})
)
if self.llm.tokens is not None:
self.llm.tokens.pipeline = PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE
self.tokens.append(self.llm.tokens)
self._collect_llm_tokens()

if "!ok!" in guide_response:
print("Response is ok and not rewritten!!!")
Expand Down Expand Up @@ -385,3 +385,8 @@ def should_execute_lecture_pipeline(self, course_id: int) -> bool:
)
return len(result.objects) > 0
return False

def _collect_llm_tokens(self):
    """Tag the LLM's last token usage with this pipeline and record it.

    Does nothing when the underlying LLM reported no token usage for
    the most recent invocation (``self.llm.tokens`` is ``None``).
    """
    if self.llm.tokens is None:
        return
    self.llm.tokens.pipeline = PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE
    self.tokens.append(self.llm.tokens)
6 changes: 3 additions & 3 deletions app/pipeline/chat/lecture_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO):
self.prompt = ChatPromptTemplate.from_messages(prompt_val)
try:
response = (self.prompt | self.pipeline).invoke({})
num_tokens = self.llm.tokens
num_tokens.pipeline = PipelineEnum.IRIS_CHAT_LECTURE_MESSAGE
self.tokens.append(num_tokens)
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_CHAT_LECTURE_MESSAGE
self.tokens.append(token_usage)
response_with_citation = self.citation_pipeline(
retrieved_lecture_chunks, response
)
Expand Down
4 changes: 2 additions & 2 deletions app/pipeline/competency_extraction_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def __call__(
response = self.request_handler.chat(
[prompt], CompletionArguments(temperature=0.4)
)
num_tokens = LLMTokenCount(
token_usage = LLMTokenCount(
model_info=response.model_info,
num_input_tokens=response.num_input_tokens,
num_output_tokens=response.num_output_tokens,
pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION,
)
self.tokens.append(num_tokens)
self.tokens.append(token_usage)
response = response.contents[0].text_content

generated_competencies: list[Competency] = []
Expand Down
7 changes: 3 additions & 4 deletions app/pipeline/lecture_ingestion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import fitz
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from sipbuild.generator.parser.tokens import tokens
from unstructured.cleaners.core import clean
from weaviate import WeaviateClient
from weaviate.classes.query import Filter
Expand Down Expand Up @@ -280,9 +279,9 @@ def merge_page_content_and_image_interpretation(
(prompt | self.pipeline).invoke({}), bullets=True, extra_whitespace=True
)
# TODO: send to artemis
num_tokens = self.llm.tokens
num_tokens.pipeline = PipelineEnum.IRIS_LECTURE_INGESTION
tokens.append(num_tokens)
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_LECTURE_INGESTION
self.tokens.append(token_usage)
return clean_output

def get_course_language(self, page_content: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion app/pipeline/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta
from typing import List

from app.llm.external import LLMTokenCount
from app.llm.external.LLMTokenCount import LLMTokenCount


class Pipeline(metaclass=ABCMeta):
Expand Down
6 changes: 3 additions & 3 deletions app/pipeline/shared/reranker_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __call__(
prompt = self.default_prompt

response = (prompt | self.pipeline).invoke(data)
num_tokens = self.llm.tokens
num_tokens.pipeline = PipelineEnum.IRIS_RERANKER_PIPELINE
self.tokens.append(num_tokens)
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_RERANKER_PIPELINE
self.tokens.append(token_usage)
return response.selected_paragraphs
17 changes: 15 additions & 2 deletions app/retrieval/lecture_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..common import convert_iris_message_to_langchain_message
from ..llm.external.LLMTokenCount import LLMTokenCount
from ..llm.external.PipelineEnum import PipelineEnum
from ..llm.langchain import IrisLangchainChatModel
from ..pipeline import Pipeline

Expand Down Expand Up @@ -82,7 +83,7 @@ class LectureRetrieval(Pipeline):
Class for retrieving lecture data from the database.
"""

tokens: LLMTokenCount
tokens: list[LLMTokenCount]

def __init__(self, client: WeaviateClient, **kwargs):
super().__init__(implementation_id="lecture_retrieval_pipeline")
Expand All @@ -101,6 +102,7 @@ def __init__(self, client: WeaviateClient, **kwargs):
self.pipeline = self.llm | StrOutputParser()
self.collection = init_lecture_schema(client)
self.reranker_pipeline = RerankerPipeline()
self.tokens = []

@traceable(name="Full Lecture Retrieval")
def __call__(
Expand Down Expand Up @@ -239,7 +241,9 @@ def rewrite_student_query(
prompt = ChatPromptTemplate.from_messages(prompt_val)
try:
response = (prompt | self.pipeline).invoke({})
self.tokens = self.llm.tokens
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE
self.tokens.append(token_usage)
logger.info(f"Response from exercise chat pipeline: {response}")
return response
except Exception as e:
Expand Down Expand Up @@ -277,6 +281,9 @@ def rewrite_student_query_with_exercise_context(
prompt = ChatPromptTemplate.from_messages(prompt_val)
try:
response = (prompt | self.pipeline).invoke({})
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE
self.tokens.append(token_usage)
logger.info(f"Response from exercise chat pipeline: {response}")
return response
except Exception as e:
Expand Down Expand Up @@ -312,6 +319,9 @@ def rewrite_elaborated_query(
prompt = ChatPromptTemplate.from_messages(prompt_val)
try:
response = (prompt | self.pipeline).invoke({})
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE
self.tokens.append(token_usage)
logger.info(f"Response from retrieval pipeline: {response}")
return response
except Exception as e:
Expand Down Expand Up @@ -351,6 +361,9 @@ def rewrite_elaborated_query_with_exercise_context(
)
try:
response = (prompt | self.pipeline).invoke({})
token_usage = self.llm.tokens
token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE
self.tokens.append(token_usage)
logger.info(f"Response from exercise chat pipeline: {response}")
return response
except Exception as e:
Expand Down

0 comments on commit c9e89be

Please sign in to comment.