diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8592911..a285a2d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,6 +40,11 @@ jobs: - name: Tests run: | pytest --cov=snowplow_tracker --cov-report=xml + + - name: MyPy + run: | + python -m pip install mypy + mypy snowplow_tracker --exclude '/test' - name: Demo run: | diff --git a/setup.py b/setup.py index d0ef7f0..e1b2aa3 100644 --- a/setup.py +++ b/setup.py @@ -65,5 +65,9 @@ "Programming Language :: Python :: 3.12", "Operating System :: OS Independent", ], - install_requires=["requests>=2.25.1,<3.0", "typing_extensions>=3.7.4"], + install_requires=[ + "requests>=2.25.1,<3.0", + "types-requests>=2.25.1,<3.0", + "typing_extensions>=3.7.4", + ], ) diff --git a/snowplow_tracker/constants.py b/snowplow_tracker/constants.py index 579ff86..53ecc15 100644 --- a/snowplow_tracker/constants.py +++ b/snowplow_tracker/constants.py @@ -18,7 +18,7 @@ from snowplow_tracker import _version, SelfDescribingJson VERSION = "py-%s" % _version.__version__ -DEFAULT_ENCODE_BASE64 = True +DEFAULT_ENCODE_BASE64: bool = True # Type hint required for Python 3.6 MyPy check BASE_SCHEMA_PATH = "iglu:com.snowplowanalytics.snowplow" MOBILE_SCHEMA_PATH = "iglu:com.snowplowanalytics.mobile" SCHEMA_TAG = "jsonschema" diff --git a/snowplow_tracker/contracts.py b/snowplow_tracker/contracts.py index c54ac66..3b17e1a 100644 --- a/snowplow_tracker/contracts.py +++ b/snowplow_tracker/contracts.py @@ -77,7 +77,7 @@ def _get_parameter_name() -> str: match = _MATCH_FIRST_PARAMETER_REGEX.search(code) if not match: return "Unnamed parameter" - return match.groups(0)[0] + return str(match.groups(0)[0]) def _check_form_element(element: Dict[str, Any]) -> bool: diff --git a/snowplow_tracker/emitters.py b/snowplow_tracker/emitters.py index af23356..6a138f0 100644 --- a/snowplow_tracker/emitters.py +++ b/snowplow_tracker/emitters.py @@ -20,7 +20,7 @@ import threading import requests import random -from typing import Optional, Union, Tuple, Dict +from typing import Optional, Union, Tuple, Dict, cast, Callable from queue import Queue from snowplow_tracker.self_describing_json import SelfDescribingJson @@ -31,6 +31,7 @@ Method, SuccessCallback, FailureCallback, + EmitterProtocol, ) from snowplow_tracker.contracts import one_of from snowplow_tracker.event_store import EventStore, InMemoryEventStore @@ -48,7 +49,20 @@ METHODS = {"get", "post"} -class Emitter(object): +# Unifes the two request methods under one interface +class Requester: + post: Callable + get: Callable + + def __init__(self, post: Callable, get: Callable): + # 3.6 MyPy compatibility: + # error: Cannot assign to a method + # https://github.com/python/mypy/issues/2427 + setattr(self, "post", post) + setattr(self, "get", get) + + +class Emitter(EmitterProtocol): """ Synchronously send Snowplow events to a Snowplow collector Supports both GET and POST requests @@ -151,12 +165,15 @@ def __init__( self.retry_timer = FlushTimer(emitter=self, repeating=False) self.max_retry_delay_seconds = max_retry_delay_seconds - self.retry_delay = 0 + self.retry_delay: Union[int, float] = 0 self.custom_retry_codes = custom_retry_codes logger.info("Emitter initialized with endpoint " + self.endpoint) - self.request_method = requests if session is None else session + if session is None: + self.request_method = Requester(post=requests.post, get=requests.get) + else: + self.request_method = Requester(post=session.post, get=session.get) @staticmethod def as_collector_uri( @@ -183,7 +200,7 @@ def as_collector_uri( if endpoint.split("://")[0] in PROTOCOLS: endpoint_arr = endpoint.split("://") - protocol = endpoint_arr[0] + protocol = cast(HttpProtocol, endpoint_arr[0]) endpoint = endpoint_arr[1] if method == "get": @@ -427,6 +444,10 @@ def _cancel_retry_timer(self) -> None: """ self.retry_timer.cancel() + # This is only here to satisfy the `EmitterProtocol` interface + def async_flush(self) -> None: + return + class AsyncEmitter(Emitter): """ @@ -446,7 +467,7 @@ def __init__( byte_limit: Optional[int] = None, request_timeout: Optional[Union[float, Tuple[float, float]]] = None, max_retry_delay_seconds: int = 60, - buffer_capacity: int = None, + buffer_capacity: Optional[int] = None, custom_retry_codes: Dict[int, bool] = {}, event_store: Optional[EventStore] = None, session: Optional[requests.Session] = None, @@ -501,7 +522,7 @@ def __init__( event_store=event_store, session=session, ) - self.queue = Queue() + self.queue: Queue = Queue() for i in range(thread_count): t = threading.Thread(target=self.consume) t.daemon = True diff --git a/snowplow_tracker/event_store.py b/snowplow_tracker/event_store.py index 898f92f..b8d1302 100644 --- a/snowplow_tracker/event_store.py +++ b/snowplow_tracker/event_store.py @@ -15,6 +15,7 @@ # language governing permissions and limitations there under. # """ +from typing import List from typing_extensions import Protocol from snowplow_tracker.typing import PayloadDict, PayloadDictList from logging import Logger @@ -25,7 +26,7 @@ class EventStore(Protocol): EventStore protocol. For buffering events in the Emitter. """ - def add_event(payload: PayloadDict) -> bool: + def add_event(self, payload: PayloadDict) -> bool: """ Add PayloadDict to buffer. Returns True if successful. @@ -35,7 +36,7 @@ def add_event(payload: PayloadDict) -> bool: """ ... - def get_events_batch() -> PayloadDictList: + def get_events_batch(self) -> PayloadDictList: """ Get a list of all the PayloadDicts in the buffer. @@ -43,7 +44,7 @@ def get_events_batch() -> PayloadDictList: """ ... - def cleanup(batch: PayloadDictList, need_retry: bool) -> None: + def cleanup(self, batch: PayloadDictList, need_retry: bool) -> None: """ Removes sent events from the event store. If events need to be retried they are re-added to the buffer. @@ -54,7 +55,7 @@ def cleanup(batch: PayloadDictList, need_retry: bool) -> None: """ ... - def size() -> int: + def size(self) -> int: """ Returns the number of events in the buffer @@ -76,7 +77,7 @@ def __init__(self, logger: Logger, buffer_capacity: int = 10000) -> None: When the buffer is full new events are lost. :type buffer_capacity int """ - self.event_buffer = [] + self.event_buffer: List[PayloadDict] = [] self.buffer_capacity = buffer_capacity self.logger = logger diff --git a/snowplow_tracker/events/event.py b/snowplow_tracker/events/event.py index c9d9b82..fb300b8 100644 --- a/snowplow_tracker/events/event.py +++ b/snowplow_tracker/events/event.py @@ -97,10 +97,9 @@ def build_payload( if self.event_subject is not None: fin_payload_dict = self.event_subject.combine_subject(subject) else: - fin_payload_dict = None if subject is None else subject.standard_nv_pairs + fin_payload_dict = {} if subject is None else subject.standard_nv_pairs - if fin_payload_dict is not None: - self.payload.add_dict(fin_payload_dict) + self.payload.add_dict(fin_payload_dict) return self.payload @property diff --git a/snowplow_tracker/events/screen_view.py b/snowplow_tracker/events/screen_view.py index d0cea5d..6b4af92 100644 --- a/snowplow_tracker/events/screen_view.py +++ b/snowplow_tracker/events/screen_view.py @@ -15,7 +15,7 @@ # language governing permissions and limitations there under. # """ -from typing import Optional, List +from typing import Dict, Optional, List from snowplow_tracker.typing import JsonEncoderFunction from snowplow_tracker.events.event import Event from snowplow_tracker.events.self_describing import SelfDescribing @@ -76,7 +76,7 @@ def __init__( super(ScreenView, self).__init__( event_subject=event_subject, context=context, true_timestamp=true_timestamp ) - self.screen_view_properties = {} + self.screen_view_properties: Dict[str, str] = {} self.id_ = id_ self.name = name self.type = type diff --git a/snowplow_tracker/events/structured_event.py b/snowplow_tracker/events/structured_event.py index 00658e9..23abafa 100644 --- a/snowplow_tracker/events/structured_event.py +++ b/snowplow_tracker/events/structured_event.py @@ -15,7 +15,7 @@ # language governing permissions and limitations there under. # """ from snowplow_tracker.events.event import Event -from typing import Optional, List +from typing import Optional, List, Union from snowplow_tracker.subject import Subject from snowplow_tracker.self_describing_json import SelfDescribingJson from snowplow_tracker.contracts import non_empty_string @@ -41,7 +41,7 @@ def __init__( action: str, label: Optional[str] = None, property_: Optional[str] = None, - value: Optional[int] = None, + value: Optional[Union[int, float]] = None, event_subject: Optional[Subject] = None, context: Optional[List[SelfDescribingJson]] = None, true_timestamp: Optional[float] = None, @@ -84,7 +84,7 @@ def category(self) -> Optional[str]: return self.payload.nv_pairs.get("se_ca") @category.setter - def category(self, value: Optional[str]): + def category(self, value: str): non_empty_string(value) self.payload.add("se_ca", value) @@ -96,7 +96,7 @@ def action(self) -> Optional[str]: return self.payload.nv_pairs.get("se_ac") @action.setter - def action(self, value: Optional[str]): + def action(self, value: str): non_empty_string(value) self.payload.add("se_ac", value) @@ -123,12 +123,12 @@ def property_(self, value: Optional[str]): self.payload.add("se_pr", value) @property - def value(self) -> Optional[int]: + def value(self) -> Optional[Union[int, float]]: """ A value associated with the user action """ return self.payload.nv_pairs.get("se_va") @value.setter - def value(self, value: Optional[int]): + def value(self, value: Optional[Union[int, float]]): self.payload.add("se_va", value) diff --git a/snowplow_tracker/payload.py b/snowplow_tracker/payload.py index 26e3262..18d1bf4 100644 --- a/snowplow_tracker/payload.py +++ b/snowplow_tracker/payload.py @@ -83,9 +83,8 @@ def add_json( if encode_base64: encoded_dict = base64.urlsafe_b64encode(json_dict.encode("utf-8")) - if not isinstance(encoded_dict, str): - encoded_dict = encoded_dict.decode("utf-8") - self.add(type_when_encoded, encoded_dict) + encoded_dict_str = encoded_dict.decode("utf-8") + self.add(type_when_encoded, encoded_dict_str) else: self.add(type_when_not_encoded, json_dict) diff --git a/snowplow_tracker/snowplow.py b/snowplow_tracker/snowplow.py index d824ed2..daa1434 100644 --- a/snowplow_tracker/snowplow.py +++ b/snowplow_tracker/snowplow.py @@ -16,7 +16,7 @@ # """ import logging -from typing import Optional +from typing import Dict, Optional from snowplow_tracker import ( Tracker, Emitter, @@ -37,7 +37,7 @@ class Snowplow: - _trackers = {} + _trackers: Dict[str, Tracker] = {} @staticmethod def create_tracker( @@ -149,7 +149,7 @@ def reset(cls): cls._trackers = {} @classmethod - def get_tracker(cls, namespace: str) -> Tracker: + def get_tracker(cls, namespace: str) -> Optional[Tracker]: """ Returns a Snowplow tracker from the Snowplow object if it exists :param namespace: Snowplow tracker namespace diff --git a/snowplow_tracker/subject.py b/snowplow_tracker/subject.py index c3165d3..cbf29aa 100644 --- a/snowplow_tracker/subject.py +++ b/snowplow_tracker/subject.py @@ -15,7 +15,7 @@ # language governing permissions and limitations there under. # """ -from typing import Optional +from typing import Dict, Optional, Union from snowplow_tracker.contracts import one_of, greater_than from snowplow_tracker.typing import SupportedPlatform, SUPPORTED_PLATFORMS, PayloadDict @@ -30,7 +30,7 @@ class Subject(object): """ def __init__(self) -> None: - self.standard_nv_pairs = {"p": DEFAULT_PLATFORM} + self.standard_nv_pairs: Dict[str, Union[str, int]] = {"p": DEFAULT_PLATFORM} def set_platform(self, value: SupportedPlatform) -> "Subject": """ diff --git a/snowplow_tracker/tracker.py b/snowplow_tracker/tracker.py index 2effe83..4dc489d 100644 --- a/snowplow_tracker/tracker.py +++ b/snowplow_tracker/tracker.py @@ -80,13 +80,13 @@ def __init__( if subject is None: subject = Subject() - if type(emitters) is list: + if isinstance(emitters, list): non_empty(emitters) self.emitters = emitters else: self.emitters = [emitters] - self.subject = subject + self.subject: Optional[Subject] = subject self.encode_base64 = encode_base64 self.json_encoder = json_encoder @@ -145,6 +145,8 @@ def track( if "eid" in payload.nv_pairs.keys(): return payload.nv_pairs["eid"] + return None + def complete_payload( self, event: Event, @@ -298,7 +300,7 @@ def track_link_click( ) non_empty_string(target_url) - properties = {} + properties: Dict[str, Union[str, ElementClasses]] = {} properties["targetUrl"] = target_url if element_id is not None: properties["elementId"] = element_id @@ -361,7 +363,7 @@ def track_add_to_cart( ) non_empty_string(sku) - properties = {} + properties: Union[Dict[str, Union[str, float, int]]] = {} properties["sku"] = sku properties["quantity"] = quantity if name is not None: @@ -425,7 +427,7 @@ def track_remove_from_cart( ) non_empty_string(sku) - properties = {} + properties: Dict[str, Union[str, float, int]] = {} properties["sku"] = sku properties["quantity"] = quantity if name is not None: @@ -493,7 +495,7 @@ def track_form_change( if type_ is not None: one_of(type_.lower(), FORM_TYPES) - properties = dict() + properties: Dict[str, Union[Optional[str], ElementClasses]] = dict() properties["formId"] = form_id properties["elementId"] = element_id properties["nodeName"] = node_name @@ -549,7 +551,9 @@ def track_form_submit( for element in elements or []: form_element(element) - properties = dict() + properties: Dict[ + str, Union[str, ElementClasses, FormClasses, List[Dict[str, Any]]] + ] = dict() properties["formId"] = form_id if form_classes is not None: properties["formClasses"] = form_classes @@ -602,7 +606,9 @@ def track_site_search( ) non_empty(terms) - properties = {} + properties: Dict[ + str, Union[Sequence[str], Dict[str, Union[str, bool]], int] + ] = {} properties["terms"] = terms if filters is not None: properties["filters"] = filters @@ -878,7 +884,7 @@ def track_struct_event( action: str, label: Optional[str] = None, property_: Optional[str] = None, - value: Optional[float] = None, + value: Optional[Union[int, float]] = None, context: Optional[List[SelfDescribingJson]] = None, tstamp: Optional[float] = None, event_subject: Optional[Subject] = None, @@ -1037,4 +1043,9 @@ def add_emitter(self, emitter: EmitterProtocol) -> "Tracker": return self def get_namespace(self) -> str: - return self.standard_nv_pairs["tna"] + # As app_id is added to the standard_nv_pairs dict above with a type of Optional[str], the type for + # the whole standard_nv_pairs dict is inferred to be dict[str, Optional[str]]. + # But, we know that "tna" should always be present in the dict, since namespace is a required argument. + # + # This ignores MyPy saying Incompatible return value type (got "str | None", expected "str") + return self.standard_nv_pairs["tna"] # type: ignore diff --git a/snowplow_tracker/tracker_configuration.py b/snowplow_tracker/tracker_configuration.py index af2a4b9..6a574dc 100644 --- a/snowplow_tracker/tracker_configuration.py +++ b/snowplow_tracker/tracker_configuration.py @@ -22,7 +22,7 @@ class TrackerConfiguration(object): def __init__( self, - encode_base64: Optional[bool] = None, + encode_base64: bool = True, json_encoder: Optional[JsonEncoderFunction] = None, ) -> None: """ @@ -37,18 +37,16 @@ def __init__( self.json_encoder = json_encoder @property - def encode_base64(self) -> Optional[bool]: + def encode_base64(self) -> bool: """ Whether JSONs in the payload should be base-64 encoded. Default is True. """ return self._encode_base64 @encode_base64.setter - def encode_base64(self, value: Optional[bool]): + def encode_base64(self, value: bool): if isinstance(value, bool) or value is None: self._encode_base64 = value - else: - raise ValueError("encode_base64 must be True or False") @property def json_encoder(self) -> Optional[JsonEncoderFunction]: diff --git a/snowplow_tracker/typing.py b/snowplow_tracker/typing.py index 5bbc477..3e97356 100644 --- a/snowplow_tracker/typing.py +++ b/snowplow_tracker/typing.py @@ -65,5 +65,10 @@ class EmitterProtocol(Protocol): - def input(self, payload: PayloadDict) -> None: - ... + def input(self, payload: PayloadDict) -> None: ... + + def flush(self) -> None: ... + + def async_flush(self) -> None: ... + + def sync_flush(self) -> None: ...