Skip to content

Commit

Permalink
V0 migration model. No surgery. See details.
Browse files Browse the repository at this point in the history
PR LorenFrankLab#694 had a lot of conflicts related to PR LorenFrankLab#711, and discussions
suggested avoiding rename of existing tables. Rather than cherry pick
and rebase, this commit replaces the work in LorenFrankLab#694.
  • Loading branch information
CBroz1 committed Dec 19, 2023
1 parent bec8471 commit 47b1884
Show file tree
Hide file tree
Showing 13 changed files with 328 additions and 327 deletions.
17 changes: 13 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,24 @@

## [0.4.4] (Unreleased)

### Infrastructure

- Additional documentation. #690
- Refactor input validation in DLC pipeline. #688
- Clean up following pre-commit checks. #688
- Add Mixin class to centralize `fetch_nwb` functionality. #692
- Minor fixes to LinearizedPositionV1 pipeline #695
- Add SpikeSorting V1 pipeline #651
- Refactor restriction use in `delete_downstream_merge` #703
- Minor fixes to LFPBandV1 populator #706
- Add `cautious_delete` to Mixin class, initial implementation. #711
- Add `deprecation_factory` to facilitate table migration

### Pipelines

- Position: Refactor input validation in DLC pipeline. #688
- Spike sorting: Add SpikeSorting V1 pipeline #651
- LFP: Minor fixes to LFPBandV1 populator #706
- Linearization:
- Minor fixes to LinearizedPositionV1 pipeline #695
- Rename `position_linearization` -> `linearization`
- Migrate linearization tables: `common_position` -> `linearization.v0`

## [0.4.3] (November 7, 2023)

Expand Down
341 changes: 35 additions & 306 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
plot_track_graph,
)

from ..settings import raw_dir, video_dir
from ..utils.dj_mixin import SpyglassMixin
from .common_behav import RawPosition, VideoFile
from .common_interval import IntervalList # noqa F401
from .common_nwbfile import AnalysisNwbfile
from spyglass.common.common_behav import RawPosition, VideoFile
from spyglass.common.common_interval import IntervalList # noqa F401
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.settings import raw_dir, video_dir
from spyglass.utils.dj_helper_fn import deprecated_factory
from spyglass.utils.dj_mixin import SpyglassMixin

schema = dj.schema("common_position")

Expand Down Expand Up @@ -498,306 +499,6 @@ def _data_to_df(data, prefix="head_", add_frame_ind=False):
return df


@schema
class LinearizationParameters(SpyglassMixin, dj.Lookup):
"""Choose whether to use an HMM to linearize position.
This can help when the euclidean distances between separate arms are too
close and the previous position has some information about which arm the
animal is on.
route_euclidean_distance_scaling: How much to prefer route distances between
successive time points that are closer to the euclidean distance. Smaller
numbers mean the route distance is more likely to be close to the euclidean
distance.
"""

definition = """
linearization_param_name : varchar(80) # name for this set of parameters
---
use_hmm = 0 : int # use HMM to determine linearization
route_euclidean_distance_scaling = 1.0 : float # Preference for euclidean.
sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm).
# Biases the transition matrix to prefer the current track segment.
diagonal_bias = 0.5 : float
"""


@schema
class TrackGraph(SpyglassMixin, dj.Manual):
"""Graph representation of track representing the spatial environment.
Used for linearizing position.
"""

definition = """
track_graph_name : varchar(80)
----
environment : varchar(80) # Type of Environment
node_positions : blob # 2D position of nodes, (n_nodes, 2)
edges: blob # shape (n_edges, 2)
linear_edge_order : blob # order of edges in linear space, (n_edges, 2)
linear_edge_spacing : blob # space btwn edges in linear space, (n_edges,)
"""

def get_networkx_track_graph(self, track_graph_parameters=None):
if track_graph_parameters is None:
track_graph_parameters = self.fetch1()
return make_track_graph(
node_positions=track_graph_parameters["node_positions"],
edges=track_graph_parameters["edges"],
)

def plot_track_graph(self, ax=None, draw_edge_labels=False, **kwds):
"""Plot the track graph in 2D position space."""
track_graph = self.get_networkx_track_graph()
plot_track_graph(
track_graph, ax=ax, draw_edge_labels=draw_edge_labels, **kwds
)

def plot_track_graph_as_1D(
self,
ax=None,
axis="x",
other_axis_start=0.0,
draw_edge_labels=False,
node_size=300,
node_color="#1f77b4",
):
"""Plot the track graph in 1D to see how the linearization is set up."""
track_graph_parameters = self.fetch1()
track_graph = self.get_networkx_track_graph(
track_graph_parameters=track_graph_parameters
)
plot_graph_as_1D(
track_graph,
edge_order=track_graph_parameters["linear_edge_order"],
edge_spacing=track_graph_parameters["linear_edge_spacing"],
ax=ax,
axis=axis,
other_axis_start=other_axis_start,
draw_edge_labels=draw_edge_labels,
node_size=node_size,
node_color=node_color,
)


@schema
class IntervalLinearizationSelection(SpyglassMixin, dj.Lookup):
definition = """
-> IntervalPositionInfo
-> TrackGraph
-> LinearizationParameters
---
"""


@schema
class IntervalLinearizedPosition(SpyglassMixin, dj.Computed):
"""Linearized position for a given interval"""

definition = """
-> IntervalLinearizationSelection
---
-> AnalysisNwbfile
linearized_position_object_id : varchar(40)
"""

def make(self, key):
print(f"Computing linear position for: {key}")

key["analysis_file_name"] = AnalysisNwbfile().create(
key["nwb_file_name"]
)

