Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
sty: lint notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Dec 19, 2024
1 parent efc305e commit e903e40
Showing 1 changed file with 112 additions and 72 deletions.
184 changes: 112 additions & 72 deletions docs/notebooks/bold_realignment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@
"metadata": {},
"outputs": [],
"source": [
"import nibabel as nb\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from shutil import rmtree, copy, move\n",
"from importlib import reload\n",
"import asyncio\n",
"import nest_asyncio\n",
"\n",
"from scipy.ndimage import binary_dilation\n",
"from skimage.morphology import ball\n",
"from pathlib import Path\n",
"from shutil import copy, move\n",
"\n",
"import nest_asyncio\n",
"import nibabel as nb\n",
"import nitransforms as nt\n",
"import numpy as np\n",
"from nipreps.synthstrip.wrappers.nipype import SynthStrip\n",
"from nipype.interfaces.afni import Volreg\n",
"from scipy.ndimage import binary_dilation\n",
"from skimage.morphology import ball\n",
"\n",
"from eddymotion.registration import ants as erants\n",
"\n",
"nest_asyncio.apply()"
Expand All @@ -31,8 +30,8 @@
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = Path('/data/datasets/')\n",
"WORKDIR = Path.home() / 'tmp' / 'eddymotiondev' / 'ismrm25'\n",
"DATA_PATH = Path(\"/data/datasets/\")\n",
"WORKDIR = Path.home() / \"tmp\" / \"eddymotiondev\" / \"ismrm25\"\n",
"WORKDIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"OUTPUT_DIR = Path(\"/data/derivatives\") / \"eddymotion-ismrm25-exp2\"\n",
Expand All @@ -46,7 +45,8 @@
"outputs": [],
"source": [
"bold_runs = [\n",
" Path(line) for line in (DATA_PATH / \"ismrm_sample.txt\").read_text().splitlines()\n",
" Path(line)\n",
" for line in (DATA_PATH / \"ismrm_sample.txt\").read_text().splitlines()\n",
" if line.strip()\n",
"]"
]
Expand All @@ -65,46 +65,48 @@
" nii = nb.load(DATA_PATH / bold_run)\n",
" average = nii.get_fdata().mean(-1)\n",
" avg_path.parent.mkdir(exist_ok=True, parents=True)\n",
" nii.__class__(\n",
" average,\n",
" nii.affine,\n",
" nii.header\n",
" ).to_filename(avg_path)\n",
" nii.__class__(average, nii.affine, nii.header).to_filename(avg_path)\n",
"\n",
" bmask_path = OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_label-brain_mask.nii.gz\"\n",
" bmask_path = (\n",
" OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_label-brain_mask.nii.gz\"\n",
" )\n",
" if not bmask_path.exists():\n",
" bmsk_results = SynthStrip(\n",
" in_file=str(avg_path),\n",
" use_gpu=True,\n",
" ).run(cwd=str(WORKDIR))\n",
" copy(bmsk_results.outputs.out_mask, bmask_path)\n",
"\n",
" dilmask_path = avg_path.parent / f\"{avg_path.name.rsplit('_', 1)[0]}_label-braindilated_mask.nii.gz\"\n",
" dilmask_path = (\n",
" avg_path.parent / f\"{avg_path.name.rsplit('_', 1)[0]}_label-braindilated_mask.nii.gz\"\n",
" )\n",
"\n",
" if not dilmask_path.exists():\n",
" niimsk = nb.load(bmask_path)\n",
" niimsk.__class__(\n",
" binary_dilation(niimsk.get_fdata() > 0.0, ball(4)).astype(\"uint8\"),\n",
" niimsk.affine,\n",
" niimsk.header,\n",
" ).to_filename(dilmask_path)\n",
" binary_dilation(niimsk.get_fdata() > 0.0, ball(4)).astype(\"uint8\"),\n",
" niimsk.affine,\n",
" niimsk.header,\n",
" ).to_filename(dilmask_path)\n",
"\n",
" oned_matrix_path = avg_path.parent / f\"{avg_path.name.rsplit('_', 1)[0]}_desc-hmc_xfm.txt\"\n",
" realign_output = avg_path.parent / f\"{avg_path.name.rsplit('_', 1)[0]}_desc-realigned_bold.nii.gz\"\n",
" oned_matrix_path = avg_path.parent / f\"{avg_path.name.rsplit('_', 1)[0]}_desc-hmc_xfm.txt\"\n",
" realign_output = (\n",
" avg_path.parent / f\"{avg_path.name.rsplit('_', 1)[0]}_desc-realigned_bold.nii.gz\"\n",
" )\n",
"\n",
" if not realign_output.exists():\n",
" volreg_results = Volreg(\n",
" in_file=str(DATA_PATH / bold_run),\n",
" in_weight_volume=str(dilmask_path),\n",
" args='-Fourier -twopass',\n",
" args=\"-Fourier -twopass\",\n",
" zpad=4,\n",
" outputtype='NIFTI_GZ',\n",
" outputtype=\"NIFTI_GZ\",\n",
" oned_matrix_save=f\"{oned_matrix_path}.aff12.1D\",\n",
" out_file=str(realign_output),\n",
" num_threads=12,\n",
" ).run(cwd=str(WORKDIR))\n",
"\n",
" move(volreg_results.outputs.oned_matrix_save, oned_matrix_path)\n"
" move(volreg_results.outputs.oned_matrix_save, oned_matrix_path)"
]
},
{
Expand Down Expand Up @@ -173,73 +175,92 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"def plot_profile(image_path, axis=None, indexing=None, cmap='gray', label=None, figsize=(15, 1.7)):\n",
"def plot_profile(image_path, axis=None, indexing=None, cmap=\"gray\", label=None, figsize=(15, 1.7)):\n",
" \"\"\"Plots a single image slice on a given axis or a new figure if axis is None.\"\"\"\n",
" # Load the image\n",
" image_data = nb.load(image_path).get_fdata()\n",
" \n",
"\n",
" # Define default indexing if not provided\n",
" if indexing is None:\n",
" indexing = (image_data.shape[0] // 2, 3 * image_data.shape[1] // 4, slice(None), slice(None))\n",
" \n",
" indexing = (\n",
" image_data.shape[0] // 2,\n",
" 3 * image_data.shape[1] // 4,\n",
" slice(None),\n",
" slice(None),\n",
" )\n",
"\n",
" # If no axis is provided, create a new figure and axis\n",
" if axis is None:\n",
" fig, axis = plt.subplots(figsize=figsize)\n",
" else:\n",
" fig = None # If axis is provided, we won't manage the figure\n",
" \n",
"\n",
" # Display the image on the specified axis with aspect='auto' and the colormap\n",
" axis.imshow(image_data[indexing], aspect='auto', cmap=cmap)\n",
" \n",
" axis.imshow(image_data[indexing], aspect=\"auto\", cmap=cmap)\n",
"\n",
" # Turn off the axis for a cleaner look\n",
" axis.axis('off')\n",
" \n",
" axis.axis(\"off\")\n",
"\n",
" if label:\n",
" # Annotate the plot with the provided label\n",
" axis.text(0.02, 0.95, label, color='white', fontsize=12, ha='left', va='top', transform=axis.transAxes)\n",
" \n",
" axis.text(\n",
" 0.02,\n",
" 0.95,\n",
" label,\n",
" color=\"white\",\n",
" fontsize=12,\n",
" ha=\"left\",\n",
" va=\"top\",\n",
" transform=axis.transAxes,\n",
" )\n",
"\n",
" # If we created the figure, show it\n",
" if fig is not None:\n",
" plt.show()\n",
" \n",
"\n",
" return fig\n",
"\n",
"\n",
"# def plot_combined_profile(images, indexing=None, figsize=(15, 1.7), cmap='gray', labels=None):\n",
"# # Create a figure with three subplots in a vertical layout and specified figure size\n",
"# n_images = len(images)\n",
"\n",
"# nplots = n_images * len(indexing or [True])\n",
"# figsize = (figsize[0], figsize[1] * nplots)\n",
"# fig, axes = plt.subplots(nplots, 1, figsize=figsize, constrained_layout=True)\n",
" \n",
"\n",
"# if labels is None or isinstance(labels, str):\n",
"# labels = (labels, ) * nplots\n",
" \n",
"\n",
"# if indexing is None or len(indexing) == 0:\n",
"# indexing = [None]\n",
" \n",
"\n",
"# for i, idx in enumerate(indexing):\n",
"# for j in range(len(images)):\n",
"# ax = axes[i * n_images + j]\n",
"# plot_profile(images[j], axis=ax, indexing=idx, cmap=cmap, label=labels[j])\n",
" \n",
"\n",
"# return fig\n",
"\n",
"def plot_combined_profile(images, afni_fd, eddymotion_fd, indexing=None, figsize=(15, 1.7), cmap='gray', labels=None):\n",
"\n",
"def plot_combined_profile(\n",
" images, afni_fd, eddymotion_fd, indexing=None, figsize=(15, 1.7), cmap=\"gray\", labels=None\n",
"):\n",
" # Calculate the number of profile plots\n",
" n_images = len(images)\n",
" nplots = n_images * len(indexing or [True])\n",
" total_height = figsize[1] * nplots + 2 # Adjust figure height for FD plot\n",
"\n",
" # Create a figure with one extra row for the FD plot, setting `sharex=True` for shared x-axis\n",
" fig, axes = plt.subplots(nplots + 1, 1, figsize=(figsize[0], total_height), constrained_layout=True, sharex=True)\n",
" fig, axes = plt.subplots(\n",
" nplots + 1, 1, figsize=(figsize[0], total_height), constrained_layout=True, sharex=True\n",
" )\n",
"\n",
" # Plot the framewise displacement on the first axis\n",
" fd_axis = axes[0]\n",
" timepoints = np.arange(len(afni_fd)) # Assuming afni_fd and eddymotion_fd have the same length\n",
" fd_axis.plot(timepoints, afni_fd, label='AFNI 3dVolreg FD', color='blue')\n",
" fd_axis.plot(timepoints, eddymotion_fd, label='eddymotion FD', color='orange')\n",
" fd_axis.plot(timepoints, afni_fd, label=\"AFNI 3dVolreg FD\", color=\"blue\")\n",
" fd_axis.plot(timepoints, eddymotion_fd, label=\"eddymotion FD\", color=\"orange\")\n",
" fd_axis.set_ylabel(\"FD (mm)\")\n",
" fd_axis.legend(loc=\"upper right\")\n",
" fd_axis.set_xticks([]) # Hide x-ticks to keep x-axis clean\n",
Expand All @@ -251,13 +272,13 @@
" # Set indexing if not provided\n",
" if indexing is None or len(indexing) == 0:\n",
" indexing = [None]\n",
" \n",
"\n",
" # Plot each profile slice below the FD plot\n",
" for i, idx in enumerate(indexing):\n",
" for j in range(len(images)):\n",
" ax = axes[i * n_images + j + 1] # Shift index by 1 to account for FD plot\n",
" plot_profile(images[j], axis=ax, indexing=idx, cmap=cmap, label=labels[j])\n",
" \n",
"\n",
" return fig"
]
},
Expand All @@ -278,7 +299,10 @@
}
],
"source": [
"plot_combined_profile((DATA_PATH / bold_runs[15], afni_realigned[15], afni_realigned[15]), labels=(\"hmc1\", \"original\", \"hmc2\"));"
"plot_combined_profile(\n",
" (DATA_PATH / bold_runs[15], afni_realigned[15], afni_realigned[15]),\n",
" labels=(\"hmc1\", \"original\", \"hmc2\"),\n",
");"
]
},
{
Expand All @@ -299,8 +323,12 @@
],
"source": [
"datashape = nb.load(DATA_PATH / bold_runs[15]).shape\n",
"plot_profile(DATA_PATH / bold_runs[15], afni_realigned[15], afni_realigned[15],\n",
" indexing=(slice(None), 3 * datashape[1] // 4, datashape[2] // 2, slice(None)));"
"plot_profile(\n",
" DATA_PATH / bold_runs[15],\n",
" afni_realigned[15],\n",
" afni_realigned[15],\n",
" indexing=(slice(None), 3 * datashape[1] // 4, datashape[2] // 2, slice(None)),\n",
");"
]
},
{
Expand All @@ -309,7 +337,6 @@
"metadata": {},
"outputs": [],
"source": [
"from eddymotion.estimator import EddyMotionEstimator\n",
"from eddymotion.model.base import AverageModel\n",
"from eddymotion.utils import random_iterator"
]
Expand Down Expand Up @@ -424,7 +451,9 @@
" workdir = WORKDIR / bold_run.parent\n",
" workdir.mkdir(parents=True, exist_ok=True)\n",
" data_path = DATA_PATH / bold_run\n",
" brainmask_path = OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_label-brain_mask.nii.gz\"\n",
" brainmask_path = (\n",
" OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_label-brain_mask.nii.gz\"\n",
" )\n",
"\n",
" nii = nb.load(data_path)\n",
" hdr = nii.header.copy()\n",
Expand All @@ -449,6 +478,7 @@
"outputs": [],
"source": [
"from nitransforms.resampling import apply\n",
"\n",
"from eddymotion.registration.utils import displacement_framewise\n",
"\n",
"afni_fd = {}\n",
Expand All @@ -472,13 +502,14 @@
" ]\n",
"\n",
" nii = nb.load(DATA_PATH / bold_run)\n",
" nitransforms_fd[str(bold_run)] = np.array([\n",
" displacement_framewise(nii, xfm)\n",
" for xfm in xfms\n",
" ])\n",
" nitransforms_fd[str(bold_run)] = np.array([displacement_framewise(nii, xfm) for xfm in xfms])\n",
"\n",
" hmc_xfm = nt.linear.LinearTransformsMapping(xfms)\n",
" out_nitransforms = OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-nitransforms_bold.nii.gz\"\n",
" out_nitransforms = (\n",
" OUTPUT_DIR\n",
" / bold_run.parent\n",
" / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-nitransforms_bold.nii.gz\"\n",
" )\n",
" if not out_nitransforms.exists():\n",
" apply(\n",
" hmc_xfm,\n",
Expand All @@ -487,20 +518,21 @@
" ).to_filename(out_nitransforms)\n",
"\n",
" afni_xfms = nt.linear.load(\n",
" OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-hmc_xfm.txt\"\n",
" OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-hmc_xfm.txt\"\n",
" )\n",
" afni_fd[str(bold_run)] = np.array(\n",
" [displacement_framewise(nii, afni_xfms[i]) for i in range(len(afni_xfms))]\n",
" )\n",
" afni_fd[str(bold_run)] = np.array([\n",
" displacement_framewise(nii, afni_xfms[i])\n",
" for i in range(len(afni_xfms))\n",
" ])\n",
"\n",
" out_afni = OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-afni_bold.nii.gz\"\n",
" out_afni = (\n",
" OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-afni_bold.nii.gz\"\n",
" )\n",
" if not out_afni.exists():\n",
" apply(\n",
" afni_xfms,\n",
" spatialimage=nii,\n",
" reference=nii,\n",
" ).to_filename(out_afni)\n"
" ).to_filename(out_afni)"
]
},
{
Expand Down Expand Up @@ -1681,8 +1713,16 @@
"\n",
" for bold_run in bold_runs:\n",
" original = DATA_PATH / bold_run\n",
" nitransforms = OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-nitransforms_bold.nii.gz\"\n",
" afni = OUTPUT_DIR / bold_run.parent / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-realigned_bold.nii.gz\"\n",
" nitransforms = (\n",
" OUTPUT_DIR\n",
" / bold_run.parent\n",
" / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-nitransforms_bold.nii.gz\"\n",
" )\n",
" afni = (\n",
" OUTPUT_DIR\n",
" / bold_run.parent\n",
" / f\"{bold_run.name.rsplit('_', 1)[0]}_desc-realigned_bold.nii.gz\"\n",
" )\n",
"\n",
" datashape = nb.load(original).shape\n",
"\n",
Expand All @@ -1691,7 +1731,7 @@
" afni_fd[str(bold_run)],\n",
" nitransforms_fd[str(bold_run)],\n",
" labels=(\"3dVolreg\", str(bold_run), \"eddymotion\"),\n",
" indexing=(None, (slice(None), 3 * datashape[1] // 4, datashape[2] // 2, slice(None)))\n",
" indexing=(None, (slice(None), 3 * datashape[1] // 4, datashape[2] // 2, slice(None))),\n",
" )\n",
"\n",
" # Save the figure\n",
Expand All @@ -1701,7 +1741,7 @@
" plt.close(fig)\n",
"\n",
" index_file.write(f\"<li><a href={out_svg.relative_to(OUTPUT_DIR)}>{bold_run}</a></li>\\n\")\n",
" \n",
"\n",
" index_file.write(\"</ul>\\n</body></html>\")"
]
},
Expand Down

0 comments on commit e903e40

Please sign in to comment.