diff --git a/src/snowflake/telemetry/__init__.py b/src/snowflake/telemetry/__init__.py index 0df4a4f..fd12137 100644 --- a/src/snowflake/telemetry/__init__.py +++ b/src/snowflake/telemetry/__init__.py @@ -11,7 +11,7 @@ Stored Procedures. """ -from opentelemetry import trace +from opentelemetry.trace import get_current_span from opentelemetry.util import types from snowflake.telemetry.version import VERSION @@ -26,7 +26,8 @@ def add_event( """ Add an event name and associated attributes to the current span. """ - trace.get_current_span().add_event(name, attributes) + get_current_span().add_event(name, attributes) + def set_span_attribute( key: str, @@ -35,4 +36,4 @@ def set_span_attribute( """ Set an attribute key, value pair on the current span. """ - trace.get_current_span().set_attribute(key, value) + get_current_span().set_attribute(key, value) diff --git a/src/snowflake/telemetry/trace/__init__.py b/src/snowflake/telemetry/trace/__init__.py new file mode 100644 index 0000000..4925ac0 --- /dev/null +++ b/src/snowflake/telemetry/trace/__init__.py @@ -0,0 +1,29 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import time +import random + +from opentelemetry import trace +from opentelemetry.sdk.trace import RandomIdGenerator + +# Generator that returns +# trace_id: the given (inherited) trace id on the first call to generate_trace_id, and a Snowflake trace_id on subsequent calls +# span_id: a random span_id +class SnowflakeTraceIdGenerator(RandomIdGenerator): + def generate_trace_id(self) -> int: + trace_id = trace.INVALID_TRACE_ID + while trace_id == trace.INVALID_TRACE_ID: + # Number of minutes since the epoch + timestamp_in_minutes = int(time.time()) // 60 + # Convert and pad to 4 bytes + timestamp_bytes = timestamp_in_minutes.to_bytes(4, byteorder='big', signed=False) + suffix_bytes = random.getrandbits(96).to_bytes(12, byteorder='big', signed=False) + trace_id = int.from_bytes(timestamp_bytes + suffix_bytes, byteorder='big', signed=False) + return trace_id + + +__all__ = [ + "SnowflakeTraceIdGenerator", +] diff --git a/tests/test_snowflake_trace_id_generator.py b/tests/test_snowflake_trace_id_generator.py new file mode 100644 index 0000000..6cbe41e --- /dev/null +++ b/tests/test_snowflake_trace_id_generator.py @@ -0,0 +1,34 @@ +import unittest +from unittest.mock import patch + +from opentelemetry import trace +from snowflake.telemetry.trace import SnowflakeTraceIdGenerator + +MOCK_TIMESTAMP = 1719588243.3379807 +INVALID_TRACE_ID = 0x00000000000000000000000000000000 +TRACE_ID_MAX_VALUE = 2**128 - 1 + + +class TestSnowflakeTraceIdGenerator(unittest.TestCase): + + @patch('time.time', return_value=MOCK_TIMESTAMP) + def test_valid_snowflake_trace_id(self, mock_timestamp): + id_generator = SnowflakeTraceIdGenerator() + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + + def _verify_snowflake_trace_id(self, trace_id: int): + # https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-api/src/opentelemetry/trace/span.py + self.assertTrue(trace.INVALID_TRACE_ID < trace_id <= TRACE_ID_MAX_VALUE) + + # Get the hex format of the snowflake_trace_id and pad it to 32 characters + # The timestamp prefix is the first 8 characters of this. + timestamp_prefix = f'{trace_id:x}'.zfill(32)[:8] + + # the expected prefix is the timestamp (in minutes) in hex format padded to 8 characters. + mock_timestamp_minutes = int(MOCK_TIMESTAMP) // 60 + mock_timestamp_prefix = f'{mock_timestamp_minutes:x}'.zfill(8) + self.assertEqual(timestamp_prefix, mock_timestamp_prefix)