Skip to content

Commit

Permalink
Merge pull request #72 from LorenFrankLab/analog_multiplex_select_bug
Browse files Browse the repository at this point in the history
Analog multiplex select bug
  • Loading branch information
edeno authored Nov 27, 2023
2 parents 2a1d7be + 499e7a9 commit 9b18ad5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/trodes_to_nwb/convert_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pynwb
from hdmf.backends.hdf5 import H5DataIO
from pynwb import NWBFile
from trodes_to_nwb.convert_ephys import RecFileDataChunkIterator

from trodes_to_nwb import convert_rec_header
from trodes_to_nwb.convert_ephys import RecFileDataChunkIterator


def add_analog_data(
Expand Down Expand Up @@ -53,7 +53,7 @@ def add_analog_data(
# by studies by the NWB team.
# could also add compression here. zstd/blosc-zstd are recommended by the NWB team, but
# they require the hdf5plugin library to be installed. gzip is available by default.
data_data_io = H5DataIO(rec_dci, chunks=(16384, min(rec_dci.n_channel, 32)))
data_data_io = H5DataIO(rec_dci, chunks=(16384, min(len(analog_channel_ids), 32)))

# make the objects to add to the nwb file
nwbfile.create_processing_module(
Expand Down
28 changes: 28 additions & 0 deletions src/trodes_to_nwb/convert_ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,34 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
)

data = (np.concatenate(data) * self.conversion).astype("int16")
# Handle the appended multiplex data
if (
self.neo_io[0].header["signal_streams"][self.stream_index]["id"]
== "ECU_analog"
) and self.is_analog:
multiplex_keys = self.neo_io[0].multiplexed_channel_xml.keys()
n_multiplex = len(multiplex_keys)
n_analog = (
self.n_channel
) # number of non-multiplexed channels in the dataset
n_analog_selected = data.shape[1] - n_multiplex
return_indices = np.arange(
n_analog_selected
) # include all non-multiplexed channels pulled
# determine which multiplex channels are being requested
if (
selection[1].stop > n_analog
): # if multiplexed channels are being requested
requested_multiplex = np.arange(n_multiplex) + n_analog_selected
multiplex_slice = slice(
max(selection[1].start - n_analog, 0),
max(selection[1].stop - n_analog, 0),
selection[1].step,
)
requested_multiplex = requested_multiplex[multiplex_slice]
return_indices = np.append(return_indices, requested_multiplex)
data = data[:, return_indices]

return data

def _get_maxshape(self) -> Tuple[int, int]:
Expand Down
41 changes: 38 additions & 3 deletions src/trodes_to_nwb/tests/test_convert_analog.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os

import pynwb

from trodes_to_nwb import convert_rec_header, convert_yaml
from trodes_to_nwb.convert_analog import add_analog_data, get_analog_channel_names
from trodes_to_nwb.convert_ephys import RecFileDataChunkIterator
from trodes_to_nwb.tests.test_convert_rec_header import default_test_xml_tree
from trodes_to_nwb.tests.utils import data_path

from trodes_to_nwb import convert_rec_header, convert_yaml


def test_add_analog_data():
# load metadata yml and make nwb file
Expand All @@ -31,7 +32,7 @@ def test_add_analog_data():
assert "analog" in read_nwbfile.processing["analog"]["analog"].time_series
assert read_nwbfile.processing["analog"]["analog"]["analog"].data.chunks == (
16384,
12,
22,
)

with pynwb.NWBHDF5IO(rec_to_nwb_file, "r", load_namespaces=True) as io2:
Expand Down Expand Up @@ -68,3 +69,37 @@ def test_add_analog_data():
).all()
# cleanup
os.remove(filename)


def test_selection_of_multiplexed_data():
rec_file = data_path / "20230622_sample_01_a1.rec"
rec_header = convert_rec_header.read_header(rec_file)
hconf = rec_header.find("HardwareConfiguration")
ecu_conf = None
for conf in hconf:
if conf.attrib["name"] == "ECU":
ecu_conf = conf
break
analog_channel_ids = []
for channel in ecu_conf:
if channel.attrib["dataType"] == "analog":
analog_channel_ids.append(channel.attrib["id"])
assert (len(analog_channel_ids)) == 12
rec_dci = RecFileDataChunkIterator(
[rec_file],
nwb_hw_channel_order=analog_channel_ids,
stream_index=2,
is_analog=True,
)
assert len(rec_dci.neo_io[0].multiplexed_channel_xml.keys()) == 10
slice_ind = [(0, 4), (0, 30), (1, 15), (5, 15), (20, 25)]
expected_channels = [4, 22, 14, 10, 2]
for ind, expected in zip(slice_ind, expected_channels):
data = rec_dci._get_data(
(
slice(0, 100, None),
slice(ind[0], ind[1], None),
)
)
assert data.shape[1] == expected
return

0 comments on commit 9b18ad5

Please sign in to comment.