-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SnowflakeTraceIdGenerator implementation
- Loading branch information
1 parent
403bdd6
commit 1c6a991
Showing
4 changed files
with
89 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
import time | ||
import random | ||
|
||
from opentelemetry import trace | ||
from opentelemetry.sdk.trace import RandomIdGenerator | ||
|
||
# Generator that returns | ||
# trace_id: the given (inherited) trace id on the first call to generate_trace_id, and a Snowflake trace_id on subsequent calls | ||
# span_id: a random span_id | ||
class SnowflakeTraceIdGenerator(RandomIdGenerator): | ||
def generate_trace_id(self) -> int: | ||
trace_id = trace.INVALID_TRACE_ID | ||
while trace_id == trace.INVALID_TRACE_ID: | ||
# Number of minutes since the epoch | ||
timestamp_in_minutes = int(time.time()) // 60 | ||
# Convert and pad to 4 bytes | ||
timestamp_bytes = timestamp_in_minutes.to_bytes(4, byteorder='big', signed=False) | ||
suffix_bytes = random.getrandbits(96).to_bytes(12, byteorder='big', signed=False) | ||
trace_id = int.from_bytes(timestamp_bytes + suffix_bytes, byteorder='big', signed=False) | ||
return trace_id | ||
|
||
|
||
__all__ = [ | ||
"SnowflakeTraceIdGenerator", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
|
||
from opentelemetry import trace | ||
from snowflake.telemetry.trace import SnowflakeTraceIdGenerator | ||
|
||
MOCK_TIMESTAMP = 1719588243.3379807 | ||
INVALID_TRACE_ID = 0x00000000000000000000000000000000 | ||
TRACE_ID_MAX_VALUE = 2**128 - 1 | ||
SPAN_ID_MAX_VALUE = 2**64 - 1 | ||
|
||
|
||
class TestSnowflakeTraceIdGenerator(unittest.TestCase): | ||
|
||
def test_valid_span_id(self): | ||
id_generator = SnowflakeTraceIdGenerator() | ||
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) | ||
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) | ||
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) | ||
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) | ||
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE) | ||
|
||
@patch('time.time', return_value=MOCK_TIMESTAMP) | ||
def test_valid_snowflake_trace_id(self, mock_timestamp): | ||
id_generator = SnowflakeTraceIdGenerator() | ||
self._verify_snowflake_trace_id(id_generator.generate_trace_id()) | ||
self._verify_snowflake_trace_id(id_generator.generate_trace_id()) | ||
self._verify_snowflake_trace_id(id_generator.generate_trace_id()) | ||
self._verify_snowflake_trace_id(id_generator.generate_trace_id()) | ||
self._verify_snowflake_trace_id(id_generator.generate_trace_id()) | ||
|
||
def _verify_snowflake_trace_id(self, trace_id: int): | ||
# https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-api/src/opentelemetry/trace/span.py | ||
self.assertTrue(trace.INVALID_TRACE_ID < trace_id <= TRACE_ID_MAX_VALUE) | ||
|
||
# Get the hex format of the snowflake_trace_id and pad it to 32 characters | ||
# The timestamp prefix is the first 8 characters of this. | ||
timestamp_prefix = f'{trace_id:x}'.zfill(32)[:8] | ||
|
||
# the expected prefix is the timestamp (in minutes) in hex format padded to 8 characters. | ||
mock_timestamp_minutes = int(MOCK_TIMESTAMP) // 60 | ||
mock_timestamp_prefix = f'{mock_timestamp_minutes:x}'.zfill(8) | ||
self.assertEqual(timestamp_prefix, mock_timestamp_prefix) |