Skip to content

Commit

Permalink
use normal integer dtype for state intervals
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Böck committed Sep 19, 2017
1 parent 515ee79 commit 1c640e5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
14 changes: 7 additions & 7 deletions madmom/features/beats_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ def __init__(self, min_interval, max_interval, num_intervals=None):
intervals = np.unique(np.round(intervals))
num_log_intervals += 1
# save the intervals
self.intervals = np.ascontiguousarray(intervals, dtype=np.uint32)
self.intervals = np.ascontiguousarray(intervals, dtype=np.int)
# number of states and intervals
self.num_states = int(np.sum(intervals))
self.num_intervals = len(intervals)
# define first and last states
first_states = np.cumsum(np.r_[0, self.intervals[:-1]])
self.first_states = first_states.astype(np.uint32)
self.last_states = np.cumsum(self.intervals).astype(np.uint32) - 1
self.first_states = first_states.astype(np.int)
self.last_states = np.cumsum(self.intervals) - 1
# define the positions and intervals of the states
self.state_positions = np.empty(self.num_states)
self.state_intervals = np.empty(self.num_states, dtype=np.uint32)
self.state_intervals = np.empty(self.num_states, dtype=np.int)
# Note: having an index counter is faster than ndenumerate
idx = 0
for i in self.intervals:
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(self, num_beats, min_interval, max_interval,
# model N beats as a bar
self.num_beats = int(num_beats)
self.state_positions = np.empty(0)
self.state_intervals = np.empty(0, dtype=np.uint32)
self.state_intervals = np.empty(0, dtype=np.int)
self.num_states = 0
# save the first and last states of the individual beats in a list
self.first_states = []
Expand Down Expand Up @@ -196,8 +196,8 @@ def __init__(self, state_spaces):
self.num_patterns = len(state_spaces)
self.state_spaces = state_spaces
self.state_positions = np.empty(0)
self.state_intervals = np.empty(0, dtype=np.uint32)
self.state_patterns = np.empty(0, dtype=np.uint32)
self.state_intervals = np.empty(0, dtype=np.int)
self.state_patterns = np.empty(0, dtype=np.int)
self.num_states = 0
# save the first and last states of the individual patterns in a list
self.first_states = []
Expand Down
16 changes: 6 additions & 10 deletions tests/test_features_beats_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def test_types(self):
self.assertIsInstance(bss.num_states, int)
self.assertIsInstance(bss.num_intervals, int)
# dtypes
self.assertTrue(bss.intervals.dtype == np.uint32)
self.assertTrue(bss.intervals.dtype == np.int)
self.assertTrue(bss.state_positions.dtype == np.float)
self.assertTrue(bss.state_intervals.dtype == np.uint32)
self.assertTrue(bss.first_states.dtype == np.uint32)
self.assertTrue(bss.last_states.dtype == np.uint32)
self.assertTrue(bss.state_intervals.dtype == np.int)
self.assertTrue(bss.first_states.dtype == np.int)
self.assertTrue(bss.last_states.dtype == np.int)

def test_values(self):
bss = BeatStateSpace(1, 4)
Expand Down Expand Up @@ -73,9 +73,8 @@ def test_types(self):
self.assertIsInstance(bss.first_states, list)
self.assertIsInstance(bss.last_states, list)
# dtypes
# self.assertTrue(bss.intervals.dtype == np.uint32)
self.assertTrue(bss.state_positions.dtype == np.float)
self.assertTrue(bss.state_intervals.dtype == np.uint32)
self.assertTrue(bss.state_intervals.dtype == np.int)

def test_values(self):
# 2 beats, intervals 1 to 4
Expand Down Expand Up @@ -128,11 +127,8 @@ def test_types(self):
# self.assertIsInstance(mpss.num_intervals, int)
self.assertIsInstance(mpss.num_patterns, int)
# dtypes
# self.assertTrue(mpss.intervals.dtype == np.uint32)
self.assertTrue(mpss.state_positions.dtype == np.float)
self.assertTrue(mpss.state_intervals.dtype == np.uint32)
# self.assertTrue(mpss.first_states.dtype == np.uint32)
# self.assertTrue(mpss.last_states.dtype == np.uint32)
self.assertTrue(mpss.state_intervals.dtype == np.int)

def test_values_beat(self):
# test with 2 BeatStateSpaces as before
Expand Down

0 comments on commit 1c640e5

Please sign in to comment.