Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WhisperForCTC Integration #222

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Empty file.
15 changes: 15 additions & 0 deletions nlu/components/classifiers/asr_whisper/whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sparknlp.annotator import *


class Whisper:
    """Factory for Spark NLP ``WhisperForCTC`` speech-to-text annotators.

    Each factory method configures the annotator to consume audio from the
    ``audio_assembler`` column and emit transcriptions into ``text``.
    """

    @staticmethod
    def get_default_model():
        """Return the default pretrained WhisperForCTC annotator."""
        model = WhisperForCTC.pretrained()
        return model.setInputCols("audio_assembler").setOutputCol("text")

    @staticmethod
    def get_pretrained_model(name, language, bucket=None):
        """Return a named pretrained WhisperForCTC annotator.

        :param name: pretrained model identifier
        :param language: language code of the model
        :param bucket: optional remote bucket/location to fetch from
        """
        model = WhisperForCTC.pretrained(name, language, bucket)
        return model.setInputCols("audio_assembler").setOutputCol("text")
1 change: 1 addition & 0 deletions nlu/pipe/utils/pipe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def add_metadata_to_pipe(pipe: NLUPipeline):
if c.license == Licenses.open_source \
and c.name != NLP_NODE_IDS.WAV2VEC_FOR_CTC \
and c.name != NLP_NODE_IDS.HUBERT_FOR_CTC \
and c.name != NLP_NODE_IDS.WHISPER_FOR_CTC \
and c.name != NLP_NODE_IDS.AUDIO_ASSEMBLER:
# TODO Table Assembler/VIT/ Other non txt open source
pipe.has_nlp_components = True
Expand Down
190 changes: 184 additions & 6 deletions nlu/spellbook.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions nlu/universe/annotator_class_universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class AnnoClassRef:
A_N.E5_SENTENCE_EMBEDDINGS: 'E5Embeddings',
A_N.INSTRUCTOR_SENTENCE_EMBEDDINGS:'InstructorEmbeddings',

A_N.WHISPER_FOR_CTC: 'WhisperForCTC',
A_N.HUBERT_FOR_CTC: 'HubertForCTC',
A_N.CAMEMBERT_FOR_QUESTION_ANSWERING: 'CamemBertForQuestionAnswering',
A_N.SWIN_IMAGE_CLASSIFICATION: 'SwinForImageClassification',
Expand Down
22 changes: 22 additions & 0 deletions nlu/universe/component_universes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from nlu.components.chunkers.ngram.ngram import NGram
from nlu.components.classifiers.asr.wav2Vec import Wav2Vec
from nlu.components.classifiers.asr_hubert.hubert import Hubert
from nlu.components.classifiers.asr_whisper.whisper import Whisper
from nlu.components.classifiers.bert_zero_shot_classification.bert_zero_shot import BertZeroShotClassifier
from nlu.components.classifiers.classifier_dl.classifier_dl import ClassifierDl
from nlu.components.classifiers.distil_bert_zero_shot_classification.distil_bert_zero_shot import \
Expand Down Expand Up @@ -1481,6 +1482,27 @@ class ComponentUniverse:
applicable_file_types=['wav', 'mp3', 'flac', 'aiff', 'aifc', 'ogg', 'aflac', 'alac',
'dsd', 'pcm', ]
),
A.WHISPER_FOR_CTC: partial(NluComponent,
name=A.WHISPER_FOR_CTC,
type=T.SPEECH_RECOGNIZER,
get_default_model=Whisper.get_default_model,
get_pretrained_model=Whisper.get_pretrained_model,
pdf_extractor_methods={'default': default_only_result_config,
'default_full': default_full_config, },
pdf_col_name_substitutor=substitute_wav2vec_cols,
output_level=L.DOCUMENT,
node=NLP_FEATURE_NODES.nodes[A.WHISPER_FOR_CTC],
description='Whisper is an automatic speech recognition (ASR) system trained on 680,000 hours of multilingual and multitask supervised data collected from the web. It transcribe in multiple languages, as well as translate from those languages into English.',
provider=ComponentBackends.open_source,
license=Licenses.open_source,
computation_context=ComputeContexts.spark,
output_context=ComputeContexts.spark,
jsl_anno_class_id=A.WHISPER_FOR_CTC,
jsl_anno_py_class=ACR.JSL_anno2_py_class[A.WHISPER_FOR_CTC],
                                        # Based on Librosa, which uses http://www.mega-nerd.com/libsndfile/
applicable_file_types=['wav', 'mp3', 'flac', 'aiff', 'aifc', 'ogg', 'aflac', 'alac',
'dsd', 'pcm', ]
),

