Skip to content

Commit

Permalink
OPT: Instrument get_middlewares if available
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Oct 15, 2024
1 parent b7cb9c4 commit e16bace
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 46 deletions.
52 changes: 26 additions & 26 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 21 additions & 7 deletions src/opentelemetry_instrumentation_kstreams/instrumentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from .package import _instruments
from .version import __version__
from .wrappers import (
# _wrap_getone,
_wrap_build_stream_middleware_stack,
_wrap_get_middlewares,
_wrap_send,
)

Expand Down Expand Up @@ -38,12 +38,26 @@ def _instrument(self, **kwargs: Any):
schema_url="https://opentelemetry.io/schemas/1.11.0",
)
wrap_function_wrapper(StreamEngine, "send", _wrap_send(tracer))
wrap_function_wrapper(
StreamEngine,
"build_stream_middleware_stack",
_wrap_build_stream_middleware_stack(tracer),
)

# kstreams >= 0.24.1
if hasattr(Stream, "get_middlewares"):
wrap_function_wrapper(
Stream,
"get_middlewares",
_wrap_get_middlewares(tracer),
)
else:
wrap_function_wrapper(
StreamEngine,
"_build_stream_middleware_stack",
_wrap_build_stream_middleware_stack(tracer),
)

def _uninstrument(self, **kwargs: Any):
unwrap(StreamEngine, "send")
unwrap(Stream, "build_stream_middleware_stack")

# kstreams >= 0.24.1
if hasattr(Stream, "get_middlewares"):
unwrap(Stream, "get_middlewares")
else:
unwrap(StreamEngine, "_build_stream_middleware_stack")
24 changes: 24 additions & 0 deletions src/opentelemetry_instrumentation_kstreams/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,27 @@ def _traced_build_stream_middleware_stack(
return next_call

return _traced_build_stream_middleware_stack


def _wrap_get_middlewares(
tracer: Tracer,
) -> Callable:
def _traced_get_middlewares(
func, instance: Stream, args, kwargs
) -> NextMiddlewareCall:
# let's check if otel is already present in the middlewares
if (
len(instance.middlewares) > 0
and instance.middlewares[0].middleware == OpenTelemetryMiddleware
):
return func(*args, **kwargs)

instance.middlewares.insert(
0, middleware.Middleware(OpenTelemetryMiddleware, tracer=tracer)
)

next_call = func(*args, **kwargs)

return next_call

return _traced_get_middlewares
4 changes: 2 additions & 2 deletions tests/test_instrumentation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kstreams import StreamEngine
from kstreams import Stream, StreamEngine
from wrapt import BoundFunctionWrapper

from opentelemetry_instrumentation_kstreams import KStreamsInstrumentor
Expand All @@ -8,4 +8,4 @@ def test_instrument_api() -> None:
instrumentation = KStreamsInstrumentor()
instrumentation.instrument()
assert isinstance(StreamEngine.send, BoundFunctionWrapper)
assert isinstance(StreamEngine.build_stream_middleware_stack, BoundFunctionWrapper)
assert isinstance(Stream.get_middlewares, BoundFunctionWrapper)
32 changes: 21 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio
from unittest import TestCase, mock

from kstreams import ConsumerRecord, StreamEngine
from kstreams import ConsumerRecord, Stream, StreamEngine
from kstreams.backends.kafka import Kafka
from kstreams.middleware import BaseMiddleware, ExceptionMiddleware, Middleware
from kstreams.streams_utils import StreamErrorPolicy
Expand Down Expand Up @@ -226,10 +226,16 @@ async def __call__(self, cr: ConsumerRecord):
consumer_class = mock.MagicMock()
producer_class = mock.MagicMock()
monitor = mock.MagicMock()

func_mock = mock.MagicMock()
# Create the stream with an extra middleware
stream = mock.MagicMock()
stream.middlewares = [Middleware(S3Middleware)]
stream = Stream(
topics=["test_topic"],
func=func_mock,
error_policy=StreamErrorPolicy.STOP,
consumer_class=consumer_class,
middlewares=[Middleware(S3Middleware)],
)
# stream.middlewares = [Middleware(S3Middleware)]

backend = Kafka()
stream_engine = StreamEngine(
Expand All @@ -243,22 +249,26 @@ async def __call__(self, cr: ConsumerRecord):
stream_engine.start()

# Build the middleware stack
stream_engine.build_stream_middleware_stack(
stream=stream, error_policy=StreamErrorPolicy.STOP_ENGINE
)
stream_engine._build_stream_middleware_stack(stream=stream)
stream_engine.stop()

assert len(stream.middlewares) == 3
assert len(stream.get_middlewares(engine=stream_engine)) == 3

# In this case, we simulated the real workflow using the stream_engine
# so the first should be the ExceptionMiddleware
first_middleware_class = stream.middlewares[0].middleware
first_middleware_class = stream.get_middlewares(engine=stream_engine)[
0
].middleware
assert first_middleware_class == ExceptionMiddleware

# The second should be the OpenTelemetryMiddleware
second_middleware_class = stream.middlewares[1].middleware
second_middleware_class = stream.get_middlewares(engine=stream_engine)[
1
].middleware
assert second_middleware_class == OpenTelemetryMiddleware

# The third should be the S3Middleware
third_middleware_class = stream.middlewares[2].middleware
third_middleware_class = stream.get_middlewares(engine=stream_engine)[
2
].middleware
assert third_middleware_class == S3Middleware

0 comments on commit e16bace

Please sign in to comment.