Feat inference transcript #11

Merged · 4 commits · Jan 8, 2025
Changes from all commits
26 changes: 18 additions & 8 deletions src/stt_data_with_llm/audio_parser.py
@@ -5,18 +5,22 @@
import librosa
import requests
import torchaudio
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from pydub import AudioSegment

from stt_data_with_llm.config import (
AUDIO_HEADERS,
AUDIO_SEG_LOWER_LIMIT,
AUDIO_SEG_UPPER_LIMIT,
HEADERS,
HYPER_PARAMETERS,
USE_AUTH_TOKEN,
)
from stt_data_with_llm.util import setup_logging

# load the environment variables
load_dotenv()

USE_AUTH_TOKEN = os.getenv("use_auth_token")
# Call the setup_logging function at the beginning of your script
setup_logging("audio_parser.log")

@@ -62,15 +66,21 @@ def sec_to_frame(sec, sr):
def initialize_vad_pipeline():
"""
Initializes the Voice Activity Detection (VAD) pipeline using Pyannote.

Returns:
Pipeline: Initialized VAD pipeline
"""
logging.info("Initializing Voice Activity Detection pipeline...")
vad_pipeline = Pipeline.from_pretrained(
"pyannote/voice-activity-detection",
use_auth_token=USE_AUTH_TOKEN,
)
try:
vad_pipeline = Pipeline.from_pretrained(
"pyannote/voice-activity-detection",
use_auth_token=USE_AUTH_TOKEN,
)
except Exception as e:
logging.warning(f"Failed to load online model: {e}. Using local model.")
vad_pipeline = Pipeline.from_pretrained(
"tests/pyannote_vad_model",
use_auth_token=False,
)
vad_pipeline.instantiate(HYPER_PARAMETERS)
logging.info("VAD pipeline initialized successfully.")
return vad_pipeline
@@ -135,7 +145,7 @@ def get_audio(audio_url):
bytes: Downloaded and converted audio data
"""
logging.info(f"Downloading audio from: {audio_url}")
response = requests.get(audio_url, headers=HEADERS, stream=True)
response = requests.get(audio_url, headers=AUDIO_HEADERS, stream=True)
if response.status_code == 200:
audio_data = response.content # Store original audio in memory
logging.info("Converting Audio to 16k")
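
Note: with load_dotenv() in place, the pyannote token now comes from the environment rather than from config. A minimal startup-check sketch, assuming the variable names this PR reads ("use_auth_token" here and "token_id" in inference_transcript.py):

import os

from dotenv import load_dotenv

# Sketch only: fail fast if the .env file does not define the tokens this PR reads.
load_dotenv()
for key in ("use_auth_token", "token_id"):
    if not os.getenv(key):
        raise RuntimeError(f"Missing environment variable: {key}")
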
9 changes: 7 additions & 2 deletions src/stt_data_with_llm/config.py
@@ -19,7 +19,7 @@
}

# Define the headers (as given)
HEADERS = {
AUDIO_HEADERS = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", # noqa: E501
"accept-encoding": "gzip, deflate, br, zstd",
"accept-language": "en-US,en;q=0.9,en-IN;q=0.8",
@@ -37,4 +37,9 @@
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0", # noqa: E501
}

USE_AUTH_TOKEN = "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"
# Inference
SAMPLE_RATE = 16000
CHANNELS = 1
SAMPLE_WIDTH = 2

API_URL = "https://wpgzw4at8o6876h0.us-east-1.aws.endpoints.huggingface.cloud"
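
Note: the new inference constants describe raw 16-bit mono PCM at 16 kHz (SAMPLE_WIDTH is in bytes per sample). A minimal sketch of what they imply for segment sizes; segment_duration_sec is a hypothetical helper, not part of this PR:

from stt_data_with_llm.config import CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH


def segment_duration_sec(raw_audio: bytes) -> float:
    """Duration in seconds of a raw PCM segment with the configured format."""
    bytes_per_second = SAMPLE_RATE * CHANNELS * SAMPLE_WIDTH  # 16000 * 1 * 2 = 32000
    return len(raw_audio) / bytes_per_second
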
104 changes: 102 additions & 2 deletions src/stt_data_with_llm/inference_transcript.py
@@ -1,2 +1,102 @@
def get_audio_inference_text(audio_seg_data):
pass
import logging
import os
import wave
from io import BytesIO

import requests
from dotenv import load_dotenv

from stt_data_with_llm.config import API_URL, CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
from stt_data_with_llm.util import setup_logging

load_dotenv()
TOKEN_ID = os.getenv("token_id")
# Call the setup_logging function at the beginning of your script
setup_logging("inference_log.log")

INFERENCE_HEADERS = {
"Accept": "application/json",
"Authorization": f"Bearer {TOKEN_ID}",
"Content-Type": "audio/wav",
}


def convert_raw_to_wav_in_memory(raw_audio, sample_rate, channels, sample_width):
"""
Converts raw audio data to a valid WAV format in memory.

Args:
raw_audio (bytes): Raw audio data.
sample_rate (int): Audio sample rate.
channels (int): Number of audio channels.
sample_width (int): Number of bytes per sample.

Returns:
BytesIO: In-memory WAV file if conversion is successful, None otherwise.
"""
try:
wav_buffer = BytesIO()
with wave.open(wav_buffer, "wb") as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(sample_width)
wav_file.setframerate(sample_rate)
wav_file.writeframes(raw_audio)
wav_buffer.seek(0) # Reset buffer to the beginning
logging.info("Raw audio successfully converted to WAV format in memory.")
return wav_buffer
except Exception as e:
logging.error(f"Error converting raw audio to WAV in memory: {e}")
return None


def query_audio_api(wav_buffer):
"""
Sends the WAV audio data to the Hugging Face API for inference.

Args:
wav_buffer (BytesIO): In-memory WAV file buffer.

Returns:
dict: API response containing the transcription.
"""
try:
response = requests.post(API_URL, headers=INFERENCE_HEADERS, data=wav_buffer)
response.raise_for_status()
api_response = response.json()
logging.info("API call successful")
return api_response
except requests.RequestException as e:
logging.error(f"Error during API call: {e}")
return None


def get_audio_inference_text(raw_audio):
"""
Generates the inference transcript for raw audio data.

Args:
raw_audio (bytes): Raw audio data of the segment.

Returns:
str: The transcript generated for the given audio segment.
"""
try:
# Convert raw audio to WAV format in memory
wav_buffer = convert_raw_to_wav_in_memory(
raw_audio, SAMPLE_RATE, CHANNELS, SAMPLE_WIDTH
)
if not wav_buffer:
return ""
logging.info("Running inference on audio segment")
# Send the WAV data to the API for transcription
response = query_audio_api(wav_buffer)
if not response or "text" not in response:
return ""
transcript = response["text"]

logging.info("Inference completed successfully")
return transcript

except Exception as e:
logging.error(f"Error during inference: {e}")
return ""
106 changes: 83 additions & 23 deletions tests/test_audio_parser.py
@@ -1,32 +1,92 @@
import json
import logging
from unittest import TestCase, mock

from stt_data_with_llm.audio_parser import get_audio, get_split_audio
from stt_data_with_llm.config import AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT


def test_get_split_audio():
"""
Test function for the get_split_audio functionality.
"""
audio_urls = {
"NW_001": "https://www.rfa.org/tibetan/sargyur/golok-china-religious-restriction-08202024054225.html/@@stream", # noqa
"NW_002": "https://vot.org/wp-content/uploads/2024/03/tc88888888888888.mp3",
"NW_003": "https://voa-audio-ns.akamaized.net/vti/2024/04/13/01000000-0aff-0242-a7bb-08dc5bc45613.mp3",
}
num_of_seg_in_audios = {}
for seg_id, audio_url in audio_urls.items():

audio_data = get_audio(audio_url)
split_audio_data = get_split_audio(
audio_data, seg_id, AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT
)
num_split = len(split_audio_data)
num_of_seg_in_audios[seg_id] = num_split
expected_num_of_seg_in_audios = "tests/data/expected_audio_data.json"
with open(expected_num_of_seg_in_audios, encoding="utf-8") as file:
expected_num_split = json.load(file)
assert num_of_seg_in_audios == expected_num_split
class TestGetSplitAudio(TestCase):
@mock.patch("stt_data_with_llm.audio_parser.initialize_vad_pipeline")
@mock.patch("stt_data_with_llm.audio_parser.Pipeline")
def test_get_split_audio(self, mock_pipeline, mock_initialize_vad):
"""
Test function for the get_split_audio functionality.
"""
# Define mock VAD outputs for each audio file
vad_outputs = {
"NW_001": "./tests/vad_output/NW_001_vad_output.json",
"NW_002": "./tests/vad_output/NW_002_vad_output.json",
"NW_003": "./tests/vad_output/NW_003_vad_output.json",
}
# Load all VAD outputs dynamically
mock_vad_results = {}
for seg_id, vad_path in vad_outputs.items():
with open(vad_path, encoding="utf-8") as file:
mock_vad_results[seg_id] = json.load(file)

class MockVADPipeline:
def __init__(self, seg_id):
self.seg_id = seg_id

def __call__(self, audio_file):
return MockVADResult(self.seg_id)

class MockVADResult:
def __init__(self, seg_id):
self.vad_output = mock_vad_results[seg_id]

def get_timeline(self):
class MockTimeline:
def __init__(self, timeline):
self.timeline = timeline

def support(self):
return [
type(
"Segment",
(),
{"start": seg["start"], "end": seg["end"]},
)
for seg in self.timeline
]

return MockTimeline(self.vad_output["timeline"])

# Setup mock behavior
def mock_initialize_pipeline(seg_id):
try:
return MockVADPipeline(seg_id)
except Exception as e:
logging.warning(
f"Mocking failed: {e}. Falling back to actual function."
)
return None

audio_urls = {
"NW_001": "https://www.rfa.org/tibetan/sargyur/golok-china-religious-restriction-08202024054225.html/@@stream", # noqa
"NW_002": "https://vot.org/wp-content/uploads/2024/03/tc88888888888888.mp3",
"NW_003": "https://voa-audio-ns.akamaized.net/vti/2024/04/13/01000000-0aff-0242-a7bb-08dc5bc45613.mp3",
}
num_of_seg_in_audios = {}
for seg_id, audio_url in audio_urls.items():
mock_pipeline = mock_initialize_pipeline(seg_id)
if mock_pipeline:
mock_initialize_vad.return_value = mock_pipeline
else:
mock_initialize_vad.side_effect = None # Disable the mock for fallback

audio_data = get_audio(audio_url)
split_audio_data = get_split_audio(
audio_data, seg_id, AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT
)
num_split = len(split_audio_data)
num_of_seg_in_audios[seg_id] = num_split
expected_num_of_seg_in_audios = "tests/data/expected_audio_data.json"
with open(expected_num_of_seg_in_audios, encoding="utf-8") as file:
expected_num_split = json.load(file)
assert num_of_seg_in_audios == expected_num_split


if __name__ == "__main__":
test_get_split_audio()
TestGetSplitAudio().test_get_split_audio()
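
Note: the mocked pipeline replays speech timelines stored under tests/vad_output/. A sketch of how such a fixture could be captured from a real VAD run, assuming the {"timeline": [{"start": ..., "end": ...}]} shape shown in the files below; the input path is a hypothetical local copy of the downloaded audio:

import json

from stt_data_with_llm.audio_parser import initialize_vad_pipeline

# Sketch only: record the detected speech regions for one audio file as a test fixture.
vad_pipeline = initialize_vad_pipeline()
vad_result = vad_pipeline("NW_001.wav")  # hypothetical local audio file
timeline = [
    {"start": segment.start, "end": segment.end}
    for segment in vad_result.get_timeline().support()
]
with open("tests/vad_output/NW_001_vad_output.json", "w", encoding="utf-8") as f:
    json.dump({"timeline": timeline}, f, indent=2)
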
56 changes: 56 additions & 0 deletions tests/vad_output/NW_001_vad_output.json
@@ -0,0 +1,56 @@
{
"timeline": [
{
"start": 0.23346875,
"end": 2.5115937500000003
},
{
"start": 2.79846875,
"end": 25.86659375
},
{
"start": 26.18721875,
"end": 30.13596875
},
{
"start": 30.50721875,
"end": 34.540343750000005
},
{
"start": 34.77659375,
"end": 40.59846875
},
{
"start": 40.85159375,
"end": 46.43721875
},
{
"start": 46.84221875,
"end": 50.487218750000004
},
{
"start": 50.790968750000005,
"end": 53.001593750000005
},
{
"start": 53.28846875,
"end": 56.19096875
},
{
"start": 56.376593750000005,
"end": 68.35784375
},
{
"start": 68.67846875000001,
"end": 146.28659375
},
{
"start": 146.53971875000002,
"end": 161.86221875
},
{
"start": 162.21659375000002,
"end": 165.74346875
}
]
}
8 changes: 8 additions & 0 deletions tests/vad_output/NW_002_vad_output.json
@@ -0,0 +1,8 @@
{
"timeline": [
{
"start": 0.03096875,
"end": 119.26971875000001
}
]
}