From 0fbd9058d5ecdc0372714cac330184f82e609b7d Mon Sep 17 00:00:00 2001
From: kaldan007
Date: Fri, 20 Dec 2024 10:41:29 +0530
Subject: [PATCH] outline set up

---
 pyproject.toml                                |   5 +
 .../get_original_transcript.py               |  10 --
 ...get_reference_transcript_with_timestamp.py |   0
 src/Stt_data_with_llm/inference_transcript.py |  25 ----
 src/Stt_data_with_llm/stt_data_corrector.py   | 137 ------------------
 src/Stt_data_with_llm/util.py                 |  19 ---
 src/stt_data_with_llm/LLM_post_corrector.py   |   2 +
 .../__init__.py                               |   0
 src/stt_data_with_llm/audio_parser.py         |   5 +
 src/stt_data_with_llm/catalog_parser.py       |   2 +
 src/stt_data_with_llm/config.py               |   2 +
 src/stt_data_with_llm/inference_transcript.py |   2 +
 src/stt_data_with_llm/main.py                 |  68 +++++++++
 13 files changed, 86 insertions(+), 191 deletions(-)
 delete mode 100644 src/Stt_data_with_llm/get_original_transcript.py
 delete mode 100644 src/Stt_data_with_llm/get_reference_transcript_with_timestamp.py
 delete mode 100644 src/Stt_data_with_llm/inference_transcript.py
 delete mode 100644 src/Stt_data_with_llm/stt_data_corrector.py
 delete mode 100644 src/Stt_data_with_llm/util.py
 create mode 100644 src/stt_data_with_llm/LLM_post_corrector.py
 rename src/{Stt_data_with_llm => stt_data_with_llm}/__init__.py (100%)
 create mode 100644 src/stt_data_with_llm/audio_parser.py
 create mode 100644 src/stt_data_with_llm/catalog_parser.py
 create mode 100644 src/stt_data_with_llm/config.py
 create mode 100644 src/stt_data_with_llm/inference_transcript.py
 create mode 100644 src/stt_data_with_llm/main.py

diff --git a/pyproject.toml b/pyproject.toml
index e7aa9e8..8d6e925 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,11 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 
+dependencies = [
+    "fast-antx @ git+https://github.com/OpenPecha/fast-antx.git",
+]
+
+
 [project.optional-dependencies]
 dev = [
     "pre-commit",
diff --git a/src/Stt_data_with_llm/get_original_transcript.py b/src/Stt_data_with_llm/get_original_transcript.py
deleted file mode 100644
index e210ad1..0000000
--- a/src/Stt_data_with_llm/get_original_transcript.py
+++ /dev/null
@@ -1,10 +0,0 @@
-def say_hi(name):
-    if name.endswith("J"):
-        return "Hello, J!"
- else: - print("NOthin") - - -if __name__ == "__main__": - p = say_hi("impaJ") - print(p) diff --git a/src/Stt_data_with_llm/get_reference_transcript_with_timestamp.py b/src/Stt_data_with_llm/get_reference_transcript_with_timestamp.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/Stt_data_with_llm/inference_transcript.py b/src/Stt_data_with_llm/inference_transcript.py deleted file mode 100644 index 6943c39..0000000 --- a/src/Stt_data_with_llm/inference_transcript.py +++ /dev/null @@ -1,25 +0,0 @@ -from util import get_audio, get_split_audio - - -def get_inference(audio_split_data): - - return inference_transcript - - -def get_inference_transcript(audio_url, full_audio_id): - segment_upper_limit = 8 - segment_lower_limit = 2 - # Get audio from URL - audio_data = get_audio(audio_url) # Store audio in memory - # Get split audio data - split_audio_data = get_split_audio( - audio_data, segment_upper_limit, segment_lower_limit - ) - inference_transcript = [] - # Get inference transcript from the split audio data - for split_data in split_audio_data: - # upload to the s3 bucket - # audio_split_url = upload_to_s3(split_data, full_audio_id) - inference_transcript.append(get_inference(split_data)) - - return inference_transcript diff --git a/src/Stt_data_with_llm/stt_data_corrector.py b/src/Stt_data_with_llm/stt_data_corrector.py deleted file mode 100644 index 73c5a1c..0000000 --- a/src/Stt_data_with_llm/stt_data_corrector.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import sys - -import pandas as pd -from anthropic import Client -from tqdm.auto import tqdm - -sys.path.append("../util") - -from common_utils import parse_args_and_load_config - -# Initialize Claude AI client -API_KEY = os.getenv("ANTHROPIC_API_KEY") -client = Client(api_key=API_KEY) - - -def correct_transcription(inference_text: str, transfer_text: str) -> str: - """ - Corrects the inference transcription by comparing it with the transferred text using Claude AI. - - Args: - inference_text (str): The initial transcription text that may contain spelling errors. - transfer_text (str): The transferred text with accurate spelling - - Returns: - str: The corrected transcription with spelling errors fixed and context preserved. - """ - - prompt = f""" - \n\nHuman: Here is an initial transcription with context accuracy: - {inference_text} - - Here is the transferred text with accurate spelling: - {transfer_text} - - Please review the initial transcription and correct **only the spelling errors** by referring to the words in the transferred text. - - Ensure that the overall structure and context of the initial transcription remain unchanged. - - If any word in the initial transcription is misspelled and a correct spelling is found in the transferred text, use it as a correction. - - If a word in the initial transcription is correct or if no similar match exists in the transferred text, leave it unchanged. - - Output only the corrected transcription without any additional explanation, ensuring the meaning and structure remain the same. 
- \n\nAssistant: - """ - response = client.completions.create( - prompt=prompt, - stop_sequences=["\n\nHuman:"], - model="claude-2", - max_tokens_to_sample=500, - temperature=0.7, - ) - - # Clean response to remove extra text like "Here is the corrected transcription:" - corrected_text = response.completion.strip() - - # Remove any leading explanation in English if present - if "Here is the corrected transcription:" in corrected_text: - corrected_text = corrected_text.split("Here is the corrected transcription:")[ - 1 - ].strip() - - return corrected_text - - -def main(config): - # Extract configuration - dept = config["DEPARTMENT"] - from_id = config["FROM_ID"] - to_id = config["TO_ID"] - group_id = config["GROUP_ID"] - - # Load inference CSV - inference_csv_path = f"../data/{dept}.csv" - transfer_csv_path = f"../data/{dept}_{from_id}_to_{to_id}_transferred.csv" - output_csv_path = f"../data/{dept}_{group_id}_{from_id}_to_{to_id}_corrected.csv" - - print("Reading inference CSV...") - if not os.path.exists(inference_csv_path): - raise FileNotFoundError(f"Inference CSV file not found: {inference_csv_path}") - inference_csv_files_df = pd.read_csv(inference_csv_path) - inference_csv_files_df["inference_transcript"] = ( - inference_csv_files_df["inference_transcript"].astype(str).str.strip() - ) - - print("Reading transferred text CSV...") - if not os.path.exists(transfer_csv_path): - raise FileNotFoundError( - f"Transferred Text CSV file not found: {transfer_csv_path}" - ) - existing_transfer_csv_files_df = pd.read_csv(transfer_csv_path) - existing_transfer_csv_files_df["inference_transcript"] = ( - existing_transfer_csv_files_df["inference_transcript"].astype(str).str.strip() - ) - - # Merge the datasets - print("Merging inference and transfer data...") - merged_data = pd.merge( - inference_csv_files_df, - existing_transfer_csv_files_df[["file_name", "inference_transcript"]], - on="file_name", - suffixes=("_inference", "_transfer"), - ) - - corrected_transcriptions = [] - is_changed_flags = [] - - print("Processing corrections...") - for _, row in tqdm( - merged_data.iterrows(), total=len(merged_data), desc="Correcting Transcriptions" - ): - file_name = row["file_name"] - original_inference_text = row["inference_transcript_inference"] - transfer_text = row["inference_transcript_transfer"] - - # Correct transcription - corrected_text = correct_transcription(original_inference_text, transfer_text) - - # Track changes - is_changed_flags.append(corrected_text != original_inference_text) - corrected_transcriptions.append(corrected_text) - - # Update the inference DataFrame with corrections - inference_csv_files_df[ - "inference_transcript" - ] = corrected_transcriptions # Replace original column - inference_csv_files_df["is_changed"] = is_changed_flags - - # Save the corrected CSV - inference_csv_files_df.to_csv(output_csv_path, index=False) - print(f"Corrected transcriptions saved to {output_csv_path}") - - -if __name__ == "__main__": - # Parse arguments and load config - config = parse_args_and_load_config() - - # Run the main pipeline logic - main(config) diff --git a/src/Stt_data_with_llm/util.py b/src/Stt_data_with_llm/util.py deleted file mode 100644 index fd17064..0000000 --- a/src/Stt_data_with_llm/util.py +++ /dev/null @@ -1,19 +0,0 @@ -def get_audio(audio_url): - # Get audio from URL - return audio_data - - -def get_split_audio(audio_url, segment_upper_limit, segment_lower_limit): - audio_data = get_audio(audio_url) # Store audio in memory - - -def upload_to_s3(split_audio_data, 
full_audio_id): - # upload to the s3 bucket - return audio_split_url - - -def save_inference_transcript_to_csv( - full_audio_id, audio_split_url, inference_transcript -): - # Save the inference transcript to their corresponding - print("Saving inference transcript") diff --git a/src/stt_data_with_llm/LLM_post_corrector.py b/src/stt_data_with_llm/LLM_post_corrector.py new file mode 100644 index 0000000..6985da2 --- /dev/null +++ b/src/stt_data_with_llm/LLM_post_corrector.py @@ -0,0 +1,2 @@ +def get_LLM_corrected_text(inference_text, reference_text): + pass \ No newline at end of file diff --git a/src/Stt_data_with_llm/__init__.py b/src/stt_data_with_llm/__init__.py similarity index 100% rename from src/Stt_data_with_llm/__init__.py rename to src/stt_data_with_llm/__init__.py diff --git a/src/stt_data_with_llm/audio_parser.py b/src/stt_data_with_llm/audio_parser.py new file mode 100644 index 0000000..9e58436 --- /dev/null +++ b/src/stt_data_with_llm/audio_parser.py @@ -0,0 +1,5 @@ +def get_audio(audio_url): + pass + +def get_split_audio(audio_data, AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT, full_audio_id): + pass \ No newline at end of file diff --git a/src/stt_data_with_llm/catalog_parser.py b/src/stt_data_with_llm/catalog_parser.py new file mode 100644 index 0000000..b537a06 --- /dev/null +++ b/src/stt_data_with_llm/catalog_parser.py @@ -0,0 +1,2 @@ +def parse_catalog(catalog): + pass \ No newline at end of file diff --git a/src/stt_data_with_llm/config.py b/src/stt_data_with_llm/config.py new file mode 100644 index 0000000..fccb846 --- /dev/null +++ b/src/stt_data_with_llm/config.py @@ -0,0 +1,2 @@ +AUDIO_SEG_UPPER_LIMIT = 8 +AUDIO_SEG_LOWER_LIMIT = 2 \ No newline at end of file diff --git a/src/stt_data_with_llm/inference_transcript.py b/src/stt_data_with_llm/inference_transcript.py new file mode 100644 index 0000000..cdf5f43 --- /dev/null +++ b/src/stt_data_with_llm/inference_transcript.py @@ -0,0 +1,2 @@ +def get_audio_inference_text(audio_seg_data): + pass \ No newline at end of file diff --git a/src/stt_data_with_llm/main.py b/src/stt_data_with_llm/main.py new file mode 100644 index 0000000..02ea350 --- /dev/null +++ b/src/stt_data_with_llm/main.py @@ -0,0 +1,68 @@ +import logging + +from fast_antx.core import transfer + +from stt_data_with_llm.catalog_parser import parse_catalog +from stt_data_with_llm.config import AUDIO_SEG_UPPER_LIMIT, AUDIO_SEG_LOWER_LIMIT +from stt_data_with_llm.audio_parser import get_audio, get_split_audio +from stt_data_with_llm.LLM_post_corrector import get_LLM_corrected_text +from stt_data_with_llm.inference_transcript import get_audio_inference_text + + +logging.basicConfig(filename='./pipeline.log', level=logging.INFO) + +def transfer_segmentation(inference_transcript, reference_transcript): + reference_transcript = reference_transcript.replace("\n", " ") + patterns = [['segmentation', '(\n)']] + reference_transcript_with_inference_segmentation = transfer(inference_transcript, patterns, reference_transcript) + return reference_transcript_with_inference_segmentation + +def is_valid_transcript(inference_transcript, reference_transcript): + pass + +def post_process_audio_transcript_pairs(audio_data_info): + post_processed_audio_transcript_pairs = {} + inference_transcript = "" + audio_url = audio_data_info.get("audio_url", "") + full_audio_id = audio_data_info.get("full_audio_id", "") + reference_transcript = audio_data_info.get("reference_transcript", "") + if not audio_url: + return None,full_audio_id + audio_data = get_audio(audio_url) + 
split_audio_data = get_split_audio(audio_data, AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT, full_audio_id)
+    for audio_seg_id, audio_seg_data in split_audio_data.items():
+        audio_seg_inference_transcript = get_audio_inference_text(audio_seg_data)
+        inference_transcript += f'{audio_seg_inference_transcript}\n'
+    if not is_valid_transcript(inference_transcript, reference_transcript):
+        return None,full_audio_id
+    reference_transcript_with_inference_segmentation = transfer_segmentation(inference_transcript, reference_transcript)
+    inference_transcripts = inference_transcript.split("\n")
+    reference_transcripts = reference_transcript_with_inference_segmentation.split("\n")
+    for seg_walker,(audio_seg_id, audio_seg_data) in enumerate(split_audio_data.items()):
+        seg_inference_text = inference_transcripts[seg_walker]
+        seg_reference_text = reference_transcripts[seg_walker]
+        seg_LLM_corrected_text = get_LLM_corrected_text(seg_inference_text, seg_reference_text)
+        post_processed_audio_transcript_pairs[audio_seg_id] = {
+            "audio_seg_data": audio_seg_data,
+            "inference_transcript": seg_inference_text,
+            "reference_transcript": seg_reference_text,
+            "LLM_corrected_text": seg_LLM_corrected_text
+        }
+    return post_processed_audio_transcript_pairs,full_audio_id
+
+
+def save_post_processed_audio_transcript_pairs(post_processed_audio_transcript_pairs, audio_data_info):
+    # Save post processed audio transcript pairs in csv
+    pass
+
+
+def get_audio_transcript_pairs(audio_transcription_catalog_url):
+    audio_transcription_datas = parse_catalog(audio_transcription_catalog_url)
+    for data_id, audio_data_info in audio_transcription_datas.items():
+        post_processed_audio_transcript_pairs,full_audio_id = post_process_audio_transcript_pairs(audio_data_info)
+        if post_processed_audio_transcript_pairs:
+            save_post_processed_audio_transcript_pairs(post_processed_audio_transcript_pairs, audio_data_info)
+        else:
+            logging.info(f"Audio data with ID {full_audio_id} has invalid transcript")
+
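
Usage sketch (not part of the patch): once the stubs above are filled in, the outlined pipeline could be driven roughly as follows. The catalog URL is a hypothetical placeholder; the real location and format are whatever parse_catalog expects.

    from stt_data_with_llm.main import get_audio_transcript_pairs

    if __name__ == "__main__":
        # Hypothetical catalog location; parse_catalog defines the actual expected input.
        CATALOG_URL = "https://example.com/audio_transcription_catalog.csv"
        get_audio_transcript_pairs(CATALOG_URL)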