diff --git a/src/opentelemetry_instrumentation_kstreams/utils.py b/src/opentelemetry_instrumentation_kstreams/utils.py
index ececd8c..2322c27 100644
--- a/src/opentelemetry_instrumentation_kstreams/utils.py
+++ b/src/opentelemetry_instrumentation_kstreams/utils.py
@@ -1,6 +1,6 @@
 import json
 from logging import getLogger
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Iterable, List, Optional
 
 from kstreams import Stream, StreamEngine, ConsumerRecord
 from kstreams.backends.kafka import Kafka
@@ -14,41 +14,48 @@
 _LOG = getLogger(__name__)
 
 
-class KStreamsContextGetter(textmap.Getter[textmap.CarrierT]):
-    def get(self, carrier: textmap.CarrierT, key: str) -> Optional[List[str]]:
+# Carrier shapes: a list of (key, bytes) header pairs or a str-to-str dict.
+HeadersT = list[tuple[str, bytes | None]] | dict[str, str | None]
+
+
+class KStreamsContextGetter(textmap.Getter[HeadersT]):
+    def get(self, carrier: HeadersT, key: str) -> Optional[List[str]]:
         if carrier is None:
             return None
-
-        carrier_items = carrier
+        carrier_items: Iterable = carrier
         if isinstance(carrier, dict):
-            carrier_items = carrier.items()  # type: ignore
+            carrier_items = carrier.items()
 
-        for item_key, value in carrier_items:  # type: ignore
+        for item_key, value in carrier_items:
             if item_key == key:
                 if value is not None:
-                    return [value.decode()]
+                    # Only list-style headers hold bytes; dict values are str.
+                    if isinstance(value, bytes):
+                        return [value.decode()]
+                    return [value]
         return None
 
-    def keys(self, carrier: textmap.CarrierT) -> List[str]:
+    def keys(self, carrier: HeadersT) -> List[str]:
         if carrier is None:
             return []
-        carrier_items = carrier
+        carrier_items: Iterable = carrier
         if isinstance(carrier, dict):
-            carrier_items = carrier.items()  # type: ignore
+            carrier_items = carrier.items()
 
-        return [key for (key, _) in carrier_items]  # type: ignore
+        return [key for (key, _) in carrier_items]
 
 
-class KStreamsContextSetter(textmap.Setter[textmap.CarrierT]):
-    def set(self, carrier: textmap.CarrierT, key: str, value: str) -> None:
+class KStreamsContextSetter(textmap.Setter[HeadersT]):
+    def set(self, carrier: HeadersT, key: str, value: str | None) -> None:
         if carrier is None or key is None:
             return
 
         if isinstance(carrier, list):
-            carrier.append((key, value))  # type: ignore
-
-        if isinstance(carrier, dict):
-            carrier[key] = value  # type: ignore
+            if value is not None:
+                carrier.append((key, value.encode()))
+        elif isinstance(carrier, dict):
+            if value is not None:
+                carrier[key] = value
 
 
 _kstreams_getter: KStreamsContextGetter = KStreamsContextGetter()
@@ -107,17 +114,46 @@ def _enrich_span(
     topic: str,
     partition: Optional[int],
     client_id: Optional[str],
+    offset: Optional[int] = None,
 ):
+    """
+    Enriches the given span with Kafka-related attributes.
+
+    Used to enrich consumer and producer spans.
+
+    Args:
+        span: The span to enrich.
+        bootstrap_servers: List of Kafka bootstrap servers.
+        topic: The Kafka topic name.
+        partition: The partition number, if available.
+        client_id: The Kafka client ID, if available.
+        offset: The message offset, if available. Defaults to None.
+
+    Returns:
+        None
+    """
+    if not span.is_recording():
+        return
+
     span.set_attribute(SpanAttributes.MESSAGING_SYSTEM, "kafka")
     span.set_attribute(SpanAttributes.MESSAGING_DESTINATION, topic)
     span.set_attribute(SpanAttributes.MESSAGING_URL, json.dumps(bootstrap_servers))
     span.set_attribute(SpanAttributes.MESSAGING_DESTINATION_KIND, "topic")
+
     if client_id is not None:
         span.set_attribute(SpanAttributes.MESSAGING_KAFKA_CLIENT_ID, client_id)
 
     if span.is_recording():
+        if offset is not None:
+            span.set_attribute(SpanAttributes.MESSAGING_KAFKA_MESSAGE_OFFSET, offset)
         if partition is not None:
             span.set_attribute(SpanAttributes.MESSAGING_KAFKA_PARTITION, partition)
+        # Compare against None: offset 0 and partition 0 are valid and falsy.
+        if offset is not None and partition is not None:
+            span.set_attribute(
+                SpanAttributes.MESSAGING_MESSAGE_ID,
+                f"{topic}.{partition}.{offset}",
+            )
 
 
 def _get_span_name(operation: str, topic: str):
@@ -133,7 +169,7 @@ async def _traced_send(func, instance: StreamEngine, args, kwargs):
     if headers is None:
         headers = []
 
-    kwargs["headers"] = headers.copy()
+    kwargs["headers"] = headers
 
     topic = KStreamsKafkaExtractor.extract_send_topic(args, kwargs)
     bootstrap_servers = KStreamsKafkaExtractor.extract_bootstrap_servers(
@@ -165,6 +201,21 @@ def _create_consumer_span(
     args: Any,
     kwargs: Any,
 ) -> None:
+    """
+    Creates and starts a consumer span for a given Kafka record.
+
+    Args:
+        tracer: The tracer instance used to create the span.
+        record: The Kafka consumer record for which the span is created.
+        extracted_context: The context extracted from the incoming message.
+        bootstrap_servers: List of bootstrap servers for the Kafka cluster.
+        client_id: The client ID of the Kafka consumer.
+        args: Additional positional arguments.
+        kwargs: Additional keyword arguments.
+
+    Returns:
+        None
+    """
     span_name = _get_span_name("receive", record.topic)
     with tracer.start_as_current_span(
         span_name,
@@ -173,7 +224,7 @@ def _create_consumer_span(
     ) as span:
         new_context = trace.set_span_in_context(span, extracted_context)
         token = context.attach(new_context)
-        _enrich_span(span, bootstrap_servers, record.topic, record.partition, client_id)
+        _enrich_span(span, bootstrap_servers, record.topic, record.partition, client_id, record.offset)
         context.detach(token)
 
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0dc1112..c4450e8 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -212,7 +212,12 @@ def test_create_consumer_span(
 
     attach.assert_called_once_with(set_span_in_context.return_value)
     enrich_span.assert_called_once_with(
-        span, bootstrap_servers, record.topic, record.partition, client_id
+        span,
+        bootstrap_servers,
+        record.topic,
+        record.partition,
+        client_id,
+        record.offset,
     )
     # consume_hook.assert_called_once_with(span, record, self.args, self.kwargs)
     detach.assert_called_once_with(attach.return_value)