Skip to content

Commit

Permalink
FIX: Improve typings
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Sep 27, 2024
1 parent e360f82 commit 7b2e11b
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/opentelemetry_instrumentation_kstreams/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from logging import getLogger
from typing import Any, Callable, Iterable, List, Optional
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Union

from kstreams import Stream, StreamEngine, ConsumerRecord
from kstreams import Send, Stream, StreamEngine, ConsumerRecord, RecordMetadata
from kstreams.backends.kafka import Kafka
from opentelemetry import propagate, context, trace
from opentelemetry.propagators import textmap
Expand All @@ -17,7 +17,7 @@
_LOG = getLogger(__name__)


HeadersT = list[tuple[str, bytes | None]] | dict[str, str | None]
HeadersT = Union[list[tuple[str, Union[bytes, None]]], dict[str, Union[str, None]]]


class KStreamsContextGetter(textmap.Getter[HeadersT]):
Expand Down Expand Up @@ -45,7 +45,7 @@ def keys(self, carrier: HeadersT) -> List[str]:


class KStreamsContextSetter(textmap.Setter[HeadersT]):
def set(self, carrier: HeadersT, key: str, value: str | None) -> None:
def set(self, carrier: HeadersT, key: str, value: Optional[str]) -> None:
if carrier is None or key is None:
return

Expand Down Expand Up @@ -89,20 +89,16 @@ def extract_send_topic(args, kwargs) -> Any:
)

@staticmethod
def extract_send_partition(record_metadata: Any) -> Any:
if hasattr(record_metadata, "partition"):
return record_metadata.partition
def extract_send_partition(record_metadata: Any) -> Optional[int]:
return getattr(record_metadata, "partition", None)

@staticmethod
def extract_send_offset(record_metadata: Any) -> Any:
if hasattr(record_metadata, "offset"):
return record_metadata.offset
def extract_send_offset(record_metadata: Any) -> Optional[int]:
return getattr(record_metadata, "offset", None)

@staticmethod
def extract_consumer_group(consumer: Any) -> Optional[str]:
if hasattr(consumer, "group_id"):
return consumer.group_id
return None
return getattr(consumer, "group_id", None)

@staticmethod
def extract_producer_client_id(instance: StreamEngine) -> Optional[str]:
Expand Down Expand Up @@ -182,7 +178,9 @@ def _get_span_name(operation: str, topic: str) -> str:


def _wrap_send(tracer: Tracer) -> Callable:
async def _traced_send(func, instance: StreamEngine, args, kwargs) -> Any:
async def _traced_send(
func: Send, instance: StreamEngine, args, kwargs
) -> Awaitable[RecordMetadata]:
if not isinstance(instance.backend, Kafka):
raise NotImplementedError("Only Kafka backend is supported for now")

Expand Down Expand Up @@ -263,7 +261,7 @@ def _create_consumer_span(
# span.set_attribute(
# SpanAttributes.MESSAGING_CONSUMER_GROUP_NAME, consumer_group
# )

# trace.set_span_in_context(span)
context.detach(token)


Expand Down Expand Up @@ -296,6 +294,8 @@ async def _traced_anext(func, instance: Stream, args, kwargs):
args,
kwargs,
)
# instance._current_context_token = context.attach(
# )

return record

Expand Down

0 comments on commit 7b2e11b

Please sign in to comment.