Skip to content

Commit

Permalink
Export Pyannote models to torchscript (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jan 8, 2025
1 parent 2ffe462 commit ea0c0b5
Show file tree
Hide file tree
Showing 6 changed files with 587 additions and 0 deletions.
165 changes: 165 additions & 0 deletions .github/workflows/export-pyannote.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
name: export-pyannote

on:
push:
branches:
- export-pyannote
workflow_dispatch:

concurrency:
group: export-pyannote-${{ github.ref }}
cancel-in-progress: true

jobs:
export-pyannote:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export ${{ matrix.model }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest]
python-version: ["3.10"]
model: ['pyannote']

steps:
- uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install pyannote
shell: bash
run: |
pip install torch==2.2.0 torchaudio==2.2.0 onnxruntime onnx kaldi-native-fbank funasr numpy==1.26.4 pyannote.audio==3.3.0
- name: Export ${{ matrix.model }}
shell: bash
run: |
pushd scripts/pyannote/segmentation
model=${{ matrix.model }}
if [[ $model == 'pyannote' ]]; then
curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
else
curl -SL -O https://huggingface.co/openspeech/revai-models/resolve/main/v1/pytorch_model.bin
fi
python3 ./export.py
ls -lh
- name: Test ${{ matrix.model }}
shell: bash
run: |
pushd scripts/pyannote/segmentation
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
python3 ./vad.py --model ./model.pt --wav ./lei-jun-test.wav
- name: Collect results
shell: bash
run: |
model=${{ matrix.model }}
if [[ $model == 'pyannote' ]]; then
d=sherpa-pyannote-segmentation-3-0
else
d=sherpa-reverb-diarization-v1
fi
mkdir $d
mv -v scripts/pyannote/segmentation/model.pt $d/
mv -v scripts/pyannote/segmentation/README.md $d/
mv -v scripts/pyannote/segmentation/LICENSE $d/
if [[ $model == revai ]]; then
echo "Models in this folder are converted from https://huggingface.co/Revai/reverb-diarization-v1" > $d/README.md
fi
cat $d/README.md
ls -lh $d
tar cjvf $d.tar.bz2 $d
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_TOKEN }}
tag: speaker-segmentation-models

- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
model=${{ matrix.model }}
if [[ $model == 'pyannote' ]]; then
src=sherpa-pyannote-segmentation-3-0
else
src=sherpa-reverb-diarization-v1
fi
git config --global user.email "[email protected]"
git config --global user.name "Fangjun Kuang"
export GIT_CLONE_PROTECTION_ACTIVE=false
export GIT_LFS_SKIP_SMUDGE=1
rm -rf huggingface
git clone https://csukuangfj:[email protected]/csukuangfj/$src huggingface
rm -rf huggingface/*
cp -av $src/* ./huggingface/
cd huggingface
git status
ls -lh
git lfs track "*.pt*"
git add .
git commit -m "upload $src" || true
git push https://csukuangfj:[email protected]/csukuangfj/$src main || true
- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
model=${{ matrix.model }}
if [[ $model == 'pyannote' ]]; then
src=sherpa-pyannote-segmentation-3-0
else
src=sherpa-reverb-diarization-v1
fi
git config --global user.email "[email protected]"
git config --global user.name "Fangjun Kuang"
export GIT_CLONE_PROTECTION_ACTIVE=false
export GIT_LFS_SKIP_SMUDGE=1
rm -rf huggingface
git clone https://csukuangfj:[email protected]/k2-fsa/sherpa-models huggingface
mkdir -p ./huggingface/speaker-segmentation
cp -av $src.tar.bz2 ./huggingface/speaker-segmentation
cd huggingface
git status
ls -lh
git lfs track "*.tar.bz2*"
git add .
git commit -m "upload $src" || true
git push https://csukuangfj:[email protected]/k2-fsa/sherpa-models main || true
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ sherpa-nemo-ctc*
/.vscode
*.pt
tokens.txt
*.bin
21 changes: 21 additions & 0 deletions scripts/pyannote/segmentation/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 CNRS

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
4 changes: 4 additions & 0 deletions scripts/pyannote/segmentation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Introduction

Models in this file are converted from
https://huggingface.co/pyannote/segmentation-3.0/tree/main
137 changes: 137 additions & 0 deletions scripts/pyannote/segmentation/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import torch
import torch.nn.functional as F
from pyannote.audio import Model
from pyannote.audio.models.blocks.sincnet import SincNet
from pyannote.core.utils.generators import pairwise
from torch import nn

"""
"linear": {'hidden_size': 128, 'num_layers': 2}
"lstm": {'hidden_size': 256, 'num_layers': 2, 'bidirectional': True, 'monolithic': True, 'dropout': 0.0, 'batch_first': True}
"num_channels": 1
"sample_rate": 16000
"sincnet": {'stride': 10, 'sample_rate': 16000}
"""


class PyanNet(torch.nn.Module):
def __init__(self, m):
super().__init__()
self.sincnet = SincNet(**m.hparams.sincnet)

multi_layer_lstm = dict(m.hparams.lstm)
del multi_layer_lstm["monolithic"]
self.lstm = nn.LSTM(60, **multi_layer_lstm)

lstm_out_features: int = m.hparams.lstm["hidden_size"] * (
2 if m.hparams.lstm["bidirectional"] else 1
)
self.linear = nn.ModuleList(
[
nn.Linear(in_features, out_features)
for in_features, out_features in pairwise(
[
lstm_out_features,
]
+ [m.hparams.linear["hidden_size"]] * m.hparams.linear["num_layers"]
)
]
)

if m.hparams.linear["num_layers"] > 0:
in_features = m.hparams.linear["hidden_size"]
else:
in_features = m.hparams.lstm["hidden_size"] * (
2 if m.hparams.lstm["bidirectional"] else 1
)

self.classifier = nn.Linear(in_features, m.dimension)
self.activation = m.default_activation()

def forward(self, waveforms):
"""Pass forward
Parameters
----------
waveforms : (batch, channel, sample)
Returns
-------
scores : (batch, frame, classes)
"""

outputs = self.sincnet(waveforms)

outputs, _ = self.lstm(torch.permute(outputs, (0, 2, 1)))

for linear in self.linear:
outputs = F.leaky_relu(linear(outputs))

return self.activation(self.classifier(outputs))


@torch.inference_mode()
def main():
# You can download ./pytorch_model.bin from
# https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
# or from
# https://huggingface.co/Revai/reverb-diarization-v1/tree/main
pt_filename = "./pytorch_model.bin"
model = Model.from_pretrained(pt_filename)
wrapper = PyanNet(model)

num_param1 = sum([p.numel() for p in model.parameters()])
num_param2 = sum([p.numel() for p in wrapper.parameters()])

assert num_param1 == num_param2, (num_param1, num_param2, model.hparams)
print(f"Number of model parameters1: {num_param1}")
print(f"Number of model parameters2: {num_param2}")

model.eval()

# model.to_torchscript() # won't work

wrapper.eval()

wrapper.load_state_dict(model.state_dict())

x = torch.rand(1, 1, 10 * 16000)

y1 = model(x)
y2 = wrapper(x)

assert y1.shape == y2.shape, (y1.shape, y2.shape)
assert torch.allclose(y1, y2), (y1.sum(), y2.sum())

m = torch.jit.script(wrapper)

sample_rate = model.audio.sample_rate
assert sample_rate == 16000, sample_rate

window_size = int(model.specifications.duration) * 16000
receptive_field_size = int(model.receptive_field.duration * 16000)
receptive_field_shift = int(model.receptive_field.step * 16000)

meta_data = {
"num_speakers": str(len(model.specifications.classes)),
"powerset_max_classes": str(model.specifications.powerset_max_classes),
"num_classes": str(model.dimension),
"sample_rate": str(sample_rate),
"window_size": str(window_size),
"receptive_field_size": str(receptive_field_size),
"receptive_field_shift": str(receptive_field_shift),
"model_type": "pyannote-segmentation-3.0",
"version": "1",
"maintainer": "k2-fsa",
}

m.save("model.pt", _extra_files=meta_data)
print(meta_data)


if __name__ == "__main__":
torch.manual_seed(20240108)
main()
Loading

0 comments on commit ea0c0b5

Please sign in to comment.