From 45710f826ba193752640dacd8ee923f9eaf39e91 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 12 Dec 2023 16:50:02 -0600 Subject: [PATCH] #700 only --- CHANGELOG.md | 2 + src/spyglass/utils/dj_merge_tables.py | 63 +++++++++++++++++++-------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 437e27493..116810a74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ - Clean up following pre-commit checks. #688 - Add Mixin class to centralize `fetch_nwb` functionality. #692 - Minor fixes to LinearizedPositionV1 pipeline #695 +- Add SpikeSorting V1 pipeline #651 +- Refactor restriction use in `delete_downstream_merge` ## [0.4.3] (November 7, 2023) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 2e184c66a..9d5cc6197 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -1,3 +1,4 @@ +import re from contextlib import nullcontext from itertools import chain as iter_chain from pprint import pprint @@ -35,9 +36,9 @@ def __init__(self): f"\n {self._reserved_pk}: uuid\n ---\n" + f" {self._reserved_sk}: varchar({RESERVED_SK_LENGTH})\n " ) - # TODO: Change warnings to logger. Throw error? - CBroz1 if not self.is_declared: - if self.definition != merge_def: + # remove comments after # from each line of definition + if self._remove_comments(self.definition) != merge_def: print( "WARNING: merge table with non-default definition\n\t" + f"Expected: {merge_def.strip()}\n\t" @@ -51,6 +52,12 @@ def __init__(self): + f"\n\tActual : {part.primary_key}" ) + def _remove_comments(self, definition): + """Use regular expressions to remove comments and blank lines""" + return re.sub( # First remove comments, then blank lines + r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition) + ) + @classmethod def _merge_restrict_parts( cls, @@ -676,9 +683,10 @@ def merge_populate(source: str, key=None): def delete_downstream_merge( table: dj.Table, - restriction: str = True, + restriction: str = None, dry_run=True, recurse_level=2, + disable_warning=False, **kwargs, ) -> list: """Given a table/restriction, id or delete relevant downstream merge entries @@ -695,6 +703,8 @@ def delete_downstream_merge( downstream of table input. Otherwise, delete merge/part table entries. recurse_level: int Default 2. Depth to recurse into table descendants. + disable_warning: bool + Default False. If True, don't warn about restrictions on table object. kwargs: dict Additional keyword arguments for DataJoint delete. @@ -703,20 +713,25 @@ def delete_downstream_merge( List[Tuple[dj.Table, dj.Table]] Entries in merge/part tables downstream of table input. """ - if table.restriction: + if not disable_warning and restriction is None and table().restriction: print( - f"Warning: ignoring table restriction: {table.restriction}.\n\t" + f"Warning: ignoring table restriction: {table().restriction}.\n\t" + "Please pass restrictions as an arg" ) + if not restriction: + restriction = True descendants = _unique_descendants(table, recurse_level) - merge_table_pairs = _master_table_pairs(descendants, restriction) + merge_table_pairs = _master_table_pairs( + table_list=descendants, + restricted_parent=(table & restriction), + ) # restrict the merge table based on uuids in part + # don't need part for del, but show on dry_run merge_pairs = [ - (merge & uuids, part) # don't need part for del, but show on dry_run + (merge & part.fetch(RESERVED_PRIMARY_KEY, as_dict=True), part) for merge, part in merge_table_pairs - for uuids in part.fetch(RESERVED_PRIMARY_KEY, as_dict=True) ] if dry_run: @@ -779,7 +794,7 @@ def recurse_descendants(sub_table, level): def _master_table_pairs( table_list: list, - restriction: str = True, + restricted_parent: dj.expression.QueryExpression = True, connection: dj.connection.Connection = None, ) -> list: """ @@ -792,8 +807,9 @@ def _master_table_pairs( ---------- table_list : List[dj.Table] A list of datajoint tables. - restriction : str - A restriction string. Default True, no restriction. + restricted_parent : dj.expression.QueryExpression + Parent table restricted, to be joined with master and part. Default + True, no restriction. connection : datajoint.connection.Connection A database connection. Default None, use connection from first table. @@ -805,22 +821,33 @@ def _master_table_pairs( conn = connection or table_list[0].connection master_table_pairs = [] + unique_parts = [] + # Adapted from Spyglass PR 535 for table in table_list: - master_name = get_master(table.full_table_name) + table_name = table.full_table_name + if table_name in unique_parts: # then repeat in list + continue + + master_name = get_master(table_name) if not master_name: # then it's not a part table continue master = dj.FreeTable(conn, master_name) - if RESERVED_PRIMARY_KEY not in master.heading.attributes.keys(): - continue - - restricted_table = table.restrict(restriction) + continue # then it's not a merge table - if not restricted_table: + restricted_join = restricted_parent * table + if not restricted_join: # No entries relevant to restriction in part continue - master_table_pairs.append((master, restricted_table)) + unique_parts.append(table_name) + master_table_pairs.append( + ( + master, + table + & restricted_join.fetch(RESERVED_PRIMARY_KEY, as_dict=True), + ) + ) return master_table_pairs