
Feature/observable tags #954

Merged
merged 4 commits on Jul 10, 2024
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/observability-drivers.md
@@ -56,7 +56,7 @@ with Observability(observability_driver=observability_driver):
agent.run("Name an animal")
```

Ouput (only relevant because of use of `ConsoleSpanExporter`):
Output (only relevant because of use of `ConsoleSpanExporter`):
```
[06/18/24 06:57:22] INFO PromptTask 2d8ef95bf817480188ae2f74e754308a
Input: Name an animal
2 changes: 1 addition & 1 deletion griptape/common/__init__.py
@@ -26,5 +26,5 @@
"PromptStack",
"Reference",
"observable",
"Observable"
"Observable",
]
4 changes: 4 additions & 0 deletions griptape/common/observable.py
@@ -29,6 +29,10 @@ def __call__(self):
args = (self.instance, *self.args) if self.instance and not hasattr(self.func, "__self__") else self.args
return self.func(*args, **self.kwargs)

@property
def tags(self) -> Optional[list[str]]:
return self.decorator_kwargs.get("tags")

def __init__(self, *args, **kwargs):
self._instance = None
if len(args) == 1 and len(kwargs) == 0 and isfunction(args[0]):
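Together with the driver changes below, the new `tags` property lets call sites attach labels at decoration time. A minimal usage sketch (the `Describer` class and its tag string are illustrative, not part of this PR):

```python
from griptape.common import observable

class Describer:
    # Keyword arguments passed to @observable are retained as
    # decorator_kwargs on the Observable.Call, so call.tags here
    # returns ["Describer.describe()"].
    @observable(tags=["Describer.describe()"])
    def describe(self, name: str) -> str:
        return f"{name} is an animal"
```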
@@ -7,7 +7,7 @@

@define
class BaseObservabilityDriver(ABC):
def __enter__(self) -> None:
def __enter__(self) -> None: # noqa: B027
pass

def __exit__(
@@ -56,10 +56,14 @@ def __exit__(
def observe(self, call: Observable.Call) -> Any:
func = call.func
instance = call.instance
tags = call.tags

class_name = f"{instance.__class__.__name__}." if instance else ""
span_name = f"{class_name}{func.__name__}()"
with self._tracer.start_as_current_span(span_name) as span: # pyright: ignore[reportCallIssue]
if tags is not None:
span.set_attribute("tags", tags)

try:
result = call()
span.set_status(Status(StatusCode.OK))
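The attribute handling above is plain OpenTelemetry. A standalone sketch of the same pattern, assuming `opentelemetry-sdk` is installed (span name and tag value are illustrative). Note the SDK hands list-valued attributes back as tuples, which is why the test below expects `("Foo.bar()",)`:

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
tracer = provider.get_tracer(__name__)

tags = ["PromptDriver.run()"]
with tracer.start_as_current_span("PromptDriver.run()") as span:
    if tags is not None:
        # Lists of strings are valid attribute values; finished spans
        # expose them as tuples on ReadableSpan.attributes.
        span.set_attribute("tags", tags)
```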
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -13,6 +13,7 @@
BaseMessageContent,
TextMessageContent,
ImageMessageContent,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
@@ -35,6 +36,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.bedrock_client.converse(**self._base_params(prompt_stack))

@@ -47,6 +49,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

@@ -7,7 +7,7 @@
from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import PromptStack, Message, TextMessageContent, DeltaMessage
from griptape.common import PromptStack, Message, TextMessageContent, DeltaMessage, observable
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
@@ -41,6 +41,7 @@ def validate_stream(self, _, stream):
if stream:
raise ValueError("streaming is not supported")

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
payload = {
"inputs": self.prompt_stack_to_string(prompt_stack),
@@ -78,6 +79,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise NotImplementedError("streaming is not supported")

3 changes: 3 additions & 0 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -14,6 +14,7 @@
PromptStack,
Message,
TextMessageContent,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer
@@ -48,6 +49,7 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.client.messages.create(**self._base_params(prompt_stack))

@@ -57,6 +59,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
events = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

2 changes: 2 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
@@ -13,6 +13,7 @@
PromptStack,
Message,
TextMessageContent,
observable,
)
from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
@@ -60,6 +61,7 @@ def after_run(self, result: Message) -> None:
)
)

@observable(tags=["PromptDriver.run()"])
def run(self, prompt_stack: PromptStack) -> Message:
for attempt in self.retrying():
with attempt:
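With this change, every concrete driver's `run()` span carries a `PromptDriver.run()` tag, since subclasses inherit the decorated method. A hedged end-to-end sketch mirroring the docs snippet touched above (import paths and driver construction are assumptions; see the observability-drivers docs page for the exact setup):

```python
from griptape.drivers import OpenTelemetryObservabilityDriver
from griptape.observability import Observability
from griptape.structures import Agent

# Assumed construction; with a ConsoleSpanExporter-backed setup, the
# exported PromptDriver.run span should include tags=('PromptDriver.run()',).
observability_driver = OpenTelemetryObservabilityDriver()

with Observability(observability_driver=observability_driver):
    Agent().run("Name an animal")
```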
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -12,6 +12,7 @@
TextMessageContent,
BaseMessageContent,
TextDeltaMessageContent,
observable,
)
from griptape.utils import import_optional_dependency
from griptape.tokenizers import BaseTokenizer
@@ -38,6 +39,7 @@ class CoherePromptDriver(BasePromptDriver):
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True)
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
result = self.client.chat(**self._base_params(prompt_stack))
usage = result.meta.tokens
@@ -48,6 +50,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
result = self.client.chat_stream(**self._base_params(prompt_stack))

4 changes: 3 additions & 1 deletion griptape/drivers/prompt/dummy_prompt_driver.py
@@ -3,7 +3,7 @@

from attrs import Factory, define, field

from griptape.common import PromptStack, Message, DeltaMessage
from griptape.common import PromptStack, Message, DeltaMessage, observable
from griptape.drivers import BasePromptDriver
from griptape.exceptions import DummyException
from griptape.tokenizers import DummyTokenizer
@@ -14,8 +14,10 @@ class DummyPromptDriver(BasePromptDriver):
model: None = field(init=False, default=None, kw_only=True)
tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
raise DummyException(__class__.__name__, "try_run")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise DummyException(__class__.__name__, "try_stream")
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -14,6 +14,7 @@
PromptStack,
Message,
TextMessageContent,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, GoogleTokenizer
@@ -45,6 +46,7 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

@@ -70,6 +72,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

4 changes: 3 additions & 1 deletion griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@@ -7,7 +7,7 @@

from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.common import PromptStack, Message, DeltaMessage, TextDeltaMessageContent
from griptape.common import PromptStack, Message, DeltaMessage, TextDeltaMessageContent, observable
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
@@ -47,6 +47,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
prompt = self.prompt_stack_to_string(prompt_stack)

@@ -62,6 +63,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
prompt = self.prompt_stack_to_string(prompt_stack)

@@ -6,7 +6,7 @@
from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, PromptStack, Message, TextMessageContent
from griptape.common import DeltaMessage, PromptStack, Message, TextMessageContent, observable
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
@@ -42,6 +42,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
)
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self._prompt_stack_to_messages(prompt_stack)

@@ -66,6 +67,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
else:
raise Exception("invalid output format")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise NotImplementedError("streaming is not supported")

4 changes: 3 additions & 1 deletion griptape/drivers/prompt/ollama_prompt_driver.py
@@ -5,7 +5,7 @@
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers.base_tokenizer import BaseTokenizer
from griptape.common import PromptStack, TextMessageContent
from griptape.common import PromptStack, TextMessageContent, observable
from griptape.utils import import_optional_dependency
from griptape.tokenizers import SimpleTokenizer
from griptape.common import Message, DeltaMessage, TextDeltaMessageContent
@@ -49,6 +49,7 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.client.chat(**self._base_params(prompt_stack))

@@ -60,6 +61,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
else:
raise Exception("invalid model response")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
stream = self.client.chat(**self._base_params(prompt_stack), stream=True)

3 changes: 3 additions & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -15,6 +15,7 @@
PromptStack,
Message,
TextMessageContent,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
@@ -73,6 +74,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
result = self.client.chat.completions.create(**self._base_params(prompt_stack))

@@ -89,6 +91,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
else:
raise Exception("Completion with more than one choice is not supported yet.")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
result = self.client.chat.completions.create(
**self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True}
2 changes: 2 additions & 0 deletions griptape/tools/base_tool.py
@@ -11,6 +11,7 @@
import yaml
from attrs import define, field, Factory
from griptape.artifacts import BaseArtifact, InfoArtifact, TextArtifact
from griptape.common import observable
from griptape.mixins import ActivityMixin

if TYPE_CHECKING:
@@ -113,6 +114,7 @@ def execute(self, activity: Callable, subtask: ActionsSubtask, action: ActionsSu
def before_run(self, activity: Callable, subtask: ActionsSubtask, action: ActionsSubtask.Action) -> Optional[dict]:
return action.input

@observable(tags=["Tool.run()"])
def run(
self, activity: Callable, subtask: ActionsSubtask, action: ActionsSubtask.Action, value: Optional[dict]
) -> BaseArtifact:
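Tool activities get the same treatment, so spans can later be grouped by tag in whatever backend receives them. A hypothetical helper (not part of this PR) that filters finished OpenTelemetry spans by tag:

```python
from opentelemetry.sdk.trace import ReadableSpan

def spans_with_tag(spans: list[ReadableSpan], tag: str) -> list[ReadableSpan]:
    # Attribute values set from lists come back as tuples,
    # e.g. ('Tool.run()',), so a membership test suffices.
    return [s for s in spans if tag in (s.attributes or {}).get("tags", ())]
```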
@@ -164,3 +164,25 @@ def test_observability_agent(self, driver, mock_span_exporter):
assert mock_span_exporter.export.call_count == 1
mock_span_exporter.export.assert_called_with(expected_spans)
mock_span_exporter.export.reset_mock()

def test_context_manager_observe_adds_tags_attribute(self, driver, mock_span_exporter):
expected_spans = ExpectedSpans(
spans=[
ExpectedSpan(name="main", parent=None, status_code=StatusCode.OK),
ExpectedSpan(
name="func()", parent="main", status_code=StatusCode.OK, attributes={"tags": ("Foo.bar()",)}
),
]
)

def func(word: str):
return word + " you"

with driver:
assert driver.observe(
Observable.Call(func=func, instance=None, args=["Hi"], decorator_kwargs={"tags": ["Foo.bar()"]})
) == "Hi you"

assert mock_span_exporter.export.call_count == 1
mock_span_exporter.export.assert_called_with(expected_spans)
mock_span_exporter.export.reset_mock()
9 changes: 9 additions & 0 deletions tests/utils/expected_spans.py
@@ -11,6 +11,7 @@ class ExpectedSpan:
parent: str = field(kw_only=True)
status_code: StatusCode = field(kw_only=True)
exception: Optional[Exception] = field(default=None, kw_only=True)
attributes: Optional[dict] = field(default=None, kw_only=True)


@define
@@ -76,4 +77,12 @@ def print_other_spans():
if actual_parent != expected_parent:
raise Exception(f"Span {child} has wrong parent")

expected_attributes_by_name = {span.name: span.attributes for span in self.spans}
for span_name, expected_attributes in expected_attributes_by_name.items():
other_span = other_span_by_name[span_name]
if expected_attributes is not None and other_span.attributes != expected_attributes:
raise Exception(
f"Span {span_name} has attributes {other_span.attributes} instead of {expected_attributes}"
)

return True