diff --git a/src/trodes_to_nwb/convert_dios.py b/src/trodes_to_nwb/convert_dios.py index 8161c53..cdcac66 100644 --- a/src/trodes_to_nwb/convert_dios.py +++ b/src/trodes_to_nwb/convert_dios.py @@ -70,26 +70,29 @@ def add_dios(nwbfile: NWBFile, recfile: list[str], metadata: dict) -> None: prefix = "ECU_" break - for channel_name in channel_name_map: - # merge streams from multiple files - all_timestamps = [] - all_state_changes = [] - for io in neo_io: + all_timestamps = [[] for _ in channel_name_map] + all_state_changes = [[] for _ in channel_name_map] + # Loop through io objects and get timestamps and state changes for each channel + for io in neo_io: + for i, channel_name in enumerate(channel_name_map): timestamps, state_changes = io.get_digitalsignal( stream_name, prefix + channel_name ) - all_timestamps.append(timestamps) - all_state_changes.append(state_changes) - all_timestamps = np.concatenate(all_timestamps) - all_state_changes = np.concatenate(all_state_changes) - assert isinstance(all_timestamps[0], np.float64) - assert isinstance(all_timestamps, np.ndarray) + all_timestamps[i].append(timestamps) + all_state_changes[i].append(state_changes) + for channel_name, state_changes, timestamps in zip( + channel_name_map, all_state_changes, all_timestamps + ): + timestamps = np.concatenate(timestamps) + state_changes = np.concatenate(state_changes) + assert isinstance(timestamps[0], np.float64) + assert isinstance(timestamps, np.ndarray) ts = TimeSeries( name=channel_name_map[channel_name], description=channel_name, - data=all_state_changes, + data=state_changes, unit="-1", # TODO change to "N/A", - timestamps=all_timestamps, # TODO adjust timestamps + timestamps=timestamps, # TODO adjust timestamps ) beh_events.add_timeseries(ts)