diff --git a/bumble/device.py b/bumble/device.py index 1d40a357d..37f161087 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -23,6 +23,7 @@ import logging from contextlib import asynccontextmanager, AsyncExitStack from dataclasses import dataclass +from collections.abc import Iterable from typing import ( Any, Callable, @@ -32,6 +33,7 @@ Optional, Tuple, Type, + TypeVar, Set, Union, cast, @@ -440,8 +442,11 @@ def __int__(self): # ----------------------------------------------------------------------------- +_PROXY_CLASS = TypeVar('_PROXY_CLASS', bound=gatt_client.ProfileServiceProxy) + + class Peer: - def __init__(self, connection): + def __init__(self, connection: Connection) -> None: self.connection = connection # Create a GATT client for the connection @@ -449,77 +454,113 @@ def __init__(self, connection): connection.gatt_client = self.gatt_client @property - def services(self): + def services(self) -> List[gatt_client.ServiceProxy]: return self.gatt_client.services - async def request_mtu(self, mtu): + async def request_mtu(self, mtu: int) -> int: mtu = await self.gatt_client.request_mtu(mtu) self.connection.emit('connection_att_mtu_update') return mtu - async def discover_service(self, uuid): + async def discover_service( + self, uuid: Union[core.UUID, str] + ) -> List[gatt_client.ServiceProxy]: return await self.gatt_client.discover_service(uuid) - async def discover_services(self, uuids=()): + async def discover_services( + self, uuids: Iterable[core.UUID] = () + ) -> List[gatt_client.ServiceProxy]: return await self.gatt_client.discover_services(uuids) - async def discover_included_services(self, service): + async def discover_included_services( + self, service: gatt_client.ServiceProxy + ) -> List[gatt_client.ServiceProxy]: return await self.gatt_client.discover_included_services(service) - async def discover_characteristics(self, uuids=(), service=None): + async def discover_characteristics( + self, + uuids: Iterable[Union[core.UUID, str]] = (), + service: Optional[gatt_client.ServiceProxy] = None, + ) -> List[gatt_client.CharacteristicProxy]: return await self.gatt_client.discover_characteristics( uuids=uuids, service=service ) async def discover_descriptors( - self, characteristic=None, start_handle=None, end_handle=None + self, + characteristic: Optional[gatt_client.CharacteristicProxy] = None, + start_handle: Optional[int] = None, + end_handle: Optional[int] = None, ): return await self.gatt_client.discover_descriptors( characteristic, start_handle, end_handle ) - async def discover_attributes(self): + async def discover_attributes(self) -> List[gatt_client.AttributeProxy]: return await self.gatt_client.discover_attributes() - async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): + async def subscribe( + self, + characteristic: gatt_client.CharacteristicProxy, + subscriber: Optional[Callable[[bytes], Any]] = None, + prefer_notify: bool = True, + ) -> None: return await self.gatt_client.subscribe( characteristic, subscriber, prefer_notify ) - async def unsubscribe(self, characteristic, subscriber=None): + async def unsubscribe( + self, + characteristic: gatt_client.CharacteristicProxy, + subscriber: Optional[Callable[[bytes], Any]] = None, + ) -> None: return await self.gatt_client.unsubscribe(characteristic, subscriber) - async def read_value(self, attribute): + async def read_value( + self, attribute: Union[int, gatt_client.AttributeProxy] + ) -> bytes: return await self.gatt_client.read_value(attribute) - async def write_value(self, attribute, value, with_response=False): + async def write_value( + self, + attribute: Union[int, gatt_client.AttributeProxy], + value: bytes, + with_response: bool = False, + ) -> None: return await self.gatt_client.write_value(attribute, value, with_response) - async def read_characteristics_by_uuid(self, uuid, service=None): + async def read_characteristics_by_uuid( + self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None + ) -> List[bytes]: return await self.gatt_client.read_characteristics_by_uuid(uuid, service) - def get_services_by_uuid(self, uuid): + def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]: return self.gatt_client.get_services_by_uuid(uuid) - def get_characteristics_by_uuid(self, uuid, service=None): + def get_characteristics_by_uuid( + self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None + ) -> List[gatt_client.CharacteristicProxy]: return self.gatt_client.get_characteristics_by_uuid(uuid, service) - def create_service_proxy(self, proxy_class): - return proxy_class.from_client(self.gatt_client) + def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS: + return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client)) - async def discover_service_and_create_proxy(self, proxy_class): + async def discover_service_and_create_proxy( + self, proxy_class: Type[_PROXY_CLASS] + ) -> Optional[_PROXY_CLASS]: # Discover the first matching service and its characteristics services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID) if services: service = services[0] await service.discover_characteristics() return self.create_service_proxy(proxy_class) + return None - async def sustain(self, timeout=None): + async def sustain(self, timeout: Optional[float] = None) -> None: await self.connection.sustain(timeout) # [Classic only] - async def request_name(self): + async def request_name(self) -> str: return await self.connection.request_remote_name() async def __aenter__(self): @@ -532,7 +573,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): pass - def __str__(self): + def __str__(self) -> str: return f'{self.connection.peer_address} as {self.connection.role_name}' @@ -732,7 +773,7 @@ async def encrypt(self, enable: bool = True) -> None: async def switch_role(self, role: int) -> None: return await self.device.switch_role(self, role) - async def sustain(self, timeout=None): + async def sustain(self, timeout: Optional[float] = None) -> None: """Idles the current task waiting for a disconnect or timeout""" abort = asyncio.get_running_loop().create_future() diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index e3b8bb212..4a3dedd01 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -38,6 +38,7 @@ Any, Iterable, Type, + Set, TYPE_CHECKING, ) @@ -128,7 +129,7 @@ class ServiceProxy(AttributeProxy): included_services: List[ServiceProxy] @staticmethod - def from_client(service_class, client, service_uuid): + def from_client(service_class, client: Client, service_uuid: UUID): # The service and its characteristics are considered to have already been # discovered services = client.get_services_by_uuid(service_uuid) @@ -246,8 +247,12 @@ def from_client(cls, client: Client) -> ProfileServiceProxy: class Client: services: List[ServiceProxy] cached_values: Dict[int, Tuple[datetime, bytes]] - notification_subscribers: Dict[int, Callable[[bytes], Any]] - indication_subscribers: Dict[int, Callable[[bytes], Any]] + notification_subscribers: Dict[ + int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]] + ] + indication_subscribers: Dict[ + int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]] + ] pending_response: Optional[asyncio.futures.Future[ATT_PDU]] pending_request: Optional[ATT_PDU] @@ -682,8 +687,8 @@ async def discover_characteristics( async def discover_descriptors( self, characteristic: Optional[CharacteristicProxy] = None, - start_handle=None, - end_handle=None, + start_handle: Optional[int] = None, + end_handle: Optional[int] = None, ) -> List[DescriptorProxy]: ''' See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors @@ -789,7 +794,12 @@ async def discover_attributes(self) -> List[AttributeProxy]: return attributes - async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): + async def subscribe( + self, + characteristic: CharacteristicProxy, + subscriber: Optional[Callable[[bytes], Any]] = None, + prefer_notify: bool = True, + ) -> None: # If we haven't already discovered the descriptors for this characteristic, # do it now if not characteristic.descriptors_discovered: @@ -833,7 +843,11 @@ async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): await self.write_value(cccd, struct.pack(' None: # If we haven't already discovered the descriptors for this characteristic, # do it now if not characteristic.descriptors_discovered: @@ -853,7 +867,7 @@ async def unsubscribe(self, characteristic, subscriber=None): self.notification_subscribers, self.indication_subscribers, ): - subscribers = subscriber_set.get(characteristic.handle, []) + subscribers = subscriber_set.get(characteristic.handle, set()) if subscriber in subscribers: subscribers.remove(subscriber) @@ -1067,7 +1081,7 @@ def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: def on_att_handle_value_notification(self, notification): # Call all subscribers subscribers = self.notification_subscribers.get( - notification.attribute_handle, [] + notification.attribute_handle, set() ) if not subscribers: logger.warning('!!! received notification with no subscriber') @@ -1081,7 +1095,9 @@ def on_att_handle_value_notification(self, notification): def on_att_handle_value_indication(self, indication): # Call all subscribers - subscribers = self.indication_subscribers.get(indication.attribute_handle, []) + subscribers = self.indication_subscribers.get( + indication.attribute_handle, set() + ) if not subscribers: logger.warning('!!! received indication with no subscriber')