Skip to content

Commit

Permalink
validate t,f0 shape (should be equal) in F0. validate non-empty f0 list
Browse files Browse the repository at this point in the history
  • Loading branch information
tandav committed Apr 7, 2023
1 parent 2628d81 commit 1bbdf37
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ repos:
rev: v0.982
hooks:
- id: mypy
additional_dependencies: [types-redis, types-tabulate]
additional_dependencies: [types-redis, types-tabulate, pydantic]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.302
Expand Down
2 changes: 2 additions & 0 deletions pitch_detectors/algorithms/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def vote_and_median(
for i, (name, data) in enumerate(algorithms.items()):
t = data.t
f0 = data.f0
if len(f0) == 0:
raise ValueError(f'algorithm {name} returned an empty f0 array')
f0_resampled[name] = np.full_like(t_resampled, fill_value=np.nan)
notna_slices = np.ma.clump_unmasked(np.ma.masked_invalid(f0))

Expand Down
9 changes: 9 additions & 0 deletions pitch_detectors/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any

import numpy as np
from pydantic import BaseModel
from pydantic import root_validator


class ArbitraryBaseModel(BaseModel):
Expand All @@ -16,6 +19,12 @@ class F0(ArbitraryBaseModel):
t: np.ndarray
f0: np.ndarray

@root_validator
def check_shape(cls, values: dict[str, Any]) -> dict[str, Any]: # pylint: disable=no-self-argument
if values['t'].shape != values['f0'].shape:
raise ValueError('t and f0 must have the same shape')
return values


class Record(ArbitraryBaseModel):
fs: int | None = None
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ warn_return_any = true
warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true
plugins = [
"pydantic.mypy",
]

[[tool.mypy.overrides]]
module = ["tests.*"]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@pytest.fixture
def record():
fs, a = util.load_wav(Path(__file__).parent.parent / 'data' / 'b1a5da49d564a7341e7e1327aa3f229a.wav')
return Record(fs, a)
return Record(fs=fs, a=a)


@pytest.fixture
Expand Down

0 comments on commit 1bbdf37

Please sign in to comment.