From 9dbd2d465288899ae597f57a05150162f7cc5ef0 Mon Sep 17 00:00:00 2001 From: uael Date: Mon, 6 Nov 2023 01:16:54 -0800 Subject: [PATCH] test --- bumble/l2cap.py | 54 ++++++-------- bumble/pandora/l2cap.py | 161 +++++++++++++++++++++------------------- 2 files changed, 105 insertions(+), 110 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 9d86e7d91..4a0c60cc9 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -26,17 +26,16 @@ from pyee import EventEmitter from typing import ( Dict, - Set, Type, List, Optional, Tuple, Callable, Any, - TypeVar, Union, Deque, Iterable, + Set, SupportsBytes, TYPE_CHECKING, overload, @@ -1450,9 +1449,6 @@ class LeCreditBased(Any): # ----------------------------------------------------------------------------- -TPendingConnection = TypeVar('TPendingConnection', bound=PendingConnection.Any) - - class IncomingConnection: @dataclasses.dataclass class Any: @@ -1462,38 +1458,26 @@ class Any: psm: int source_cid: int - def expired(self) -> bool: - ... - - class Future(asyncio.Future[TPendingConnection]): - def accept(self, pend: TPendingConnection) -> bool: - """Accept this connection request.""" - try: - self.set_result(pend) - return True - except asyncio.InvalidStateError: - return False - - def expired(self) -> bool: - return self.done() or self.cancelled() + def __post_init__(self) -> None: + self.future: asyncio.Future[Any] = asyncio.Future() @dataclasses.dataclass - class Basic(Future[PendingConnection.Basic], Any): + class Basic(Any): """L2CAP incoming basic channel connection request.""" - def __post_init__(self) -> None: - super().__init__() + future: asyncio.Future[PendingConnection.Basic] = dataclasses.field(init=False) @dataclasses.dataclass - class LeCreditBased(Future[PendingConnection.LeCreditBased], Any): + class LeCreditBased(Any): """L2CAP incoming LE credit based channel connection request.""" mtu: int mps: int initial_credits: int - def __post_init__(self) -> None: - super().__init__() + future: asyncio.Future[PendingConnection.LeCreditBased] = dataclasses.field( + init=False + ) # ----------------------------------------------------------------------------- @@ -1684,13 +1668,16 @@ def create_classic_server( def listener(incoming: IncomingConnection.Basic) -> None: if incoming.psm == spec.psm: - incoming.accept(PendingConnection.Basic(server.on_connection, spec.mtu)) + incoming.future.set_result( + PendingConnection.Basic(server.on_connection, spec.mtu) + ) def close() -> None: self.unlisten(listener) assert spec.psm is not None self.free_psm(spec.psm) + self.listen(listener) server = ClassicChannelServer(close, spec.psm, handler) return server @@ -1725,7 +1712,7 @@ def create_le_credit_based_server( def listener(incoming: IncomingConnection.LeCreditBased) -> None: if incoming.psm == spec.psm: - incoming.accept( + incoming.future.set_result( PendingConnection.LeCreditBased( server.on_connection, spec.mtu, spec.mps, spec.max_credits ) @@ -1736,6 +1723,7 @@ def close() -> None: assert spec.psm is not None self.free_psm(spec.psm) + self.listen(listener) server = LeCreditBasedChannelServer(close, spec.psm, handler) return server @@ -1848,13 +1836,13 @@ async def handle_connection_request() -> None: # Dispatch incoming connection. for listener in self.listeners: - if not incoming.done(): + if not incoming.future.done(): listener(incoming) try: - pending = await asyncio.wait_for(incoming, timeout=3.0) + pending = await asyncio.wait_for(incoming.future, timeout=3.0) except asyncio.TimeoutError as e: - incoming.cancel(e) + incoming.future.cancel(e) pending = None if pending: @@ -2127,13 +2115,13 @@ async def handle_connection_request() -> None: # Dispatch incoming connection. for listener in self.listeners: - if not incoming.done(): + if not incoming.future.done(): listener(incoming) try: - pending = await asyncio.wait_for(incoming, timeout=3.0) + pending = await asyncio.wait_for(incoming.future, timeout=3.0) except asyncio.TimeoutError as e: - incoming.cancel(e) + incoming.future.cancel(e) pending = None if pending: diff --git a/bumble/pandora/l2cap.py b/bumble/pandora/l2cap.py index c8904e192..47cb7a83e 100644 --- a/bumble/pandora/l2cap.py +++ b/bumble/pandora/l2cap.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import asyncio import dataclasses -import logging import grpc import struct from bumble import device from bumble import l2cap -from bumble.utils import EventWatcher from bumble.pandora import config from bumble.pandora import utils +from bumble.utils import EventWatcher from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error from pandora import l2cap_pb2 @@ -31,35 +29,20 @@ from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Union -class ChannelProxy(abc.ABC): - rx: asyncio.Queue[bytes] = asyncio.Queue() - - async def receive(self) -> bytes: - return await self.rx.get() - - def send(self, data: bytes) -> None: - ... - - async def disconnect(self) -> None: - ... - - async def wait_disconnect(self) -> None: - ... - - @dataclasses.dataclass -class DynamicChannelProxy(ChannelProxy): +class ChannelProxy: channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel, None] def __post_init__(self) -> None: assert self.channel - self.disconnection_result = asyncio.get_event_loop().create_future() + self.rx: asyncio.Queue[bytes] = asyncio.Queue() + self._disconnection_result: asyncio.Future[None] = asyncio.Future() self.channel.sink = self.rx.put_nowait def on_close() -> None: - assert not self.disconnection_result.done() + assert not self._disconnection_result.done() self.channel = None - self.disconnection_result.set_result(None) + self._disconnection_result.set_result(None) self.channel.on('close', on_close) @@ -75,22 +58,8 @@ async def disconnect(self) -> None: await self.channel.disconnect() async def wait_disconnect(self) -> None: - return await self.disconnection_result - - -@dataclasses.dataclass -class FixedChannelProxy(ChannelProxy): - connection: device.Connection - cid: int - - def send(self, data: bytes) -> None: - self.connection.device.send_l2cap_pdu(self.connection.handle, self.cid, data) - - async def disconnect(self) -> None: - raise RuntimeError('Fixed channel cannot be disconnected') - - async def wait_disconnect(self) -> None: - raise RuntimeError('Fixed channel cannot be disconnected') + await self._disconnection_result + assert not self.channel @dataclasses.dataclass @@ -123,34 +92,33 @@ def __init__(self, dev: device.Device, config: config.Config) -> None: self.device = dev self.config = config - @self.device.l2cap_channel_manager.listen - def _(incoming: l2cap.IncomingConnection.Any) -> None: + def on_connection(incoming: l2cap.IncomingConnection.Any) -> None: self.pending.append(incoming) for acceptor in self.accepts: acceptor.put_nowait(incoming) + # Make sure our listener is called before the builtins ones. + self.device.l2cap_channel_manager.listeners.insert(0, on_connection) + def register(self, index: ChannelIndex, proxy: ChannelProxy) -> None: self.channels[index] = proxy def on_close(*_: Any) -> None: - # TODO: Fix Bumble L2CAP emit `close` twice. + # TODO: Fix Bumble L2CAP which emit `close` event twice. if index in self.channels: del self.channels[index] # Listen for disconnection. - if isinstance(proxy, FixedChannelProxy): - proxy.connection.on('disconnection', on_close) - elif isinstance(proxy, DynamicChannelProxy): - assert proxy.channel - proxy.channel.on('close', on_close) + assert proxy.channel + proxy.channel.on('close', on_close) async def listen(self) -> AsyncIterator[l2cap.IncomingConnection.Any]: for incoming in self.pending: - if incoming.expired(): + if incoming.future.done(): self.pending.remove(incoming) continue yield incoming - queue = asyncio.Queue() + queue: asyncio.Queue[l2cap.IncomingConnection.Any] = asyncio.Queue() self.accepts.append(queue) try: while incoming := await queue.get(): @@ -168,6 +136,7 @@ async def Connect( if connection is None: raise RuntimeError(f'{connection_handle}: not connection for handle') + channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel] if request.type_variant() == 'basic': assert request.basic channel = await connection.create_l2cap_channel( @@ -189,7 +158,7 @@ async def Connect( raise NotImplementedError(f"{request.type_variant()}: unsupported type") index = ChannelIndex(channel.connection.handle, channel.source_cid) - self.register(index, DynamicChannelProxy(channel)) + self.register(index, ChannelProxy(channel)) return l2cap_pb2.ConnectResponse(channel=index.into_token()) @utils.rpc @@ -208,43 +177,43 @@ async def WaitConnection( if request.type_variant() == 'basic': assert request.basic - async for it in ( + basic = l2cap.PendingConnection.Basic( + fut.set_result, + request.basic.mtu or l2cap.L2CAP_MIN_BR_EDR_MTU, + ) + async for i in ( it async for it in iter if isinstance(it, l2cap.IncomingConnection.Basic) - and it.psm == request.basic.psm ): - pend = l2cap.PendingConnection.Basic( - fut.set_result, - request.basic.mtu or l2cap.L2CAP_MIN_BR_EDR_MTU, - ) - if it.accept(pend): + if not i.future.done() and i.psm == request.basic.psm: + i.future.set_result(basic) break elif request.type_variant() == 'le_credit_based': assert request.le_credit_based - async for it in ( + le_credit_based = l2cap.PendingConnection.LeCreditBased( + fut.set_result, + request.le_credit_based.mtu + or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + request.le_credit_based.mps + or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + request.le_credit_based.initial_credit + or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, + ) + async for j in ( it async for it in iter if isinstance(it, l2cap.IncomingConnection.LeCreditBased) - and it.psm == request.le_credit_based.spsm ): - pend = l2cap.PendingConnection.LeCreditBased( - fut.set_result, - request.le_credit_based.mtu - or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - request.le_credit_based.mps - or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, - request.le_credit_based.initial_credit - or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, - ) - if it.accept(pend): + if not j.future.done() and j.psm == request.le_credit_based.spsm: + j.future.set_result(le_credit_based) break else: raise NotImplementedError(f"{request.type_variant()}: unsupported type") channel = await fut index = ChannelIndex(channel.connection.handle, channel.source_cid) - self.register(index, DynamicChannelProxy(channel)) + self.register(index, ChannelProxy(channel)) return l2cap_pb2.WaitConnectionResponse(channel=index.into_token()) @utils.rpc @@ -267,16 +236,54 @@ async def WaitDisconnection( async def Receive( self, request: l2cap_pb2.ReceiveRequest, context: grpc.ServicerContext ) -> AsyncGenerator[l2cap_pb2.ReceiveResponse, None]: - # TODO: fixed channel `Receive` - channel = self.channels[ChannelIndex.from_token(request.channel)] - while packet := await channel.receive(): - yield l2cap_pb2.ReceiveResponse(data=packet) + watcher = EventWatcher() + if request.source_variant() == 'channel': + assert request.channel + channel = self.channels[ChannelIndex.from_token(request.channel)] + rx = channel.rx + elif request.source_variant() == 'fixed_channel': + assert request.fixed_channel + rx = asyncio.Queue() + handle = request.fixed_channel.connection is not None and int.from_bytes( + request.fixed_channel.connection.cookie.value, 'big' + ) + + @watcher.on(self.device.host, 'l2cap_pdu') + def _(connection: device.Connection, cid: int, pdu: bytes) -> None: + assert request.fixed_channel + if cid == request.fixed_channel.cid and ( + handle is None or handle == connection.handle + ): + rx.put_nowait(pdu) + + else: + raise NotImplementedError(f"{request.source_variant()}: unsupported type") + try: + while data := await rx.get(): + yield l2cap_pb2.ReceiveResponse(data=data) + finally: + watcher.close() @utils.rpc async def Send( self, request: l2cap_pb2.SendRequest, context: grpc.ServicerContext ) -> l2cap_pb2.SendResponse: - # TODO: fixed channel `Send` - channel = self.channels[ChannelIndex.from_token(request.channel)] - channel.send(request.data) + if request.sink_variant() == 'channel': + assert request.channel + channel = self.channels[ChannelIndex.from_token(request.channel)] + channel.send(request.data) + elif request.sink_variant() == 'fixed_channel': + assert request.fixed_channel + # Retrieve Bumble `Connection` from request. + connection_handle = int.from_bytes( + request.fixed_channel.connection.cookie.value, 'big' + ) + connection = self.device.lookup_connection(connection_handle) + if connection is None: + raise RuntimeError(f'{connection_handle}: not connection for handle') + self.device.l2cap_channel_manager.send_pdu( + connection, request.fixed_channel.cid, request.data + ) + else: + raise NotImplementedError(f"{request.sink_variant()}: unsupported type") return l2cap_pb2.SendResponse(success=empty_pb2.Empty())