Skip to content

Commit

Permalink
feat: improves API of FilterChain
Browse files Browse the repository at this point in the history
  • Loading branch information
M0r13n committed Jun 25, 2024
1 parent b973000 commit fa8bb66
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 48 deletions.
25 changes: 6 additions & 19 deletions examples/filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pyais import decode
from pyais.filter import (
AttributeFilter,
DistanceFilter,
Expand All @@ -7,6 +6,7 @@
MessageTypeFilter,
NoneFilter
)
from pyais.stream import TCPConnection

# Define the filter chain with various criteria
chain = FilterChain([
Expand All @@ -26,21 +26,8 @@
GridFilter(lat_min=50, lon_min=0, lat_max=52, lon_max=5),
])

# Example AIS data to filter
data = [
decode(b"!AIVDM,1,1,,B,15NG6V0P01G?cFhE`R2IU?wn28R>,0*05"),
decode(b"!AIVDM,1,1,,A,13HOI:0P0000VOHLCnHQKwvL05Ip,0*23"),
decode(b"!AIVDM,1,1,,B,100h00PP0@PHFV`Mg5gTH?vNPUIp,0*3B"),
decode(b"!AIVDM,1,1,,A,133sVfPP00PD>hRMDH@jNOvN20S8,0*7F"),
decode(b"!AIVDM,1,1,,B,13eaJF0P00Qd388Eew6aagvH85Ip,0*45"),
decode(b"!AIVDM,1,1,,A,14eGrSPP00ncMJTO5C6aBwvP2D0?,0*7A"),
decode(b"!AIVDM,1,1,,A,15MrVH0000KH<:V:NtBLoqFP2H9:,0*2F"),
decode(b"!AIVDM,1,1,,A,702R5`hwCjq8,0*6B"),
]

# Filter the data using the defined chain
filtered_data = list(chain.filter(data))

# Print the latitude and longitude of each message that passed the filters
for msg in filtered_data:
print(msg.lat, msg.lon)
# Create a stream of ais messages
with TCPConnection('153.44.253.27', port=5631) as ais_stream:
for ais_msg in chain.filter(ais_stream):
# Only messages that pass this filter chain are printed
print(ais_msg)
3 changes: 2 additions & 1 deletion pyais/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pyais.messages import NMEAMessage, ANY_MESSAGE, AISSentence
from pyais.stream import TCPConnection, FileReaderStream, IterMessages
from pyais.stream import TCPConnection, FileReaderStream, IterMessages, Stream
from pyais.encode import encode_dict, encode_msg, ais_to_nmea_0183
from pyais.decode import decode
from pyais.tracker import AISTracker, AISTrack
Expand All @@ -18,6 +18,7 @@
'TCPConnection',
'IterMessages',
'FileReaderStream',
'Stream',
'decode',
'AISTracker',
'AISTrack',
Expand Down
57 changes: 30 additions & 27 deletions pyais/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
"""

import math
import socket
import typing
import pyais

# Type Aliases for readability
AIS_STREAM = typing.Generator[pyais.AISSentence, None, None]
FILTER_FUNCTION = typing.Callable[[pyais.AISSentence], bool]
F = typing.TypeVar("F", typing.BinaryIO, socket.socket, None)
AIS_STREAM = pyais.Stream[F]
MESSAGE_STREAM = typing.Generator[pyais.ANY_MESSAGE, None, None]
FILTER_FUNCTION = typing.Callable[[pyais.ANY_MESSAGE], bool]
LAT_LON = typing.Tuple[float, float] # Tuple type for latitude and longitude


Expand Down Expand Up @@ -66,30 +69,30 @@ def set_next(self, filter: 'Filter') -> None:
"""
self.next_filter = filter

def filter(self, data: AIS_STREAM) -> AIS_STREAM:
def filter(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Apply the filter to the data and then pass it to the next filter.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Returns:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
data = self.filter_data(data)
if self.next_filter:
return self.next_filter.filter(data)
return data

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Abstract method to filter data. Should be implemented by subclasses.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Returns:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
raise NotImplementedError("This method should be overridden by subclasses.")

Expand All @@ -109,15 +112,15 @@ def __init__(self, ff: FILTER_FUNCTION) -> None:
super().__init__()
self.ff = ff

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data based on the user-defined function.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
yield from filter(self.ff, data)

Expand All @@ -137,15 +140,15 @@ def __init__(self, *attrs: str) -> None:
super().__init__()
self.attrs = attrs

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data, allowing only messages where specified attributes are not None.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if all(getattr(msg, attr, None) is not None for attr in self.attrs):
Expand All @@ -167,18 +170,18 @@ def __init__(self, *types: int) -> None:
super().__init__()
self.types = types

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data, allowing only messages of specified types.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if msg.msg_type not in self.types: # type: ignore
if msg.msg_type not in self.types:
continue
yield msg

Expand All @@ -200,15 +203,15 @@ def __init__(self, ref_lat_lon: LAT_LON, distance_km: float) -> None:
self.ref_lat_lon = ref_lat_lon
self.distance_km = distance_km

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data based on distance from a reference point.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if hasattr(msg, 'lat'):
Expand All @@ -235,15 +238,15 @@ def __init__(self, lat_min: float, lon_min: float, lat_max: float, lon_max: floa
self.lat_max = lat_max
self.lon_max = lon_max

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data based on whether it falls within a specified grid.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if hasattr(msg, 'lat'):
Expand Down Expand Up @@ -274,14 +277,14 @@ def __init__(self, filters: typing.List[Filter]) -> None:
self.filters = filters
self.start = filters[0]

def filter(self, data: AIS_STREAM) -> AIS_STREAM:
def filter(self, stream: AIS_STREAM[F]) -> MESSAGE_STREAM:
"""
Apply the chain of filters to the data.
Parameters:
data (AIS_STREAM): The stream of data to filter.
stream (AIS_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
"""
yield from self.start.filter(data)
yield from self.start.filter(x.decode() for x in stream)
10 changes: 9 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import subprocess
import unittest

KEYWORDS_TO_IGNORE = (
'tcp',
'udp',
'live',
'tracking',
'filters',
)


class TestExamples(unittest.TestCase):
"""
Expand All @@ -14,7 +22,7 @@ def test_run_every_file(self):
i = -1
exe = sys.executable
for i, file in enumerate(pathlib.Path(__file__).parent.parent.joinpath('examples').glob('*.py')):
if 'tcp' not in str(file) and 'udp' not in str(file) and 'live' not in str(file) and 'tracking' not in str(file):
if all(kw not in str(file) for kw in KEYWORDS_TO_IGNORE):
env = os.environ
env['PYTHONPATH'] = f':{pathlib.Path(__file__).parent.parent.absolute()}'
assert subprocess.check_call(f'{exe} {file}'.split(), env=env, shell=False) == 0
Expand Down
4 changes: 4 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ def __init__(self, msg_type=None, lat=None, lon=None, other_attr=None):
self.lon = lon
self.other_attr = other_attr

def decode(self):
return self


class TestNoneFilter(unittest.TestCase):
def test_filtering_none_attributes(self):
Expand Down Expand Up @@ -165,6 +168,7 @@ def test_filter_chain(self):
filter1 = NoneFilter('lat', 'lon')
filter2 = MessageTypeFilter(1, 2)
chain = FilterChain([filter1, filter2])

mock_data = [MockAISMessage(lat=1, lon=1, msg_type=1), MockAISMessage(lat=None, lon=1, msg_type=2)]

# Execute
Expand Down

0 comments on commit fa8bb66

Please sign in to comment.