Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Dec 12, 2023
1 parent 91ce55f commit 45710f8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
- Clean up following pre-commit checks. #688
- Add Mixin class to centralize `fetch_nwb` functionality. #692
- Minor fixes to LinearizedPositionV1 pipeline #695
- Add SpikeSorting V1 pipeline #651
- Refactor restriction use in `delete_downstream_merge`

## [0.4.3] (November 7, 2023)

Expand Down
63 changes: 45 additions & 18 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from contextlib import nullcontext
from itertools import chain as iter_chain
from pprint import pprint
Expand Down Expand Up @@ -35,9 +36,9 @@ def __init__(self):
f"\n {self._reserved_pk}: uuid\n ---\n"
+ f" {self._reserved_sk}: varchar({RESERVED_SK_LENGTH})\n "
)
# TODO: Change warnings to logger. Throw error? - CBroz1
if not self.is_declared:
if self.definition != merge_def:
# remove comments after # from each line of definition
if self._remove_comments(self.definition) != merge_def:
print(
"WARNING: merge table with non-default definition\n\t"
+ f"Expected: {merge_def.strip()}\n\t"
Expand All @@ -51,6 +52,12 @@ def __init__(self):
+ f"\n\tActual : {part.primary_key}"
)

def _remove_comments(self, definition):
"""Use regular expressions to remove comments and blank lines"""
return re.sub( # First remove comments, then blank lines
r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition)
)

@classmethod
def _merge_restrict_parts(
cls,
Expand Down Expand Up @@ -676,9 +683,10 @@ def merge_populate(source: str, key=None):

def delete_downstream_merge(
table: dj.Table,
restriction: str = True,
restriction: str = None,
dry_run=True,
recurse_level=2,
disable_warning=False,
**kwargs,
) -> list:
"""Given a table/restriction, id or delete relevant downstream merge entries
Expand All @@ -695,6 +703,8 @@ def delete_downstream_merge(
downstream of table input. Otherwise, delete merge/part table entries.
recurse_level: int
Default 2. Depth to recurse into table descendants.
disable_warning: bool
Default False. If True, don't warn about restrictions on table object.
kwargs: dict
Additional keyword arguments for DataJoint delete.
Expand All @@ -703,20 +713,25 @@ def delete_downstream_merge(
List[Tuple[dj.Table, dj.Table]]
Entries in merge/part tables downstream of table input.
"""
if table.restriction:
if not disable_warning and restriction is None and table().restriction:
print(
f"Warning: ignoring table restriction: {table.restriction}.\n\t"
f"Warning: ignoring table restriction: {table().restriction}.\n\t"
+ "Please pass restrictions as an arg"
)
if not restriction:
restriction = True

descendants = _unique_descendants(table, recurse_level)
merge_table_pairs = _master_table_pairs(descendants, restriction)
merge_table_pairs = _master_table_pairs(
table_list=descendants,
restricted_parent=(table & restriction),
)

# restrict the merge table based on uuids in part
# don't need part for del, but show on dry_run
merge_pairs = [
(merge & uuids, part) # don't need part for del, but show on dry_run
(merge & part.fetch(RESERVED_PRIMARY_KEY, as_dict=True), part)
for merge, part in merge_table_pairs
for uuids in part.fetch(RESERVED_PRIMARY_KEY, as_dict=True)
]

if dry_run:
Expand Down Expand Up @@ -779,7 +794,7 @@ def recurse_descendants(sub_table, level):

def _master_table_pairs(
table_list: list,
restriction: str = True,
restricted_parent: dj.expression.QueryExpression = True,
connection: dj.connection.Connection = None,
) -> list:
"""
Expand All @@ -792,8 +807,9 @@ def _master_table_pairs(
----------
table_list : List[dj.Table]
A list of datajoint tables.
restriction : str
A restriction string. Default True, no restriction.
restricted_parent : dj.expression.QueryExpression
Parent table restricted, to be joined with master and part. Default
True, no restriction.
connection : datajoint.connection.Connection
A database connection. Default None, use connection from first table.
Expand All @@ -805,22 +821,33 @@ def _master_table_pairs(
conn = connection or table_list[0].connection

master_table_pairs = []
unique_parts = []

# Adapted from Spyglass PR 535
for table in table_list:
master_name = get_master(table.full_table_name)
table_name = table.full_table_name
if table_name in unique_parts: # then repeat in list
continue

master_name = get_master(table_name)
if not master_name: # then it's not a part table
continue

master = dj.FreeTable(conn, master_name)

if RESERVED_PRIMARY_KEY not in master.heading.attributes.keys():
continue

restricted_table = table.restrict(restriction)
continue # then it's not a merge table

if not restricted_table:
restricted_join = restricted_parent * table
if not restricted_join: # No entries relevant to restriction in part
continue

master_table_pairs.append((master, restricted_table))
unique_parts.append(table_name)
master_table_pairs.append(
(
master,
table
& restricted_join.fetch(RESERVED_PRIMARY_KEY, as_dict=True),
)
)

return master_table_pairs

0 comments on commit 45710f8

Please sign in to comment.