From ff29a570a7a3599f94669cfcdc432a7ebe843705 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Nov 2023 10:57:30 -0600 Subject: [PATCH] V0 migration model. Renaming - surgery required --- src/spyglass/common/__init__.py | 12 +- src/spyglass/common/common_position.py | 305 +++++++-------- src/spyglass/linearization/merge.py | 8 + src/spyglass/linearization/v0/__init__.py | 15 +- src/spyglass/linearization/v0/main.py | 429 +--------------------- src/spyglass/linearization/v1/__init__.py | 7 + src/spyglass/linearization/v1/main.py | 4 +- src/spyglass/utils/dj_helper_fn.py | 2 +- 8 files changed, 195 insertions(+), 587 deletions(-) diff --git a/src/spyglass/common/__init__.py b/src/spyglass/common/__init__.py index 540bcad30..1866f6055 100644 --- a/src/spyglass/common/__init__.py +++ b/src/spyglass/common/__init__.py @@ -70,9 +70,9 @@ from .prepopulate import populate_from_yaml, prepopulate_default from spyglass.linearization.v0 import ( # isort:skip - IntervalLinearizationSelection, - IntervalLinearizedPosition, - LinearizationParameters, + LinearizationParams, + LinearizedSelection, + LinearizedV0, TrackGraph, ) @@ -89,8 +89,6 @@ "ElectrodeGroup", "FirFilterParameters", "Institution", - "IntervalLinearizationSelection", - "IntervalLinearizedPosition", "IntervalList", "IntervalPositionInfo", "IntervalPositionInfoSelection", @@ -101,7 +99,9 @@ "Lab", "LabMember", "LabTeam", - "LinearizationParameters", + "LinearizationParams", + "LinearizedV0", + "LinearizedSelection", "Nwbfile", "NwbfileKachery", "PositionInfoParameters", diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 93441983a..a12d976f3 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -496,157 +496,6 @@ def _data_to_df(data, prefix="head_", add_frame_ind=False): return df -# ------------------------------ Migrated Tables ------------------------------ - -from spyglass.linearization.v0 import main as linV0 # noqa: E402 - -( - LinearizationParameters, - TrackGraph, - IntervalLinearizationSelection, - IntervalLinearizedPosition, -) = deprecated_factory( - [ - ("LinearizationParameters", linV0.LinearizationParameters), - ("TrackGraph", linV0.TrackGraph), - ( - "IntervalLinearizationSelection", - linV0.IntervalLinearizationSelection, - ), - ( - "IntervalLinearizedPosition", - linV0.IntervalLinearizedPosition, - ), - ], - old_module=__name__, -) - - -class NodePicker: - """Interactive creation of track graph by looking at video frames.""" - - def __init__( - self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100 - ): - if ax is None: - ax = plt.gca() - self.ax = ax - self.canvas = ax.get_figure().canvas - self.cid = None - self._nodes = [] - self.node_color = node_color - self._nodes_plot = ax.scatter( - [], [], zorder=5, s=node_size, color=node_color - ) - self.edges = [[]] - self.video_filename = video_filename - - if video_filename is not None: - self.video = cv2.VideoCapture(video_filename) - frame = self.get_video_frame() - ax.imshow(frame, picker=True) - ax.set_title( - "Left click to place node.\nRight click to remove node." - "\nShift+Left click to clear nodes." - "\nCntrl+Left click two nodes to place an edge" - ) - - self.connect() - - @property - def node_positions(self): - return np.asarray(self._nodes) - - def connect(self): - if self.cid is None: - self.cid = self.canvas.mpl_connect( - "button_press_event", self.click_event - ) - - def disconnect(self): - if self.cid is not None: - self.canvas.mpl_disconnect(self.cid) - self.cid = None - - def click_event(self, event): - if not event.inaxes: - return - if (event.key not in ["control", "shift"]) & ( - event.button == 1 - ): # left click - self._nodes.append((event.xdata, event.ydata)) - if (event.key not in ["control", "shift"]) & ( - event.button == 3 - ): # right click - self.remove_point((event.xdata, event.ydata)) - if (event.key == "shift") & (event.button == 1): - self.clear() - if (event.key == "control") & (event.button == 1): - point = (event.xdata, event.ydata) - distance_to_nodes = np.linalg.norm( - self.node_positions - point, axis=1 - ) - closest_node_ind = np.argmin(distance_to_nodes) - if len(self.edges[-1]) < 2: - self.edges[-1].append(closest_node_ind) - else: - self.edges.append([closest_node_ind]) - - self.redraw() - - def redraw(self): - # Draw Node Circles - if len(self.node_positions) > 0: - self._nodes_plot.set_offsets(self.node_positions) - else: - self._nodes_plot.set_offsets([]) - - # Draw Node Numbers - self.ax.texts = [] - for ind, (x, y) in enumerate(self.node_positions): - self.ax.text( - x, - y, - ind, - zorder=6, - fontsize=12, - horizontalalignment="center", - verticalalignment="center", - clip_on=True, - bbox=None, - transform=self.ax.transData, - ) - # Draw Edges - self.ax.lines = [] # clears the existing lines - for edge in self.edges: - if len(edge) > 1: - x1, y1 = self.node_positions[edge[0]] - x2, y2 = self.node_positions[edge[1]] - self.ax.plot( - [x1, x2], [y1, y2], color=self.node_color, linewidth=2 - ) - - self.canvas.draw_idle() - - def remove_point(self, point): - if len(self._nodes) > 0: - distance_to_nodes = np.linalg.norm( - self.node_positions - point, axis=1 - ) - closest_node_ind = np.argmin(distance_to_nodes) - self._nodes.pop(closest_node_ind) - - def clear(self): - self._nodes = [] - self.edges = [[]] - self.redraw() - - def get_video_frame(self): - is_grabbed, frame = self.video.read() - if is_grabbed: - return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - - @schema class PositionVideo(dj.Computed): """Creates a video of the computed head position and orientation as well as @@ -656,7 +505,6 @@ class PositionVideo(dj.Computed): definition = """ -> IntervalPositionInfo - --- """ def make(self, key): @@ -887,6 +735,159 @@ def make_video( cv2.destroyAllWindows() +# ------------------------------ Migrated Tables ------------------------------ + +from spyglass.linearization.v0 import main as linV0 # noqa: E402 + +( + LinearizationParameters, + TrackGraph, + IntervalLinearizationSelection, + IntervalLinearizedPosition, +) = deprecated_factory( + [ + ("LinearizationParameters", linV0.LinearizationParams), + ("TrackGraph", linV0.TrackGraph), + ( + "IntervalLinearizationSelection", + linV0.LinearizedSelection, + ), + ( + "IntervalLinearizedPosition", + linV0.LinearizedV0, + ), + ], + old_module=__name__, +) + +# ------------------------ Helper classes and functions ------------------------ + + +class NodePicker: + """Interactive creation of track graph by looking at video frames.""" + + def __init__( + self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100 + ): + if ax is None: + ax = plt.gca() + self.ax = ax + self.canvas = ax.get_figure().canvas + self.cid = None + self._nodes = [] + self.node_color = node_color + self._nodes_plot = ax.scatter( + [], [], zorder=5, s=node_size, color=node_color + ) + self.edges = [[]] + self.video_filename = video_filename + + if video_filename is not None: + self.video = cv2.VideoCapture(video_filename) + frame = self.get_video_frame() + ax.imshow(frame, picker=True) + ax.set_title( + "Left click to place node.\nRight click to remove node." + "\nShift+Left click to clear nodes." + "\nCntrl+Left click two nodes to place an edge" + ) + + self.connect() + + @property + def node_positions(self): + return np.asarray(self._nodes) + + def connect(self): + if self.cid is None: + self.cid = self.canvas.mpl_connect( + "button_press_event", self.click_event + ) + + def disconnect(self): + if self.cid is not None: + self.canvas.mpl_disconnect(self.cid) + self.cid = None + + def click_event(self, event): + if not event.inaxes: + return + if (event.key not in ["control", "shift"]) & ( + event.button == 1 + ): # left click + self._nodes.append((event.xdata, event.ydata)) + if (event.key not in ["control", "shift"]) & ( + event.button == 3 + ): # right click + self.remove_point((event.xdata, event.ydata)) + if (event.key == "shift") & (event.button == 1): + self.clear() + if (event.key == "control") & (event.button == 1): + point = (event.xdata, event.ydata) + distance_to_nodes = np.linalg.norm( + self.node_positions - point, axis=1 + ) + closest_node_ind = np.argmin(distance_to_nodes) + if len(self.edges[-1]) < 2: + self.edges[-1].append(closest_node_ind) + else: + self.edges.append([closest_node_ind]) + + self.redraw() + + def redraw(self): + # Draw Node Circles + if len(self.node_positions) > 0: + self._nodes_plot.set_offsets(self.node_positions) + else: + self._nodes_plot.set_offsets([]) + + # Draw Node Numbers + self.ax.texts = [] + for ind, (x, y) in enumerate(self.node_positions): + self.ax.text( + x, + y, + ind, + zorder=6, + fontsize=12, + horizontalalignment="center", + verticalalignment="center", + clip_on=True, + bbox=None, + transform=self.ax.transData, + ) + # Draw Edges + self.ax.lines = [] # clears the existing lines + for edge in self.edges: + if len(edge) > 1: + x1, y1 = self.node_positions[edge[0]] + x2, y2 = self.node_positions[edge[1]] + self.ax.plot( + [x1, x2], [y1, y2], color=self.node_color, linewidth=2 + ) + + self.canvas.draw_idle() + + def remove_point(self, point): + if len(self._nodes) > 0: + distance_to_nodes = np.linalg.norm( + self.node_positions - point, axis=1 + ) + closest_node_ind = np.argmin(distance_to_nodes) + self._nodes.pop(closest_node_ind) + + def clear(self): + self._nodes = [] + self.edges = [[]] + self.redraw() + + def get_video_frame(self): + is_grabbed, frame = self.video.read() + if is_grabbed: + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + def _fix_col_names(spatial_df): """Renames columns in spatial dataframe according to previous norm diff --git a/src/spyglass/linearization/merge.py b/src/spyglass/linearization/merge.py index e084037a2..9585341e3 100644 --- a/src/spyglass/linearization/merge.py +++ b/src/spyglass/linearization/merge.py @@ -1,5 +1,6 @@ import datajoint as dj +from spyglass.linearization.v0.main import LinearizedV0 # noqa F401 from spyglass.linearization.v1.main import LinearizedV1 # noqa F401 from ..utils.dj_merge_tables import _Merge @@ -15,6 +16,13 @@ class LinearizedOutput(_Merge): source: varchar(32) """ + class LinearizedV0(dj.Part): # noqa F811 + definition = """ + -> master + --- + -> LinearizedV0 + """ + class LinearizedV1(dj.Part): # noqa F811 definition = """ -> master diff --git a/src/spyglass/linearization/v0/__init__.py b/src/spyglass/linearization/v0/__init__.py index 438552bcf..0d054524b 100644 --- a/src/spyglass/linearization/v0/__init__.py +++ b/src/spyglass/linearization/v0/__init__.py @@ -1,6 +1,13 @@ -from .main import ( - IntervalLinearizationSelection, - IntervalLinearizedPosition, - LinearizationParameters, +from spyglass.linearization.v0.main import ( + LinearizationParams, + LinearizedSelection, + LinearizedV0, TrackGraph, ) + +__all__ = [ + "LinearizationParams", + "Linearized0", + "LinearizedSelection", + "TrackGraph", +] diff --git a/src/spyglass/linearization/v0/main.py b/src/spyglass/linearization/v0/main.py index e4b5a4932..fd13f9fa6 100644 --- a/src/spyglass/linearization/v0/main.py +++ b/src/spyglass/linearization/v0/main.py @@ -19,16 +19,11 @@ from spyglass.settings import raw_dir, video_dir from spyglass.utils.dj_helper_fn import fetch_nwb -schema = dj.schema("common_position") -# CBroz: I would rename 'linearization_v0', but would require db surgery -# Similarly, I would rename tables below and transfer contents -# - LinearizationParameters -> LinearizationParams -# - IntervalLinearizedSelection -> LinerarizedSeledtion -# - IntervalLinearizedPosition -> LinearizedV0 +schema = dj.schema("linearization_v0") @schema -class LinearizationParameters(dj.Lookup): +class LinearizationParams(dj.Lookup): """Choose whether to use an HMM to linearize position. This can help when the euclidean distances between separate arms are too @@ -112,20 +107,20 @@ def plot_track_graph_as_1D( @schema -class IntervalLinearizationSelection(dj.Lookup): +class LinearizedSelection(dj.Lookup): definition = """ -> IntervalPositionInfo -> TrackGraph - -> LinearizationParameters + -> LinearizationParams """ @schema -class IntervalLinearizedPosition(dj.Computed): +class LinearizedV0(dj.Computed): """Linearized position for a given interval""" definition = """ - -> IntervalLinearizationSelection + -> LinearizedSelection --- -> AnalysisNwbfile linearized_position_object_id : varchar(40) @@ -155,7 +150,7 @@ def make(self, key): ) linearization_parameters = ( - LinearizationParameters() + LinearizationParams() & {"linearization_param_name": key["linearization_param_name"]} ).fetch1() track_graph_info = ( @@ -204,413 +199,3 @@ def fetch_nwb(self, *attrs, **kwargs): def fetch1_dataframe(self): return self.fetch_nwb()[0]["linearized_position"].set_index("time") - - -class NodePicker: - """Interactive creation of track graph by looking at video frames.""" - - def __init__( - self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100 - ): - if ax is None: - ax = plt.gca() - self.ax = ax - self.canvas = ax.get_figure().canvas - self.cid = None - self._nodes = [] - self.node_color = node_color - self._nodes_plot = ax.scatter( - [], [], zorder=5, s=node_size, color=node_color - ) - self.edges = [[]] - self.video_filename = video_filename - - if video_filename is not None: - self.video = cv2.VideoCapture(video_filename) - frame = self.get_video_frame() - ax.imshow(frame, picker=True) - ax.set_title( - "Left click to place node.\nRight click to remove node." - "\nShift+Left click to clear nodes." - "\nCntrl+Left click two nodes to place an edge" - ) - - self.connect() - - @property - def node_positions(self): - return np.asarray(self._nodes) - - def connect(self): - if self.cid is None: - self.cid = self.canvas.mpl_connect( - "button_press_event", self.click_event - ) - - def disconnect(self): - if self.cid is not None: - self.canvas.mpl_disconnect(self.cid) - self.cid = None - - def click_event(self, event): - if not event.inaxes: - return - if (event.key not in ["control", "shift"]) & ( - event.button == 1 - ): # left click - self._nodes.append((event.xdata, event.ydata)) - if (event.key not in ["control", "shift"]) & ( - event.button == 3 - ): # right click - self.remove_point((event.xdata, event.ydata)) - if (event.key == "shift") & (event.button == 1): - self.clear() - if (event.key == "control") & (event.button == 1): - point = (event.xdata, event.ydata) - distance_to_nodes = np.linalg.norm( - self.node_positions - point, axis=1 - ) - closest_node_ind = np.argmin(distance_to_nodes) - if len(self.edges[-1]) < 2: - self.edges[-1].append(closest_node_ind) - else: - self.edges.append([closest_node_ind]) - - self.redraw() - - def redraw(self): - # Draw Node Circles - if len(self.node_positions) > 0: - self._nodes_plot.set_offsets(self.node_positions) - else: - self._nodes_plot.set_offsets([]) - - # Draw Node Numbers - self.ax.texts = [] - for ind, (x, y) in enumerate(self.node_positions): - self.ax.text( - x, - y, - ind, - zorder=6, - fontsize=12, - horizontalalignment="center", - verticalalignment="center", - clip_on=True, - bbox=None, - transform=self.ax.transData, - ) - # Draw Edges - self.ax.lines = [] # clears the existing lines - for edge in self.edges: - if len(edge) > 1: - x1, y1 = self.node_positions[edge[0]] - x2, y2 = self.node_positions[edge[1]] - self.ax.plot( - [x1, x2], [y1, y2], color=self.node_color, linewidth=2 - ) - - self.canvas.draw_idle() - - def remove_point(self, point): - if len(self._nodes) > 0: - distance_to_nodes = np.linalg.norm( - self.node_positions - point, axis=1 - ) - closest_node_ind = np.argmin(distance_to_nodes) - self._nodes.pop(closest_node_ind) - - def clear(self): - self._nodes = [] - self.edges = [[]] - self.redraw() - - def get_video_frame(self): - is_grabbed, frame = self.video.read() - if is_grabbed: - return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - - -@schema -class PositionVideo(dj.Computed): - """Creates a video of the computed head position and orientation as well as - the original LED positions overlaid on the video of the animal. - - Use for debugging the effect of position extraction parameters.""" - - definition = """ - -> IntervalPositionInfo - --- - """ - - def make(self, key): - M_TO_CM = 100 - - print("Loading position data...") - raw_position_df = ( - RawPosition() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - } - ).fetch1_dataframe() - position_info_df = ( - IntervalPositionInfo() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - "position_info_param_name": key["position_info_param_name"], - } - ).fetch1_dataframe() - - print("Loading video data...") - epoch = ( - int( - key["interval_list_name"] - .replace("pos ", "") - .replace(" valid times", "") - ) - + 1 - ) - video_info = ( - VideoFile() - & {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} - ).fetch1() - io = pynwb.NWBHDF5IO(raw_dir + "/" + video_info["nwb_file_name"], "r") - nwb_file = io.read() - nwb_video = nwb_file.objects[video_info["video_file_object_id"]] - video_filename = nwb_video.external_file[0] - - nwb_base_filename = key["nwb_file_name"].replace(".nwb", "") - output_video_filename = ( - f"{nwb_base_filename}_{epoch:02d}_" - f'{key["position_info_param_name"]}.mp4' - ) - - # ensure standardized column names - raw_position_df = _fix_col_names(raw_position_df) - # if IntervalPositionInfo supersampled position, downsample to video - if position_info_df.shape[0] > raw_position_df.shape[0]: - ind = np.digitize( - raw_position_df.index, position_info_df.index, right=True - ) - position_info_df = position_info_df.iloc[ind] - - centroids = { - "red": np.asarray(raw_position_df[["xloc", "yloc"]]), - "green": np.asarray(raw_position_df[["xloc2", "yloc2"]]), - } - head_position_mean = np.asarray( - position_info_df[["head_position_x", "head_position_y"]] - ) - head_orientation_mean = np.asarray( - position_info_df[["head_orientation"]] - ) - video_time = np.asarray(nwb_video.timestamps) - position_time = np.asarray(position_info_df.index) - cm_per_pixel = nwb_video.device.meters_per_pixel * M_TO_CM - - print("Making video...") - self.make_video( - f"{video_dir}/{video_filename}", - centroids, - head_position_mean, - head_orientation_mean, - video_time, - position_time, - output_video_filename=output_video_filename, - cm_to_pixels=cm_per_pixel, - disable_progressbar=False, - ) - - @staticmethod - def convert_to_pixels(data, frame_size, cm_to_pixels=1.0): - """Converts from cm to pixels and flips the y-axis. - Parameters - ---------- - data : ndarray, shape (n_time, 2) - frame_size : array_like, shape (2,) - cm_to_pixels : float - - Returns - ------- - converted_data : ndarray, shape (n_time, 2) - """ - return data / cm_to_pixels - - @staticmethod - def fill_nan(variable, video_time, variable_time): - video_ind = np.digitize(variable_time, video_time[1:]) - - n_video_time = len(video_time) - try: - n_variable_dims = variable.shape[1] - filled_variable = np.full((n_video_time, n_variable_dims), np.nan) - except IndexError: - filled_variable = np.full((n_video_time,), np.nan) - filled_variable[video_ind] = variable - - return filled_variable - - def make_video( - self, - video_filename, - centroids, - head_position_mean, - head_orientation_mean, - video_time, - position_time, - output_video_filename="output.mp4", - cm_to_pixels=1.0, - disable_progressbar=False, - arrow_radius=15, - circle_radius=8, - ): - RGB_PINK = (234, 82, 111) - RGB_YELLOW = (253, 231, 76) - RGB_WHITE = (255, 255, 255) - - video = cv2.VideoCapture(video_filename) - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - frame_size = (int(video.get(3)), int(video.get(4))) - frame_rate = video.get(5) - n_frames = int(head_orientation_mean.shape[0]) - - out = cv2.VideoWriter( - output_video_filename, fourcc, frame_rate, frame_size, True - ) - - centroids = { - color: self.fill_nan(data, video_time, position_time) - for color, data in centroids.items() - } - head_position_mean = self.fill_nan( - head_position_mean, video_time, position_time - ) - head_orientation_mean = self.fill_nan( - head_orientation_mean, video_time, position_time - ) - - for time_ind in tqdm( - range(n_frames - 1), desc="frames", disable=disable_progressbar - ): - is_grabbed, frame = video.read() - if is_grabbed: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - - red_centroid = centroids["red"][time_ind] - green_centroid = centroids["green"][time_ind] - - head_position = head_position_mean[time_ind] - head_position = self.convert_to_pixels( - head_position, frame_size, cm_to_pixels - ) - head_orientation = head_orientation_mean[time_ind] - - if np.all(~np.isnan(red_centroid)): - cv2.circle( - img=frame, - center=tuple(red_centroid.astype(int)), - radius=circle_radius, - color=RGB_YELLOW, - thickness=-1, - shift=cv2.CV_8U, - ) - - if np.all(~np.isnan(green_centroid)): - cv2.circle( - img=frame, - center=tuple(green_centroid.astype(int)), - radius=circle_radius, - color=RGB_PINK, - thickness=-1, - shift=cv2.CV_8U, - ) - - if np.all(~np.isnan(head_position)) & np.all( - ~np.isnan(head_orientation) - ): - arrow_tip = ( - int( - head_position[0] - + arrow_radius * np.cos(head_orientation) - ), - int( - head_position[1] - + arrow_radius * np.sin(head_orientation) - ), - ) - cv2.arrowedLine( - img=frame, - pt1=tuple(head_position.astype(int)), - pt2=arrow_tip, - color=RGB_WHITE, - thickness=4, - line_type=8, - shift=cv2.CV_8U, - tipLength=0.25, - ) - - if np.all(~np.isnan(head_position)): - cv2.circle( - img=frame, - center=tuple(head_position.astype(int)), - radius=circle_radius, - color=RGB_WHITE, - thickness=-1, - shift=cv2.CV_8U, - ) - - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - out.write(frame) - else: - break - - video.release() - out.release() - cv2.destroyAllWindows() - - -def _fix_col_names(spatial_df): - """Renames columns in spatial dataframe according to previous norm - - Accepts unnamed first led, 1 or 0 indexed. - Prompts user for confirmation of renaming unexpected columns. - For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2" - """ - - DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"] - ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"] - ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"] - - input_cols = list(spatial_df.columns) - - has_default = all([c in input_cols for c in DEFAULT_COLS]) - has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS]) - has_1_idx = all([c in input_cols for c in ONE_IDX_COLS]) - - if has_default: - # move the 4 position columns to front, continue - spatial_df = spatial_df[DEFAULT_COLS] - elif has_0_idx: - # move the 4 position columns to front, rename to default, continue - spatial_df = spatial_df[ZERO_IDX_COLS] - spatial_df.columns = DEFAULT_COLS - elif has_1_idx: - # move the 4 position columns to front, rename to default, continue - spatial_df = spatial_df[ONE_IDX_COLS] - spatial_df.columns = DEFAULT_COLS - else: - if len(input_cols) != 4 or not has_default: - choice = dj.utils.user_choice( - "Unexpected columns in raw position. Assume " - + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" - ) - if choice.lower() not in ["yes", "y"]: - raise ValueError( - f"Unexpected columns in raw position: {input_cols}" - ) - # rename first 4 columns, keep rest. Rest dropped below - spatial_df.columns = DEFAULT_COLS + input_cols[4:] - - return spatial_df diff --git a/src/spyglass/linearization/v1/__init__.py b/src/spyglass/linearization/v1/__init__.py index 9a8415311..0087c209e 100644 --- a/src/spyglass/linearization/v1/__init__.py +++ b/src/spyglass/linearization/v1/__init__.py @@ -4,3 +4,10 @@ LinearizedV1, TrackGraph, ) + +__all__ = [ + "LinearizationParams", + "LinearizedPosition", + "LinearizedSelection", + "TrackGraph", +] diff --git a/src/spyglass/linearization/v1/main.py b/src/spyglass/linearization/v1/main.py index a02df4b11..4dbd3a90c 100644 --- a/src/spyglass/linearization/v1/main.py +++ b/src/spyglass/linearization/v1/main.py @@ -111,7 +111,7 @@ def plot_track_graph_as_1D( @schema -class LinearizationSelection(dj.Lookup): +class LinearizedSelection(dj.Lookup): definition = """ -> PositionOutput.proj(pos_merge_id='merge_id') -> TrackGraph @@ -124,7 +124,7 @@ class LinearizedV1(dj.Computed): """Linearized position for a given interval""" definition = """ - -> LinearizationSelection + -> LinearizedSelection --- -> AnalysisNwbfile linearized_position_object_id : varchar(40) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 4ad2d58ce..e0247dffb 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -37,7 +37,7 @@ def deprecated_factory(classes: list, old_module: str = "") -> list: def _subclass_factory( old_name: str, new_class: Type, old_module: str = "" ) -> Type: - """Creates a sublcass with a deprecation warning on __init__ + """Creates a subclass with a deprecation warning on __init__ Old class is a subclass of new class, so it will inherit all of the new class's methods. Old class retains its original name and module. Use