feat(llmobs): submit llmobs payloads from gemini integration (#10619)
This PR enables submitting LLMObs spans from the Gemini integration.

`generate_content()`/`generate_content_async()` calls are traced by the
Gemini APM integration. This PR also generates LLMObs span events from
those spans and submits them to LLM Observability (a brief usage sketch
follows the data list below).

The following data is collected by Gemini LLMObs spans:
- span kind: LLM
- provider: google
- model name
- metadata: any of `temperature`, `max_output_tokens`, `candidate_count`,
`top_p`, `top_k` that is set on the model instance or in the request itself
- input messages: includes system instruction as a system message (if
provided), and history of input messages passed to the model/request
(also includes function calls/responses in addition to text messages)
- output messages: the model's text responses (also includes function calls
in addition to text)
- metrics: token metrics (input/output/total)
- additional span data provided out of the box by LLMObs (span
duration, error, tags, etc.)
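
For context, a minimal usage sketch of what this enables once LLM Observability is turned on (the `ml_app` name, API key placeholder, and model id below are illustrative, not part of this PR):

```python
# Sketch only: assumes LLM Observability and the Gemini APM integration are enabled;
# "my-ml-app", the API key placeholder, and the model id are illustrative values.
import google.generativeai as genai

from ddtrace.llmobs import LLMObs

LLMObs.enable(ml_app="my-ml-app")  # or set DD_LLMOBS_ENABLED=1 and DD_LLMOBS_ML_APP

genai.configure(api_key="<GEMINI_API_KEY>")
model = genai.GenerativeModel("gemini-1.5-flash")

# generate_content() is traced by the APM integration; with this PR an LLM-kind
# LLMObs span event (input/output messages, metadata, token metrics) is also submitted.
response = model.generate_content(
    "What is the capital of France?",
    generation_config=genai.types.GenerationConfig(temperature=0.2, max_output_tokens=64),
)
print(response.text)
```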

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
Yun-Kim authored Sep 11, 2024
1 parent c0fd013 commit dc7e31e
Showing 10 changed files with 819 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ddtrace/contrib/google_generativeai/__init__.py
@@ -77,7 +77,7 @@
Pin.override(genai, service="my-gemini-service")
""" # noqa: E501
-from ...internal.utils.importlib import require_modules
+from ddtrace.internal.utils.importlib import require_modules


required_modules = ["google.generativeai"]
8 changes: 8 additions & 0 deletions ddtrace/contrib/internal/google_generativeai/_utils.py
@@ -30,6 +30,10 @@ def __iter__(self):
        else:
            tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance)
        finally:
            if self._dd_integration.is_pc_sampled_llmobs(self._dd_span):
                self._dd_integration.llmobs_set_tags(
                    self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__
                )
            self._dd_span.finish()


@@ -44,6 +48,10 @@ async def __aiter__(self):
        else:
            tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance)
        finally:
            if self._dd_integration.is_pc_sampled_llmobs(self._dd_span):
                self._dd_integration.llmobs_set_tags(
                    self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__
                )
            self._dd_span.finish()


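The wrappers above handle streamed responses: the span stays open until the stream generator is exhausted, at which point the LLMObs tags are set and the span is finished. A rough usage sketch (the model id is illustrative):

```python
# Sketch: with stream=True the Gemini span (and its LLMObs span event) is only
# finalized once the response iterator has been fully consumed.
import google.generativeai as genai

model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Tell me a short story.", stream=True)
for chunk in response:  # iteration drives the traced stream wrapper above
    print(chunk.text, end="")
# Stream exhausted here: llmobs_set_tags() has run and the span is finished.
```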
6 changes: 6 additions & 0 deletions ddtrace/contrib/internal/google_generativeai/patch.py
@@ -43,6 +43,7 @@ def traced_generate(genai, pin, func, instance, args, kwargs):
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        provider="google",
        model=_extract_model_name(instance),
        submit_to_llmobs=True,
    )
    try:
        tag_request(span, integration, instance, args, kwargs)
@@ -59,6 +60,8 @@ def traced_generate(genai, pin, func, instance, args, kwargs):
    finally:
        # streamed spans will be finished separately once the stream generator is exhausted
        if span.error or not stream:
            if integration.is_pc_sampled_llmobs(span):
                integration.llmobs_set_tags(span, args, kwargs, instance, generations)
            span.finish()
    return generations

@@ -73,6 +76,7 @@ async def traced_agenerate(genai, pin, func, instance, args, kwargs):
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        provider="google",
        model=_extract_model_name(instance),
        submit_to_llmobs=True,
    )
    try:
        tag_request(span, integration, instance, args, kwargs)
@@ -86,6 +90,8 @@ async def traced_agenerate(genai, pin, func, instance, args, kwargs):
    finally:
        # streamed spans will be finished separately once the stream generator is exhausted
        if span.error or not stream:
            if integration.is_pc_sampled_llmobs(span):
                integration.llmobs_set_tags(span, args, kwargs, instance, generations)
            span.finish()
    return generations

1 change: 1 addition & 0 deletions ddtrace/llmobs/_constants.py
@@ -23,6 +23,7 @@
    "Span started while LLMObs is disabled." " Spans will not be sent to LLM Observability."
)

GEMINI_APM_SPAN_NAME = "gemini.request"
LANGCHAIN_APM_SPAN_NAME = "langchain.request"
OPENAI_APM_SPAN_NAME = "openai.request"

140 changes: 140 additions & 0 deletions ddtrace/llmobs/_integrations/gemini.py
@@ -1,9 +1,25 @@
import json
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional

from ddtrace import Span
from ddtrace.internal.utils import get_argument_value
from ddtrace.llmobs._constants import INPUT_MESSAGES
from ddtrace.llmobs._constants import INPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import METADATA
from ddtrace.llmobs._constants import METRICS
from ddtrace.llmobs._constants import MODEL_NAME
from ddtrace.llmobs._constants import MODEL_PROVIDER
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
from ddtrace.llmobs._utils import _get_attr
from ddtrace.llmobs._utils import _unserializable_default_repr


class GeminiIntegration(BaseLLMIntegration):
@@ -16,3 +32,127 @@ def _set_base_span_tags(
            span.set_tag_str("google_generativeai.request.provider", str(provider))
        if model is not None:
            span.set_tag_str("google_generativeai.request.model", str(model))

    def llmobs_set_tags(
        self, span: Span, args: List[Any], kwargs: Dict[str, Any], instance: Any, generations: Any = None
    ) -> None:
        if not self.llmobs_enabled:
            return

        span.set_tag_str(SPAN_KIND, "llm")
        span.set_tag_str(MODEL_NAME, span.get_tag("google_generativeai.request.model") or "")
        span.set_tag_str(MODEL_PROVIDER, span.get_tag("google_generativeai.request.provider") or "")

        metadata = self._llmobs_set_metadata(kwargs, instance)
        span.set_tag_str(METADATA, json.dumps(metadata, default=_unserializable_default_repr))

        system_instruction = _get_attr(instance, "_system_instruction", None)
        input_contents = get_argument_value(args, kwargs, 0, "contents")
        input_messages = self._extract_input_message(input_contents, system_instruction)
        span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages, default=_unserializable_default_repr))

        if span.error or generations is None:
            span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
        else:
            output_messages = self._extract_output_message(generations)
            span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages, default=_unserializable_default_repr))

        usage = self._get_llmobs_metrics_tags(span)
        if usage:
            span.set_tag_str(METRICS, json.dumps(usage, default=_unserializable_default_repr))

    @staticmethod
    def _llmobs_set_metadata(kwargs, instance):
        metadata = {}
        model_config = instance._generation_config or {}
        request_config = kwargs.get("generation_config", {})
        parameters = ("temperature", "max_output_tokens", "candidate_count", "top_p", "top_k")
        for param in parameters:
            model_config_value = _get_attr(model_config, param, None)
            request_config_value = _get_attr(request_config, param, None)
            if model_config_value or request_config_value:
                metadata[param] = request_config_value or model_config_value
        return metadata

    @staticmethod
    def _extract_message_from_part(part, role):
        text = _get_attr(part, "text", "")
        function_call = _get_attr(part, "function_call", None)
        function_response = _get_attr(part, "function_response", None)
        message = {"content": text}
        if role:
            message["role"] = role
        if function_call:
            function_call_dict = function_call
            if not isinstance(function_call, dict):
                function_call_dict = type(function_call).to_dict(function_call)
            message["tool_calls"] = [
                {"name": function_call_dict.get("name", ""), "arguments": function_call_dict.get("args", {})}
            ]
        if function_response:
            function_response_dict = function_response
            if not isinstance(function_response, dict):
                function_response_dict = type(function_response).to_dict(function_response)
            message["content"] = "[tool result: {}]".format(function_response_dict.get("response", ""))
        return message

    def _extract_input_message(self, contents, system_instruction=None):
        messages = []
        if system_instruction:
            for part in system_instruction.parts:
                messages.append({"content": part.text or "", "role": "system"})
        if isinstance(contents, str):
            messages.append({"content": contents})
            return messages
        if isinstance(contents, dict):
            message = {"content": contents.get("text", "")}
            if contents.get("role", None):
                message["role"] = contents["role"]
            messages.append(message)
            return messages
        if not isinstance(contents, list):
            messages.append({"content": "[Non-text content object: {}]".format(repr(contents))})
            return messages
        for content in contents:
            if isinstance(content, str):
                messages.append({"content": content})
                continue
            role = _get_attr(content, "role", None)
            parts = _get_attr(content, "parts", [])
            if not parts or not isinstance(parts, Iterable):
                message = {"content": "[Non-text content object: {}]".format(repr(content))}
                if role:
                    message["role"] = role
                messages.append(message)
                continue
            for part in parts:
                message = self._extract_message_from_part(part, role)
                messages.append(message)
        return messages

    def _extract_output_message(self, generations):
        output_messages = []
        generations_dict = generations.to_dict()
        for candidate in generations_dict.get("candidates", []):
            content = candidate.get("content", {})
            role = content.get("role", "model")
            parts = content.get("parts", [])
            for part in parts:
                message = self._extract_message_from_part(part, role)
                output_messages.append(message)
        return output_messages

    @staticmethod
    def _get_llmobs_metrics_tags(span):
        usage = {}
        input_tokens = span.get_metric("google_generativeai.response.usage.prompt_tokens")
        output_tokens = span.get_metric("google_generativeai.response.usage.completion_tokens")
        total_tokens = span.get_metric("google_generativeai.response.usage.total_tokens")

        if input_tokens is not None:
            usage[INPUT_TOKENS_METRIC_KEY] = input_tokens
        if output_tokens is not None:
            usage[OUTPUT_TOKENS_METRIC_KEY] = output_tokens
        if total_tokens is not None:
            usage[TOTAL_TOKENS_METRIC_KEY] = total_tokens
        return usage
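
To illustrate the extraction helpers above, a rough sketch of the message shapes they produce (the values are illustrative, not taken from a real Gemini response):

```python
# Illustrative only: approximate input/output message shapes produced by the
# extraction helpers above for a hypothetical request and tool-call response.
contents = [
    {"role": "user", "parts": [{"text": "What's the weather in Paris?"}]},
]
# _extract_input_message(contents) would yield roughly:
input_messages = [
    {"content": "What's the weather in Paris?", "role": "user"},
]
# A candidate part carrying a function_call would be extracted as something like:
output_message = {
    "content": "",
    "role": "model",
    "tool_calls": [{"name": "get_weather", "arguments": {"location": "Paris"}}],
}
```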
3 changes: 2 additions & 1 deletion ddtrace/llmobs/_utils.py
@@ -5,6 +5,7 @@
from ddtrace import config
from ddtrace.ext import SpanTypes
from ddtrace.internal.logger import get_logger
from ddtrace.llmobs._constants import GEMINI_APM_SPAN_NAME
from ddtrace.llmobs._constants import LANGCHAIN_APM_SPAN_NAME
from ddtrace.llmobs._constants import ML_APP
from ddtrace.llmobs._constants import OPENAI_APM_SPAN_NAME
@@ -46,7 +47,7 @@ def _get_llmobs_parent_id(span: Span) -> Optional[str]:


def _get_span_name(span: Span) -> str:
-    if span.name == LANGCHAIN_APM_SPAN_NAME and span.resource != "":
+    if span.name in (LANGCHAIN_APM_SPAN_NAME, GEMINI_APM_SPAN_NAME) and span.resource != "":
        return span.resource
    elif span.name == OPENAI_APM_SPAN_NAME and span.resource != "":
        return "openai.{}".format(span.resource)
4 changes: 4 additions & 0 deletions releasenotes/notes/feat-llmobs-gemini-b65c714ceef9eb12.yaml
@@ -0,0 +1,4 @@
---
features:
  - |
    LLM Observability: Adds support to automatically submit Gemini Python SDK calls to LLM Observability.
18 changes: 18 additions & 0 deletions tests/contrib/google_generativeai/conftest.py
@@ -1,9 +1,11 @@
import os

import mock
import pytest

from ddtrace.contrib.google_generativeai import patch
from ddtrace.contrib.google_generativeai import unpatch
from ddtrace.llmobs import LLMObs
from ddtrace.pin import Pin
from tests.contrib.google_generativeai.utils import MockGenerativeModelAsyncClient
from tests.contrib.google_generativeai.utils import MockGenerativeModelClient
@@ -35,11 +37,27 @@ def mock_tracer(ddtrace_global_config, genai):
        mock_tracer = DummyTracer(writer=DummyWriter(trace_flush_enabled=False))
        pin.override(genai, tracer=mock_tracer)
        pin.tracer.configure()
        if ddtrace_global_config.get("_llmobs_enabled", False):
            # Have to disable and re-enable LLMObs to use the mock tracer.
            LLMObs.disable()
            LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False)
        yield mock_tracer
    except Exception:
        yield


@pytest.fixture
def mock_llmobs_writer():
    patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsSpanWriter")
    try:
        LLMObsSpanWriterMock = patcher.start()
        m = mock.MagicMock()
        LLMObsSpanWriterMock.return_value = m
        yield m
    finally:
        patcher.stop()


@pytest.fixture
def mock_client():
    yield MockGenerativeModelClient()
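
A rough sketch of how a test might combine these fixtures to assert that an LLMObs span event is enqueued (fixture wiring, LLMObs enablement via `ddtrace_global_config`, and response stubbing are assumed from the surrounding test suite, not shown in this diff):

```python
# Sketch only: fixture names/wiring are assumed; the real tests also stub the
# Gemini client response and assert on the full expected span event payload.
def test_gemini_completion_submits_llmobs_span(genai, mock_llmobs_writer, mock_tracer):
    model = genai.GenerativeModel("gemini-1.5-flash")
    model.generate_content("Why is the sky blue?")
    # One LLMObs span event should have been enqueued for the gemini.request span.
    assert mock_llmobs_writer.enqueue.call_count == 1
```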