add some tests for core; rework copy_entries
dmeliza committed Jan 2, 2025
1 parent be90646 commit b70b69c
Showing 2 changed files with 148 additions and 18 deletions.
39 changes: 21 additions & 18 deletions arfx/core.py
@@ -21,6 +21,8 @@
 import os
 import sys
 from functools import cache
+from pathlib import Path
+from typing import Iterable, Optional, Union

 import arf
 import h5py as h5
@@ -356,49 +358,50 @@ def delete_entries(src, entries, **options):
     repack_file(src, **options)


-def copy_entries(tgt, files, **options):
+def copy_entries(
+    tgt: Union[Path, str],
+    files: Iterable[Union[Path, str]],
+    entry_base: Optional[str] = None,
+    **options,
+) -> None:
     """
     Copy data from another arf file. Arguments can refer to entire arf
     files (just the filename) or specific entries (using path
     notation). Record IDs and all other metadata are copied with the entry.
-    entry_base: if specified, rename entries sequentially in target file
+    entry_base: if specified, rename entries sequentially in target file using this base
     """
     import posixpath as pp

     from h5py import Group

-    ebase = options.get("template", None)
     acache = cache(arf.open_file)

     with arf.open_file(tgt, "a") as arfp:
         arf.check_file_version(arfp)
         for f in files:
             # this is a bit tricky:
             # file.arf is a file; file.arf/entry is entry
             # dir/file.arf is a file; dir/file.arf/entry is entry
             # on windows, dir\file.arf/entry is an entry
-            pn, fn = pp.split(f)
-            if os.path.isfile(f):
-                it = ((f, entry) for ename, entry in acache(f, mode="r").items())
-            elif os.path.isfile(pn, mode="r"):
-                fp = acache(pn)
-                if fn in fp:
-                    it = ((pn, fp[fn]),)
-                else:
+            src = Path(f)
+            if src.is_file():
+                items = ((src, entry) for _, entry in acache(f, mode="r").items())
+            elif src.parent.is_file():
+                fp = acache(src.parent, mode="r")
+                try:
+                    items = ((src.parent, fp[src.name]),)
+                except KeyError:
                     log.error("unable to copy %s: no such entry", f)
                     continue
             else:
                 log.error("unable to copy %s: does not exist", f)
                 continue

-            for fname, entry in it:
-                if ebase is not None:
+            for fname, entry in items:
+                if entry_base is not None:
                     entry_name = default_entry_template.format(
-                        base=ebase, index=arf.count_children(arfp, Group)
+                        base=entry_base, index=arf.count_children(arfp, Group)
                     )
                 else:
-                    entry_name = pp.basename(entry.name)
+                    entry_name = Path(entry.name).name
                 arfp.copy(entry, arfp, name=entry_name)
                 log.debug("%s%s -> %s/%s", fname, entry.name, tgt, entry_name)

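For context, a minimal usage sketch of the reworked copy_entries follows. The file and entry names here are hypothetical; sequential names produced with entry_base come from default_entry_template, as in the diff above.

from arfx import core

# copy every entry in src.arf into tgt.arf, keeping entry names
core.copy_entries("tgt.arf", ["src.arf"])

# copy a single entry, selected with path notation (file.arf/entry)
core.copy_entries("tgt2.arf", ["src.arf/entry_001"])

# copy from several files, renaming entries sequentially from a base name
core.copy_entries("tgt3.arf", ["a.arf", "b.arf"], entry_base="site")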
127 changes: 127 additions & 0 deletions test/test_core.py
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
# -*- mode: python -*-
import time

import arf
import numpy as np
import pytest
from numpy.random import randint, randn

from arfx import core

entry_base = "entry_%03d"
tstamp = time.mktime(time.localtime())
entry_attributes = {
    "intattr": 1,
    "vecattr": [1, 2, 3],
    "arrattr": randn(5),
    "strattr": "an attribute",
}
datasets = [
    dict(
        name="acoustic",
        data=randn(100000),
        sampling_rate=20000,
        datatype=arf.DataTypes.ACOUSTIC,
        maxshape=(None,),
        microphone="DK-1234",
        compression=0,
    ),
    dict(
        name="neural",
        data=(randn(100000) * 2**16).astype("h"),
        sampling_rate=20000,
        datatype=arf.DataTypes.EXTRAC_HP,
        compression=9,
    ),
    dict(
        name="multichannel",
        data=randn(10000, 2),
        sampling_rate=20000,
        datatype=arf.DataTypes.ACOUSTIC,
    ),
    dict(
        name="spikes",
        data=randint(0, 100000, 100),
        datatype=arf.DataTypes.SPIKET,
        units="samples",
        sampling_rate=20000,  # required
    ),
    dict(
        name="empty-spikes",
        data=np.array([], dtype="f"),
        datatype=arf.DataTypes.SPIKET,
        method="broken",
        maxshape=(None,),
        units="s",
    ),
    dict(
        name="events",
        data=np.rec.fromrecords(
            [(1.0, 1, b"stimulus"), (5.0, 0, b"stimulus")],
            names=("start", "state", "name"),
        ),  # 'start' required
        datatype=arf.DataTypes.EVENT,
        units=(b"s", b"", b""),
    ),  # only bytes supported by h5py
]


@pytest.fixture
def src_file(tmp_path):
    path = tmp_path / "input.arf"
    with arf.open_file(path, "w") as fp:
        entry = arf.create_entry(fp, "entry", tstamp)
        for dset in datasets:
            _ = arf.create_dataset(entry, **dset)
    return path


def test_copy_file(src_file, tmp_path):
    tgt_file = tmp_path / "output.arf"
    core.copy_entries(tgt_file, [src_file])

    with arf.open_file(tgt_file, "r") as fp:
        entry = fp["/entry"]
        assert len(entry) == len(datasets)
        assert set(entry.keys()) == set(dset["name"] for dset in datasets)
        # this will fail if iteration is not in order of creation
        for dset, d in zip(datasets, entry.values()):
            assert d.shape == dset["data"].shape
            assert not arf.is_entry(d)


def test_copy_files(src_file, tmp_path):
    tgt_file = tmp_path / "output.arf"
    with pytest.raises(RuntimeError):
        # names will collide and produce an error after copying one entry
        core.copy_entries(tgt_file, [src_file, src_file])

    core.copy_entries(tgt_file, [src_file, src_file], entry_base="new_entry")
    with arf.open_file(tgt_file, "r") as fp:
        assert len(fp) == 3
        for i in range(2):
            entry_name = core.default_entry_template.format(base="new_entry", index=i + 1)
            entry = fp[entry_name]
            assert len(entry) == len(datasets)
            assert set(entry.keys()) == set(dset["name"] for dset in datasets)
            # this will fail if iteration is not in order of creation
            for dset, d in zip(datasets, entry.values()):
                assert d.shape == dset["data"].shape
                assert not arf.is_entry(d)


def test_copy_entry(src_file, tmp_path):
tgt_file = tmp_path / "output.arf"
src_entry = src_file / "entry"
core.copy_entries(tgt_file, [src_entry])

with arf.open_file(tgt_file, "r") as fp:
entry = fp["/entry"]
assert len(entry) == len(datasets)
assert set(entry.keys()) == set(dset["name"] for dset in datasets)
# this will fail if iteration is not in order of creation
for dset, d in zip(datasets, entry.values()):
assert d.shape == dset["data"].shape
assert not arf.is_entry(d)

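As a quick manual check outside the test suite, a copied file can be inspected the same way the tests do. This is a minimal sketch; input.arf and output.arf are hypothetical paths, with input.arf assumed to be built like the src_file fixture above.

import arf
from arfx import core

core.copy_entries("output.arf", ["input.arf"])
with arf.open_file("output.arf", "r") as fp:
    for name, entry in fp.items():
        print(name, sorted(entry.keys()))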