Skip to content

Commit

Permalink
Merge pull request #355 from google/gbg/fix-gatt-unsubscribe
Browse files Browse the repository at this point in the history
fix #354 (gatt unsubscribe)
  • Loading branch information
barbibulle authored Nov 30, 2023
2 parents a9c4c58 + 58c9c4f commit 320164d
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 10 deletions.
40 changes: 30 additions & 10 deletions bumble/gatt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def on_change(value):

return await self.client.subscribe(self, subscriber, prefer_notify)

async def unsubscribe(self, subscriber=None):
async def unsubscribe(self, subscriber=None, force=False):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)

return await self.client.unsubscribe(self, subscriber)
return await self.client.unsubscribe(self, subscriber, force)

def __str__(self) -> str:
return (
Expand Down Expand Up @@ -262,10 +262,8 @@ def __init__(self, connection: Connection) -> None:
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.notification_subscribers = {} # Subscriber set, by attribute handle
self.indication_subscribers = {} # Subscriber set, by attribute handle
self.services = []
self.cached_values = {}

Expand Down Expand Up @@ -836,6 +834,7 @@ async def subscribe(
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)

# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
Expand All @@ -847,7 +846,14 @@ async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
force: bool = False,
) -> None:
'''
Unsubscribe from a characteristic.
If `force` is True, this will write zeros to the CCCD when there are no
subscribers left, even if there were already no registered subscribers.
'''
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
Expand All @@ -861,25 +867,39 @@ async def unsubscribe(
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return

# Check if the characteristic has subscribers
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not force:
return

# Remove the subscriber(s)
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, set())
if subscriber in subscribers:
if (
subscribers := subscriber_set.get(characteristic.handle)
) and subscriber in subscribers:
subscribers.remove(subscriber)

# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
# Remove all subscribers for this attribute from the sets
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)

if not self.notification_subscribers and not self.indication_subscribers:
# Update the CCCD
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)

Expand Down
79 changes: 79 additions & 0 deletions tests/gatt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import struct
import pytest
from unittest.mock import Mock, ANY

from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
Expand Down Expand Up @@ -763,6 +764,83 @@ def on_c3_update_3(value): # for indicate
assert not c3._called_3


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe():
[client, server] = LinkedDevices().devices[:2]

characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
characteristic2 = Characteristic(
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)

service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic1, characteristic2],
)
server.add_services([service1])

mock1 = Mock()
characteristic1.on('subscription', mock1)
mock2 = Mock()
characteristic2.on('subscription', mock2)

await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)

await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert len(c) == 1
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert len(c) == 1
c2 = c[0]

await c1.subscribe()
await async_barrier()
mock1.assert_called_once_with(ANY, True, False)

await c2.subscribe()
await async_barrier()
mock2.assert_called_once_with(ANY, True, False)

mock1.reset_mock()
await c1.unsubscribe()
await async_barrier()
mock1.assert_called_once_with(ANY, False, False)

mock2.reset_mock()
await c2.unsubscribe()
await async_barrier()
mock2.assert_called_once_with(ANY, False, False)

mock1.reset_mock()
await c1.unsubscribe()
await async_barrier()
mock1.assert_not_called()

mock2.reset_mock()
await c2.unsubscribe()
await async_barrier()
mock2.assert_not_called()

mock1.reset_mock()
await c1.unsubscribe(force=True)
await async_barrier()
mock1.assert_called_once_with(ANY, False, False)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
Expand Down Expand Up @@ -886,6 +964,7 @@ async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_unsubscribe()
await test_characteristic_encoding()
await test_mtu_exchange()

Expand Down

0 comments on commit 320164d

Please sign in to comment.