Skip to content

Commit

Permalink
add TCN tempo histogram processor and TCNTempoDetector
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Böck committed Jan 6, 2022
1 parent 1933f4d commit d25e3b8
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 18 deletions.
107 changes: 107 additions & 0 deletions bin/TCNTempoDetector
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#!/usr/bin/env python
# encoding: utf-8
"""
TCNTempoDetector beat tracking algorithm.
"""

from __future__ import absolute_import, division, print_function

import argparse

import numpy as np

from madmom.audio import SignalProcessor
from madmom.features import ActivationsProcessor
from madmom.features.beats import TCNBeatProcessor
from madmom.features.tempo import TCNTempoHistogramProcessor, TempoEstimationProcessor
from madmom.io import write_events, write_tempo
from madmom.processors import IOProcessor, io_arguments


def main():
"""TCNTempoDetector"""

# define parser
p = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter, description='''
The TCNTempoDetector program detects the tempo in an audio file according
to the method described in:
"Multi-Task learning of tempo and beat: learning one to improve the other"
Sebastian Böck, Matthew Davies and Peter Knees.
Proc. of the 20th International Society for Music Information Retrieval
Conference (ISMIR), 2019.
This program can be run in 'single' file mode to process a single audio
file and write the detected beats to STDOUT or the given output file.
$ TCNTempoDetector single INFILE [-o OUTFILE]
If multiple audio files should be processed, the program can also be run
in 'batch' mode to save the detected beats to files with the given suffix.
$ TCNTempoDetector batch [-o OUTPUT_DIR] [-s OUTPUT_SUFFIX] FILES
If no output directory is given, the program writes the files with the
detected beats to the same location as the audio files.
The 'pickle' mode can be used to store the used parameters to be able to
exactly reproduce experiments.
''')
# version
p.add_argument('--version', action='version',
version='TCNTempoDetector')
# input/output options
io_arguments(p, output_suffix='.bpm.txt', online=True)
ActivationsProcessor.add_arguments(p)
# signal processing arguments
SignalProcessor.add_arguments(p, norm=False, gain=0)
# tempo arguments
TempoEstimationProcessor.add_arguments(p, hist_smooth=15)

# parse arguments
args = p.parse_args()

# set immutable arguments
args.tasks = (1, )
args.interpolate = True
args.method = None
args.act_smooth = None

# print arguments
if args.verbose:
print(args)

# input processor
if args.load:
# load the activations from file
in_processor = ActivationsProcessor(mode='r', **vars(args))
else:
# use a TCN to predict beats and tempo
in_processor = TCNBeatProcessor(**vars(args))

# output processor
if args.save:
# save the TCN activations to file
out_processor = ActivationsProcessor(mode='w', **vars(args))
else:
# extract the tempo histogram from the NN output
args.histogram_processor = TCNTempoHistogramProcessor(**vars(args))
# estimate tempo
tempo_estimator = TempoEstimationProcessor(**vars(args))
# output handler
output = write_tempo
# sequentially process them
out_processor = [tempo_estimator, output]

# create an IOProcessor
processor = IOProcessor(in_processor, out_processor)

# and call the processing function
args.func(processor, **vars(args))


if __name__ == '__main__':
main()
89 changes: 71 additions & 18 deletions madmom/features/tempo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import sys
import warnings
from operator import itemgetter

import numpy as np

Expand Down Expand Up @@ -204,9 +205,9 @@ def detect_tempo(histogram, fps=None, interpolate=False):
Histogram (tuple of 2 numpy arrays, the first giving the strengths of
the bins and the second corresponding tempo/delay values).
fps : float, optional
Frames per second. If 'None', the second element is interpreted as tempo
values. If set, the histogram's second element is interpreted as inter
beat intervals (IBIs) in frames with the given rate.
Frames per second. If 'None', the second element is interpreted as
tempo values. If set, the histogram's second element is interpreted as
inter beat intervals (IBIs) in frames with the given rate.
interpolate : bool, optional
Interpolate the histogram bins.
Expand Down Expand Up @@ -285,8 +286,8 @@ def __init__(self, min_bpm, max_bpm, hist_buffer=HIST_BUFFER, fps=None,
online=False, **kwargs):
# pylint: disable=unused-argument
super(TempoHistogramProcessor, self).__init__(online=online)
self.min_bpm = min_bpm
self.max_bpm = max_bpm
self.min_bpm = float(min_bpm)
self.max_bpm = float(max_bpm)
self.hist_buffer = hist_buffer
self.fps = fps
if self.online:
Expand Down Expand Up @@ -609,34 +610,86 @@ def process_online(self, activations, reset=True, **kwargs):
return np.sum(bins, axis=0), self.intervals


