Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update typing #364

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ jobs:
- name: Tests
run: |
pytest --cov=snowplow_tracker --cov-report=xml

- name: MyPy
run: |
python -m pip install mypy
mypy snowplow_tracker --exclude '/test'

- name: Demo
run: |
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,9 @@
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
],
install_requires=["requests>=2.25.1,<3.0", "typing_extensions>=3.7.4"],
install_requires=[
"requests>=2.25.1,<3.0",
"types-requests>=2.25.1,<3.0",
"typing_extensions>=3.7.4",
],
)
2 changes: 1 addition & 1 deletion snowplow_tracker/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from snowplow_tracker import _version, SelfDescribingJson

VERSION = "py-%s" % _version.__version__
DEFAULT_ENCODE_BASE64 = True
DEFAULT_ENCODE_BASE64: bool = True # Type hint required for Python 3.6 MyPy check
BASE_SCHEMA_PATH = "iglu:com.snowplowanalytics.snowplow"
MOBILE_SCHEMA_PATH = "iglu:com.snowplowanalytics.mobile"
SCHEMA_TAG = "jsonschema"
Expand Down
2 changes: 1 addition & 1 deletion snowplow_tracker/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_parameter_name() -> str:
match = _MATCH_FIRST_PARAMETER_REGEX.search(code)
if not match:
return "Unnamed parameter"
return match.groups(0)[0]
return str(match.groups(0)[0])


def _check_form_element(element: Dict[str, Any]) -> bool:
Expand Down
35 changes: 28 additions & 7 deletions snowplow_tracker/emitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import threading
import requests
import random
from typing import Optional, Union, Tuple, Dict
from typing import Optional, Union, Tuple, Dict, cast, Callable
from queue import Queue

from snowplow_tracker.self_describing_json import SelfDescribingJson
Expand All @@ -31,6 +31,7 @@
Method,
SuccessCallback,
FailureCallback,
EmitterProtocol,
)
from snowplow_tracker.contracts import one_of
from snowplow_tracker.event_store import EventStore, InMemoryEventStore
Expand All @@ -48,7 +49,20 @@
METHODS = {"get", "post"}


class Emitter(object):
# Unifes the two request methods under one interface
class Requester:
post: Callable
get: Callable

def __init__(self, post: Callable, get: Callable):
# 3.6 MyPy compatibility:
# error: Cannot assign to a method
# https://github.com/python/mypy/issues/2427
setattr(self, "post", post)
setattr(self, "get", get)


class Emitter(EmitterProtocol):
"""
Synchronously send Snowplow events to a Snowplow collector
Supports both GET and POST requests
Expand Down Expand Up @@ -151,12 +165,15 @@ def __init__(
self.retry_timer = FlushTimer(emitter=self, repeating=False)

self.max_retry_delay_seconds = max_retry_delay_seconds
self.retry_delay = 0
self.retry_delay: Union[int, float] = 0

self.custom_retry_codes = custom_retry_codes
logger.info("Emitter initialized with endpoint " + self.endpoint)

self.request_method = requests if session is None else session
if session is None:
self.request_method = Requester(post=requests.post, get=requests.get)
else:
self.request_method = Requester(post=session.post, get=session.get)

@staticmethod
def as_collector_uri(
Expand All @@ -183,7 +200,7 @@ def as_collector_uri(

if endpoint.split("://")[0] in PROTOCOLS:
endpoint_arr = endpoint.split("://")
protocol = endpoint_arr[0]
protocol = cast(HttpProtocol, endpoint_arr[0])
endpoint = endpoint_arr[1]

if method == "get":
Expand Down Expand Up @@ -427,6 +444,10 @@ def _cancel_retry_timer(self) -> None:
"""
self.retry_timer.cancel()

# This is only here to satisfy the `EmitterProtocol` interface
def async_flush(self) -> None:
return


class AsyncEmitter(Emitter):
"""
Expand All @@ -446,7 +467,7 @@ def __init__(
byte_limit: Optional[int] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
max_retry_delay_seconds: int = 60,
buffer_capacity: int = None,
buffer_capacity: Optional[int] = None,
custom_retry_codes: Dict[int, bool] = {},
event_store: Optional[EventStore] = None,
session: Optional[requests.Session] = None,
Expand Down Expand Up @@ -501,7 +522,7 @@ def __init__(
event_store=event_store,
session=session,
)
self.queue = Queue()
self.queue: Queue = Queue()
for i in range(thread_count):
t = threading.Thread(target=self.consume)
t.daemon = True
Expand Down
11 changes: 6 additions & 5 deletions snowplow_tracker/event_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# language governing permissions and limitations there under.
# """

from typing import List
from typing_extensions import Protocol
from snowplow_tracker.typing import PayloadDict, PayloadDictList
from logging import Logger
Expand All @@ -25,7 +26,7 @@ class EventStore(Protocol):
EventStore protocol. For buffering events in the Emitter.
"""

def add_event(payload: PayloadDict) -> bool:
def add_event(self, payload: PayloadDict) -> bool:
"""
Add PayloadDict to buffer. Returns True if successful.

Expand All @@ -35,15 +36,15 @@ def add_event(payload: PayloadDict) -> bool:
"""
...

def get_events_batch() -> PayloadDictList:
def get_events_batch(self) -> PayloadDictList:
"""
Get a list of all the PayloadDicts in the buffer.

:rtype PayloadDictList
"""
...

def cleanup(batch: PayloadDictList, need_retry: bool) -> None:
def cleanup(self, batch: PayloadDictList, need_retry: bool) -> None:
"""
Removes sent events from the event store. If events need to be retried they are re-added to the buffer.