A.TAPAS_FOR_QA: partial(NluComponent,
name=A.TAPAS_FOR_QA,
Expand Down
1 change: 1 addition & 0 deletions nlu/universe/feature_node_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class NLP_NODE_IDS:
AUDIO_ASSEMBLER = JslAnnoId('audio_assembler')
WAV2VEC_FOR_CTC = JslAnnoId('wav2vec_for_ctc')
HUBERT_FOR_CTC = JslAnnoId('hubert_for_ctc')
WHISPER_FOR_CTC = JslAnnoId('whisper_for_ctc')
TABLE_ASSEMBLER = JslAnnoId('table_assembler')
TAPAS_FOR_QA = JslAnnoId('tapas')
MULTI_DOCUMENT_ASSEMBLER = JslAnnoId('multi_document_assembler')
Expand Down
1 change: 1 addition & 0 deletions nlu/universe/feature_node_universes.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class NLP_FEATURE_NODES: # or Mode Node?
# A.WAV2VEC_FOR_CTC: NlpFeatureNode(A.PARTIALLY_IMPLEMENTED, [F.AUDIO], [F.RECOGNIZED_SPEECH_TEXT]),
A.WAV2VEC_FOR_CTC: NlpFeatureNode(A.WAV2VEC_FOR_CTC, [F.AUDIO], [E.RAW_TEXT]),
A.HUBERT_FOR_CTC: NlpFeatureNode(A.HUBERT_FOR_CTC, [F.AUDIO], [E.RAW_TEXT]),
A.WHISPER_FOR_CTC: NlpFeatureNode(A.WHISPER_FOR_CTC, [F.AUDIO], [E.RAW_TEXT]),

A.IMAGE_ASSEMBLER: NlpFeatureNode(A.IMAGE_ASSEMBLER, [F.SPARK_NLP_IMAGE, F.SPARK_NLP_FILE_PATH], [F.IMAGE]),
A.DOCUMENT_NORMALIZER: NlpFeatureNode(A.DOCUMENT_NORMALIZER, [F.DOCUMENT], [F.DOCUMENT_GENERATED]),
Expand Down
31 changes: 14 additions & 17 deletions tests/nlu_core_tests/component_tests/classifier_tests/asr_tests.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
import unittest
import sparknlp
import librosa as librosa
from sparknlp.base import *
from sparknlp.annotator import *
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import *
import pyspark.sql.functions as F
import sparknlp
import sparknlp
from pyspark.ml import Pipeline
from sparknlp.annotator import *
from sparknlp.base import *
import os


os.environ['PYSPARK_PYTHON'] = '/home/ckl/anaconda3/bin/python3'
os.environ['PYSPARK_DRIVER_PYTHON'] = '/home/ckl/anaconda3/bin/python3'
#os.environ['PYSPARK_PYTHON'] = '/home/ckl/anaconda3/bin/python3'
#os.environ['PYSPARK_DRIVER_PYTHON'] = '/home/ckl/anaconda3/bin/python3'



class AsrTestCase(unittest.TestCase):
def test_wav2vec(self):
import nlu
p = nlu.load('en.wav2vec.wip',verbose=True)
p = nlu.load('en.speech2text.wav2vec2.v2_base_960h',verbose=True)
FILE_PATH = os.path.normpath(r"tests/datasets/audio/asr/ngm_12484_01067234848.wav")

print("Got p ",p)
Expand All @@ -35,7 +22,17 @@ def test_wav2vec(self):

def test_hubert(self):
import nlu
p = nlu.load('en.asr_hubert_large_ls960',verbose=True)
p = nlu.load('en.speech2text.hubert.large_ls960',verbose=True)
FILE_PATH = os.path.normpath(r"tests/datasets/audio/asr/ngm_12484_01067234848.wav")

print("Got p ",p)
df = p.predict(FILE_PATH)
print(df)
df = p.predict([FILE_PATH,FILE_PATH])
print(df)
def test_whisper(self):
import nlu
p = nlu.load('xx.speech2text.whisper.tiny',verbose=True)
FILE_PATH = os.path.normpath(r"tests/datasets/audio/asr/ngm_12484_01067234848.wav")

print("Got p ",p)
Expand Down
Loading