Skip to content

Commit

Permalink
Refactor Fishers
Browse files Browse the repository at this point in the history
  • Loading branch information
JulioAPeraza committed May 23, 2024
1 parent ae25648 commit d012701
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions nimare/meta/ibma.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,21 @@ def _generate_description(self):
)
return description

def _fit_model(self, stat_maps):
"""Fit the model to the data."""
n_studies, n_voxels = stat_maps.shape

pymare_dset = pymare.Dataset(y=stat_maps)
est = pymare.estimators.FisherCombinationTest()
est.fit_dataset(pymare_dset)
est_summary = est.summary()

z_map = est_summary.z.squeeze()
p_map = est_summary.p.squeeze()
dof_map = np.tile(n_studies - 1, n_voxels).astype(np.int32)

return z_map, p_map, dof_map

def _fit(self, dataset):
self.dataset = dataset
self.masker = self.masker or dataset.masker
Expand All @@ -234,32 +249,25 @@ def _fit(self, dataset):
)

if self.aggressive_mask:
pymare_dset = pymare.Dataset(y=self.inputs_["z_maps"])
est = pymare.estimators.FisherCombinationTest()
est.fit_dataset(pymare_dset)
est_summary = est.summary()
voxel_mask = self.inputs_["aggressive_mask"]

z_map = _boolean_unmask(est_summary.z.squeeze(), self.inputs_["aggressive_mask"])
p_map = _boolean_unmask(est_summary.p.squeeze(), self.inputs_["aggressive_mask"])
dof_map = np.tile(
self.inputs_["z_maps"].shape[0] - 1,
self.inputs_["z_maps"].shape[1],
).astype(np.int32)
dof_map = _boolean_unmask(dof_map, self.inputs_["aggressive_mask"])
z_map, p_map, dof_map = self._fit_model(self.inputs_["z_maps"][:, voxel_mask])

z_map = _boolean_unmask(z_map, voxel_mask)
p_map = _boolean_unmask(p_map, voxel_mask)
dof_map = _boolean_unmask(dof_map, voxel_mask)

else:
n_total_voxels = self.inputs_["z_maps"].shape[1]
z_map = np.zeros(n_total_voxels, dtype=float)
p_map = np.zeros(n_total_voxels, dtype=float)
dof_map = np.zeros(n_total_voxels, dtype=np.int32)
for bag in self.inputs_["data_bags"]["z_maps"]:
pymare_dset = pymare.Dataset(y=bag["values"])
est = pymare.estimators.FisherCombinationTest()
est.fit_dataset(pymare_dset)
est_summary = est.summary()
z_map[bag["voxel_mask"]] = est_summary.z.squeeze()
p_map[bag["voxel_mask"]] = est_summary.p.squeeze()
dof_map[bag["voxel_mask"]] = bag["values"].shape[0] - 1
z_map_temp, p_map_temp, dof_map_temp = self._fit_model(bag["values"])

z_map[bag["voxel_mask"]] = z_map_temp
p_map[bag["voxel_mask"]] = p_map_temp
dof_map[bag["voxel_mask"]] = dof_map_temp

maps = {"z": z_map, "p": p_map, "dof": dof_map}
description = self._generate_description()
Expand All @@ -277,6 +285,7 @@ class Stouffers(IBMAEstimator):
.. versionchanged:: 0.2.3
* Add correction for multiple contrasts within a study.
* New parameter: ``use_group_size`` to use publication group sizes for weights.
.. versionchanged:: 0.2.1
Expand Down Expand Up @@ -339,6 +348,8 @@ def __init__(self, use_sample_size=False, use_group_size=False, **kwargs):
if self.use_sample_size:
self._required_inputs["sample_sizes"] = ("metadata", "sample_sizes")

self.use_group_size = use_group_size

def _preprocess_input(self, dataset):
"""Preprocess additional inputs to the Estimator from the Dataset as needed."""
super()._preprocess_input(dataset)
Expand Down Expand Up @@ -419,12 +430,11 @@ def _fit(self, dataset):
corr = np.corrcoef(self.inputs_["z_maps"], rowvar=True)

if self.aggressive_mask:
stat_maps = self.inputs_["z_maps"]
study_mask = self.dataset.images["id"].isin(self.inputs_["id"])
voxel_mask = self.inputs_["aggressive_mask"]

z_map, p_map, dof_map = self._fit_model(
stat_maps[:, voxel_mask],
self.inputs_["z_maps"][:, voxel_mask],
study_mask,
corr=corr,
)
Expand All @@ -439,12 +449,9 @@ def _fit(self, dataset):
p_map = np.zeros(n_voxels, dtype=float)
dof_map = np.zeros(n_voxels, dtype=np.int32)
for bag in self.inputs_["data_bags"]["z_maps"]:
study_mask = bag["study_mask"]
voxel_mask = bag["voxel_mask"]

z_map_temp, p_map_temp, dof_map_temp = self._fit_model(
bag["values"],
study_mask,
bag["study_mask"],
corr=corr,
)

Expand Down

0 comments on commit d012701

Please sign in to comment.