Expand All @@ -54,7 +55,7 @@ def cleanup(batch: PayloadDictList, need_retry: bool) -> None:
"""
...

def size() -> int:
def size(self) -> int:
"""
Returns the number of events in the buffer

Expand All @@ -76,7 +77,7 @@ def __init__(self, logger: Logger, buffer_capacity: int = 10000) -> None:
When the buffer is full new events are lost.
:type buffer_capacity int
"""
self.event_buffer = []
self.event_buffer: List[PayloadDict] = []
self.buffer_capacity = buffer_capacity
self.logger = logger

Expand Down
5 changes: 2 additions & 3 deletions snowplow_tracker/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ def build_payload(
if self.event_subject is not None:
fin_payload_dict = self.event_subject.combine_subject(subject)
else:
fin_payload_dict = None if subject is None else subject.standard_nv_pairs
fin_payload_dict = {} if subject is None else subject.standard_nv_pairs

if fin_payload_dict is not None:
self.payload.add_dict(fin_payload_dict)
self.payload.add_dict(fin_payload_dict)
return self.payload

@property
Expand Down
4 changes: 2 additions & 2 deletions snowplow_tracker/events/screen_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# language governing permissions and limitations there under.
# """

from typing import Optional, List
from typing import Dict, Optional, List
from snowplow_tracker.typing import JsonEncoderFunction
from snowplow_tracker.events.event import Event
from snowplow_tracker.events.self_describing import SelfDescribing
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
super(ScreenView, self).__init__(
event_subject=event_subject, context=context, true_timestamp=true_timestamp
)
self.screen_view_properties = {}
self.screen_view_properties: Dict[str, str] = {}
self.id_ = id_
self.name = name
self.type = type
Expand Down
12 changes: 6 additions & 6 deletions snowplow_tracker/events/structured_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# language governing permissions and limitations there under.
# """
from snowplow_tracker.events.event import Event
from typing import Optional, List
from typing import Optional, List, Union
from snowplow_tracker.subject import Subject
from snowplow_tracker.self_describing_json import SelfDescribingJson
from snowplow_tracker.contracts import non_empty_string
Expand All @@ -41,7 +41,7 @@ def __init__(
action: str,
label: Optional[str] = None,
property_: Optional[str] = None,
value: Optional[int] = None,
value: Optional[Union[int, float]] = None,
event_subject: Optional[Subject] = None,
context: Optional[List[SelfDescribingJson]] = None,
true_timestamp: Optional[float] = None,
Expand Down Expand Up @@ -84,7 +84,7 @@ def category(self) -> Optional[str]:
return self.payload.nv_pairs.get("se_ca")

@category.setter
def category(self, value: Optional[str]):
def category(self, value: str):
non_empty_string(value)
self.payload.add("se_ca", value)

Expand All @@ -96,7 +96,7 @@ def action(self) -> Optional[str]:
return self.payload.nv_pairs.get("se_ac")

@action.setter
def action(self, value: Optional[str]):
def action(self, value: str):
non_empty_string(value)
self.payload.add("se_ac", value)

Expand All @@ -123,12 +123,12 @@ def property_(self, value: Optional[str]):
self.payload.add("se_pr", value)

@property
def value(self) -> Optional[int]:
def value(self) -> Optional[Union[int, float]]:
"""
A value associated with the user action
"""
return self.payload.nv_pairs.get("se_va")

@value.setter
def value(self, value: Optional[int]):
def value(self, value: Optional[Union[int, float]]):
self.payload.add("se_va", value)
5 changes: 2 additions & 3 deletions snowplow_tracker/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def add_json(

if encode_base64:
encoded_dict = base64.urlsafe_b64encode(json_dict.encode("utf-8"))
if not isinstance(encoded_dict, str):
encoded_dict = encoded_dict.decode("utf-8")
self.add(type_when_encoded, encoded_dict)
encoded_dict_str = encoded_dict.decode("utf-8")
self.add(type_when_encoded, encoded_dict_str)

else:
self.add(type_when_not_encoded, json_dict)
Expand Down
6 changes: 3 additions & 3 deletions snowplow_tracker/snowplow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# """

import logging
from typing import Optional
from typing import Dict, Optional
from snowplow_tracker import (
Tracker,
Emitter,
Expand All @@ -37,7 +37,7 @@


class Snowplow:
_trackers = {}
_trackers: Dict[str, Tracker] = {}

@staticmethod
def create_tracker(
Expand Down Expand Up @@ -149,7 +149,7 @@ def reset(cls):
cls._trackers = {}

@classmethod
def get_tracker(cls, namespace: str) -> Tracker:
def get_tracker(cls, namespace: str) -> Optional[Tracker]:
"""
Returns a Snowplow tracker from the Snowplow object if it exists
:param namespace: Snowplow tracker namespace
Expand Down
4 changes: 2 additions & 2 deletions snowplow_tracker/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# language governing permissions and limitations there under.
# """

from typing import Optional
from typing import Dict, Optional, Union
from snowplow_tracker.contracts import one_of, greater_than
from snowplow_tracker.typing import SupportedPlatform, SUPPORTED_PLATFORMS, PayloadDict

Expand All @@ -30,7 +30,7 @@ class Subject(object):
"""

def __init__(self) -> None:
self.standard_nv_pairs = {"p": DEFAULT_PLATFORM}
self.standard_nv_pairs: Dict[str, Union[str, int]] = {"p": DEFAULT_PLATFORM}

def set_platform(self, value: SupportedPlatform) -> "Subject":
"""
Expand Down
Loading
Loading