Skip to content

Commit

Permalink
Typing GATT Client and Device Peer
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Nov 28, 2023
1 parent a13e193 commit 58ba8a1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 34 deletions.
87 changes: 64 additions & 23 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from collections.abc import Iterable
from typing import (
Any,
Callable,
Expand All @@ -32,6 +33,7 @@
Optional,
Tuple,
Type,
TypeVar,
Set,
Union,
cast,
Expand Down Expand Up @@ -440,86 +442,125 @@ def __int__(self):


# -----------------------------------------------------------------------------
_PROXY_CLASS = TypeVar('_PROXY_CLASS', bound=gatt_client.ProfileServiceProxy)


class Peer:
def __init__(self, connection):
def __init__(self, connection: Connection) -> None:
self.connection = connection

# Create a GATT client for the connection
self.gatt_client = gatt_client.Client(connection)
connection.gatt_client = self.gatt_client

@property
def services(self):
def services(self) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.services

async def request_mtu(self, mtu):
async def request_mtu(self, mtu: int) -> int:
mtu = await self.gatt_client.request_mtu(mtu)
self.connection.emit('connection_att_mtu_update')
return mtu

async def discover_service(self, uuid):
async def discover_service(
self, uuid: Union[core.UUID, str]
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_service(uuid)

async def discover_services(self, uuids=()):
async def discover_services(
self, uuids: Iterable[core.UUID] = ()
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_services(uuids)

async def discover_included_services(self, service):
async def discover_included_services(
self, service: gatt_client.ServiceProxy
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_included_services(service)

async def discover_characteristics(self, uuids=(), service=None):
async def discover_characteristics(
self,
uuids: Iterable[Union[core.UUID, str]] = (),
service: Optional[gatt_client.ServiceProxy] = None,
) -> List[gatt_client.CharacteristicProxy]:
return await self.gatt_client.discover_characteristics(
uuids=uuids, service=service
)

async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None
self,
characteristic: Optional[gatt_client.CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
):
return await self.gatt_client.discover_descriptors(
characteristic, start_handle, end_handle
)

async def discover_attributes(self):
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes()

async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
async def subscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
return await self.gatt_client.subscribe(
characteristic, subscriber, prefer_notify
)

async def unsubscribe(self, characteristic, subscriber=None):
async def unsubscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
return await self.gatt_client.unsubscribe(characteristic, subscriber)

async def read_value(self, attribute):
async def read_value(
self, attribute: Union[int, gatt_client.AttributeProxy]
) -> bytes:
return await self.gatt_client.read_value(attribute)

async def write_value(self, attribute, value, with_response=False):
async def write_value(
self,
attribute: Union[int, gatt_client.AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
return await self.gatt_client.write_value(attribute, value, with_response)

async def read_characteristics_by_uuid(self, uuid, service=None):
async def read_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[bytes]:
return await self.gatt_client.read_characteristics_by_uuid(uuid, service)

def get_services_by_uuid(self, uuid):
def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.get_services_by_uuid(uuid)

def get_characteristics_by_uuid(self, uuid, service=None):
def get_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[gatt_client.CharacteristicProxy]:
return self.gatt_client.get_characteristics_by_uuid(uuid, service)

def create_service_proxy(self, proxy_class):
return proxy_class.from_client(self.gatt_client)
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))

async def discover_service_and_create_proxy(self, proxy_class):
async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
# Discover the first matching service and its characteristics
services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID)
if services:
service = services[0]
await service.discover_characteristics()
return self.create_service_proxy(proxy_class)
return None

async def sustain(self, timeout=None):
async def sustain(self, timeout: Optional[float] = None) -> None:
await self.connection.sustain(timeout)

# [Classic only]
async def request_name(self):
async def request_name(self) -> str:
return await self.connection.request_remote_name()

async def __aenter__(self):
Expand All @@ -532,7 +573,7 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_value, traceback):
pass

def __str__(self):
def __str__(self) -> str:
return f'{self.connection.peer_address} as {self.connection.role_name}'


Expand Down Expand Up @@ -732,7 +773,7 @@ async def encrypt(self, enable: bool = True) -> None:
async def switch_role(self, role: int) -> None:
return await self.device.switch_role(self, role)

async def sustain(self, timeout=None):
async def sustain(self, timeout: Optional[float] = None) -> None:
"""Idles the current task waiting for a disconnect or timeout"""

abort = asyncio.get_running_loop().create_future()
Expand Down
39 changes: 28 additions & 11 deletions bumble/gatt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
Any,
Iterable,
Type,
TypeVar,
Set,
TYPE_CHECKING,
)

Expand Down Expand Up @@ -128,7 +130,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy]

@staticmethod
def from_client(service_class, client, service_uuid):
def from_client(service_class, client: Client, service_uuid: UUID):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
Expand Down Expand Up @@ -246,8 +248,12 @@ def from_client(cls, client: Client) -> ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]]
indication_subscribers: Dict[int, Callable[[bytes], Any]]
notification_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]

Expand Down Expand Up @@ -682,8 +688,8 @@ async def discover_characteristics(
async def discover_descriptors(
self,
characteristic: Optional[CharacteristicProxy] = None,
start_handle=None,
end_handle=None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
) -> List[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
Expand Down Expand Up @@ -789,7 +795,12 @@ async def discover_attributes(self) -> List[AttributeProxy]:

return attributes

async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
Expand Down Expand Up @@ -833,7 +844,11 @@ async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):

await self.write_value(cccd, struct.pack('<H', bits), with_response=True)

async def unsubscribe(self, characteristic, subscriber=None):
async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
Expand All @@ -853,7 +868,7 @@ async def unsubscribe(self, characteristic, subscriber=None):
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, [])
subscribers = subscriber_set.get(characteristic.handle, set())
if subscriber in subscribers:
subscribers.remove(subscriber)

Expand All @@ -871,7 +886,7 @@ async def unsubscribe(self, characteristic, subscriber=None):

async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any:
) -> bytes:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
Expand Down Expand Up @@ -1067,7 +1082,7 @@ def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
notification.attribute_handle, set()
)
if not subscribers:
logger.warning('!!! received notification with no subscriber')
Expand All @@ -1081,7 +1096,9 @@ def on_att_handle_value_notification(self, notification):

def on_att_handle_value_indication(self, indication):
# Call all subscribers
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
subscribers = self.indication_subscribers.get(
indication.attribute_handle, set()
)
if not subscribers:
logger.warning('!!! received indication with no subscriber')

Expand Down

0 comments on commit 58ba8a1

Please sign in to comment.