Skip to content

Commit

Permalink
WIP: Add coverage, see details:
Browse files Browse the repository at this point in the history
- Add `return_fig` param to plotting helper functions to permit tests
  - `common_filter`
  - `common_interval`
- Add coverage for ~1/2 of `common`
  - `common_behav`
  - `common_device`
  - `common_ephys`
  - `common_filter`
  - `common_interval` - with helper funcs tested seperately
  - `common_lab`
  - `common_nwbfile` - partial
  • Loading branch information
CBroz1 committed Jan 9, 2024
1 parent 7ff01ad commit dec3655
Show file tree
Hide file tree
Showing 12 changed files with 523 additions and 11 deletions.
12 changes: 8 additions & 4 deletions src/spyglass/common/common_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def add_filter(
def _filter_restrict(self, filter_name, fs):
return (
self & {"filter_name": filter_name} & {"filter_sampling_rate": fs}
).fetch1(as_dict=True)
).fetch1()

def plot_magnitude(self, filter_name, fs):
def plot_magnitude(self, filter_name, fs, return_fig=False):
filter_dict = self._filter_restrict(filter_name, fs)
plt.figure()
w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536)
Expand All @@ -178,11 +178,13 @@ def plot_magnitude(self, filter_name, fs):
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.title("Frequency Response")
plt.xlim(0, np.max(filter_dict["filter_coeffand_edges"] * 2))
plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2))
plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1)
plt.grid(True)
if return_fig:
return plt.gcf()

def plot_fir_filter(self, filter_name, fs):
def plot_fir_filter(self, filter_name, fs, return_fig=False):
filter_dict = self._filter_restrict(filter_name, fs)
plt.figure()
plt.clf()
Expand All @@ -191,6 +193,8 @@ def plot_fir_filter(self, filter_name, fs):
plt.ylabel("Magnitude")
plt.title("Filter Taps")
plt.grid(True)
if return_fig:
return plt.gcf()

def filter_delay(self, filter_name, fs):
return self.calc_filter_delay(
Expand Down
8 changes: 6 additions & 2 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name):

cls.insert1(epoch_dict, skip_duplicates=True)

def plot_intervals(self, figsize=(20, 5)):
def plot_intervals(self, figsize=(20, 5), return_fig=False):
interval_list = pd.DataFrame(self)
fig, ax = plt.subplots(figsize=figsize)
interval_count = 0
Expand All @@ -83,8 +83,10 @@ def plot_intervals(self, figsize=(20, 5)):
ax.set_yticklabels(interval_list.interval_list_name)
ax.set_xlabel("Time [s]")
ax.grid(True)
if return_fig:
return fig

def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False):
interval_list = pd.DataFrame(self)
fig, ax = plt.subplots(figsize=(30, 3))

Expand Down Expand Up @@ -144,6 +146,8 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
ax.set_yticklabels(["pos valid times", "raw data valid times", "epoch"])
ax.set_xlabel("Time [s]")
ax.grid(True)
if return_fig:
return fig


def intervals_by_length(interval_list, min_length=0.0, max_length=1e10):
Expand Down
5 changes: 5 additions & 0 deletions tests/common/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,8 @@ def pos_src(common):
@pytest.fixture(scope="session")
def pos_interval_01(pos_src):
yield [pos_src.get_pos_interval_name(x) for x in range(1)]


