Skip to content

Commit

Permalink
expose bmag parameter and check number of b0s consistent after rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
josephmje committed Mar 11, 2021
1 parent fe3abf4 commit 075c706
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 13 deletions.
12 changes: 8 additions & 4 deletions dmriprep/interfaces/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
\t\t<summary>Summary</summary>
\t\t<ul class="elem-desc">
\t\t\t<li>Phase-encoding (PE) direction: {pe_direction}</li>
\t\t\t<li>Distinct shells: {shells_dist}</li>
\t\t\t<li>Shell distribution: {shell_dist}</li>
\t\t</ul>
\t\t</details>
"""
Expand Down Expand Up @@ -133,18 +133,22 @@ class DiffusionSummaryInputSpec(BaseInterfaceInputSpec):
pe_direction = traits.Enum(
None, "i", "i-", "j", "j-", "k", "k-", desc="Phase-encoding direction detected"
)
shells_dist = traits.Dict(mandatory=True, desc="Number of distinct shells")
shell_dist = traits.Dict(mandatory=True, desc="Shell distribution")


class DiffusionSummary(SummaryInterface):
input_spec = DiffusionSummaryInputSpec

def _generate_segment(self):
pe_direction = self.inputs.pe_direction
shells_dist = self.inputs.shells_dist
shell_dist = self.inputs.shell_dist
shell_dist_text = ", ".join(
f"{shell_dist[key]} directions at b={key} s/mm\u00B2"
for key in shell_dist
)

return DIFFUSION_TEMPLATE.format(
pe_direction=pe_direction, shells_dist=shells_dist
pe_direction=pe_direction, shell_dist=shell_dist_text
)


Expand Down
10 changes: 6 additions & 4 deletions dmriprep/interfaces/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class _CheckGradientTableInputSpec(BaseInterfaceInputSpec):
in_rasb = File(exists=True, xor=["in_bval", "in_bvec"])
b0_threshold = traits.Float(B0_THRESHOLD, usedefault=True)
bvec_norm_epsilon = traits.Float(BVEC_NORM_EPSILON, usedefault=True)
b_mag = traits.Int(None, usedefault=True)
b_scale = traits.Bool(True, usedefault=True)


Expand All @@ -30,7 +31,7 @@ class _CheckGradientTableOutputSpec(TraitedSpec):
full_sphere = traits.Bool()
pole = traits.Tuple(traits.Float, traits.Float, traits.Float)
num_shells = traits.Int
shells_dist = traits.Dict
shell_dist = traits.Dict
b0_ixs = traits.List(traits.Int)
b0_mask = traits.List(traits.Bool)

Expand All @@ -52,7 +53,7 @@ class CheckGradientTable(SimpleInterface):
True
>>> check.outputs.num_shells
3
>>> check.outputs.shells_dist
>>> check.outputs.shell_dist
{0.0: 12, 1200.0: 32, 2500.0: 61}
>>> check = CheckGradientTable(
Expand All @@ -65,7 +66,7 @@ class CheckGradientTable(SimpleInterface):
True
>>> check.outputs.num_shells
3
>>> check.outputs.shells_dist
>>> check.outputs.shell_dist
{0: 12, 1200: 32, 2500: 61}
>>> newrasb = np.loadtxt(check.outputs.out_rasb, skiprows=1)
>>> oldrasb = np.loadtxt(str(data_dir / 'dwi.tsv'), skiprows=1)
Expand All @@ -85,6 +86,7 @@ def _run_interface(self, runtime):
bvecs=_undefined(self.inputs, "in_bvec"),
bvals=_undefined(self.inputs, "in_bval"),
rasb_file=rasb_file,
b_mag=self.inputs.b_mag,
b_scale=self.inputs.b_scale,
bvec_norm_epsilon=self.inputs.bvec_norm_epsilon,
b0_threshold=self.inputs.b0_threshold,
Expand All @@ -93,7 +95,7 @@ def _run_interface(self, runtime):
self._results["pole"] = tuple(pole)
self._results["full_sphere"] = np.all(pole == 0.0)
self._results["num_shells"] = len(table.count_shells)
self._results["shells_dist"] = table.count_shells
self._results["shell_dist"] = table.count_shells
self._results["b0_ixs"] = np.where(table.b0mask)[0].tolist()

cwd = Path(runtime.cwd).absolute()
Expand Down
23 changes: 19 additions & 4 deletions dmriprep/utils/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class DiffusionGradientTable:
__slots__ = [
"_affine",
"_b0_thres",
"_b_mag",
"_b_scale",
"_bvals",
"_bvec_norm_epsilon",
Expand All @@ -30,6 +31,7 @@ class DiffusionGradientTable:
def __init__(
self,
b0_threshold=B0_THRESHOLD,
b_mag=None,
b_scale=True,
bvals=None,
bvec_norm_epsilon=BVEC_NORM_EPSILON,
Expand All @@ -46,12 +48,14 @@ def __init__(
----------
b0_threshold : :obj:`float`
The upper threshold to consider a low-b shell as :math:`b=0`.
b_mag : :obj:`int`
The order of magnitude to round the b-values.
b_scale : :obj:`bool`
Automatically scale the *b*-values with the norm of the corresponding
*b*-vectors before the latter are normalized.
bvals : str or os.pathlike or numpy.ndarray
File path of the b-values.
b_vec_norm_epsilon : :obj:`float`
bvec_norm_epsilon : :obj:`float`
The minimum difference in the norm of two *b*-vectors to consider them different.
bvecs : str or os.pathlike or numpy.ndarray
File path of the b-vectors.
Expand Down Expand Up @@ -93,6 +97,7 @@ def __init__(
"""
self._affine = None
self._b0_thres = b0_threshold
self._b_mag = b_mag
self._b_scale = b_scale
self._bvals = None
self._bvec_norm_epsilon = bvec_norm_epsilon
Expand Down Expand Up @@ -195,6 +200,7 @@ def normalize(self):
self.bvals,
b0_threshold=self._b0_thres,
bvec_norm_epsilon=self._bvec_norm_epsilon,
b_mag=self._b_mag,
b_scale=self._b_scale,
raise_error=self._raise_inconsistent,
)
Expand Down Expand Up @@ -290,6 +296,7 @@ def normalize_gradients(
bvals,
b0_threshold=B0_THRESHOLD,
bvec_norm_epsilon=BVEC_NORM_EPSILON,
b_mag=None,
b_scale=True,
raise_error=False,
):
Expand Down Expand Up @@ -361,17 +368,25 @@ def normalize_gradients(
raise ValueError(msg)
config.loggers.cli.warning(msg)

# Rescale b-vals if requested
# Rescale bvals if requested
if b_scale:
bvals[~b0s] *= np.linalg.norm(bvecs[~b0s], axis=1) ** 2

# Ensure b0s have (0, 0, 0) vectors
bvecs[b0s, :3] = np.zeros(3)

# Round bvals
bvals = round_bvals(bvals)
bvals = round_bvals(bvals, bmag=b_mag)

# Rescale b-vecs, skipping b0's, on the appropriate axis to unit-norm length.
# Ensure rounding bvals doesn't change the number of b0s
rounded_b0s = bvals == 0
if not np.all(b0s == rounded_b0s):
msg = f"Inconsistent b0s before ({b0s.sum()}) and after rounding ({rounded_b0s.sum()})."
if raise_error:
raise ValueError(msg)
config.loggers.cli.warning(msg)

# Rescale bvecs, skipping b0's, on the appropriate axis to unit-norm length.
bvecs[~b0s] /= np.linalg.norm(bvecs[~b0s], axis=1)[..., np.newaxis]
return bvecs, bvals.astype("uint16")

Expand Down
2 changes: 1 addition & 1 deletion dmriprep/workflows/dwi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def init_dwi_preproc_wf(dwi_file, has_fieldmap=False):
("in_bvec", "in_bvec"),
("in_bval", "in_bval")]),
(inputnode, dwi_reference_wf, [("dwi_file", "inputnode.dwi_file")]),
(gradient_table, summary, [("shells_dist", "shells_dist")]),
(gradient_table, summary, [("shell_dist", "shell_dist")]),
(gradient_table, dwi_reference_wf, [("b0_ixs", "inputnode.b0_ixs")]),
(gradient_table, outputnode, [("out_rasb", "gradients_rasb")]),
])
Expand Down

0 comments on commit 075c706

Please sign in to comment.