From 8715333706f81f35cf2e368db913f4fd9b501d7f Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Mon, 18 Nov 2024 12:13:19 -0800 Subject: [PATCH] Add a GATT adapter that uses `from_bytes` and `__bytes__` as conversion methods. --- bumble/att.py | 20 +- bumble/gatt.py | 38 ++- bumble/gatt_server.py | 13 +- bumble/profiles/aics.py | 84 ++--- bumble/profiles/bass.py | 11 +- bumble/profiles/device_information_service.py | 5 +- bumble/profiles/heart_rate_service.py | 11 +- bumble/utils.py | 25 +- examples/keyboard.py | 2 +- examples/run_gatt_server.py | 2 +- examples/run_gatt_with_adapters.py | 319 ++++++++++++++++++ tests/aics_test.py | 12 +- tests/gatt_test.py | 110 +++--- 13 files changed, 520 insertions(+), 132 deletions(-) create mode 100644 examples/run_gatt_with_adapters.py diff --git a/bumble/att.py b/bumble/att.py index 15ad8c69..61a7ae35 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -759,13 +759,13 @@ class AttributeValue: def __init__( self, read: Union[ - Callable[[Optional[Connection]], bytes], - Callable[[Optional[Connection]], Awaitable[bytes]], + Callable[[Optional[Connection]], Any], + Callable[[Optional[Connection]], Awaitable[Any]], None, ] = None, write: Union[ - Callable[[Optional[Connection], bytes], None], - Callable[[Optional[Connection], bytes], Awaitable[None]], + Callable[[Optional[Connection], Any], None], + Callable[[Optional[Connection], Any], Awaitable[None]], None, ] = None, ): @@ -824,13 +824,13 @@ def from_string(cls, permissions_str: str) -> Attribute.Permissions: READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION - value: Union[bytes, AttributeValue] + value: Any def __init__( self, attribute_type: Union[str, bytes, UUID], permissions: Union[str, Attribute.Permissions], - value: Union[str, bytes, AttributeValue] = b'', + value: Any = b'', ) -> None: EventEmitter.__init__(self) self.handle = 0 @@ -848,11 +848,7 @@ def __init__( else: self.type = attribute_type - # Convert the value to a byte array - if isinstance(value, str): - self.value = bytes(value, 'utf-8') - else: - self.value = value + self.value = value def encode_value(self, value: Any) -> bytes: return value @@ -895,6 +891,8 @@ async def read_value(self, connection: Optional[Connection]) -> bytes: else: value = self.value + self.emit('read', connection, value) + return self.encode_value(value) async def write_value(self, connection: Connection, value_bytes: bytes) -> None: diff --git a/bumble/gatt.py b/bumble/gatt.py index ea65116d..9d82f023 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -28,12 +28,15 @@ import logging import struct from typing import ( + Any, Callable, Dict, Iterable, List, Optional, Sequence, + SupportsBytes, + Type, Union, TYPE_CHECKING, ) @@ -41,6 +44,7 @@ from bumble.colors import color from bumble.core import BaseBumbleError, UUID from bumble.att import Attribute, AttributeValue +from bumble.utils import ByteSerializable if TYPE_CHECKING: from bumble.gatt_client import AttributeProxy @@ -343,7 +347,7 @@ class Service(Attribute): def __init__( self, uuid: Union[str, UUID], - characteristics: List[Characteristic], + characteristics: Iterable[Characteristic], primary=True, included_services: Iterable[Service] = (), ) -> None: @@ -362,7 +366,7 @@ def __init__( ) self.uuid = uuid self.included_services = list(included_services) - self.characteristics = characteristics[:] + self.characteristics = list(characteristics) self.primary = primary def get_advertising_data(self) -> Optional[bytes]: @@ -393,7 +397,7 @@ class TemplateService(Service): def __init__( self, - characteristics: List[Characteristic], + characteristics: Iterable[Characteristic], primary: bool = True, included_services: Iterable[Service] = (), ) -> None: @@ -490,7 +494,7 @@ def __init__( uuid: Union[str, bytes, UUID], properties: Characteristic.Properties, permissions: Union[str, Attribute.Permissions], - value: Union[str, bytes, CharacteristicValue] = b'', + value: Any = b'', descriptors: Sequence[Descriptor] = (), ): super().__init__(uuid, permissions, value) @@ -525,7 +529,11 @@ class CharacteristicDeclaration(Attribute): characteristic: Characteristic - def __init__(self, characteristic: Characteristic, value_handle: int) -> None: + def __init__( + self, + characteristic: Characteristic, + value_handle: int, + ) -> None: declaration_bytes = ( struct.pack(' str: return value.decode('utf-8') +# ----------------------------------------------------------------------------- +class SerializableCharacteristicAdapter(CharacteristicAdapter): + ''' + Adapter that converts any class to/from bytes using the class' + `to_bytes` and `__bytes__` methods, respectively. + ''' + + def __init__(self, characteristic, cls: Type[ByteSerializable]): + super().__init__(characteristic) + self.cls = cls + + def encode_value(self, value: SupportsBytes) -> bytes: + return bytes(value) + + def decode_value(self, value: bytes) -> Any: + return self.cls.from_bytes(value) + + # ----------------------------------------------------------------------------- class Descriptor(Attribute): ''' diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 0ee673c0..a0ff6093 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -28,7 +28,17 @@ import logging from collections import defaultdict import struct -from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING +from typing import ( + Dict, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Type, + Union, + TYPE_CHECKING, +) from pyee import EventEmitter from bumble.colors import color @@ -68,6 +78,7 @@ GATT_REQUEST_TIMEOUT, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, Characteristic, + CharacteristicAdapter, CharacteristicDeclaration, CharacteristicValue, IncludedServiceDeclaration, diff --git a/bumble/profiles/aics.py b/bumble/profiles/aics.py index 3a696272..eb6a7420 100644 --- a/bumble/profiles/aics.py +++ b/bumble/profiles/aics.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import logging import struct @@ -28,10 +29,11 @@ from bumble.att import ATT_Error from bumble.gatt import ( Characteristic, - DelegatedCharacteristicAdapter, + SerializableCharacteristicAdapter, + PackedCharacteristicAdapter, TemplateService, CharacteristicValue, - PackedCharacteristicAdapter, + UTF8CharacteristicAdapter, GATT_AUDIO_INPUT_CONTROL_SERVICE, GATT_AUDIO_INPUT_STATE_CHARACTERISTIC, GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, @@ -154,9 +156,6 @@ async def notify_subscribers_via_connection(self, connection: Connection) -> Non attribute=self.attribute_value, value=bytes(self) ) - def on_read(self, _connection: Optional[Connection]) -> bytes: - return bytes(self) - @dataclass class GainSettingsProperties: @@ -173,7 +172,7 @@ def from_bytes(cls, data: bytes): (gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = ( struct.unpack('BBB', data) ) - GainSettingsProperties( + return GainSettingsProperties( gain_settings_unit, gain_settings_minimum, gain_settings_maximum ) @@ -186,9 +185,6 @@ def __bytes__(self) -> bytes: ] ) - def on_read(self, _connection: Optional[Connection]) -> bytes: - return bytes(self) - @dataclass class AudioInputControlPoint: @@ -321,21 +317,14 @@ class AudioInputDescription: audio_input_description: str = "Bluetooth" attribute_value: Optional[CharacteristicValue] = None - @classmethod - def from_bytes(cls, data: bytes): - return cls(audio_input_description=data.decode('utf-8')) - - def __bytes__(self) -> bytes: - return self.audio_input_description.encode('utf-8') - - def on_read(self, _connection: Optional[Connection]) -> bytes: - return self.audio_input_description.encode('utf-8') + def on_read(self, _connection: Optional[Connection]) -> str: + return self.audio_input_description - async def on_write(self, connection: Optional[Connection], value: bytes) -> None: + async def on_write(self, connection: Optional[Connection], value: str) -> None: assert connection assert self.attribute_value - self.audio_input_description = value.decode('utf-8') + self.audio_input_description = value await connection.device.notify_subscribers( attribute=self.attribute_value, value=value ) @@ -375,26 +364,29 @@ def __init__( self.audio_input_state, self.gain_settings_properties ) - self.audio_input_state_characteristic = DelegatedCharacteristicAdapter( + self.audio_input_state_characteristic = SerializableCharacteristicAdapter( Characteristic( uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC, properties=Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, - value=CharacteristicValue(read=self.audio_input_state.on_read), + value=self.audio_input_state, ), - encode=lambda value: bytes(value), + AudioInputState, ) self.audio_input_state.attribute_value = ( self.audio_input_state_characteristic.value ) - self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter( - Characteristic( - uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, - properties=Characteristic.Properties.READ, - permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, - value=CharacteristicValue(read=self.gain_settings_properties.on_read), + self.gain_settings_properties_characteristic = ( + SerializableCharacteristicAdapter( + Characteristic( + uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, + properties=Characteristic.Properties.READ, + permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, + value=self.gain_settings_properties, + ), + GainSettingsProperties, ) ) @@ -402,7 +394,7 @@ def __init__( uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC, properties=Characteristic.Properties.READ, permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, - value=audio_input_type, + value=bytes(audio_input_type, 'utf-8'), ) self.audio_input_status_characteristic = Characteristic( @@ -412,18 +404,14 @@ def __init__( value=bytes([self.audio_input_status]), ) - self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter( - Characteristic( - uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC, - properties=Characteristic.Properties.WRITE, - permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION, - value=CharacteristicValue( - write=self.audio_input_control_point.on_write - ), - ) + self.audio_input_control_point_characteristic = Characteristic( + uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC, + properties=Characteristic.Properties.WRITE, + permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION, + value=CharacteristicValue(write=self.audio_input_control_point.on_write), ) - self.audio_input_description_characteristic = DelegatedCharacteristicAdapter( + self.audio_input_description_characteristic = UTF8CharacteristicAdapter( Characteristic( uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC, properties=Characteristic.Properties.READ @@ -469,8 +457,8 @@ def __init__(self, service_proxy: ServiceProxy) -> None: ) ): raise gatt.InvalidServiceError("Audio Input State Characteristic not found") - self.audio_input_state = DelegatedCharacteristicAdapter( - characteristic=characteristics[0], decode=AudioInputState.from_bytes + self.audio_input_state = SerializableCharacteristicAdapter( + characteristics[0], AudioInputState ) if not ( @@ -481,9 +469,8 @@ def __init__(self, service_proxy: ServiceProxy) -> None: raise gatt.InvalidServiceError( "Gain Settings Attribute Characteristic not found" ) - self.gain_settings_properties = PackedCharacteristicAdapter( - characteristics[0], - 'BBB', + self.gain_settings_properties = SerializableCharacteristicAdapter( + characteristics[0], GainSettingsProperties ) if not ( @@ -494,10 +481,7 @@ def __init__(self, service_proxy: ServiceProxy) -> None: raise gatt.InvalidServiceError( "Audio Input Status Characteristic not found" ) - self.audio_input_status = PackedCharacteristicAdapter( - characteristics[0], - 'B', - ) + self.audio_input_status = PackedCharacteristicAdapter(characteristics[0], 'B') if not ( characteristics := service_proxy.get_characteristics_by_uuid( @@ -517,4 +501,4 @@ def __init__(self, service_proxy: ServiceProxy) -> None: raise gatt.InvalidServiceError( "Audio Input Description Characteristic not found" ) - self.audio_input_description = characteristics[0] + self.audio_input_description = UTF8CharacteristicAdapter(characteristics[0]) diff --git a/bumble/profiles/bass.py b/bumble/profiles/bass.py index 57531dbd..9ded4ef9 100644 --- a/bumble/profiles/bass.py +++ b/bumble/profiles/bass.py @@ -276,10 +276,7 @@ class BigEncryption(utils.OpenIntEnum): subgroups: List[SubgroupInfo] @classmethod - def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]: - if not data: - return None - + def from_bytes(cls, data: bytes) -> BroadcastReceiveState: source_id = data[0] _, source_address = hci.Address.parse_address_preceded_by_type(data, 2) source_adv_sid = data[8] @@ -357,7 +354,7 @@ class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy): SERVICE_CLASS = BroadcastAudioScanService broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy - broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter] + broadcast_receive_states: List[gatt.SerializableCharacteristicAdapter] def __init__(self, service_proxy: gatt_client.ServiceProxy): self.service_proxy = service_proxy @@ -381,8 +378,8 @@ def __init__(self, service_proxy: gatt_client.ServiceProxy): "Broadcast Receive State characteristic not found" ) self.broadcast_receive_states = [ - gatt.DelegatedCharacteristicAdapter( - characteristic, decode=BroadcastReceiveState.from_bytes + gatt.SerializableCharacteristicAdapter( + characteristic, BroadcastReceiveState ) for characteristic in characteristics ] diff --git a/bumble/profiles/device_information_service.py b/bumble/profiles/device_information_service.py index ecb1c0f8..d1128038 100644 --- a/bumble/profiles/device_information_service.py +++ b/bumble/profiles/device_information_service.py @@ -64,7 +64,10 @@ def __init__( ): characteristics = [ Characteristic( - uuid, Characteristic.Properties.READ, Characteristic.READABLE, field + uuid, + Characteristic.Properties.READ, + Characteristic.READABLE, + bytes(field, 'utf-8'), ) for (field, uuid) in ( (manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), diff --git a/bumble/profiles/heart_rate_service.py b/bumble/profiles/heart_rate_service.py index 0c9a12f0..7685e52e 100644 --- a/bumble/profiles/heart_rate_service.py +++ b/bumble/profiles/heart_rate_service.py @@ -30,6 +30,7 @@ TemplateService, Characteristic, CharacteristicValue, + SerializableCharacteristicAdapter, DelegatedCharacteristicAdapter, PackedCharacteristicAdapter, ) @@ -150,15 +151,14 @@ def __init__( body_sensor_location=None, reset_energy_expended=None, ): - self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter( + self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter( Characteristic( GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, Characteristic.Properties.NOTIFY, 0, CharacteristicValue(read=read_heart_rate_measurement), ), - # pylint: disable=unnecessary-lambda - encode=lambda value: bytes(value), + HeartRateService.HeartRateMeasurement, ) characteristics = [self.heart_rate_measurement_characteristic] @@ -204,9 +204,8 @@ def __init__(self, service_proxy): if characteristics := service_proxy.get_characteristics_by_uuid( GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC ): - self.heart_rate_measurement = DelegatedCharacteristicAdapter( - characteristics[0], - decode=HeartRateService.HeartRateMeasurement.from_bytes, + self.heart_rate_measurement = SerializableCharacteristicAdapter( + characteristics[0], HeartRateService.HeartRateMeasurement ) else: self.heart_rate_measurement = None diff --git a/bumble/utils.py b/bumble/utils.py index 4c9407f5..d8864bb1 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -24,17 +24,19 @@ import sys import warnings from typing import ( + Any, Awaitable, - Set, - TypeVar, - List, - Tuple, Callable, - Any, + List, Optional, + Protocol, + Set, + Tuple, + TypeVar, Union, overload, ) +from typing_extensions import Self from pyee import EventEmitter @@ -487,3 +489,16 @@ def _missing_(cls, value): obj._value_ = value obj._name_ = f"{cls.__name__}[{value}]" return obj + + +# ----------------------------------------------------------------------------- +class ByteSerializable(Protocol): + """ + Type protocol for classes that can be instantiated from bytes and serialized + to bytes. + """ + + @classmethod + def from_bytes(cls, data: bytes) -> Self: ... + + def __bytes__(self) -> bytes: ... diff --git a/examples/keyboard.py b/examples/keyboard.py index f2afe189..52a4c783 100644 --- a/examples/keyboard.py +++ b/examples/keyboard.py @@ -282,7 +282,7 @@ async def keyboard_device(device, command): GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, Characteristic.Properties.READ, Characteristic.READABLE, - 'Bumble', + bytes('Bumble', 'utf-8'), ) ], ), diff --git a/examples/run_gatt_server.py b/examples/run_gatt_server.py index 874115cf..66f65538 100644 --- a/examples/run_gatt_server.py +++ b/examples/run_gatt_server.py @@ -127,7 +127,7 @@ async def main() -> None: '486F64C6-4B5F-4B3B-8AFF-EDE134A8446A', Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, Characteristic.READABLE, - 'hello', + bytes('hello', 'utf-8'), ), ], ) diff --git a/examples/run_gatt_with_adapters.py b/examples/run_gatt_with_adapters.py new file mode 100644 index 00000000..f5430b8e --- /dev/null +++ b/examples/run_gatt_with_adapters.py @@ -0,0 +1,319 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import dataclasses +import logging +import os +import random +import struct +import sys +from typing import Any, List, Union + +from bumble.device import Connection, Device, Peer +from bumble import transport +from bumble import gatt +from bumble import hci +from bumble import core + + +# ----------------------------------------------------------------------------- +SERVICE_UUID = core.UUID("50DB505C-8AC4-4738-8448-3B1D9CC09CC5") +CHARACTERISTIC_UUID_BASE = "D901B45B-4916-412E-ACCA-0000000000" + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class CustomSerializableClass: + x: int + y: int + + @classmethod + def from_bytes(cls, data: bytes) -> CustomSerializableClass: + return cls(*struct.unpack(">II", data)) + + def __bytes__(self) -> bytes: + return struct.pack(">II", self.x, self.y) + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class CustomClass: + a: int + b: int + + @classmethod + def decode(cls, data: bytes) -> CustomClass: + return cls(*struct.unpack(">II", data)) + + def encode(self) -> bytes: + return struct.pack(">II", self.a, self.b) + + +# ----------------------------------------------------------------------------- +async def client(device: Device, address: hci.Address) -> None: + print(f'=== Connecting to {address}...') + connection = await device.connect(address) + print('=== Connected') + + # Discover all characteristics. + peer = Peer(connection) + print("*** Discovering services and characteristics...") + await peer.discover_all() + print("*** Discovery complete") + + service = peer.get_services_by_uuid(SERVICE_UUID)[0] + characteristics = [] + for index in range(1, 9): + characteristics.append( + service.get_characteristics_by_uuid( + CHARACTERISTIC_UUID_BASE + f"{index:02X}" + )[0] + ) + + # Read all characteristics as raw bytes. + for characteristic in characteristics: + value = await characteristic.read_value() + print(f"### {characteristic} = {value} ({value.hex()})") + + # Static characteristic with a bytes value. + c1 = characteristics[0] + c1_value = await c1.read_value() + print(f"@@@ C1 {c1} value = {c1_value} (type={type(c1_value)})") + await c1.write_value("happy π day".encode("utf-8")) + + # Static characteristic with a string value. + c2 = gatt.UTF8CharacteristicAdapter(characteristics[1]) + c2_value = await c2.read_value() + print(f"@@@ C2 {c2} value = {c2_value} (type={type(c2_value)})") + await c2.write_value("happy π day") + + # Static characteristic with a tuple value. + c3 = gatt.PackedCharacteristicAdapter(characteristics[2], ">III") + c3_value = await c3.read_value() + print(f"@@@ C3 {c3} value = {c3_value} (type={type(c3_value)})") + await c3.write_value((2001, 2002, 2003)) + + # Static characteristic with a named tuple value. + c4 = gatt.MappedCharacteristicAdapter( + characteristics[3], ">III", ["f1", "f2", "f3"] + ) + c4_value = await c4.read_value() + print(f"@@@ C4 {c4} value = {c4_value} (type={type(c4_value)})") + await c4.write_value({"f1": 4001, "f2": 4002, "f3": 4003}) + + # Static characteristic with a serializable value. + c5 = gatt.SerializableCharacteristicAdapter( + characteristics[4], CustomSerializableClass + ) + c5_value = await c5.read_value() + print(f"@@@ C5 {c5} value = {c5_value} (type={type(c5_value)})") + await c5.write_value(CustomSerializableClass(56, 57)) + + # Static characteristic with a delegated value. + c6 = gatt.DelegatedCharacteristicAdapter( + characteristics[5], encode=CustomClass.encode, decode=CustomClass.decode + ) + c6_value = await c6.read_value() + print(f"@@@ C6 {c6} value = {c6_value} (type={type(c6_value)})") + await c6.write_value(CustomClass(6, 7)) + + # Dynamic characteristic with a bytes value. + c7 = characteristics[6] + c7_value = await c7.read_value() + print(f"@@@ C7 {c7} value = {c7_value} (type={type(c7_value)})") + await c7.write_value(bytes.fromhex("01020304")) + + # Dynamic characteristic with a string value. + c8 = gatt.UTF8CharacteristicAdapter(characteristics[7]) + c8_value = await c8.read_value() + print(f"@@@ C8 {c8} value = {c8_value} (type={type(c8_value)})") + await c8.write_value("howdy") + + +# ----------------------------------------------------------------------------- +def dynamic_read(selector: str) -> Union[bytes, str]: + if selector == "bytes": + print("$$$ Returning random bytes") + return random.randbytes(7) + elif selector == "string": + print("$$$ Returning random string") + return random.randbytes(7).hex() + + raise ValueError("invalid selector") + + +# ----------------------------------------------------------------------------- +def dynamic_write(selector: str, value: Any) -> None: + print(f"$$$ Received[{selector}]: {value} (type={type(value)})") + + +# ----------------------------------------------------------------------------- +def on_characteristic_read(characteristic: gatt.Characteristic, value: Any) -> None: + """Event listener invoked when a characteristic is read.""" + print(f"<<< READ: {characteristic} -> {value} ({type(value)})") + + +# ----------------------------------------------------------------------------- +def on_characteristic_write(characteristic: gatt.Characteristic, value: Any) -> None: + """Event listener invoked when a characteristic is written.""" + print(f"<<< WRITE: {characteristic} <- {value} ({type(value)})") + + +# ----------------------------------------------------------------------------- +async def main() -> None: + if len(sys.argv) < 2: + print("Usage: run_gatt_with_adapters.py []") + print("example: run_gatt_with_adapters.py usb:0 E1:CA:72:48:C4:E8") + return + + async with await transport.open_transport(sys.argv[1]) as hci_transport: + # Create a device to manage the host + device = Device.with_hci( + "Bumble", + hci.Address("F0:F1:F2:F3:F4:F5"), + hci_transport.source, + hci_transport.sink, + ) + + # Static characteristic with a bytes value. + c1 = gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "01", + gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + b'hello', + ) + + # Static characteristic with a string value. + c2 = gatt.UTF8CharacteristicAdapter( + gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "02", + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + 'hello', + ) + ) + + # Static characteristic with a tuple value. + c3 = gatt.PackedCharacteristicAdapter( + gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "03", + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + (1007, 1008, 1009), + ), + ">III", + ) + + # Static characteristic with a named tuple value. + c4 = gatt.MappedCharacteristicAdapter( + gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "04", + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + {"f1": 3007, "f2": 3008, "f3": 3009}, + ), + ">III", + ["f1", "f2", "f3"], + ) + + # Static characteristic with a serializable value. + c5 = gatt.SerializableCharacteristicAdapter( + gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "05", + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + CustomSerializableClass(11, 12), + ), + CustomSerializableClass, + ) + + # Static characteristic with a delegated value. + c6 = gatt.DelegatedCharacteristicAdapter( + gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "06", + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + CustomClass(1, 2), + ), + encode=CustomClass.encode, + decode=CustomClass.decode, + ) + + # Dynamic characteristic with a bytes value. + c7 = gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "07", + gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + gatt.CharacteristicValue( + read=lambda connection: dynamic_read("bytes"), + write=lambda connection, value: dynamic_write("bytes", value), + ), + ) + + # Dynamic characteristic with a string value. + c8 = gatt.UTF8CharacteristicAdapter( + gatt.Characteristic( + CHARACTERISTIC_UUID_BASE + "08", + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.WRITE, + gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE, + gatt.CharacteristicValue( + read=lambda connection: dynamic_read("string"), + write=lambda connection, value: dynamic_write("string", value), + ), + ) + ) + + characteristics: List[ + Union[gatt.Characteristic, gatt.CharacteristicAdapter] + ] = [c1, c2, c3, c4, c5, c6, c7, c8] + + # Listen for read and write events. + for characteristic in characteristics: + characteristic.on( + "read", + lambda _, value, c=characteristic: on_characteristic_read(c, value), + ) + characteristic.on( + "write", + lambda _, value, c=characteristic: on_characteristic_write(c, value), + ) + + device.add_service(gatt.Service(SERVICE_UUID, characteristics)) # type: ignore + + # Get things going + await device.power_on() + + # Connect to a peer + if len(sys.argv) > 2: + await client(device, hci.Address(sys.argv[2])) + else: + await device.start_advertising(auto_restart=True) + + await hci_transport.source.wait_for_termination() + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) +asyncio.run(main()) diff --git a/tests/aics_test.py b/tests/aics_test.py index 9526558f..44826a9c 100644 --- a/tests/aics_test.py +++ b/tests/aics_test.py @@ -28,6 +28,7 @@ AudioInputState, AICSServiceProxy, GainMode, + GainSettingsProperties, AudioInputStatus, AudioInputControlPointOpCode, ErrorCode, @@ -82,7 +83,12 @@ async def test_init_service(aics_client: AICSServiceProxy): gain_mode=GainMode.MANUAL, change_counter=0, ) - assert await aics_client.gain_settings_properties.read_value() == (1, 0, 255) + assert ( + await aics_client.gain_settings_properties.read_value() + == GainSettingsProperties( + gain_settings_unit=1, gain_settings_minimum=0, gain_settings_maximum=255 + ) + ) assert await aics_client.audio_input_status.read_value() == ( AudioInputStatus.ACTIVE ) @@ -481,12 +487,12 @@ async def test_set_automatic_gain_mode_when_automatic_only( @pytest.mark.asyncio async def test_audio_input_description_initial_value(aics_client: AICSServiceProxy): description = await aics_client.audio_input_description.read_value() - assert description.decode('utf-8') == "Bluetooth" + assert description == "Bluetooth" @pytest.mark.asyncio async def test_audio_input_description_write_and_read(aics_client: AICSServiceProxy): - new_description = "Line Input".encode('utf-8') + new_description = "Line Input" await aics_client.audio_input_description.write_value(new_description) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 8d73eb3f..a7ce8930 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -15,11 +15,13 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import os import struct import pytest +from typing_extensions import Self from unittest.mock import AsyncMock, Mock, ANY from bumble.controller import Controller @@ -31,6 +33,7 @@ GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, CharacteristicAdapter, + SerializableCharacteristicAdapter, DelegatedCharacteristicAdapter, PackedCharacteristicAdapter, MappedCharacteristicAdapter, @@ -310,7 +313,7 @@ async def test_attribute_getters(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_CharacteristicAdapter(): +async def test_CharacteristicAdapter() -> None: # Check that the CharacteristicAdapter base class is transparent v = bytes([1, 2, 3]) c = Characteristic( @@ -329,67 +332,94 @@ async def test_CharacteristicAdapter(): assert c.value == v # Simple delegated adapter - a = DelegatedCharacteristicAdapter( + delegated = DelegatedCharacteristicAdapter( c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)) ) - value = await a.read_value(None) - assert value == bytes(reversed(v)) + delegated_value = await delegated.read_value(None) + assert delegated_value == bytes(reversed(v)) - v = bytes([3, 4, 5]) - await a.write_value(None, v) - assert a.value == bytes(reversed(v)) + delegated_value2 = bytes([3, 4, 5]) + await delegated.write_value(None, delegated_value2) + assert delegated.value == bytes(reversed(delegated_value2)) # Packed adapter with single element format - v = 1234 - pv = struct.pack('>H', v) - c.value = v - a = PackedCharacteristicAdapter(c, '>H') + packed_value_ref = 1234 + packed_value_bytes = struct.pack('>H', packed_value_ref) + c.value = packed_value_ref + packed = PackedCharacteristicAdapter(c, '>H') - value = await a.read_value(None) - assert value == pv - c.value = None - await a.write_value(None, pv) - assert a.value == v + packed_value_read = await packed.read_value(None) + assert packed_value_read == packed_value_bytes + c.value = b'' + await packed.write_value(None, packed_value_bytes) + assert packed.value == packed_value_ref # Packed adapter with multi-element format v1 = 1234 v2 = 5678 - pv = struct.pack('>HH', v1, v2) + packed_multi_value_bytes = struct.pack('>HH', v1, v2) c.value = (v1, v2) - a = PackedCharacteristicAdapter(c, '>HH') + packed_multi = PackedCharacteristicAdapter(c, '>HH') - value = await a.read_value(None) - assert value == pv - c.value = None - await a.write_value(None, pv) - assert a.value == (v1, v2) + packed_multi_read_value = await packed_multi.read_value(None) + assert packed_multi_read_value == packed_multi_value_bytes + packed_multi.value = b'' + await packed_multi.write_value(None, packed_multi_value_bytes) + assert packed_multi.value == (v1, v2) # Mapped adapter v1 = 1234 v2 = 5678 - pv = struct.pack('>HH', v1, v2) + packed_mapped_value_bytes = struct.pack('>HH', v1, v2) mapped = {'v1': v1, 'v2': v2} c.value = mapped - a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2')) + packed_mapped = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2')) - value = await a.read_value(None) - assert value == pv - c.value = None - await a.write_value(None, pv) - assert a.value == mapped + packed_mapped_read_value = await packed_mapped.read_value(None) + assert packed_mapped_read_value == packed_mapped_value_bytes + c.value = b'' + await packed_mapped.write_value(None, packed_mapped_value_bytes) + assert packed_mapped.value == mapped # UTF-8 adapter - v = 'Hello π' - ev = v.encode('utf-8') - c.value = v - a = UTF8CharacteristicAdapter(c) - - value = await a.read_value(None) - assert value == ev - c.value = None - await a.write_value(None, ev) - assert a.value == v + string_value = 'Hello π' + string_value_bytes = string_value.encode('utf-8') + c.value = string_value + string_c = UTF8CharacteristicAdapter(c) + + string_read_value = await string_c.read_value(None) + assert string_read_value == string_value_bytes + c.value = b'' + await string_c.write_value(None, string_value_bytes) + assert string_c.value == string_value + + # Class adapter + class BlaBla: + def __init__(self, a: int, b: int) -> None: + self.a = a + self.b = b + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + a, b = struct.unpack(">II", data) + return cls(a, b) + + def __bytes__(self) -> bytes: + return struct.pack(">II", self.a, self.b) + + class_value = BlaBla(3, 4) + class_value_bytes = struct.pack(">II", 3, 4) + c.value = class_value + class_c = SerializableCharacteristicAdapter(c, BlaBla) + + class_read_value = await class_c.read_value(None) + assert class_read_value == class_value_bytes + c.value = b'' + await class_c.write_value(None, class_value_bytes) + assert isinstance(c.value, BlaBla) + assert c.value.a == 3 + assert c.value.b == 4 # -----------------------------------------------------------------------------