diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e6997e68..df0022b40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Update DataJoint to 0.14.2 #1081 - Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086, #1126 - Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116 +- Allow definition of tasks and new probe entries from config #1074, #1120 +- Enforce match between ingested nwb probe geometry and existing table entry #1074 ### Pipelines diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index 9b8368192..744d5d773 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -134,8 +134,8 @@ def make(self, key): # schema if it isn't there and then add an entry for each epoch tasks_mod = nwbf.processing.get("tasks") - config_tasks = config.get("Tasks") - if tasks_mod is None and config_tasks is None: + config_tasks = config.get("Tasks", []) + if tasks_mod is None and (not config_tasks): logger.warn( f"No tasks processing module found in {nwbf} or config\n" ) @@ -236,12 +236,10 @@ def get_epoch_interval_name(cls, epoch, session_intervals): if target_interval in interval ] if not possible_targets: - logger.warn( - f"Interval not found for epoch {epoch} in {nwb_file_name}." - ) + logger.warn(f"Interval not found for epoch {epoch}.") elif len(possible_targets) > 1: logger.warn( - f"Multiple intervals found for epoch {epoch} in {nwb_file_name}. " + f"Multiple intervals found for epoch {epoch}. " + f"matches are {possible_targets}." ) else: diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py index bcfd50270..0295023df 100644 --- a/tests/common/test_behav.py +++ b/tests/common/test_behav.py @@ -101,8 +101,11 @@ def test_pos_interval_no_transaction(verbose_context, common, mini_restr): common.PositionIntervalMap()._no_transaction_make(mini_restr) after = common.PositionIntervalMap().fetch() assert ( - len(after) == len(before) + 2 + len(after) == len(before) + 3 ), "PositionIntervalMap no_transaction had unexpected effect" + assert ( + "" in after["position_interval_name"] + ), "PositionIntervalMap null insert failed" def test_get_pos_interval_name(pos_src, pos_interval_01):