class TCNTempoHistogramProcessor(TempoHistogramProcessor):
"""
Derive a tempo histogram from (multi-task) TCN output.
Parameters
----------
min_bpm : float, optional
Minimum tempo to detect [bpm].
max_bpm : float, optional
Maximum tempo to detect [bpm].
References
----------
.. [1] Sebastian Böck, Matthew Davies and Peter Knees,
"Multi-Task learning of tempo and beat: learning one to improve the
other",
Proceedings of the 20th International Society for Music Information
Retrieval Conference (ISMIR), 2019.
"""

def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, **kwargs):
# pylint: disable=unused-argument
super(TCNTempoHistogramProcessor, self).__init__(
min_bpm=min_bpm, max_bpm=max_bpm, **kwargs)

def process(self, data, **kwargs):
"""
Extract tempo histogram from (multi-task) TCN output.
Parameters
----------
data : numpy array or tuple of numpy arrays
Tempo-task (numpy array) or multi-task (tuple) output of TCN.
Returns
-------
histogram_bins : numpy array
Bins of tempo histogram, i.e. tempo strengths.
histogram_tempi : numpy array
Corresponding tempi [bpm].
"""
# if data is a tuple, tempo is usually last item of TCN output
if type(data) == tuple:
data = itemgetter(-1)(data)
# use a linear tempo range
tempi = np.arange(len(data))
# determine tempo range to consider
min_idx = np.argmax(tempi >= self.min_bpm)
max_idx = np.argmin(tempi <= self.max_bpm)
# return only selected range
return data[min_idx:max_idx], tempi[min_idx:max_idx]


