Commit

add some tests
dmeliza committed Dec 29, 2024
1 parent 7fbb312 commit 70a3955
Showing 2 changed files with 112 additions and 72 deletions.
11 changes: 2 additions & 9 deletions arf.py
@@ -401,7 +401,7 @@ def convert_timestamp(obj: Timestamp) -> ArfTimeStamp:
     else:
         try:
             out[:2] = obj[:2]
-        except IndexError as err:
+        except (IndexError, ValueError) as err:
             raise TypeError(f"unable to convert {obj} to timestamp") from err
     return out

@@ -439,18 +439,11 @@ def set_uuid(obj: h5.HLObject, uuid: Union[str, bytes, UUID, None] = None):
 def get_uuid(obj: h5.HLObject) -> UUID:
     """Return the uuid for obj, or null uuid if none is set"""
     # TODO: deprecate null uuid ret val
-    from uuid import UUID
-
     try:
         uuid = obj.attrs["uuid"]
     except KeyError:
         return UUID(int=0)
-    # convert to unicode for python 3
-    try:
-        uuid = uuid.decode("ascii")
-    except (LookupError, AttributeError):
-        pass
-    return UUID(uuid)
+    return UUID(uuid.decode("ascii"))


 def count_children(obj: h5.HLObject, type=None) -> int:
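The simplified `get_uuid` drops the defensive decode and assumes the stored attribute is ASCII bytes; it still round-trips with `set_uuid` and still falls back to the null UUID when the attribute is absent. A short sketch of the expected behavior (the in-memory h5py file is illustrative; the calls are the ones exercised by the new tests below):

```python
import h5py as h5
import arf
from uuid import uuid4

fp = h5.File("example.h5", "w", driver="core", backing_store=False)

uuid = uuid4()
arf.set_uuid(fp, uuid.bytes)      # bytes input is accepted
assert arf.get_uuid(fp) == uuid   # attribute is decoded from ascii

del fp.attrs["uuid"]
assert arf.get_uuid(fp).int == 0  # missing attribute -> null uuid
```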
173 changes: 110 additions & 63 deletions tests/test_arf.py
@@ -2,11 +2,13 @@

 import time
 
-import numpy as nx
+import numpy as np
 import pytest
+import h5py as h5
 from h5py.version import version as h5py_version
 from numpy.random import randint, randn
 from packaging import version
+from uuid import UUID, uuid4
 
 import arf

@@ -35,6 +37,12 @@
         datatype=arf.DataTypes.EXTRAC_HP,
         compression=9,
     ),
+    dict(
+        name="multichannel",
+        data=randn(10000, 2),
+        sampling_rate=20000,
+        datatype=arf.DataTypes.ACOUSTIC
+    ),
     dict(
         name="spikes",
         data=randint(0, 100000, 100),
@@ -44,15 +52,15 @@
     ),
     dict(
         name="empty-spikes",
-        data=nx.array([], dtype="f"),
+        data=np.array([], dtype="f"),
         datatype=arf.DataTypes.SPIKET,
         method="broken",
         maxshape=(None,),
         units="s",
     ),
     dict(
         name="events",
-        data=nx.rec.fromrecords(
+        data=np.rec.fromrecords(
             [(1.0, 1, b"stimulus"), (5.0, 0, b"stimulus")],
             names=("start", "state", "name"),
         ), # 'start' required
@@ -70,19 +78,19 @@
     ),
     dict(
         name="missing start field",
-        data=nx.rec.fromrecords([(1.0, 1), (2.0, 2)], names=("time", "state")),
+        data=np.rec.fromrecords([(1.0, 1), (2.0, 2)], names=("time", "state")),
         units="s",
     ),
     dict(
         name="missing units for complex dtype",
-        data=nx.rec.fromrecords(
+        data=np.rec.fromrecords(
             [(1.0, 1, b"stimulus"), (5.0, 0, b"stimulus")],
             names=("start", "state", "name"),
         ),
     ),
     dict(
         name="wrong length units for complex dtype",
-        data=nx.rec.fromrecords(
+        data=np.rec.fromrecords(
             [(1.0, 1, b"stimulus"), (5.0, 0, b"stimulus")],
             names=("start", "state", "name"),
         ),
@@ -92,114 +100,149 @@


 @pytest.fixture
-def test_file(tmp_path):
+def tmp_file(tmp_path):
     path = tmp_path / "test"
     fp = arf.open_file(path, "w", driver="core", backing_store=False)
     yield fp
     fp.close()
 
 
 @pytest.fixture
-def test_entry(test_file):
-    return arf.create_entry(test_file, "entry", tstamp)
+def tmp_entry(tmp_file):
+    return arf.create_entry(tmp_file, "entry", tstamp)
 
 
 @pytest.fixture
-def test_dataset(test_entry):
-    return arf.create_dataset(test_entry, **datasets[2])
+def tmp_dataset(tmp_entry):
+    return arf.create_dataset(tmp_entry, **datasets[2])
+
+
+@pytest.fixture
+def read_only_file(tmp_path):
+    path = tmp_path / "test"
+    fp = arf.open_file(path, "w")
+    entry = arf.create_entry(fp, "entry", tstamp)
+    for dset in datasets:
+        d = arf.create_dataset(entry, **dset)
+    fp.close()
+    return arf.open_file(path, "r")
+
+
+def test_created_datasets(read_only_file):
+    tmp_entry = read_only_file["/entry"]
+    assert len(tmp_entry) == len(datasets)
+    assert set(tmp_entry.keys()) == set(dset["name"] for dset in datasets)
+    for dset, d in zip(datasets, tmp_entry.values(), strict=True):
+        assert d.shape == dset["data"].shape
+        assert not arf.is_entry(d)
+
+
+def test_child_type_counts(read_only_file):
+    assert arf.count_children(read_only_file) == 1
+    assert arf.count_children(read_only_file, h5.Group) == 1
+    assert arf.count_children(read_only_file, h5.Dataset) == 0
+    entry = read_only_file["/entry"]
+    assert arf.count_children(entry) == len(datasets)
+    assert arf.count_children(entry, h5.Group) == 0
+    assert arf.count_children(entry, h5.Dataset) == len(datasets)
+
+
+def test_channel_counts(read_only_file):
+    dset1 = read_only_file["/entry/acoustic"]
+    assert arf.count_channels(dset1) == 1
+    dset2 = read_only_file["/entry/multichannel"]
+    assert arf.count_channels(dset2) == 2
 
 
-def test00_create_entries(test_file):
+def test_create_entries(tmp_file):
     N = 5
     for i in range(N):
         name = entry_base % i
-        g = arf.create_entry(test_file, name, tstamp, **entry_attributes)
-        assert name in test_file
+        g = arf.create_entry(tmp_file, name, tstamp, **entry_attributes)
+        assert name in tmp_file
         assert arf.is_entry(g)
         assert arf.timestamp_to_float(g.attrs["timestamp"]) > 0
         for k in entry_attributes:
             assert k in g.attrs
-    assert len(test_file) == N
+    assert len(tmp_file) == N
 
 
-def test01_create_existing_entry(test_file, test_entry):
+def test_create_existing_entry(tmp_file, tmp_entry):
     with pytest.raises(ValueError):
-        arf.create_entry(test_file, "entry", tstamp, **entry_attributes)
+        arf.create_entry(tmp_file, "entry", tstamp, **entry_attributes)
 
 
-def test02_create_datasets(test_entry):
-    for dset in datasets:
-        d = arf.create_dataset(test_entry, **dset)
-        assert d.shape == dset["data"].shape
-        assert not arf.is_entry(d)
-    assert len(test_entry) == len(datasets)
-    assert set(test_entry.keys()) == set(dset["name"] for dset in datasets)
-
-
-def test04_create_bad_dataset(test_entry):
+def test_create_bad_dataset(tmp_entry):
     for dset in bad_datasets:
         with pytest.raises(ValueError):
-            _ = arf.create_dataset(test_entry, **dset)
+            _ = arf.create_dataset(tmp_entry, **dset)
 
 
-def test05_set_attributes(test_entry):
+def test_set_attributes(tmp_entry):
     """tests the set_attributes convenience function"""
-    arf.set_attributes(test_entry, mystr="myvalue", myint=5000)
-    assert test_entry.attrs["myint"] == 5000
-    assert test_entry.attrs["mystr"] == "myvalue"
-    arf.set_attributes(test_entry, mystr=None)
-    assert "mystr" not in test_entry.attrs
+    arf.set_attributes(tmp_entry, mystr="myvalue", myint=5000)
+    assert tmp_entry.attrs["myint"] == 5000
+    assert tmp_entry.attrs["mystr"] == "myvalue"
+    arf.set_attributes(tmp_entry, mystr="blah blah", overwrite=False)
+    assert tmp_entry.attrs["mystr"] == "myvalue"
+    arf.set_attributes(tmp_entry, mystr=None)
+    assert "mystr" not in tmp_entry.attrs
 
 
-def test06_null_uuid(test_entry):
+def test_set_null_uuid(tmp_entry):
     # nulls in a uuid can make various things barf
-    from uuid import UUID
-
     uuid = UUID(bytes=b"".rjust(16, b"\0"))
-    arf.set_uuid(test_entry, uuid)
-    assert arf.get_uuid(test_entry) == uuid
+    arf.set_uuid(tmp_entry, uuid)
+    assert arf.get_uuid(tmp_entry) == uuid
+
+
+def test_get_null_uuid(tmp_entry):
+    uuid = UUID(bytes=b"".rjust(16, b"\0"))
+    del tmp_entry.attrs["uuid"]
+    assert arf.get_uuid(tmp_entry) == uuid
+
+
+def test_set_uuid_with_bytes(tmp_entry):
+    uuid = uuid4()
+    arf.set_uuid(tmp_entry, uuid.bytes)
+    assert arf.get_uuid(tmp_entry) == uuid
 
 
-def test07_copy_entry_with_attrs(test_file, test_entry):
-    src_entry_attrs = dict(test_entry.attrs)
+def test_copy_entry_with_attrs(tmp_file, tmp_entry):
+    src_entry_attrs = dict(tmp_entry.attrs)
     src_entry_timestamp = src_entry_attrs.pop("timestamp")
     tgt_entry = arf.create_entry(
-        test_file, "copied_entry", src_entry_timestamp, **src_entry_attrs
+        tmp_file, "copied_entry", src_entry_timestamp, **src_entry_attrs
     )
-    assert test_entry.attrs["uuid"] == tgt_entry.attrs["uuid"]
+    assert tmp_entry.attrs["uuid"] == tgt_entry.attrs["uuid"]
 
 
-def test08_check_file_version(test_file):
-    arf.check_file_version(test_file)
+def test_check_file_version(tmp_file):
+    arf.check_file_version(tmp_file)
 
 
-def test09_append_to_table(test_file):
-    dtype = nx.dtype({"names": ("f1", "f2"), "formats": [nx.uint, nx.int32]})
-    dset = arf.create_table(test_file, "test", dtype=dtype)
+def test_append_to_table(tmp_file):
+    dtype = np.dtype({"names": ("f1", "f2"), "formats": [np.uint, np.int32]})
+    dset = arf.create_table(tmp_file, "test", dtype=dtype)
     assert dset.shape[0] == 0
     arf.append_data(dset, (5, 10))
     assert dset.shape[0] == 1
 
+
+def test_append_nothing(tmp_file):
+    data = np.random.randn(100)
+    dset = arf.create_dataset(tmp_file, "test", data=data, sampling_rate=1)
+    arf.append_data(dset, np.random.randn(0))
+    assert dset.shape == data.shape
+
 
 @pytest.mark.skipif(
     version.Version(h5py_version) < version.Version("2.2"),
     reason="not supported on h5py < 2.2",
 )
-def test01_creation_iter(test_file):
+def test_creation_iter(tmp_file):
     # self.fp = arf.open_file("test06", mode="a", driver="core", backing_store=False)
     entry_names = ["z", "y", "a", "q", "zzyfij"]
     for name in entry_names:
-        g = arf.create_entry(test_file, name, 0)
+        g = arf.create_entry(tmp_file, name, 0)
         arf.create_dataset(g, "dset", (1,), sampling_rate=1)
-    assert list(arf.keys_by_creation(test_file)) == entry_names
+    assert list(arf.keys_by_creation(tmp_file)) == entry_names
 
 
 @pytest.mark.skipif(
     version.Version(h5py_version) < version.Version("2.2"),
     reason="not supported on h5py < 2.2",
 )
-def test10_select_from_timeseries(test_file):
-    entry = arf.create_entry(test_file, "entry", tstamp)
+def test_select_from_timeseries(tmp_file):
+    entry = arf.create_entry(tmp_file, "entry", tstamp)
     for data in datasets:
         arf.create_dataset(entry, **data)
         dset = entry[data["name"]]
@@ -208,23 +251,27 @@ def test10_select_from_timeseries(test_file):
         else:
             selected, offset = arf.select_interval(dset, 0.0, 1.0)
         if arf.is_time_series(dset):
-            nx.testing.assert_array_equal(
+            np.testing.assert_array_equal(
                 selected, data["data"][: data["sampling_rate"]]
             )
 
 
-def test01_timestamp_conversion():
+def test_timestamp_conversion():
     from datetime import datetime
 
     dt = datetime.now()
     ts = arf.convert_timestamp(dt)
     assert arf.timestamp_to_datetime(ts) == dt
     assert all(arf.convert_timestamp(ts) == ts)
     # lose the sub-second resolution
     assert arf.convert_timestamp(dt.timetuple())[0] == ts[0]
+    ts = arf.convert_timestamp(1000)
+    assert int(arf.timestamp_to_float(ts)) == 1000
+    with pytest.raises(TypeError):
+        arf.convert_timestamp("blah blah")
 
 
-def test99_various():
+def test_various():
     # test some functions difficult to cover otherwise
     arf.DataTypes._doc()
     arf.DataTypes._todict()
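For reference, the append semantics the new tests pin down, restated as a sketch that uses only calls appearing in this diff (the file name is arbitrary):

```python
import numpy as np
import arf

fp = arf.open_file("scratch.arf", "w", driver="core", backing_store=False)

# Tables start empty and grow row by row (test_append_to_table).
dtype = np.dtype({"names": ("f1", "f2"), "formats": [np.uint, np.int32]})
tbl = arf.create_table(fp, "events", dtype=dtype)
arf.append_data(tbl, (5, 10))
assert tbl.shape[0] == 1

# Appending a zero-length array leaves a sampled dataset unchanged
# (test_append_nothing).
dset = arf.create_dataset(fp, "samples", data=np.random.randn(100), sampling_rate=1)
arf.append_data(dset, np.random.randn(0))
assert dset.shape == (100,)
```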
