From b70b69c0ec432832a25b19d7015a004558b546fe Mon Sep 17 00:00:00 2001 From: Dan Meliza Date: Thu, 2 Jan 2025 17:58:04 -0500 Subject: [PATCH] add some tests for core; rework copy_entries --- arfx/core.py | 39 +++++++------- test/test_core.py | 127 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 18 deletions(-) create mode 100644 test/test_core.py diff --git a/arfx/core.py b/arfx/core.py index 741f6e7..a1da27a 100644 --- a/arfx/core.py +++ b/arfx/core.py @@ -21,6 +21,8 @@ import os import sys from functools import cache +from pathlib import Path +from typing import Iterable, Optional, Union import arf import h5py as h5 @@ -356,21 +358,22 @@ def delete_entries(src, entries, **options): repack_file(src, **options) -def copy_entries(tgt, files, **options): +def copy_entries( + tgt: Union[Path, str], + files: Iterable[Union[Path, str]], + entry_base: Optional[str] = None, + **options, +) -> None: """ Copy data from another arf file. Arguments can refer to entire arf files (just the filename) or specific entries (using path notation). Record IDs and all other metadata are copied with the entry. - entry_base: if specified, rename entries sequentially in target file + entry_base: if specified, rename entries sequentially in target file using this base """ - import posixpath as pp - from h5py import Group - ebase = options.get("template", None) acache = cache(arf.open_file) - with arf.open_file(tgt, "a") as arfp: arf.check_file_version(arfp) for f in files: @@ -378,27 +381,27 @@ def copy_entries(tgt, files, **options): # file.arf is a file; file.arf/entry is entry # dir/file.arf is a file; dir/file.arf/entry is entry # on windows, dir\file.arf/entry is an entry - pn, fn = pp.split(f) - if os.path.isfile(f): - it = ((f, entry) for ename, entry in acache(f, mode="r").items()) - elif os.path.isfile(pn, mode="r"): - fp = acache(pn) - if fn in fp: - it = ((pn, fp[fn]),) - else: + src = Path(f) + if src.is_file(): + items = ((src, entry) for _, entry in acache(f, mode="r").items()) + elif src.parent.is_file(): + fp = acache(src.parent, mode="r") + try: + items = ((src.parent, fp[src.name]),) + except KeyError: log.error("unable to copy %s: no such entry", f) continue else: log.error("unable to copy %s: does not exist", f) continue - for fname, entry in it: - if ebase is not None: + for fname, entry in items: + if entry_base is not None: entry_name = default_entry_template.format( - base=ebase, index=arf.count_children(arfp, Group) + base=entry_base, index=arf.count_children(arfp, Group) ) else: - entry_name = pp.basename(entry.name) + entry_name = Path(entry.name).name arfp.copy(entry, arfp, name=entry_name) log.debug("%s%s -> %s/%s", fname, entry.name, tgt, entry_name) diff --git a/test/test_core.py b/test/test_core.py new file mode 100644 index 0000000..09ee92d --- /dev/null +++ b/test/test_core.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# -*- mode: python -*- +import time + +import arf +import numpy as np +import pytest +from numpy.random import randint, randn + +from arfx import core + +entry_base = "entry_%03d" +tstamp = time.mktime(time.localtime()) +entry_attributes = { + "intattr": 1, + "vecattr": [1, 2, 3], + "arrattr": randn(5), + "strattr": "an attribute", +} +datasets = [ + dict( + name="acoustic", + data=randn(100000), + sampling_rate=20000, + datatype=arf.DataTypes.ACOUSTIC, + maxshape=(None,), + microphone="DK-1234", + compression=0, + ), + dict( + name="neural", + data=(randn(100000) * 2**16).astype("h"), + sampling_rate=20000, + datatype=arf.DataTypes.EXTRAC_HP, + compression=9, + ), + dict( + name="multichannel", + data=randn(10000, 2), + sampling_rate=20000, + datatype=arf.DataTypes.ACOUSTIC, + ), + dict( + name="spikes", + data=randint(0, 100000, 100), + datatype=arf.DataTypes.SPIKET, + units="samples", + sampling_rate=20000, # required + ), + dict( + name="empty-spikes", + data=np.array([], dtype="f"), + datatype=arf.DataTypes.SPIKET, + method="broken", + maxshape=(None,), + units="s", + ), + dict( + name="events", + data=np.rec.fromrecords( + [(1.0, 1, b"stimulus"), (5.0, 0, b"stimulus")], + names=("start", "state", "name"), + ), # 'start' required + datatype=arf.DataTypes.EVENT, + units=(b"s", b"", b""), + ), # only bytes supported by h5py +] + + +@pytest.fixture +def src_file(tmp_path): + path = tmp_path / "input.arf" + with arf.open_file(path, "w") as fp: + entry = arf.create_entry(fp, "entry", tstamp) + for dset in datasets: + _ = arf.create_dataset(entry, **dset) + return path + + +def test_copy_file(src_file, tmp_path): + tgt_file = tmp_path / "output.arf" + core.copy_entries(tgt_file, [src_file]) + + with arf.open_file(tgt_file, "r") as fp: + entry = fp["/entry"] + assert len(entry) == len(datasets) + assert set(entry.keys()) == set(dset["name"] for dset in datasets) + # this will fail if iteration is not in order of creation + for dset, d in zip(datasets, entry.values()): + assert d.shape == dset["data"].shape + assert not arf.is_entry(d) + + +def test_copy_files(src_file, tmp_path): + tgt_file = tmp_path / "output.arf" + with pytest.raises(RuntimeError): + # names will collide and produce error after copying one entry + core.copy_entries(tgt_file, [src_file, src_file]) + + core.copy_entries(tgt_file, [src_file, src_file], entry_base="new_entry") + fp = arf.open_file(tgt_file, "r") + print(fp.keys()) + assert len(fp) == 3 + for i in range(2): + entry_name = core.default_entry_template.format(base="new_entry", index=i + 1) + entry = fp[entry_name] + assert len(entry) == len(datasets) + assert set(entry.keys()) == set(dset["name"] for dset in datasets) + # this will fail if iteration is not in order of creation + for dset, d in zip(datasets, entry.values()): + assert d.shape == dset["data"].shape + assert not arf.is_entry(d) + + +def test_copy_entry(src_file, tmp_path): + tgt_file = tmp_path / "output.arf" + src_entry = src_file / "entry" + core.copy_entries(tgt_file, [src_entry]) + + with arf.open_file(tgt_file, "r") as fp: + entry = fp["/entry"] + assert len(entry) == len(datasets) + assert set(entry.keys()) == set(dset["name"] for dset in datasets) + # this will fail if iteration is not in order of creation + for dset, d in zip(datasets, entry.values()): + assert d.shape == dset["data"].shape + assert not arf.is_entry(d)