Add session displacement generation #3231

Open
wants to merge 11 commits into
base: main
from

Conversation

JoeZiminski
Collaborator

@JoeZiminski JoeZiminski commented Jul 18, 2024

This PR adds an 'inter-session displacement' ground-truth recording generator. It is intended as test data to support inter-session alignment (e.g. #2626, #3126). The idea is to create separate recordings with the same templates but shifted unit locations across recordings. There are options to model:

  • rigid shift in (x, y)
  • nonrigid shift in (x, y)
  • neurons dropping out across recordings. recording_amplitude_scalings scales the injected template amplitudes to different sizes across recordings (scaling a unit to zero removes it from that recording).
  • additional neurons being introduced into the recording. By default, when the templates shift, the space they previously occupied on the probe is left empty. In reality, new neurons would be shifted into the recording. shift_units_outside_probe=True introduces new neurons into the recording after the probe shift.
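
As a rough illustration of the default rigid-shift behaviour described above (not the PR's implementation), the sketch below shifts a few unit locations and checks which remain on the probe. The helper name `shift_units`, the probe y-range of 0 to 800 um, and the unit positions are all hypothetical, and only the y extent of the probe is checked.

```python
def shift_units(unit_locations, shift, probe_y_range):
    """Rigidly shift (x, y) unit locations and report which stay on the probe.

    Hypothetical helper for illustration only; it is not part of SpikeInterface.
    """
    dx, dy = shift
    y_min, y_max = probe_y_range
    # Apply the same (x, y) shift to every unit (rigid shift).
    shifted = [(x + dx, y + dy) for x, y in unit_locations]
    # Keep only the units whose shifted y position is still within the probe.
    on_probe = [loc for loc in shifted if y_min <= loc[1] <= y_max]
    return shifted, on_probe


# Three units on an assumed probe spanning y = 0 to 800 um, shifted up by 250 um.
units = [(0, 50), (0, 400), (0, 750)]
shifted, on_probe = shift_units(units, shift=(0, 250), probe_y_range=(0, 800))
print(shifted)   # all shifted positions, including the unit that left the probe
print(on_probe)  # positions still on the probe
```

In this toy case the top unit leaves the probe and nothing fills the bottom 250 um, which is the gap that shift_units_outside_probe=True is meant to fill.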

This PR tries to use the existing motion machinery where possible, and makes some refactorings with this aim. The main changes are the refactorings, introduction of a new function for generating multi-session drifting recordings, and associated tests. Below are some sections to highlight usage, and example scripts for quick-running or deeper debugging.

I'm not sure about the name 'session displacement'. Maybe just 'generate_multi_session_recording()' is easier to understand. Also, please let me know if any other variable names are unclear.

Examples

Example 1

Set the amplitude of 2 units to zero between sessions, also with a 100 um shift. The units are shifted across sessions, and the top / bottom units are removed (e.g. simulating these neurons disappearing between sessions).

Screenshot 2024-09-03 at 17 00 57

Example 2

Set shift_units_outside_probe to True which introduces new units into the recording due to shift, alongside a 250 um shift. Note the top unit has been shifted out of the probe, the middle units are shifted up, and 2 new units are introduced at the bottom of the probe.

Screenshot 2024-09-03 at 17 02 53

Note

When running with a set of n num_units, it is initially surprising that you do not always see exactly n units clearly (e.g. below, which was run with num_units=5). The main driver of this is that the default z range in generate_unit_locations_kwargs (minimum_z to maximum_z) is between 5 and 45. For low-amplitude neurons, a far z position will not generate enough signal to reach the probe (this is of course more realistic for simulation purposes). If you want to be sure you see n units in the raster, set margin_um and maximum_z lower in generate_unit_locations_kwargs e.g.:

image

generate_unit_locations_kwargs=dict(
    margin_um=0.0,
    minimum_z=5.0,
    maximum_z=10.0,
    minimum_distance=18.0,
    max_iteration=100,
    distance_strict=False,
)
Quick Run Code
import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import matplotlib.pyplot as plt
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

rec_list, _ = generate_session_displacement_recordings(
    non_rigid_gradient=None,  # Note this will set nonlinearity to both x and y (the same)
    num_units=5,
    recording_durations=[25, 25, 25],
    recording_shifts=[
        (0, 0),
        (0, 75),
        (0, -150),
    ],
    shift_units_outside_probe=False,
    seed=None,
)

# Plot the raster maps.

for rec in rec_list:

    peaks = detect_peaks(rec, method="locally_exclusive")
    peak_locs = localize_peaks(rec, peaks, method="grid_convolution")

    si.plot_drift_raster_map(
        peaks=peaks,
        peak_locations=peak_locs,
        recording=rec,
        clim=(-300, 0)  # fix clim for comparability across plots
    )
    plt.show()
Full Debugging Code
import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import matplotlib.pyplot as plt
import numpy as np
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

# Generate a ground truth recording where every unit is firing a lot,
# with high amplitude, and is close to the probe, so all are picked up.
# This just makes it easy to play around with the units, e.g.
# if specifying 5 units, 5 unit peaks are clearly visible, none are lost
# because their position is too far from probe.

default_unit_params_range = dict(
    alpha=(100.0, 500.0),
    depolarization_ms=(0.09, 0.14),
    repolarization_ms=(0.5, 0.8),
    recovery_ms=(1.0, 1.5),
    positive_amplitude=(0.1, 0.25),
    smooth_ms=(0.03, 0.07),
    spatial_decay=(20, 40),
    propagation_speed=(250.0, 350.0),
    b=(0.1, 1),
    c=(0.1, 1),
    x_angle=(0, np.pi),
    y_angle=(0, np.pi),
    z_angle=(0, np.pi),
)

default_unit_params_range["alpha"] = (500, 500)  # do this or change the margin on generate_unit_locations_kwargs
default_unit_params_range["b"] = (0.5, 1)        # and make the units fatter, easier to receive signal!
default_unit_params_range["c"] = (0.5, 1)

scale_ = [np.array([0.25, 0.5, 1, 1, 0])] * 2
scale_ = [np.ones(5)] + scale_

rec_list, _ = generate_session_displacement_recordings(
    non_rigid_gradient=None,  # 0.05, TODO: note this will set nonlinearity to both x and y (the same)
    num_units=5,
    recording_durations=(25, 25, 25),  # TODO: checks on inputs
    recording_shifts=(
        (0, 0),
        (0, 0),
        (0, 0),
    ),
    recording_amplitude_scalings={
        "method": "by_amplitude_and_firing_rate",
        "scalings": scale_,
    },
    shift_units_outside_probe=False,
    generate_sorting_kwargs=dict(firing_rates=(0, 200), refractory_period_ms=4.0),
    generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
    seed=None,
    generate_unit_locations_kwargs=dict(
        margin_um=0.0,  # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
        minimum_z=5.0,
        maximum_z=45.0,
        minimum_distance=18.0,
        max_iteration=100,
        distance_strict=False,
    ),
)

# Iterate through each recording, plotting the raw traces then
# detecting and plotting the peaks.

for rec in rec_list:

    si.plot_traces(rec, time_range=(0, 1))
    plt.show()

    peaks = detect_peaks(rec, method="locally_exclusive")
    peak_locs = localize_peaks(rec, peaks, method="grid_convolution")

    si.plot_drift_raster_map(
        peaks=peaks,
        peak_locations=peak_locs,
        recording=rec,
        clim=(-300, 0)  # fix clim for comparability across plots
    )
    plt.show()

TODO:

  • tests are randomly failing every now and then, but only on macOS

@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from 089cbc0 to 60c8e5e Compare July 29, 2024 15:39
@alejoe91 alejoe91 added motion correction Questions related to motion correction generators Related to generator tools labels Aug 27, 2024
@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from 72ca7fa to 33254b9 Compare September 3, 2024 14:13
@JoeZiminski JoeZiminski marked this pull request as ready for review September 3, 2024 16:08
@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from 33254b9 to e996dee Compare September 3, 2024 16:35
Collaborator

@cwindolf cwindolf left a comment

Cool! This seems really nice. I think my only question is whether there can be more logic sharing with the single-session generate_drifting_recording, since some of the changes you've made seem like they could be helpful there. But, not sure how much sense that makes.


displacement_vectors = np.concatenate(displacement_vectors, axis=2)

return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps


def calculate_displacement_unit_factor(
Collaborator

maybe this could be called something like "simulate_linear_gradient_drift"? i was a bit confused reading it, but it seems to be generating drift which is 0 at the top of the probe and something not zero at the bottom?

maybe someone can help explain what exactly

displacement_unit_factor = factors * (1 - f) + f

ends up producing... is it like there is some global drift plus per-unit linear drift?

Collaborator Author

Hey @cwindolf, thanks a lot for this review. This function is a refactoring of this code. I like simulate_linear_gradient_drift, but for this PR I will keep the current naming for consistency with the old code. However, I'll make an issue based on some of the points you raise in this PR (e.g. including some things in the within-session drift) and add a note on this there.

I agree, I got quite confused the first (few) times looking through the use of non_rigid_gradient. I think the easiest way to see it is with some example values. The first part of the function takes the dot product of the displacement vector and the unit location (expressed as a vector from the probe origin; I am not sure where that is, maybe bottom-left). In the y-displacement-only case, this is just the unit y position. These unit positions are scaled to [0, 1] and called factors.

The f (here, the value of non_rigid_gradient) and the expression you show ensure that the 'largest' unit location (e.g. near the top of the probe if the origin is bottom-left) is scaled by 1 (no change). The scaling is linear across all unit positions; it's kind of like a linspace where the max value is 1 and f sets the min value. E.g. looking at the smallest, a middle and the largest unit location (i.e. factors 0, 0.5 and 1):

non_rigid_gradient=0.8

0 * 0.2 + 0.8 = 0.8
0.5 * 0.2 + 0.8 = 0.9
1 * 0.2 + 0.8 = 1

non_rigid_gradient=0.2

0 * 0.8 + 0.2 = 0.2
0.5 * 0.8 + 0.2 = 0.6
1 * 0.8 + 0.2 = 1

So in the first case, the scaling of the units is only in the range [0.8, 1] of the normalised position of the unit. But for the smaller non_rigid_gradient=0.2, the scaling is between [0.2, 1].

I re-wrote the docstring of the function, let me know if it's any clearer. I think there is still room for improvement, and I am also not sure how much depth to go into.
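
The worked values above can be reproduced with a small sketch. It assumes the y-displacement-only case, so the projection is just each unit's y position, and it uses a hypothetical helper name (`displacement_unit_factors`). It only mirrors the expression `factors * (1 - f) + f` quoted in this thread and is not the function in the PR.

```python
def displacement_unit_factors(unit_y_positions, non_rigid_gradient):
    """Per-unit scaling of the displacement, linear in normalised unit position.

    Hypothetical illustration; assumes at least two distinct y positions.
    """
    lowest, highest = min(unit_y_positions), max(unit_y_positions)
    if highest == lowest:
        raise ValueError("need at least two distinct unit positions")
    # Min-max normalise the positions to [0, 1]; these are the 'factors'.
    factors = [(y - lowest) / (highest - lowest) for y in unit_y_positions]
    f = non_rigid_gradient
    # Linear ramp from f (smallest position) up to 1 (largest position).
    return [factor * (1 - f) + f for factor in factors]


# Smallest, middle and largest unit positions (factors 0, 0.5 and 1).
positions = [0.0, 50.0, 100.0]
print(displacement_unit_factors(positions, 0.8))  # approximately 0.8, 0.9, 1.0
print(displacement_unit_factors(positions, 0.2))  # approximately 0.2, 0.6, 1.0
```

A value of f close to 1 therefore gives an almost rigid shift, while a small f makes units at the low end of the projection move much less than those at the high end.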

)


def _update_kwargs_for_extended_units(
Collaborator

this is a cool feature! in the single-session case, i have previously dealt with this by adding extra units that are off the probe to start. i'm wondering how this fits into the single session case... would this function be useful there too?

Collaborator Author

@JoeZiminski JoeZiminski Jan 8, 2025

Cheers! The approach here is exactly the same: some additional simulated unit locations are generated that are off the probe. As the displacement is applied, these out-of-probe units are moved into the probe. I think this would also be useful in the single-session case, as when the probe drifts within a session, more and more 'new' signal (i.e. signal that was not detected in the probe's original position) will be introduced and could affect the correction.

I guess this is a difficult problem, as the nature of the 'new' units introduced into the region in which the probe is measuring will be random and presumably highly variable across preparations.

# units is duplicated for each section, so the new num units
# is 3x the old num units.
num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = (
_update_kwargs_for_extended_units(
Collaborator

i guess it would only need max_displacement padding of units? maybe plus a spread factor of 200um?

also, would it be worth adding this to the single-session generator too?
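
To make the padding suggestion concrete, here is a minimal sketch of how the required y-range for unit generation could be derived from the session shifts. The helper name `padded_unit_y_range`, the probe extent, and the sign convention (a shift of +s moves units by +s in y) are assumptions for illustration; the optional spread factor mentioned above is left out.

```python
def padded_unit_y_range(probe_y_min, probe_y_max, y_shifts):
    """Y-range in which units must exist so every shifted session is covered.

    Hypothetical helper. Assumes a shift of +s moves units by +s in y, so a
    unit seen at probe position p after the shift started at p - s.
    """
    # Include 0 so the unshifted probe extent is always covered.
    largest_up = max(max(y_shifts), 0)
    largest_down = min(min(y_shifts), 0)
    return probe_y_min - largest_up, probe_y_max - largest_down


# Shifts from the quick-run example, on an assumed probe spanning 0 to 1000 um.
y_range = padded_unit_y_range(0, 1000, [0, 75, -150])
print(y_range)
```

With these numbers, units would only need to be generated between -75 and 1150 um rather than over three full probe lengths.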

Collaborator Author

That's a good point! I did it this way as it was relatively easy to implement and covers all cases, and I didn't think about this option. I guess the primary benefit of it would be faster execution (?), which I did not check. I just did a quick test and, although including 3x as many units increases the time to generate the recording, it is already so fast that the difference is negligible (see below). As such, I will probably keep it as-is to keep the implementation simple, if you agree that makes sense.

| Recording time | Introduce new units | Generate recording | Detect and localise |
|---|---|---|---|
| 25s | ON | 0.067 | 16.2 |
| | OFF | 0.221 | 16.24 |
| | | | |
| | ON | 0.0661 | 53.7 |
| | OFF | 0.0228 | 53.452 |

I think this would be useful for within-session drift benchmarking, I'll add it as an issue!

@JoeZiminski
Collaborator Author

Hey @samuelgarcia @alejoe91 hope you both are good! I know things are very busy at the moment, but I was wondering if anyone might be available to give this a quick review? It would be useful for #3231 to merge this and it is (relatively) orthogonal from existing code.

The number of units in the generated recordings.
recording_durations : list
An array of length (num_recordings,) specifying the
duration that each created recording should be.
Collaborator

Worth saying the units?

Suggested change
- duration that each created recording should be.
+ duration that each created recording should be, in seconds.

@@ -776,7 +786,7 @@ def synthesize_poisson_spike_vector(
unit_indices = unit_indices[sort_indices]
spike_frames = spike_frames[sort_indices]

- return spike_frames, unit_indices
+ return spike_frames, unit_indices, firing_rates
Collaborator

I'm not sure it's worth changing the output of this public function? Alternatively, you could add

firing_rates_array = _ensure_firing_rates(firing_rates, num_units, seed)

at line 154 and keep it as before.

I know @h-mayorquin uses this function.

Collaborator

Agree on this one.

Plus, the firing rates are already passed to synthesize_poisson_spike_vector, why return them? You can make an array of the input in one line if necessary.

recordings are shifted relative to it. e.g. to create two recordings,
the second shifted by 50 um in the x-direction and 250 um in the y
direction : ((0, 0), (50, 250)).
non_rigid_gradient : float
Collaborator

Could it be worthwhile to have a different non_rigid_gradient for each segment, and the user passes a list?
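
One possible shape for this, sketched with a hypothetical helper (`per_session_gradients`) that is not in the PR: accept either a single value (or None) applied to every session, or a list with one entry per session.

```python
def per_session_gradients(non_rigid_gradient, num_sessions):
    """Expand non_rigid_gradient to one value per session.

    Hypothetical helper for illustration. A scalar or None is repeated for
    every session; a list or tuple must have one entry per session.
    """
    if non_rigid_gradient is None or isinstance(non_rigid_gradient, (int, float)):
        return [non_rigid_gradient] * num_sessions
    gradients = list(non_rigid_gradient)
    if len(gradients) != num_sessions:
        raise ValueError(
            f"expected {num_sessions} gradients, got {len(gradients)}"
        )
    return gradients


print(per_session_gradients(0.5, 3))           # same gradient for all sessions
print(per_session_gradients([None, 0.2], 2))   # rigid first session, nonrigid second
```

Keeping the scalar form valid would mean existing calls that pass a single float continue to work unchanged.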

def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations):
"""
Get the formatted `displacement_vector` and `displacement_unit_factor`
used to shift the `unit_locations`..
Collaborator

Suggested change
- used to shift the `unit_locations`..
+ used to shift the `unit_locations`.

Labels
generators Related to generator tools motion correction Questions related to motion correction
5 participants