Skip to content

Commit

Permalink
Typing A2DP
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Nov 17, 2023
1 parent 0667e83 commit e1fdb12
Showing 1 changed file with 83 additions and 68 deletions.
151 changes: 83 additions & 68 deletions bumble/a2dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import dataclasses
import struct
import logging
from collections import namedtuple
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable

from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
Expand Down Expand Up @@ -239,24 +243,20 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):


# -----------------------------------------------------------------------------
class SbcMediaCodecInformation(
namedtuple(
'SbcMediaCodecInformation',
[
'sampling_frequency',
'channel_mode',
'block_length',
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value',
],
)
):
@dataclasses.dataclass
class SbcMediaCodecInformation:
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''

sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
minimum_bitpool_value: int
maximum_bitpool_value: int

SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
Expand All @@ -272,7 +272,7 @@ class SbcMediaCodecInformation(
}

@staticmethod
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
Expand All @@ -293,14 +293,14 @@ def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
@classmethod
def from_discrete_values(
cls,
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
):
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
Expand All @@ -314,14 +314,14 @@ def from_discrete_values(
@classmethod
def from_lists(
cls,
sampling_frequencies,
channel_modes,
block_lengths,
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value,
):
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
Expand All @@ -348,7 +348,7 @@ def __bytes__(self) -> bytes:
]
)

def __str__(self):
def __str__(self) -> str:
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
Expand All @@ -367,16 +367,19 @@ def __str__(self):


# -----------------------------------------------------------------------------
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
@dataclasses.dataclass
class AacMediaCodecInformation:
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''

object_type: int
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int

OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
Expand All @@ -400,7 +403,7 @@ class AacMediaCodecInformation(
CHANNELS_BITS = {1: 1 << 1, 2: 1}

@staticmethod
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
def from_bytes(data: bytes) -> AacMediaCodecInformation:
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
Expand All @@ -413,8 +416,13 @@ def from_bytes(data: bytes) -> 'AacMediaCodecInformation':

@classmethod
def from_discrete_values(
cls, object_type, sampling_frequency, channels, vbr, bitrate
):
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
Expand All @@ -425,7 +433,14 @@ def from_discrete_values(
)

@classmethod
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
Expand All @@ -449,7 +464,7 @@ def __bytes__(self) -> bytes:
]
)

def __str__(self):
def __str__(self) -> str:
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
Expand All @@ -474,26 +489,26 @@ def __str__(self):
)


@dataclasses.dataclass
# -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation:
'''
A2DP spec - 4.7.2 Codec Specific Information Elements
'''

vendor_id: int
codec_id: int
value: bytes

@staticmethod
def from_bytes(data):
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])

def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value

def __bytes__(self):
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)

def __str__(self):
def __str__(self) -> str:
# pylint: disable=line-too-long
return '\n'.join(
[
Expand All @@ -506,29 +521,27 @@ def __str__(self):


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
def __init__(
self, sampling_frequency, block_count, channel_mode, subband_count, payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
sampling_frequency: int
block_count: int
channel_mode: int
subband_count: int
payload: bytes

@property
def sample_count(self):
def sample_count(self) -> int:
return self.subband_count * self.block_count

@property
def bitrate(self):
def bitrate(self) -> int:
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)

@property
def duration(self):
def duration(self) -> float:
return self.sample_count / self.sampling_frequency

def __str__(self):
def __str__(self) -> str:
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
Expand All @@ -540,12 +553,12 @@ def __str__(self):

# -----------------------------------------------------------------------------
class SbcParser:
def __init__(self, read):
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read

@property
def frames(self):
async def generate_frames():
def frames(self) -> AsyncGenerator[SbcFrame, None]:
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
while True:
# Read 4 bytes of header
header = await self.read(4)
Expand Down Expand Up @@ -589,7 +602,9 @@ async def generate_frames():

# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(self, read, mtu, codec_capabilities):
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
Expand Down

0 comments on commit e1fdb12

Please sign in to comment.