Skip to content

Commit

Permalink
Add SnowflakeTraceIdGenerator implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bdrutu committed Jul 18, 2024
1 parent 403bdd6 commit c7053e1
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/snowflake/telemetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Stored Procedures.
"""

from opentelemetry import trace
from opentelemetry.trace import get_current_span
from opentelemetry.util import types

from snowflake.telemetry.version import VERSION
Expand All @@ -26,7 +26,8 @@ def add_event(
"""
Add an event name and associated attributes to the current span.
"""
trace.get_current_span().add_event(name, attributes)
get_current_span().add_event(name, attributes)


def set_span_attribute(
key: str,
Expand All @@ -35,4 +36,4 @@ def set_span_attribute(
"""
Set an attribute key, value pair on the current span.
"""
trace.get_current_span().set_attribute(key, value)
get_current_span().set_attribute(key, value)
29 changes: 29 additions & 0 deletions src/snowflake/telemetry/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import time
import random

from opentelemetry import trace
from opentelemetry.sdk.trace import RandomIdGenerator

# Generator that returns
# trace_id: the given (inherited) trace id on the first call to generate_trace_id, and a Snowflake trace_id on subsequent calls
# span_id: a random span_id
class SnowflakeTraceIdGenerator(RandomIdGenerator):
def generate_trace_id(self) -> int:
trace_id = trace.INVALID_TRACE_ID
while trace_id == trace.INVALID_TRACE_ID:
# Number of minutes since the epoch
timestamp_in_minutes = int(time.time()) // 60
# Convert and pad to 4 bytes
timestamp_bytes = timestamp_in_minutes.to_bytes(4, byteorder='big', signed=False)
suffix_bytes = random.getrandbits(96).to_bytes(12, byteorder='big', signed=False)
trace_id = int.from_bytes(timestamp_bytes + suffix_bytes, byteorder='big', signed=False)
return trace_id


__all__ = [
"SnowflakeTraceIdGenerator",
]
34 changes: 34 additions & 0 deletions tests/test_snowflake_trace_id_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest
from unittest.mock import patch

from opentelemetry import trace
from snowflake.telemetry.trace import SnowflakeTraceIdGenerator

MOCK_TIMESTAMP = 1719588243.3379807
INVALID_TRACE_ID = 0x00000000000000000000000000000000
TRACE_ID_MAX_VALUE = 2**128 - 1


class TestSnowflakeTraceIdGenerator(unittest.TestCase):

@patch('time.time', return_value=MOCK_TIMESTAMP)
def test_valid_snowflake_trace_id(self, mock_timestamp):
id_generator = SnowflakeTraceIdGenerator()
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())

def _verify_snowflake_trace_id(self, trace_id: int):
# https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-api/src/opentelemetry/trace/span.py
self.assertTrue(trace.INVALID_TRACE_ID < trace_id <= TRACE_ID_MAX_VALUE)

# Get the hex format of the snowflake_trace_id and pad it to 32 characters
# The timestamp prefix is the first 8 characters of this.
timestamp_prefix = f'{trace_id:x}'.zfill(32)[:8]

# the expected prefix is the timestamp (in minutes) in hex format padded to 8 characters.
mock_timestamp_minutes = int(MOCK_TIMESTAMP) // 60
mock_timestamp_prefix = f'{mock_timestamp_minutes:x}'.zfill(8)
self.assertEqual(timestamp_prefix, mock_timestamp_prefix)

0 comments on commit c7053e1

Please sign in to comment.