Skip to content

Commit

Permalink
Set logger level in Chat (#1064)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Aug 15, 2024
1 parent 60037e9 commit 9111570
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `QueryTool` for having the LLM query text.
- Support for bitshift composition in `BaseTask` for adding parent/child tasks.
- `JsonArtifact` for handling de/seralization of values.
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
Expand All @@ -46,11 +47,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Changed `JsonExtractionEngine.template_schema` from a `run` argument to a class attribute.
- **BREAKING**: Changed `CsvExtractionEngine.column_names` from a `run` argument to a class attribute.
- **BREAKING**: Removed `JsonExtractionTask`, and `CsvExtractionTask` use `ExtractionTask` instead.
- **BREAKING**: Removed `TaskMemoryClient`, use `RagClient`, `ExtractionTool`, or `PromptSummaryTool` instead.
- **BREAKING**: Removed `TaskMemoryClient`, use `QueryClient`, `ExtractionTool`, or `PromptSummaryTool` instead.
- **BREAKING**: `BaseTask.add_parent/child` now take a `BaseTask` instead of `str | BaseTask`.
- Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default.
- `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible.
- `BaseTask.add_parent/child` now returns `self`, allowing for chaining.
- `Chat` now sets the `griptape` logger level to `logging.ERROR`, suppressing all logs except for errors.

### Fixed
- `JsonExtractionEngine` failing to parse json when the LLM outputs more than just the json.
Expand Down
9 changes: 9 additions & 0 deletions griptape/utils/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, Optional

from attrs import Factory, define, field
Expand All @@ -23,6 +24,7 @@ class Chat:
default=Factory(lambda self: self.default_output_fn, takes_self=True),
kw_only=True,
)
logger_level: int = field(default=logging.ERROR, kw_only=True)

def default_output_fn(self, text: str) -> None:
from griptape.tasks.prompt_task import PromptTask
Expand All @@ -38,6 +40,10 @@ def default_output_fn(self, text: str) -> None:
def start(self) -> None:
from griptape.config import config

# Hide Griptape's logging output except for errors
old_logger_level = logging.getLogger(config.logging.logger_name).getEffectiveLevel()
logging.getLogger(config.logging.logger_name).setLevel(self.logger_level)

if self.intro_text:
self.output_fn(self.intro_text)
while True:
Expand All @@ -57,3 +63,6 @@ def start(self) -> None:
else:
self.output_fn(self.processing_text)
self.output_fn(f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}")

# Restore the original logger level
logging.getLogger(config.logging.logger_name).setLevel(old_logger_level)
24 changes: 22 additions & 2 deletions tests/unit/utils/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from unittest.mock import patch

from griptape.config import config
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent
from griptape.utils import Chat


class TestConversation:
def test_init(self):
import logging

agent = Agent(conversation_memory=ConversationMemory())

chat = Chat(
Expand All @@ -18,6 +20,7 @@ def test_init(self):
prompt_prefix="Question: ",
response_prefix="Answer: ",
output_fn=logging.info,
logger_level=logging.INFO,
)
assert chat.structure == agent
assert chat.exiting_text == "foo..."
Expand All @@ -26,3 +29,20 @@ def test_init(self):
assert chat.prompt_prefix == "Question: "
assert chat.response_prefix == "Answer: "
assert callable(chat.output_fn)
assert chat.logger_level == logging.INFO

@patch("builtins.input", side_effect=["exit"])
def test_chat_logger_level(self, mock_input):
agent = Agent(conversation_memory=ConversationMemory())

chat = Chat(agent)

logger = logging.getLogger(config.logging.logger_name)
logger.setLevel(logging.DEBUG)

assert logger.getEffectiveLevel() == logging.DEBUG

chat.start()

assert logger.getEffectiveLevel() == logging.DEBUG
assert mock_input.call_count == 1

0 comments on commit 9111570

Please sign in to comment.