class TempoEstimationProcessor(OnlineProcessor):
"""
Tempo Estimation Processor class.
Parameters
----------
method : {'comb', 'acf', 'dbn'}
Method used for tempo estimation.
method : {'comb', 'acf', 'dbn', None}
Method used for tempo histogram creation, e.g. from a beat
activation function or tempo classification layer.
min_bpm : float, optional
Minimum tempo to detect [bpm].
max_bpm : float, optional
Maximum tempo to detect [bpm].
act_smooth : float, optional (default: 0.14)
act_smooth : float, optional
Smooth the activation function over `act_smooth` seconds.
hist_smooth : int, optional (default: 7)
hist_smooth : int, optional
Smooth the tempo histogram over `hist_smooth` bins.
alpha : float, optional
Scaling factor for the comb filter.
fps : float, optional
Frames per second.
histogram_processor : :class:`TempoHistogramProcessor`, optional
Processor used to create a tempo histogram. If 'None', a default
combfilter histogram processor will be created and used.
Processor used to create a tempo histogram.
interpolate : bool, optional
Interpolate tempo with quadratic interpolation.
kwargs : dict, optional
Keyword arguments passed to :class:`CombFilterTempoHistogramProcessor`
if no `histogram_processor` was given.
Examples
--------
Expand Down Expand Up @@ -670,8 +723,8 @@ def __init__(self, method=METHOD, min_bpm=MIN_BPM, max_bpm=MAX_BPM,
if method is not None:
warnings.warn(
'Usage of `method` is deprecated as of version 0.17. '
'Please use a dedicated `TempoHistogramProcessor` '
'before the `TempoEstimationProcessor` instead. '
'Please pass a dedicated `TempoHistogramProcessor` '
'instance as `histogram_processor`.'
'Functionality will be removed in version 0.19.')
self.method = method
self.act_smooth = act_smooth
Expand Down Expand Up @@ -750,8 +803,8 @@ def process_offline(self, activations, **kwargs):
if self.act_smooth is not None:
act_smooth = int(round(self.fps * self.act_smooth))
activations = smooth_signal(activations, act_smooth)
# generate a histogram of beat intervals
histogram = self.interval_histogram(activations.astype(float))
# generate tempo histogram from beat activations/TCN classification
histogram = self.histogram_processor(activations)
# smooth the histogram
histogram = smooth_histogram(histogram, self.hist_smooth)
# detect the tempi and return them
Expand Down
Binary file added tests/data/activations/sample.beats_tcn_tempo.npz
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/detections/sample.tcn_tempo_detector.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
87.16 174.89 0.74
28 changes: 28 additions & 0 deletions tests/test_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,34 @@ def test_all_tempi(self):
[68.97, 0.099], [82.19, 0.096]]))


class TestTCNTempoDetectorProgram(unittest.TestCase):

def setUp(self):
self.bin = pj(program_path, "TCNTempoDetector")
self.activations = Activations(
pj(ACTIVATIONS_PATH, "sample.beats_tcn_tempo.npz"))
self.result = np.loadtxt(
pj(DETECTIONS_PATH, "sample.tcn_tempo_detector.txt"))

def test_help(self):
self.assertTrue(run_help(self.bin))

def test_binary(self):
# save activations as binary file
run_save(self.bin, sample_file, tmp_act)
act = Activations(tmp_act)
self.assertTrue(np.allclose(act, self.activations, atol=1e-5))
# reload from file
run_load(self.bin, tmp_act, tmp_result)
result = np.loadtxt(tmp_result)
self.assertTrue(np.allclose(result, self.result, atol=1e-5))

def test_run(self):
run_single(self.bin, sample_file, tmp_result)
result = np.loadtxt(tmp_result)
self.assertTrue(np.allclose(result, self.result, atol=1e-5))


# clean up
def teardown_module():
os.unlink(tmp_act)
Expand Down
51 changes: 51 additions & 0 deletions tests/test_features_tempo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import unittest
from os.path import join as pj

from madmom.features import Activations
from madmom.features.tempo import *
from madmom.io import write_tempo, load_tempo
from . import ACTIVATIONS_PATH
Expand All @@ -36,6 +37,8 @@
DBN_TEMPI_ONLINE = [[176.470588, 0.580877380], [86.9565217, 0.244729904],
[74.0740741, 0.127887992], [40.8163265, 0.0232523621],
[250.000000, 0.0232523621]]
TCN_TEMPI = np.array([[87, 0.62103526], [175, 0.2131467], [58, 0.1607556],
[41, 0.00323059], [115, 0.0008726]])
HIST = interval_histogram_comb(act, 0.79, min_tau=24, max_tau=150)


Expand Down Expand Up @@ -419,6 +422,54 @@ def test_process_online(self):
self.assertTrue(np.allclose(np.median(hist), 0))


class TestTCNTempoHistogramProcessorClass(unittest.TestCase):

def setUp(self):
self.processor = TCNTempoHistogramProcessor(min_bpm=10, max_bpm=250)
self.act = Activations(pj(ACTIVATIONS_PATH,
"sample.beats_tcn_tempo.npz"))

def test_types(self):
self.assertIsInstance(self.processor.min_bpm, float)
self.assertIsInstance(self.processor.max_bpm, float)
self.assertIsNone(self.processor.fps)

def test_values(self):
self.assertTrue(self.processor.min_bpm == 10)
self.assertTrue(self.processor.max_bpm == 250)
self.assertTrue(np.sum(self.act) == 1)

def test_process(self):
hist, tempi = self.processor(self.act)
self.assertTrue(np.allclose(tempi, np.arange(10, 251)))
self.assertTrue(np.allclose(hist.max(), 0.1326968))
self.assertTrue(np.allclose(hist.min(), 5.05e-09))
self.assertTrue(np.allclose(hist.argmax(), 77))
self.assertTrue(np.allclose(hist.argmin(), 182))
# hist sum is not 1, since we excluded tempi < 10
self.assertTrue(np.allclose(np.sum(hist), 0.9999768))
self.assertTrue(np.allclose(np.mean(hist), 0.0041492814))
self.assertTrue(np.allclose(np.median(hist), 7.891e-06))

def test_tempo(self):
tempo_processor = TempoEstimationProcessor(
histogram_processor=self.processor, act_smooth=None)
tempi = tempo_processor(self.act)
self.assertTrue(tempi.shape == (14, 2))
self.assertTrue(np.allclose(tempi[:, 0],
[87, 174, 58, 41, 115, 132, 100,
32, 197, 246, 228, 23, 214, 14]))
self.assertTrue(np.allclose(np.sum(tempi[:, 1]), 1))

def test_tempo_hist_smooth(self):
tempo_processor = TempoEstimationProcessor(
histogram_processor=self.processor, act_smooth=None,
hist_smooth=15)
tempi = tempo_processor(self.act)
self.assertTrue(tempi.shape == (9, 2))
self.assertTrue(np.allclose(tempi[:5], TCN_TEMPI, atol=0.01))


class TestWriteTempoFunction(unittest.TestCase):

def setUp(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pj(ACTIVATIONS_PATH, 'sample.beats_blstm_mm.npz'),
pj(ACTIVATIONS_PATH, 'sample.beats_lstm.npz'),
pj(ACTIVATIONS_PATH, 'sample.beats_tcn_beats.npz'),
pj(ACTIVATIONS_PATH, 'sample.beats_tcn_tempo.npz'),
pj(ACTIVATIONS_PATH, 'sample.cnn_chord_features.npz'),
pj(ACTIVATIONS_PATH, 'sample.downbeats_blstm.npz'),
pj(ACTIVATIONS_PATH, 'sample.deep_chroma.npz'),
Expand Down Expand Up @@ -90,6 +91,7 @@
pj(DETECTIONS_PATH, 'sample.super_flux.txt'),
pj(DETECTIONS_PATH, 'sample.super_flux_nn.txt'),
pj(DETECTIONS_PATH, 'sample.tcn_beat_tracker.txt'),
pj(DETECTIONS_PATH, 'sample.tcn_tempo_detector.txt'),
pj(DETECTIONS_PATH, 'sample.tempo_detector.txt'),
pj(DETECTIONS_PATH, 'sample2.cnn_chord_recognition.txt'),
pj(DETECTIONS_PATH, 'sample2.dc_chord_recognition.txt'),
Expand Down

0 comments on commit d25e3b8

Please sign in to comment.