@pytest.fixture(scope="session")
def common_ephys(common):
yield common.common_ephys
6 changes: 3 additions & 3 deletions tests/common/test_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def test_posinterval_no_transaction(verbose_context, common, mini_restr):
with verbose_context:
common.PositionIntervalMap()._no_transaction_make(mini_restr)
after = common.PositionIntervalMap().fetch()
assert array_equal(
before, after
assert (
len(after) == len(before) + 2
), "PositionIntervalMap no_transaction had unexpected effect"


def test_get_pos_interval_name(pos_src, mini_copy_name, pos_interval_01):
def test_get_pos_interval_name(pos_src, pos_interval_01):
"""Test get pos interval name"""
names = [f"pos {x} valid times" for x in range(1)]
assert pos_interval_01 == names, "get_pos_interval_name failed"
Expand Down
2 changes: 1 addition & 1 deletion tests/common/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_create_probe(common, mini_devices, mini_path, mini_copy_name):
probe_type = common.ProbeType.fetch("KEY", as_dict=True)[0]
before = common.Probe.fetch()
common.Probe.create_from_nwbfile(
nwb_file_name=mini_copy_name.split("/")[-1],
nwb_file_name=mini_copy_name,
nwb_device_name="probe 0",
contact_side_numbering=False,
**probe_id,
Expand Down
33 changes: 33 additions & 0 deletions tests/common/test_ephys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from numpy import array_equal


def test_create_from_config(mini_insert, common_ephys, mini_path):
before = common_ephys.Electrode().fetch()
common_ephys.Electrode.create_from_config(mini_path.stem)
after = common_ephys.Electrode().fetch()
# Because already inserted, expect no change
assert array_equal(
before, after
), "Electrode.create_from_config had unexpected effect"


def test_raw_object(mini_insert, common_ephys, mini_dict, mini_content):
obj_fetch = common_ephys.Raw().nwb_object(mini_dict).object_id
obj_raw = mini_content.get_acquisition().object_id
assert obj_fetch == obj_raw, "Raw.nwb_object did not return expected object"


def test_set_lfp_electrodes(mini_insert, common_ephys, mini_copy_name):
before = common_ephys.LFPSelection().fetch()
common_ephys.LFPSelection().set_lfp_electrodes(mini_copy_name, [0])
after = common_ephys.LFPSelection().fetch()
# Because already inserted, expect no change
assert (
len(after) == len(before) + 1
), "Set LFP electrodes had unexpected effect"


@pytest.mark.skip(reason="Not testing V0: common lfp")
def test_lfp():
pass
79 changes: 79 additions & 0 deletions tests/common/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest


@pytest.fixture(scope="session")
def filter_parameters(common):
yield common.FirFilterParameters()


@pytest.fixture(scope="session")
def filter_dict(filter_parameters):
yield {"filter_name": "test", "fs": 10}


@pytest.fixture(scope="session")
def add_filter(filter_parameters, filter_dict):
filter_parameters.add_filter(
**filter_dict, filter_type="lowpass", band_edges=[1, 2]
)


@pytest.fixture(scope="session")
def filter_coeff(filter_parameters, filter_dict):
yield filter_parameters._filter_restrict(**filter_dict)["filter_coeff"]


def test_add_filter(filter_parameters, add_filter, filter_dict):
"""Test add filter"""
assert filter_parameters & filter_dict, "add_filter failed"


def test_filter_restrict(
filter_parameters, add_filter, filter_dict, filter_coeff
):
assert sum(filter_coeff) == pytest.approx(
0.999134, abs=1e-6
), "filter_restrict failed"


def test_plot_magitude(filter_parameters, add_filter, filter_dict):
fig = filter_parameters.plot_magnitude(**filter_dict, return_fig=True)
assert sum(fig.get_axes()[0].lines[0].get_xdata()) == pytest.approx(
163837.5, abs=1
), "plot_magnitude failed"


def test_plot_fir_filter(
filter_parameters, add_filter, filter_dict, filter_coeff
):
fig = filter_parameters.plot_fir_filter(**filter_dict, return_fig=True)
assert sum(fig.get_axes()[0].lines[0].get_ydata()) == sum(
filter_coeff
), "Plot filter failed"


def test_filter_delay(filter_parameters, add_filter, filter_dict):
delay = filter_parameters.filter_delay(**filter_dict)
assert delay == 27, "filter_delay failed"


def test_time_bound_warning(filter_parameters, add_filter, filter_dict):
with pytest.warns(UserWarning):
filter_parameters._time_bound_check(1, 3, [2, 5], 4)


@pytest.mark.skip(reason="Not testing V0: filter_data")
def test_filter_data(filter_parameters, mini_content):
pass


def test_calc_filter_delay(filter_parameters, filter_coeff):
delay = filter_parameters.calc_filter_delay(filter_coeff)
assert delay == 27, "filter_delay failed"


def test_create_standard_filters(filter_parameters):
filter_parameters.create_standard_filters()
assert filter_parameters & {
"filter_name": "LFP 0-400 Hz"
}, "create_standard_filters failed"
27 changes: 27 additions & 0 deletions tests/common/test_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from numpy import array_equal


@pytest.fixture(scope="session")
def interval_list(common):
yield common.IntervalList()


def test_plot_intervals(mini_insert, interval_list):
fig = interval_list.plot_intervals(return_fig=True)
interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
times_fetch = (
interval_list & {"interval_list_name": interval_list_name}
).fetch1("valid_times")[0]
times_plot = fig.get_axes()[0].lines[0].get_xdata()

assert array_equal(times_fetch, times_plot), "plot_intervals failed"


def test_plot_epoch(mini_insert, interval_list):
fig = interval_list.plot_epoch_pos_raw_intervals(return_fig=True)
epoch_label = fig.get_axes()[0].get_yticklabels()[-1].get_text()
assert epoch_label == "epoch", "plot_epoch failed"

epoch_interv = fig.get_axes()[0].lines[0].get_ydata()
assert array_equal(epoch_interv, [1, 1]), "plot_epoch failed"
Loading

0 comments on commit dec3655

Please sign in to comment.