diff --git a/src/snowflake/telemetry/__init__.py b/src/snowflake/telemetry/__init__.py index 0df4a4f..fd12137 100644 --- a/src/snowflake/telemetry/__init__.py +++ b/src/snowflake/telemetry/__init__.py @@ -11,7 +11,7 @@ Stored Procedures. """ -from opentelemetry import trace +from opentelemetry.trace import get_current_span from opentelemetry.util import types from snowflake.telemetry.version import VERSION @@ -26,7 +26,8 @@ def add_event( """ Add an event name and associated attributes to the current span. """ - trace.get_current_span().add_event(name, attributes) + get_current_span().add_event(name, attributes) + def set_span_attribute( key: str, @@ -35,4 +36,4 @@ def set_span_attribute( """ Set an attribute key, value pair on the current span. """ - trace.get_current_span().set_attribute(key, value) + get_current_span().set_attribute(key, value) diff --git a/src/snowflake/telemetry/trace/__init__.py b/src/snowflake/telemetry/trace/__init__.py new file mode 100644 index 0000000..4925ac0 --- /dev/null +++ b/src/snowflake/telemetry/trace/__init__.py @@ -0,0 +1,29 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import time +import random + +from opentelemetry import trace +from opentelemetry.sdk.trace import RandomIdGenerator + +# Generator that returns +# trace_id: the given (inherited) trace id on the first call to generate_trace_id, and a Snowflake trace_id on subsequent calls +# span_id: a random span_id +class SnowflakeTraceIdGenerator(RandomIdGenerator): + def generate_trace_id(self) -> int: + trace_id = trace.INVALID_TRACE_ID + while trace_id == trace.INVALID_TRACE_ID: + # Number of minutes since the epoch + timestamp_in_minutes = int(time.time()) // 60 + # Convert and pad to 4 bytes + timestamp_bytes = timestamp_in_minutes.to_bytes(4, byteorder='big', signed=False) + suffix_bytes = random.getrandbits(96).to_bytes(12, byteorder='big', signed=False) + trace_id = int.from_bytes(timestamp_bytes + suffix_bytes, byteorder='big', signed=False) + return trace_id + + +__all__ = [ + "SnowflakeTraceIdGenerator", +] diff --git a/tests/test_basic.py b/tests/test_basic.py index 3493a96..aa99894 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -4,7 +4,7 @@ import unittest from snowflake import telemetry -from opentelemetry import trace +from opentelemetry.trace import get_current_span from opentelemetry.sdk.trace import ( TracerProvider, ) @@ -14,10 +14,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( InMemorySpanExporter, ) -from opentelemetry.trace import ( - Status, - StatusCode, -) + from opentelemetry.trace.span import INVALID_SPAN @@ -33,7 +30,7 @@ def test_api_without_current_span(self): """ Tests that no exceptions are raised by public API methods when called without a current span """ - self.assertEqual(trace.get_current_span(), INVALID_SPAN) + self.assertEqual(get_current_span(), INVALID_SPAN) telemetry.add_event("EventName1") telemetry.add_event("EventName2", { @@ -48,9 +45,9 @@ def test_add_event(self): """ self.configure_open_telemetry() with self.tracer.start_as_current_span("Auto-instrumented span"): - self.assertNotEqual(trace.get_current_span(), INVALID_SPAN) + self.assertNotEqual(get_current_span(), INVALID_SPAN) telemetry.add_event("EventName1") - trace.get_current_span().end() + get_current_span().end() spans = self.memory_exporter.get_finished_spans() self.assertEqual(1, len(spans)) events = spans[0].events @@ -64,9 +61,9 @@ def test_add_event_none_name(self): """ self.configure_open_telemetry() with self.tracer.start_as_current_span("Auto-instrumented span"): - self.assertNotEqual(trace.get_current_span(), INVALID_SPAN) + self.assertNotEqual(get_current_span(), INVALID_SPAN) telemetry.add_event(None) - trace.get_current_span().end() + get_current_span().end() spans = self.memory_exporter.get_finished_spans() self.assertEqual(1, len(spans)) events = spans[0].events @@ -80,9 +77,9 @@ def test_add_event_empty_name(self): """ self.configure_open_telemetry() with self.tracer.start_as_current_span("Auto-instrumented span"): - self.assertNotEqual(trace.get_current_span(), INVALID_SPAN) + self.assertNotEqual(get_current_span(), INVALID_SPAN) telemetry.add_event("") - trace.get_current_span().end() + get_current_span().end() spans = self.memory_exporter.get_finished_spans() self.assertEqual(1, len(spans)) events = spans[0].events @@ -93,7 +90,7 @@ def test_add_event_empty_name(self): def test_add_event_with_attributes(self): self.configure_open_telemetry() with self.tracer.start_as_current_span("Auto-instrumented span"): - self.assertNotEqual(trace.get_current_span(), INVALID_SPAN) + self.assertNotEqual(get_current_span(), INVALID_SPAN) telemetry.add_event("EventName2", { "some int": 42, @@ -103,7 +100,7 @@ def test_add_event_with_attributes(self): "a false value": False, "a none value": None, }) - trace.get_current_span().end() + get_current_span().end() spans = self.memory_exporter.get_finished_spans() self.assertEqual(1, len(spans)) events = spans[0].events @@ -120,14 +117,14 @@ def test_add_event_with_attributes(self): def test_set_span_attribute(self): self.configure_open_telemetry() with self.tracer.start_as_current_span("Auto-instrumented span"): - self.assertNotEqual(trace.get_current_span(), INVALID_SPAN) + self.assertNotEqual(get_current_span(), INVALID_SPAN) telemetry.set_span_attribute("some int", 42) telemetry.set_span_attribute("some str", "Val1") telemetry.set_span_attribute("some float", 3.14) telemetry.set_span_attribute("a true value", True) telemetry.set_span_attribute("a false value", False) telemetry.set_span_attribute("a none value", None) - trace.get_current_span().end() + get_current_span().end() spans = self.memory_exporter.get_finished_spans() self.assertEqual(1, len(spans)) attributes = spans[0].attributes diff --git a/tests/test_snowflake_trace_id_generator.py b/tests/test_snowflake_trace_id_generator.py new file mode 100644 index 0000000..ce4fbf3 --- /dev/null +++ b/tests/test_snowflake_trace_id_generator.py @@ -0,0 +1,43 @@ +import unittest +from unittest.mock import patch + +from opentelemetry import trace +from snowflake.telemetry.trace import SnowflakeTraceIdGenerator + +MOCK_TIMESTAMP = 1719588243.3379807 +INVALID_TRACE_ID = 0x00000000000000000000000000000000 +TRACE_ID_MAX_VALUE = 2**128 - 1 +SPAN_ID_MAX_VALUE = 2**64 - 1 + + +class TestSnowflakeTraceIdGenerator(unittest.TestCase): + + def test_valid_span_id(self): + id_generator = SnowflakeTraceIdGenerator() + self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) + self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) + self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) + self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) + self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) + + @patch('time.time', return_value=MOCK_TIMESTAMP) + def test_valid_snowflake_trace_id(self, mock_time): + id_generator = SnowflakeTraceIdGenerator() + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + self._verify_snowflake_trace_id(id_generator.generate_trace_id()) + + def _verify_snowflake_trace_id(self, trace_id: int): + # https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-api/src/opentelemetry/trace/span.py + self.assertTrue(trace.INVALID_TRACE_ID < trace_id <= TRACE_ID_MAX_VALUE) + + # Get the hex format of the snowflake_trace_id and pad it to 32 characters + # The timestamp prefix is the first 8 characters of this. + timestamp_prefix = f'{trace_id:x}'.zfill(32)[:8] + + # the expected prefix is the timestamp (in minutes) in hex format padded to 8 characters. + mock_timestamp_minutes = int(MOCK_TIMESTAMP) // 60 + mock_timestamp_prefix = f'{mock_timestamp_minutes:x}'.zfill(8) + self.assertEqual(timestamp_prefix, mock_timestamp_prefix)