position_nwb = (
IntervalPositionInfo
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["interval_list_name"],
"position_info_param_name": key["position_info_param_name"],
}
).fetch_nwb()[0]

position = np.asarray(
position_nwb["head_position"].get_spatial_series().data
)
time = np.asarray(
position_nwb["head_position"].get_spatial_series().timestamps
)

linearization_parameters = (
LinearizationParameters()
& {"linearization_param_name": key["linearization_param_name"]}
).fetch1()
track_graph_info = (
TrackGraph() & {"track_graph_name": key["track_graph_name"]}
).fetch1()

track_graph = make_track_graph(
node_positions=track_graph_info["node_positions"],
edges=track_graph_info["edges"],
)

linear_position_df = get_linearized_position(
position=position,
track_graph=track_graph,
edge_spacing=track_graph_info["linear_edge_spacing"],
edge_order=track_graph_info["linear_edge_order"],
use_HMM=linearization_parameters["use_hmm"],
route_euclidean_distance_scaling=linearization_parameters[
"route_euclidean_distance_scaling"
],
sensor_std_dev=linearization_parameters["sensor_std_dev"],
diagonal_bias=linearization_parameters["diagonal_bias"],
)

linear_position_df["time"] = time

# Insert into analysis nwb file
nwb_analysis_file = AnalysisNwbfile()

key["linearized_position_object_id"] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=linear_position_df,
)

nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
analysis_file_name=key["analysis_file_name"],
)

self.insert1(key)

def fetch1_dataframe(self):
return self.fetch_nwb()[0]["linearized_position"].set_index("time")


class NodePicker:
"""Interactive creation of track graph by looking at video frames."""

def __init__(
self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100
):
if ax is None:
ax = plt.gca()
self.ax = ax
self.canvas = ax.get_figure().canvas
self.cid = None
self._nodes = []
self.node_color = node_color
self._nodes_plot = ax.scatter(
[], [], zorder=5, s=node_size, color=node_color
)
self.edges = [[]]
self.video_filename = video_filename

if video_filename is not None:
self.video = cv2.VideoCapture(video_filename)
frame = self.get_video_frame()
ax.imshow(frame, picker=True)
ax.set_title(
"Left click to place node.\nRight click to remove node."
"\nShift+Left click to clear nodes."
"\nCntrl+Left click two nodes to place an edge"
)

self.connect()

@property
def node_positions(self):
return np.asarray(self._nodes)

def connect(self):
if self.cid is None:
self.cid = self.canvas.mpl_connect(
"button_press_event", self.click_event
)

def disconnect(self):
if self.cid is not None:
self.canvas.mpl_disconnect(self.cid)
self.cid = None

def click_event(self, event):
if not event.inaxes:
return
if (event.key not in ["control", "shift"]) & (
event.button == 1
): # left click
self._nodes.append((event.xdata, event.ydata))
if (event.key not in ["control", "shift"]) & (
event.button == 3
): # right click
self.remove_point((event.xdata, event.ydata))
if (event.key == "shift") & (event.button == 1):
self.clear()
if (event.key == "control") & (event.button == 1):
point = (event.xdata, event.ydata)
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
if len(self.edges[-1]) < 2:
self.edges[-1].append(closest_node_ind)
else:
self.edges.append([closest_node_ind])

self.redraw()

def redraw(self):
# Draw Node Circles
if len(self.node_positions) > 0:
self._nodes_plot.set_offsets(self.node_positions)
else:
self._nodes_plot.set_offsets([])

# Draw Node Numbers
self.ax.texts = []
for ind, (x, y) in enumerate(self.node_positions):
self.ax.text(
x,
y,
ind,
zorder=6,
fontsize=12,
horizontalalignment="center",
verticalalignment="center",
clip_on=True,
bbox=None,
transform=self.ax.transData,
)
# Draw Edges
self.ax.lines = [] # clears the existing lines
for edge in self.edges:
if len(edge) > 1:
x1, y1 = self.node_positions[edge[0]]
x2, y2 = self.node_positions[edge[1]]
self.ax.plot(
[x1, x2], [y1, y2], color=self.node_color, linewidth=2
)

self.canvas.draw_idle()

def remove_point(self, point):
if len(self._nodes) > 0:
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
self._nodes.pop(closest_node_ind)

def clear(self):
self._nodes = []
self.edges = [[]]
self.redraw()

def get_video_frame(self):
is_grabbed, frame = self.video.read()
if is_grabbed:
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


@schema
class PositionVideo(SpyglassMixin, dj.Computed):
"""Creates a video of the computed head position and orientation as well as
Expand All @@ -807,7 +508,6 @@ class PositionVideo(SpyglassMixin, dj.Computed):

definition = """
-> IntervalPositionInfo
---
"""

def make(self, key):
Expand Down Expand Up @@ -1038,6 +738,35 @@ def make_video(
cv2.destroyAllWindows()


# ----------------------------- Migrated Tables -----------------------------


from spyglass.linearization.v0 import main as linV0 # noqa: E402

(
LinearizationParameters,
TrackGraph,
IntervalLinearizationSelection,
IntervalLinearizedPosition,
) = deprecated_factory(
[
("LinearizationParameters", linV0.LinearizationParameters),
("TrackGraph", linV0.TrackGraph),
(
"IntervalLinearizationSelection",
linV0.IntervalLinearizationSelection,
),
(
"IntervalLinearizedPosition",
linV0.IntervalLinearizedPosition,
),
],
old_module=__name__,
)

# ----------------------------- Helper Functions -----------------------------


def _fix_col_names(spatial_df):
"""Renames columns in spatial dataframe according to previous norm
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/linearization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from spyglass.linearization.merge import LinearizedPositionOutput
Loading

0 comments on commit 47b1884

Please sign in to comment.