Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/LorenFrankLab/spyglass into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Dec 21, 2023
2 parents 6f8c6c7 + 00bb398 commit 89db508
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
},
},
"isort.args": [
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@schema
class SpikeSortingOutput(_Merge):
class SpikeSortingOutput(_Merge, SpyglassMixin):
definition = """
# Output of spike sorting pipelines.
merge_id: uuid
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ def _consolidate_intervals(intervals, timestamps):
"""
# Convert intervals to a numpy array if it's not
intervals = np.array(intervals)
if intervals.ndim == 1:
intervals = intervals.reshape(-1, 2)
if intervals.shape[1] != 2:
raise ValueError(
"Input array must have shape (N, 2) where N is the number of intervals."
Expand Down
57 changes: 45 additions & 12 deletions src/spyglass/spikesorting/v1/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,34 @@ def get_sorting(cls, key: dict) -> si.BaseSorting:
"""

recording_id = (
SpikeSortingRecording * SpikeSortingSelection & key
).fetch1("recording_id")
recording = SpikeSortingRecording.get_recording(
{"recording_id": recording_id}
)
sampling_frequency = recording.get_sampling_frequency()
analysis_file_name = (cls & key).fetch1("analysis_file_name")
analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
analysis_file_name
)
sorting = se.read_nwb_sorting(analysis_file_abs_path)
with pynwb.NWBHDF5IO(
analysis_file_abs_path, "r", load_namespaces=True
) as io:
nwbf = io.read()
units = nwbf.units.to_dataframe()
units_dict_list = [
{
unit_id: np.searchsorted(recording.get_times(), spike_times)
for unit_id, spike_times in zip(
units.index, units["spike_times"]
)
}
]

sorting = si.NumpySorting.from_unit_dict(
units_dict_list, sampling_frequency=sampling_frequency
)

return sorting

Expand Down Expand Up @@ -330,18 +353,28 @@ def _write_sorting_to_nwb(
load_namespaces=True,
) as io:
nwbf = io.read()
nwbf.add_unit_column(
name="curation_label",
description="curation label applied to a unit",
)
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=timestamps[spike_times],
id=unit_id,
obs_intervals=sort_interval,
curation_label="uncurated",
if sorting.get_num_units() == 0:
nwbf.units = pynwb.misc.Units(
name="units", description="Empty units table."
)
else:
nwbf.add_unit_column(
name="curation_label",
description="curation label applied to a unit",
)
obs_interval = (
sort_interval
if sort_interval.ndim == 2
else sort_interval.reshape(1, 2)
)
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=timestamps[spike_times],
id=unit_id,
obs_intervals=obs_interval,
curation_label="uncurated",
)
units_object_id = nwbf.units.object_id
io.write(nwbf)
return analysis_nwb_file, units_object_id

0 comments on commit 89db508

Please sign in to comment.