From e20c8caddb677c1a57b705bb7d87533a3c2ae869 Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Fri, 29 Nov 2024 09:01:48 +0100 Subject: [PATCH] python312Packages.asteroid-filterbanks: fix tests --- .../asteroid-filterbanks/default.nix | 55 ++++++++++++++++--- .../torch-stft-return-complex.patch | 20 +++++++ 2 files changed, 67 insertions(+), 8 deletions(-) create mode 100644 pkgs/development/python-modules/asteroid-filterbanks/torch-stft-return-complex.patch diff --git a/pkgs/development/python-modules/asteroid-filterbanks/default.nix b/pkgs/development/python-modules/asteroid-filterbanks/default.nix index e7dc0cba9798f8..91b4ae2a24dc85 100644 --- a/pkgs/development/python-modules/asteroid-filterbanks/default.nix +++ b/pkgs/development/python-modules/asteroid-filterbanks/default.nix @@ -1,9 +1,13 @@ { lib, + stdenv, buildPythonPackage, fetchFromGitHub, + + # build-system setuptools, - wheel, + + # dependencies black, coverage, librosa, @@ -12,6 +16,9 @@ pytest, scipy, torch, + + # tests + pytestCheckHook, }: buildPythonPackage rec { @@ -22,16 +29,25 @@ buildPythonPackage rec { src = fetchFromGitHub { owner = "asteroid-team"; repo = "asteroid-filterbanks"; - rev = "v${version}"; + rev = "refs/tags/v${version}"; hash = "sha256-Z5M2Xgj83lzqov9kCw/rkjJ5KXbjuP+FHYCjhi5nYFE="; }; - nativeBuildInputs = [ + patches = [ + ./torch-stft-return-complex.patch + ]; + + # np.float is deprecated + postPatch = '' + substituteInPlace asteroid_filterbanks/multiphase_gammatone_fb.py \ + --replace-fail "np.float(" "float(" + ''; + + build-system = [ setuptools - wheel ]; - propagatedBuildInputs = [ + dependencies = [ black coverage librosa @@ -44,10 +60,33 @@ buildPythonPackage rec { pythonImportsCheck = [ "asteroid_filterbanks" ]; - meta = with lib; { + nativeCheckInputs = [ + pytestCheckHook + ]; + + disabledTests = + [ + # RuntimeError: cannot cache function '__o_fold': no locator available for file + # '/nix/store/d1znhn1n48z2raj0j9zbz80hhg4k2shw-python3.12-librosa-0.10.2.post1/lib/python3.12/site-packages/librosa/core/notation.py' + "test_melgram_encoder" + "test_melscale" + + # AssertionError: The values for attribute 'shape' do not match + "test_torch_stft" + ] + ++ lib.optionals stdenv.hostPlatform.isDarwin [ + # Issue with JIT on darwin: + # RuntimeError: required keyword attribute 'value' has the wrong type + "test_jit_filterbanks" + "test_jit_filterbanks_enc" + "test_pcen_jit" + "test_stateful_pcen_jit" + ]; + + meta = { description = "PyTorch-based audio source separation toolkit for researchers"; homepage = "https://github.com/asteroid-team/asteroid-filterbanks"; - license = licenses.mit; - maintainers = with maintainers; [ matthewcroughan ]; + license = lib.licenses.mit; + maintainers = with lib.maintainers; [ matthewcroughan ]; }; } diff --git a/pkgs/development/python-modules/asteroid-filterbanks/torch-stft-return-complex.patch b/pkgs/development/python-modules/asteroid-filterbanks/torch-stft-return-complex.patch new file mode 100644 index 00000000000000..c6472258a4ea5e --- /dev/null +++ b/pkgs/development/python-modules/asteroid-filterbanks/torch-stft-return-complex.patch @@ -0,0 +1,20 @@ +diff --git a/tests/torch_stft_test.py b/tests/torch_stft_test.py +index 1d29a51..0c7bb30 100644 +--- a/tests/torch_stft_test.py ++++ b/tests/torch_stft_test.py +@@ -128,6 +128,7 @@ def test_torch_stft( + pad_mode=pad_mode, + normalized=normalized, + onesided=True, ++ return_complex=True, + ) + + spec_asteroid = stft(wav) +@@ -145,6 +146,7 @@ def test_torch_stft( + normalized=normalized, + onesided=True, + length=output_len, ++ return_complex=True, + ) + + except RuntimeError: