Skip to content

Commit

Permalink
BIG Sync app
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Nov 18, 2024
1 parent 698aa8d commit 21c291c
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 26 deletions.
256 changes: 244 additions & 12 deletions apps/auracast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@
import asyncio
import contextlib
import dataclasses
import enum
import functools
import logging
import os
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple

import click
import pyee

import ctypes
import wasmtime
import wasmtime.loader
from lea_unicast import liblc3 # type: ignore # pylint: disable=E0401

from bumble.colors import color
import bumble.company_ids
import bumble.core
Expand Down Expand Up @@ -54,6 +61,94 @@
AURACAST_DEFAULT_ATT_MTU = 256


# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)


class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3


MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)

DECODER_STACK_POINTER = STACK_POINTER
DECODE_BUFFER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


decoders = list[int]()


def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]


def decode(
frame_duration_us: int,
input_bytes: bytes,
channel_index: int,
) -> bytes:
if not input_bytes:
return b''

input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size

# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore

output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)

res = liblc3.lc3_decode(
decoders[channel_index],
input_buffer_offset,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset,
1,
)

if res != 0:
logging.error(f"Parsing failed, res={res}")

# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)


# -----------------------------------------------------------------------------
# Scan For Broadcasts
# -----------------------------------------------------------------------------
Expand All @@ -62,6 +157,7 @@ class BroadcastScanner(pyee.EventEmitter):
class Broadcast(pyee.EventEmitter):
name: str | None
sync: bumble.device.PeriodicAdvertisingSync
broadcast_id: int
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
Expand Down Expand Up @@ -280,11 +376,14 @@ def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
bumble.core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
)
) or not (
any(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
broadcast_audio_annoucement := next(
(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
),
None,
)
):
return
Expand All @@ -293,25 +392,35 @@ def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
bumble.core.AdvertisingData.BROADCAST_NAME
)
assert isinstance(broadcast_name, str) or broadcast_name is None
assert isinstance(broadcast_audio_annoucement[1], bytes)

if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
return

bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
self.on_new_broadcast(
broadcast_name,
advertisement,
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(
broadcast_audio_annoucement[1]
).broadcast_id,
)
)

async def on_new_broadcast(
self, name: str | None, advertisement: bumble.device.Advertisement
self,
name: str | None,
advertisement: bumble.device.Advertisement,
broadcast_id: int,
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
sid=advertisement.sid,
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(name, periodic_advertising_sync)
broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
Expand All @@ -323,10 +432,11 @@ def on_broadcast_loss(self, broadcast: Broadcast) -> None:
self.emit('broadcast_loss', broadcast)


class PrintingBroadcastScanner:
class PrintingBroadcastScanner(pyee.EventEmitter):
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
super().__init__()
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
Expand Down Expand Up @@ -610,6 +720,108 @@ async def run_pair(transport: str, address: str) -> None:
print("+++ Paired")


async def run_receive(
transport: str, broadcast_id: int, broadcast_code: str | None, sync_timeout: float
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return

scanner = BroadcastScanner(device, False, sync_timeout)
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
asyncio.get_running_loop().create_future()
)

def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if scan_result.done():
return
if broadcast.broadcast_id == broadcast_id:
scan_result.set_result(broadcast)

scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
print('Start scanning...')
broadcast = await scan_result
print('Advertisement found:')
broadcast.print()
basic_audio_announcement_scanned = asyncio.Event()

def on_change() -> None:
if (
broadcast.basic_audio_announcement
and not basic_audio_announcement_scanned.is_set()
):
basic_audio_announcement_scanned.set()

broadcast.on('change', on_change)
if not broadcast.basic_audio_announcement:
print('Wait for Basic Audio Announcement...')
await basic_audio_announcement_scanned.wait()
print('Basic Audio Announcement found')
broadcast.print()
print('Stop scanning')
await scanner.stop()
print('Start sync to BIG')
assert broadcast.basic_audio_announcement
configuration = broadcast.basic_audio_announcement.subgroups[
0
].codec_specific_configuration
assert configuration
assert (sampling_frequency := configuration.sampling_frequency)
assert (frame_duration := configuration.frame_duration)
big_sync = await device.create_big_sync(
broadcast.sync,
bumble.device.BigSyncParameters(
big_sync_timeout=0x4000,
bis=[
bis.index
for bis in broadcast.basic_audio_announcement.subgroups[0].bis
],
),
)
setup_decoders(
sampling_frequency.hz,
frame_duration.us,
len(big_sync.bis_links),
)

for i, bis_link in enumerate(big_sync.bis_links):
print(f'Setup ISO for BIS {bis_link.handle}')

def sink(
subprocess: asyncio.subprocess.Process,
index: int,
packet: bumble.hci.HCI_IsoDataPacket,
):
assert subprocess.stdin
pcm = decode(frame_duration.us, packet.iso_sdu_fragment, index)
subprocess.stdin.write(pcm)

subprocess = await asyncio.create_subprocess_shell(
f'ffmpeg -ac 1 -f s16le -i pipe: -af "pan=stereo|c{i}=c0" -f nut pipe: | ffplay -i -',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
bis_link.sink = functools.partial(sink, subprocess, i)
await device.send_command(
bumble.hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=bis_link.handle,
data_path_direction=bumble.hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST,
data_path_id=0,
codec_id=bumble.hci.CodingFormat(
codec_id=bumble.hci.CodecID.TRANSPARENT
),
controller_delay=0,
codec_configuration=b'',
),
check_result=True,
)

await asyncio.Event().wait()


def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
Expand All @@ -631,9 +843,7 @@ def run_async(async_command: Coroutine) -> None:
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
def auracast(ctx):
ctx.ensure_object(dict)


Expand Down Expand Up @@ -691,6 +901,28 @@ def pair(ctx, transport, address):
run_async(run_pair(transport, address))


@auracast.command('receive')
@click.argument('transport')
@click.argument('broadcast_id', type=int)
@click.option(
'--broadcast-code',
metavar='BROADCAST_CODE',
type=str,
help='Boradcast encryption code in hex format',
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
help='Sync timeout (in seconds)',
)
@click.pass_context
def receive(ctx, transport, broadcast_id, broadcast_code, sync_timeout):
"""Receive a broadcast source"""
run_async(run_receive(transport, broadcast_id, broadcast_code, sync_timeout))


def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()
Expand Down
1 change: 1 addition & 0 deletions apps/liblc3.wasm
5 changes: 4 additions & 1 deletion bumble/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,10 @@ def ad_data_to_string(ad_type, ad_data):
ad_data_str = f'"{ad_data.decode("utf-8")}"'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
Expand Down
Loading

0 comments on commit 21c291c

Please sign in to comment.