Skip to content

Commit

Permalink
Pytest WIP. Position centriod fix. Centralize device prompt logic
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Dec 29, 2023
1 parent 8800f99 commit 03be7d4
Show file tree
Hide file tree
Showing 23 changed files with 974 additions and 688 deletions.
10 changes: 0 additions & 10 deletions .github/workflows/test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ jobs:
env:
OS: ${{ matrix.os }}
PYTHON: '3.8'
# SPYGLASS_BASE_DIR: ./data
# KACHERY_STORAGE_DIR: ./data/kachery-storage
# DJ_SUPPORT_FILEPATH_MANAGEMENT: True
# services:
# datajoint_test_server:
# image: datajoint/mysql
# ports:
# - 3306:3306
# options: >-
# -e MYSQL_ROOT_PASSWORD=tutorial
steps:
- name: Cancel Workflow Action
uses: styfle/[email protected]
Expand Down
42 changes: 40 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ spyglass_cli = "spyglass.cli:cli"
"Homepage" = "https://github.com/LorenFrankLab/spyglass"
"Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues"

[project.optional-dependencies]
[project.optional-dependencies]
position = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"]
test = [
"docker", # for tests in a container
"pytest", # unit testing
"pytest-cov", # code coverage
"kachery", # database access
Expand Down Expand Up @@ -110,5 +111,42 @@ line-length = 80

[tool.codespell]
skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*'
ignore-words-list = 'nevers'
# Nevers - name in Citation
ignore-words-list = 'nevers'

[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"-sv",
"-p no:warnings",
"--no-teardown",
"--quiet-spy",
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
"--cov-report=term-missing",
"--no-cov-on-fail",
]
testpaths = ["tests"]
log_level = "INFO"

[tool.coverage.run]
source = ["*/src/spyglass/*"]
omit = [
"*/__init__.py",
"*/_version.py",
"*/cli/*",
# "*/common/*",
# "*/data_import/*",
"*/decoding/*",
"*/figurl_views/*",
"*/lfp/*",
"*/linearization/*",
"*/lock/*",
"*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
# "*/utils/*",
]

158 changes: 70 additions & 88 deletions src/spyglass/common/common_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import ndx_franklab_novela

from spyglass.common.errors import PopulateException
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils.logging import logger
from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_nwb_file

schema = dj.schema("common_device")
Expand Down Expand Up @@ -154,25 +154,9 @@ def _add_device(cls, new_device_dict):
all_values = DataAcquisitionDevice.fetch(
"data_acquisition_device_name"
).tolist()
if name not in all_values:
# no entry with the same name exists, prompt user to add a new entry
logger.info(
f"\nData acquisition device '{name}' was not found in the "
f"database. The current values are: {all_values}. "
"Please ensure that the device you want to add does not already"
" exist in the database under a different name or spelling. "
"If you want to use an existing device in the database, "
"please change the corresponding Device object in the NWB file."
" Entering 'N' will raise an exception."
)
to_db = " to the database"
val = input(f"Add data acquisition device '{name}'{to_db}? (y/N)")
if val.lower() in ["y", "yes"]:
cls.insert1(new_device_dict, skip_duplicates=True)
return
raise PopulateException(
f"User chose not to add device '{name}'{to_db}."
)
if prompt_insert(name=name, all_values=all_values):
cls.insert1(new_device_dict, skip_duplicates=True)
return

# Check if values provided match the values stored in the database
db_dict = (
Expand Down Expand Up @@ -213,28 +197,11 @@ def _add_system(cls, system):
all_values = DataAcquisitionDeviceSystem.fetch(
"data_acquisition_device_system"
).tolist()
if system not in all_values:
logger.info(
f"\nData acquisition device system '{system}' was not found in"
f" the database. The current values are: {all_values}. "
"Please ensure that the system you want to add does not already"
" exist in the database under a different name or spelling. "
"If you want to use an existing system in the database, "
"please change the corresponding Device object in the NWB file."
" Entering 'N' will raise an exception."
)
val = input(
f"Do you want to add data acquisition device system '{system}'"
+ " to the database? (y/N)"
)
if val.lower() in ["y", "yes"]:
key = {"data_acquisition_device_system": system}
DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
else:
raise PopulateException(
"User chose not to add data acquisition device system "
+ f"'{system}' to the database."
)
if prompt_insert(
name=system, all_values=all_values, table_type="system"
):
key = {"data_acquisition_device_system": system}
DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
return system

@classmethod
Expand Down Expand Up @@ -264,30 +231,11 @@ def _add_amplifier(cls, amplifier):
all_values = DataAcquisitionDeviceAmplifier.fetch(
"data_acquisition_device_amplifier"
).tolist()
if amplifier not in all_values:
logger.info(
f"\nData acquisition device amplifier '{amplifier}' was not "
f"found in the database. The current values are: {all_values}. "
"Please ensure that the amplifier you want to add does not "
"already exist in the database under a different name or "
"spelling. If you want to use an existing name in the database,"
" please change the corresponding Device object in the NWB "
"file. Entering 'N' will raise an exception."
)
val = input(
"Do you want to add data acquisition device amplifier "
+ f"'{amplifier}' to the database? (y/N)"
)
if val.lower() in ["y", "yes"]:
key = {"data_acquisition_device_amplifier": amplifier}
DataAcquisitionDeviceAmplifier.insert1(
key, skip_duplicates=True
)
else:
raise PopulateException(
"User chose not to add data acquisition device amplifier "
+ f"'{amplifier}' to the database."
)
if prompt_insert(
name=amplifier, all_values=all_values, table_type="amplifier"
):
key = {"data_acquisition_device_amplifier": amplifier}
DataAcquisitionDeviceAmplifier.insert1(key, skip_duplicates=True)
return amplifier


Expand Down Expand Up @@ -576,27 +524,9 @@ def _add_probe_type(cls, new_probe_type_dict):
"""
probe_type = new_probe_type_dict["probe_type"]
all_values = ProbeType.fetch("probe_type").tolist()
if probe_type not in all_values:
logger.info(
f"\nProbe type '{probe_type}' was not found in the database. "
f"The current values are: {all_values}. "
"Please ensure that the probe type you want to add does not "
"already exist in the database under a different name or "
"spelling. If you want to use an existing name in the "
"database, please change the corresponding Probe object in the "
"NWB file. Entering 'N' will raise an exception."
)
val = input(
f"Do you want to add probe type '{probe_type}' to the database?"
+ " (y/N)"
)
if val.lower() in ["y", "yes"]:
ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
return
raise PopulateException(
f"User chose not to add probe type '{probe_type}' to the "
+ "database."
)
if prompt_insert(probe_type, all_values, table="probe type"):
ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
return

# else / entry exists: check whether the values provided match the
# values stored in the database
Expand Down Expand Up @@ -738,3 +668,55 @@ def create_from_nwbfile(
cls.Shank.insert1(shank, skip_duplicates=True)
for electrode in elect_dict.values():
cls.Electrode.insert1(electrode, skip_duplicates=True)


# ---------------------------- Helper functions ----------------------------


# Migrated down to reduce redundancy and centralize 'test_mode' check for pytest
def prompt_insert(
name: str,
all_values: list,
table: str = "Data Acquisition Device",
table_type: str = None,
) -> bool:
"""Prompt user to add an item to the database. Return True if yes.
Assume insert during test mode.
Parameters
----------
name : str
The name of the item to add.
all_values : list
List of all values in the database.
table : str, optional
The name of the table to add to, by default Data Acquisition Device
table_type : str, optional
The type of item to add, by default None. Data Acquisition Device X
"""
if name in all_values:
return False

if test_mode:
return True

if table_type:
table_type += " "

logger.info(
f"{table}{table_type} '{name}' was not found in the"
f"database. The current values are: {all_values}.\n"
"Please ensure that the device you want to add does not already"
"exist in the database under a different name or spelling. If you"
"want to use an existing device in the database, please change the"
"corresponding Device object in the NWB file.\nEntering 'N' will "
"raise an exception."
)
msg = f"Do you want to add {table}{table_type} '{name}' to the database?"
if dj.utils.user_choice(msg).lower() in ["y", "yes"]:
return True

raise PopulateException(
f"User chose not to add {table}{table_type} '{name}' to the database."
)
4 changes: 2 additions & 2 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pynwb.behavior
from position_tools import (
get_angle,
get_centriod,
get_centroid,
get_distance,
get_speed,
get_velocity,
Expand Down Expand Up @@ -417,7 +417,7 @@ def calculate_position_info(
)

# Calculate position, orientation, velocity, speed
position = get_centriod(back_LED, front_LED) # cm
position = get_centroid(back_LED, front_LED) # cm

orientation = get_angle(back_LED, front_LED) # radians
is_nan = np.isnan(orientation)
Expand Down
23 changes: 9 additions & 14 deletions src/spyglass/common/common_session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import datajoint as dj

from spyglass.common.common_device import (
CameraDevice,
DataAcquisitionDevice,
Probe,
)
from spyglass.common.common_device import CameraDevice, DataAcquisitionDevice, Probe
from spyglass.common.common_lab import Institution, Lab, LabMember
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common.common_subject import Subject
Expand Down Expand Up @@ -63,13 +59,15 @@ def make(self, key):
nwbf = get_nwb_file(nwb_file_abspath)
config = get_config(nwb_file_abspath)

# certain data are not associated with a single NWB file / session because they may apply to
# multiple sessions. these data go into dj.Manual tables.
# e.g., a lab member may be associated with multiple experiments, so the lab member table should not
# be dependent on (contain a primary key for) a session.
# certain data are not associated with a single NWB file / session
# because they may apply to multiple sessions. these data go into
# dj.Manual tables. e.g., a lab member may be associated with multiple
# experiments, so the lab member table should not be dependent on
# (contain a primary key for) a session.

# here, we create new entries in these dj.Manual tables based on the values read from the NWB file
# then, they are linked to the session via fields of Session (e.g., Subject, Institution, Lab) or part
# here, we create new entries in these dj.Manual tables based on the
# values read from the NWB file then, they are linked to the session
# via fields of Session (e.g., Subject, Institution, Lab) or part
# tables (e.g., Experimenter, DataAcquisitionDevice).

logger.info("Institution...")
Expand All @@ -87,15 +85,12 @@ def make(self, key):
if not debug_mode: # TODO: remove when demo files agree on device
logger.info("Populate DataAcquisitionDevice...")
DataAcquisitionDevice.insert_from_nwbfile(nwbf, config)
logger.info()

logger.info("Populate CameraDevice...")
CameraDevice.insert_from_nwbfile(nwbf)
logger.info()

logger.info("Populate Probe...")
Probe.insert_from_nwbfile(nwbf, config)
logger.info()

if nwbf.subject is not None:
subject_id = nwbf.subject.subject_id
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/data_import/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# TODO: change naming to avoid match between module and function
from .insert_sessions import insert_sessions
2 changes: 1 addition & 1 deletion src/spyglass/data_import/insert_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name):
if os.path.exists(out_nwb_file_abs_path):
if debug_mode:
return out_nwb_file_abs_path
warnings.warn(
logger.warn(
f"Output file {out_nwb_file_abs_path} exists and will be "
+ "overwritten."
)
Expand Down
Loading

0 comments on commit 03be7d4

Please sign in to comment.