From 4e6ae99353693e77e7498905e3757abe4781f6da Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 21 Jan 2024 20:39:19 -0500 Subject: [PATCH 01/40] Move small object filter to post_process.py --- cellulus/post_process.py | 9 +++++++++ cellulus/utils/mean_shift.py | 18 ------------------ cellulus/utils/misc.py | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/cellulus/post_process.py b/cellulus/post_process.py index 090e5a5..591d6d9 100644 --- a/cellulus/post_process.py +++ b/cellulus/post_process.py @@ -5,6 +5,7 @@ from cellulus.configs.inference_config import InferenceConfig from cellulus.datasets.meta_data import DatasetMetaData +from cellulus.utils.misc import size_filter def post_process(inference_config: InferenceConfig) -> None: @@ -36,6 +37,7 @@ def post_process(inference_config: InferenceConfig) -> None: ds_postprocessed.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + # remove halo for sample in tqdm(range(dataset_meta_data.num_samples)): # first instance label masks are expanded by `grow_distance` # next, expanded instance label masks are shrunk by `shrink_distance` @@ -46,3 +48,10 @@ def post_process(inference_config: InferenceConfig) -> None: distance_background = dtedt(expanded_mask) segmentation[distance_background < inference_config.shrink_distance] = 0 ds_postprocessed[sample, bandwidth_factor, ...] = segmentation + + # size filter - remove small objects + for sample in tqdm(range(dataset_meta_data.num_samples)): + for bandwidth_factor in range(inference_config.num_bandwidths): + ds_postprocessed[sample, bandwidth_factor, ...] = size_filter( + ds_postprocessed[sample, bandwidth_factor], inference_config.min_size + ) diff --git a/cellulus/utils/mean_shift.py b/cellulus/utils/mean_shift.py index 829bf4a..bdf6012 100644 --- a/cellulus/utils/mean_shift.py +++ b/cellulus/utils/mean_shift.py @@ -1,6 +1,5 @@ import numpy as np import torch -from skimage import measure from sklearn.cluster import MeanShift @@ -36,7 +35,6 @@ def mean_shift_segmentation( reduction_probability=reduction_probability, cluster_all=False, )[0] - segmentation = sizefilter(segmentation, min_size) return segmentation @@ -49,22 +47,6 @@ def segment_with_meanshift( return anchor_mean_shift(embedding, mask=mask) + 1 -def sizefilter(segmentation, min_size, filter_non_connected=True): - if min_size == 0: - return segmentation - - if filter_non_connected: - filter_labels = measure.label(segmentation, background=0) - else: - filter_labels = segmentation - ids, sizes = np.unique(filter_labels, return_counts=True) - filter_ids = ids[sizes < min_size] - mask = np.in1d(filter_labels, filter_ids).reshape(filter_labels.shape) - segmentation[mask] = 0 - - return segmentation - - class AnchorMeanshift: def __init__(self, bandwidth, reduction_probability, cluster_all): self.mean_shift = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all) diff --git a/cellulus/utils/misc.py b/cellulus/utils/misc.py index 490022a..2f77c72 100644 --- a/cellulus/utils/misc.py +++ b/cellulus/utils/misc.py @@ -4,6 +4,24 @@ from zipfile import ZipFile import matplotlib.pyplot as plt +import numpy as np +from skimage import measure + + +def size_filter(segmentation, min_size, filter_non_connected=True): + if min_size == 0: + return segmentation + + if filter_non_connected: + filter_labels = measure.label(segmentation, background=0) + else: + filter_labels = segmentation + ids, sizes = np.unique(filter_labels, return_counts=True) + filter_ids = ids[sizes < min_size] + mask = np.in1d(filter_labels, filter_ids).reshape(filter_labels.shape) + segmentation[mask] = 0 + + return segmentation def extract_data(zip_url, data_dir, project_name): From 194cc05815b2d5ab8ccda7abf0f9ec212ae85672 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 15:55:40 -0500 Subject: [PATCH 02/40] Update docstrings --- cellulus/configs/dataset_config.py | 5 ++- cellulus/configs/experiment_config.py | 7 ++-- cellulus/configs/inference_config.py | 58 +++++++++++++++++++++------ cellulus/configs/model_config.py | 15 +++---- cellulus/configs/train_config.py | 9 +++-- 5 files changed, 66 insertions(+), 28 deletions(-) diff --git a/cellulus/configs/dataset_config.py b/cellulus/configs/dataset_config.py index f63a3ec..cc304e9 100644 --- a/cellulus/configs/dataset_config.py +++ b/cellulus/configs/dataset_config.py @@ -8,7 +8,8 @@ class DatasetConfig: """Dataset configuration. - Parameters: + Parameters + ---------- container_path: @@ -16,7 +17,7 @@ class DatasetConfig: dataset_name: - The name of the dataset containing raw data in the container. + The name of the dataset containing the raw data in the container. secondary_dataset_name: diff --git a/cellulus/configs/experiment_config.py b/cellulus/configs/experiment_config.py index bfb8305..d2bb761 100644 --- a/cellulus/configs/experiment_config.py +++ b/cellulus/configs/experiment_config.py @@ -13,13 +13,14 @@ class ExperimentConfig: """Top-level config for an experiment (containing training and prediction). - Parameters: + Parameters + ---------- experiment_name: (default = 'YYYY-MM-DD') A unique name for the experiment. - object_size: (default = 26.0) + object_size: (default = 26) A rough estimate of the size of objects in the image, given in world units. The "patch size" of the network will be chosen based @@ -42,7 +43,7 @@ class ExperimentConfig: experiment_name: str = attrs.field( default=datetime.today().strftime("%Y-%m-%d"), validator=instance_of(str) ) - object_size: float = attrs.field(default=26.0, validator=instance_of(float)) + object_size: int = attrs.field(default=26, validator=instance_of(int)) train_config: TrainConfig = attrs.field( default=None, converter=to_config(TrainConfig) diff --git a/cellulus/configs/inference_config.py b/cellulus/configs/inference_config.py index 8bd4159..ed2e3d4 100644 --- a/cellulus/configs/inference_config.py +++ b/cellulus/configs/inference_config.py @@ -1,7 +1,7 @@ from typing import List import attrs -from attrs.validators import instance_of +from attrs.validators import in_, instance_of from .dataset_config import DatasetConfig from .utils import to_config @@ -11,7 +11,8 @@ class InferenceConfig: """Inference configuration. - Parameters: + Parameters + ---------- dataset_config: @@ -33,45 +34,72 @@ class InferenceConfig: Configuration object for the ground truth masks. - crop_size: + crop_size (default = [252, 252]): ROI used by the scan node in gunpowder. - p_salt_pepper: + p_salt_pepper (default = 0.01): Fraction of pixels that will have salt-pepper noise. - num_infer_iterations: + num_infer_iterations (default = 16): Number of times the salt-peper noise is added to the raw image. + This is used to infer the foreground and background in the raw image. - bandwidth: + bandwidth (default = None): Band-width used to perform mean-shift clustering on the predicted embeddings. - threshold: + threshold (default = None): Threshold to use for binary partitioning into foreground and background pixel regions. If None, this is figured out automatically by performing Otsu Thresholding on the last channel of the predicted embeddings. - reduction_probability: + min_size (default = None): - - min_size: - - Ignore objects which are smaller than min_size number of pixels. + Ignore objects which are smaller than `min_size` number of pixels. device (default = 'cuda:0'): The device to infer on. Set to 'cpu' to infer without GPU. + clustering (default = 'meanshift'): + + How to cluster the embeddings? + Can be one of 'meanshift' or 'greedy'. + num_bandwidths (default = 1): Number of bandwidths to obtain segmentations for. + min_size (default = None): + + Objects below `min_size` pixels will be removed. + + post_processing (default= 'morphological'): + + Can be one of 'morphological' or 'intensity' operations. + If 'morphological', the individual detections grow and shrink by + 'grow_distance' and 'shrink_distance' number of pixels. + + If 'intensity', each detection is partitioned based on a binary intensity + threshold calculated automatically from the raw image data. + By default, the channel `0` in the raw image is used for + intensity thresholding. + + grow_distance (default = 3): + + Only used if post_processing (see above) is equal to 'morphological'. + + shrink_distance (default = 6): + + Only used if post_processing (see above) is equal to + 'morphological'. + """ dataset_config: DatasetConfig = attrs.field( @@ -100,6 +128,9 @@ class InferenceConfig: threshold = attrs.field( default=None, validator=attrs.validators.optional(instance_of(float)) ) + clustering = attrs.field( + default="meanshift", validator=in_(["meanshift", "greedy"]) + ) bandwidth = attrs.field( default=None, validator=attrs.validators.optional(instance_of(int)) ) @@ -108,5 +139,8 @@ class InferenceConfig: min_size = attrs.field( default=None, validator=attrs.validators.optional(instance_of(int)) ) + post_processing = attrs.field( + default="morphological", validator=in_(["morphological", "intensity"]) + ) grow_distance = attrs.field(default=3, validator=instance_of(int)) shrink_distance = attrs.field(default=6, validator=instance_of(int)) diff --git a/cellulus/configs/model_config.py b/cellulus/configs/model_config.py index 7920618..e74c9f7 100644 --- a/cellulus/configs/model_config.py +++ b/cellulus/configs/model_config.py @@ -9,9 +9,10 @@ @attrs.define class ModelConfig: - """Model configuration. + """Model Configuration. - Parameters: + Parameters + ---------- num_fmaps: @@ -22,27 +23,27 @@ class ModelConfig: The factor by which to increase the number of feature maps between levels of the U-Net. - features_in_last_layer (optional, default = 64): + features_in_last_layer (default = 64): The number of feature channels in the last layer of the U-Net - downsampling_factors: + downsampling_factors (default = [[2,2]]): A list of downsampling factors, each given per dimension (e.g., [[2,2], [3,3]] would correspond to two downsample layers, one with an isotropic factor of 2, and another one with 3). This parameter will also determine the number of levels in the U-Net. - checkpoint (optional, default ``None``): + checkpoint (default = None): A path to a checkpoint of the network. Needs to be set for networks that are used for prediction. If set during training, the checkpoint will be used to resume training, otherwise the network will be trained from scratch. - initialize (default: True) + initialize (default = True) - If True, initialize the model weights with Kaiming Normal + If True, initialize the model weights with Kaiming Normal. """ diff --git a/cellulus/configs/train_config.py b/cellulus/configs/train_config.py index 093f02c..f085e2b 100644 --- a/cellulus/configs/train_config.py +++ b/cellulus/configs/train_config.py @@ -11,18 +11,19 @@ class TrainConfig: """Train configuration. - Parameters: + Parameters + ---------- - crop_size: + crop_size (default = [252, 252]): The size of the crops - specified as a list of number of pixels - extracted from the raw images, used during training. - batch_size: + batch_size (default = 8): The number of samples to use per batch. - max_iterations: + max_iterations (default = 100000): The maximum number of iterations to train for. From f580a7a6ef64507eb5579d87f7855d875165f15d Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 15:56:15 -0500 Subject: [PATCH 03/40] Rename ds to ds_groundtruth --- cellulus/evaluate.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/cellulus/evaluate.py b/cellulus/evaluate.py index 359c06e..57a0067 100644 --- a/cellulus/evaluate.py +++ b/cellulus/evaluate.py @@ -11,20 +11,17 @@ def evaluate(inference_config: InferenceConfig) -> None: dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) f = zarr.open(inference_config.evaluation_dataset_config.container_path) - ds = f[inference_config.evaluation_dataset_config.secondary_dataset_name] - - f_segmentation = zarr.open( - inference_config.evaluation_dataset_config.container_path - ) - ds_segmentation = f_segmentation[ - inference_config.evaluation_dataset_config.dataset_name + ds_segmentation = f[ + inference_config.evaluation_dataset_config.secondary_dataset_name ] + ds_groundtruth = f[inference_config.evaluation_dataset_config.dataset_name] + for bandwidth in range(inference_config.num_bandwidths): F1_list, SEG_list, TP_list, FP_list, FN_list = [], [], [], [], [] SEG_dataset, n_ids_dataset = 0, 0 for sample in tqdm(range(dataset_meta_data.num_samples)): - groundtruth = ds[sample, 0].astype(np.uint16) + groundtruth = ds_groundtruth[sample, 0].astype(np.uint16) prediction = ds_segmentation[sample, bandwidth].astype(np.uint16) IoU, SEG_image, n_GTids_image = compute_pairwise_IoU( prediction, groundtruth @@ -38,8 +35,6 @@ def evaluate(inference_config: InferenceConfig) -> None: FP_list.append(FP_image) FN_list.append(FN_image) print(f"{sample}:, F1={F1_image:.3f}, SEG={SEG_image/n_GTids_image:.3f}") - print(f"The mean F1 score is {np.mean(F1_list)}") - print(f"The mean SEG score is {np.mean(SEG_list)}") F1_dataset = 2 * sum(TP_list) / (2 * sum(TP_list) + sum(FP_list) + sum(FN_list)) @@ -52,13 +47,14 @@ def evaluate(inference_config: InferenceConfig) -> None: f.writelines("+++++++++++++++++++++++++++++++++\n") for sample in range(dataset_meta_data.num_samples): f.writelines( - f"{sample}, {F1_list[sample]:.05f}, {SEG_list[sample]:.05f}, {TP_list[sample]}, {FP_list[sample]}, {FN_list[sample]}\n" + f"{sample}," + + f" {F1_list[sample]:.05f}," + + f" {SEG_list[sample]:.05f}," + + f" {TP_list[sample]}," + + f" {FP_list[sample]}," + + f" {FN_list[sample]}\n", ) f.writelines("+++++++++++++++++++++++++++++++++\n") - f.writelines(f"Avg. F1 (averaged per sample) is {np.mean(F1_list):.05f} \n") - f.writelines( - f"Avg. SEG (averaged per sample) is {np.mean(SEG_list):.05f} \n" - ) f.writelines(f"F1 for complete dataset is {F1_dataset:.05f} \n") f.writelines( f"SEG for complete dataset is {SEG_dataset/n_ids_dataset:.05f} \n" @@ -91,7 +87,7 @@ def compute_pairwise_IoU(prediction, groundtruth): def compute_F1(IoU_table, threshold=0.5): - IoU_table_thresholded = IoU_table >= threshold + IoU_table_thresholded = IoU_table > threshold FP = np.sum(np.sum(IoU_table_thresholded, axis=1) == 0) FN = np.sum(np.sum(IoU_table_thresholded, axis=0) == 0) TP = IoU_table.shape[1] - FN From 90f3a3925519292e56d6128f2d42a5742fcf0f47 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 15:56:50 -0500 Subject: [PATCH 04/40] Add greedy clustering --- cellulus/utils/greedy_cluster.py | 252 +++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 cellulus/utils/greedy_cluster.py diff --git a/cellulus/utils/greedy_cluster.py b/cellulus/utils/greedy_cluster.py new file mode 100644 index 0000000..e991229 --- /dev/null +++ b/cellulus/utils/greedy_cluster.py @@ -0,0 +1,252 @@ +import numpy as np +import torch + + +class Cluster2d: + """ + Class for Greedy Clustering of Embeddings on 2D samples. + """ + def __init__(self, width, height, fg_mask, device): + """Initializes objects of class `Cluster2d`. + + Parameters + ---------- + + width: + + Width (`W`) of the Raw Image, in number of pixels. + + height: + + Height (`H`) of the Raw Image, in number of pixels. + + fg_mask: (shape is `H` x `W`) + + Foreground Mask corresponding to the region which should be + partitioned into individual objects. + + device: + + Device on which inference is being run. + + """ + + xm = torch.linspace(0, width - 1, width).view(1, 1, -1).expand(1, height, width) + ym = ( + torch.linspace(0, height - 1, height) + .view(1, -1, 1) + .expand(1, height, width) + ) + xym = torch.cat((xm, ym), 0) + self.device = device + self.fg_mask = torch.from_numpy(fg_mask[np.newaxis]).to(self.device) + self.xym = xym.to(self.device) + + def cluster( + self, + prediction, + bandwidth, + min_object_size, + seed_thresh=0.9, + min_unclustered_sum=0, + ): + """Cluster Function. + + Parameters + ---------- + + prediction: (shape is 3 x `H` x `W`) + + Embeddings predicted for the whole raw imnage sample. + + bandwidth: + + Clustering bandwidth or sigma. + + min_object_size: + + Clusters below the `min_object_size` are ignored. + + seed_thresh (default = 0.9): + + Pixels with certainty below 0.9 are ignored to be object + centers. + + min_unclustered_sum (default = 0): + + If number of pixels which have not been clustered yet falls + below min_unclustered_sum, the clustering proces stops. + + + """ + prediction = torch.from_numpy(prediction).to(self.device) + height, width = prediction.size(1), prediction.size(2) + xym_s = self.xym[:, 0:height, 0:width] + embeddings = prediction[0:2] + xym_s # 2 x h x w + seed_map = prediction[2:3] # 1 x h x w + seed_map_min = seed_map.min() + seed_map_max = seed_map.max() + seed_map = (seed_map - seed_map_max) / (seed_map_min - seed_map_max) + instance_map = torch.zeros(height, width).short() + count = 1 + + embeddings_masked = embeddings[self.fg_mask.expand_as(embeddings)].view(2, -1) + seed_map_masked = seed_map[self.fg_mask].view(1, -1) + unclustered = torch.ones(self.fg_mask.sum()).short().to(self.device) + instance_map_masked = torch.zeros(self.fg_mask.sum()).short().to(self.device) + while unclustered.sum() > min_unclustered_sum: + seed = (seed_map_masked * unclustered.float()).argmax().item() + seed_score = (seed_map_masked * unclustered.float()).max().item() + if seed_score < seed_thresh: + break + center = embeddings_masked[:, seed : seed + 1] + unclustered[seed] = 0 + dist = torch.exp( + -1 + * torch.sum( + torch.pow(embeddings_masked - center, 2) / (2 * (bandwidth**2)), 0 + ) + ) + proposal = (dist > 0.5).squeeze() + if proposal.sum() > min_object_size: + if unclustered[proposal].sum().float() / proposal.sum().float() > 0.5: + instance_map_masked[proposal.squeeze()] = count + instance_mask = torch.zeros(height, width).short() + instance_mask[self.fg_mask.squeeze().cpu()] = proposal.short().cpu() + count += 1 + unclustered[proposal] = 0 + instance_map[self.fg_mask.squeeze().cpu()] = instance_map_masked.cpu() + return instance_map + + +class Cluster3d: + """ + Class for Greedy Clustering of Embeddings for 3D samples. + """ + def __init__(self, width, height, depth, fg_mask, device): + + """Initializes objects of class `Cluster3d`. + + Parameters + ---------- + + width: + + Width (`W`) of the Raw Image, in number of pixels. + + height: + + Height (`H`) of the Raw Image, in number of pixels. + + depth: + + Depth (`D`) of the Raw Image, in number of pixels. + + fg_mask: (shape is `D` x `H` x `W`) + + Foreground Mask corresponding to the region which should be + partitioned into individual objects. + + device: + + Device on which inference is being run. + + """ + xm = ( + torch.linspace(0, width - 1, width) + .view(1, 1, 1, -1) + .expand(1, depth, height, width) + ) + ym = ( + torch.linspace(0, height - 1, height) + .view(1, 1, -1, 1) + .expand(1, depth, height, width) + ) + zm = ( + torch.linspace(0, depth - 1, depth) + .view(1, -1, 1, 1) + .expand(1, depth, height, width) + ) + xyzm = torch.cat((xm, ym, zm), 0) + self.device = device + self.fg_mask = torch.from_numpy(fg_mask[np.newaxis]).to(self.device) + self.xyzm = xyzm.to(self.device) + + def cluster( + self, + prediction, + bandwidth, + min_object_size, + seed_thresh=0.9, + min_unclustered_sum=0, + ): + + """Cluster Function.. + + Parameters + ---------- + + prediction: (shape is 3 x `D` x `H` x `W`) + + Embeddings predicted for the whole raw imnage sample. + + bandwidth: + + Clustering bandwidth or sigma. + + min_object_size: + + Clusters below the `min_object_size` are ignored. + + seed_thresh (default = 0.9): + + Pixels with certainty below 0.9 are ignored to be object + centers. + + min_unclustered_sum (default = 0): + + If number of pixels which have not been clustered yet falls + below min_unclustered_sum, the clustering proces stops. + + """ + prediction = torch.from_numpy(prediction).to(self.device) + depth, height, width = ( + prediction.size(1), + prediction.size(2), + prediction.size(3)) + xyzm_s = self.xyzm[:, 0:depth, 0:height, 0:width] + embeddings = prediction[0:3] + xyzm_s # 3 x d x h x w + seed_map = prediction[3:4] # 1 x d x h x w + seed_map_min = seed_map.min() + seed_map_max = seed_map.max() + seed_map = (seed_map - seed_map_max) / (seed_map_min - seed_map_max) + instance_map = torch.zeros(depth, height, width).short() + count = 1 + + embeddings_masked = embeddings[self.fg_mask.expand_as(embeddings)].view(3, -1) + seed_map_masked = seed_map[self.fg_mask].view(1, -1) + unclustered = torch.ones(self.fg_mask.sum()).short().to(self.device) + instance_map_masked = torch.zeros(self.fg_mask.sum()).short().to(self.device) + while unclustered.sum() > min_unclustered_sum: + seed = (seed_map_masked * unclustered.float()).argmax().item() + seed_score = (seed_map_masked * unclustered.float()).max().item() + if seed_score < seed_thresh: + break + center = embeddings_masked[:, seed : seed + 1] + unclustered[seed] = 0 + dist = torch.exp( + -1 + * torch.sum( + torch.pow(embeddings_masked - center, 2) / (2 * (bandwidth**2)), 0 + ) + ) + proposal = (dist > 0.5).squeeze() + if proposal.sum() > min_object_size: + if unclustered[proposal].sum().float() / proposal.sum().float() > 0.5: + instance_map_masked[proposal.squeeze()] = count + instance_mask = torch.zeros(depth, height, width).short() + instance_mask[self.fg_mask.squeeze().cpu()] = proposal.short().cpu() + count += 1 + unclustered[proposal] = 0 + instance_map[self.fg_mask.squeeze().cpu()] = instance_map_masked.cpu() + return instance_map From aa60a59d3c4c65cef50d5c1ee06b41452ead7577 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 15:57:17 -0500 Subject: [PATCH 05/40] Add intensity-based post-processing --- cellulus/post_process.py | 66 ++++++++++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 10 deletions(-) diff --git a/cellulus/post_process.py b/cellulus/post_process.py index 591d6d9..a5407e7 100644 --- a/cellulus/post_process.py +++ b/cellulus/post_process.py @@ -1,6 +1,8 @@ import numpy as np import zarr +from scipy.ndimage import binary_fill_holes, label from scipy.ndimage import distance_transform_edt as dtedt +from skimage.filters import threshold_otsu from tqdm import tqdm from cellulus.configs.inference_config import InferenceConfig @@ -38,16 +40,60 @@ def post_process(inference_config: InferenceConfig) -> None: ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims # remove halo - for sample in tqdm(range(dataset_meta_data.num_samples)): - # first instance label masks are expanded by `grow_distance` - # next, expanded instance label masks are shrunk by `shrink_distance` - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = ds[sample, bandwidth_factor] - distance_foreground = dtedt(segmentation == 0) - expanded_mask = distance_foreground < inference_config.grow_distance - distance_background = dtedt(expanded_mask) - segmentation[distance_background < inference_config.shrink_distance] = 0 - ds_postprocessed[sample, bandwidth_factor, ...] = segmentation + if inference_config.post_processing == "morphological": + for sample in tqdm(range(dataset_meta_data.num_samples)): + # first instance label masks are expanded by `grow_distance` + # next, expanded instance label masks are shrunk by `shrink_distance` + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = ds[sample, bandwidth_factor] + distance_foreground = dtedt(segmentation == 0) + expanded_mask = distance_foreground < inference_config.grow_distance + distance_background = dtedt(expanded_mask) + segmentation[distance_background < inference_config.shrink_distance] = 0 + ds_postprocessed[sample, bandwidth_factor, ...] = segmentation + elif inference_config.post_processing == "intensity": + ds_raw = f[inference_config.dataset_config.dataset_name] + for sample in tqdm(range(dataset_meta_data.num_samples)): + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = ds[sample, bandwidth_factor] + raw_image = ds_raw[sample, 0] + ids = np.unique(segmentation) + ids = ids[ids != 0] + for id_ in ids: + raw_image_masked = raw_image[segmentation == id_] + threshold = threshold_otsu(raw_image_masked) + mask = (segmentation == id_) & (raw_image > threshold) + mask = binary_fill_holes(mask) + if dataset_meta_data.num_spatial_dims == 2: + y, x = np.where(mask) + ds_postprocessed[sample, bandwidth_factor, y, x] = id_ + elif dataset_meta_data.num_spatial_dims == 3: + z, y, x = np.where(mask) + ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_ + + # remove non-connected components + for bandwidth_factor in range(inference_config.num_bandwidths): + ids = np.unique(ds_postprocessed[sample, bandwidth_factor]) + ids = ids[ids != 0] + counter = np.max(ids) + 1 + for id_ in ids: + ma_id = ds_postprocessed[sample, bandwidth_factor] == id_ + array, num_features = label(ma_id) + if num_features > 1: + ids_array = np.unique(array) + ids_array = ids_array[ids_array != 0] + for id_array in ids_array: + if dataset_meta_data.num_spatial_dims == 2: + y, x = np.where(array == id_array) + ds_postprocessed[ + sample, bandwidth_factor, y, x + ] = counter + elif dataset_meta_data.num_spatial_dims == 3: + z, y, x = np.where(array == id_array) + ds_postprocessed[ + sample, bandwidth_factor, z, y, x + ] = counter + counter += 1 # size filter - remove small objects for sample in tqdm(range(dataset_meta_data.num_samples)): From ab5ce5f6ce4798dfbc7e3cfb64bdbac42ffa34d5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 15:58:51 -0500 Subject: [PATCH 06/40] Specify clustering style based on value of clustering parameter --- cellulus/segment.py | 68 ++++++++++++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/cellulus/segment.py b/cellulus/segment.py index 3b1f6f3..7461e32 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -5,6 +5,7 @@ from cellulus.configs.inference_config import InferenceConfig from cellulus.datasets.meta_data import DatasetMetaData +from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d from cellulus.utils.mean_shift import mean_shift_segmentation @@ -118,23 +119,50 @@ def segment(inference_config: InferenceConfig) -> None: embeddings_centered[2] -= c_z ds_object_centered_embeddings[sample] = embeddings_centered - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_size=inference_config.min_size, - reduction_probability=inference_config.reduction_probability, - threshold=threshold, - ) - # Note that the line below is needed - # because the embeddings_mean is modified - # by mean_shift_segmentation - embeddings_mean = embeddings[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() - ds_segmentation[ - sample, - bandwidth_factor, - ..., - ] = segmentation + if inference_config.clustering == "meanshift": + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + ) + # Note that the line below is needed + # because the embeddings_mean is modified + # by mean_shift_segmentation + embeddings_mean = embeddings[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() + + elif inference_config.clustering == "greedy": + if dataset_meta_data.num_spatial_dims == 3: + cluster3d = Cluster3d( + width=embeddings.shape[-1], + height=embeddings.shape[-2], + depth=embeddings.shape[-3], + fg_mask=binary_mask, + device=inference_config.device, + ) + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = cluster3d.cluster( + prediction=embeddings, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_object_size=inference_config.min_size, + ) + elif dataset_meta_data.num_spatial_dims == 2: + cluster2d = Cluster2d( + width=embeddings.shape[-1], + height=embeddings.shape[-2], + fg_mask=binary_mask, + device=inference_config.device, + ) + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = cluster2d.cluster( + prediction=embeddings, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_object_size=inference_config.min_size, + ) + + ds_segmentation[sample, bandwidth_factor, ...] = segmentation From ebb51a3479b3a0675561ab3edb3e1749e5ad4ce7 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 16:55:36 -0500 Subject: [PATCH 07/40] Update object size to be float type and default size = 30.0 --- cellulus/configs/experiment_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cellulus/configs/experiment_config.py b/cellulus/configs/experiment_config.py index d2bb761..14b1f03 100644 --- a/cellulus/configs/experiment_config.py +++ b/cellulus/configs/experiment_config.py @@ -20,7 +20,7 @@ class ExperimentConfig: A unique name for the experiment. - object_size: (default = 26) + object_size: (default = 30.0) A rough estimate of the size of objects in the image, given in world units. The "patch size" of the network will be chosen based @@ -43,7 +43,7 @@ class ExperimentConfig: experiment_name: str = attrs.field( default=datetime.today().strftime("%Y-%m-%d"), validator=instance_of(str) ) - object_size: int = attrs.field(default=26, validator=instance_of(int)) + object_size: float = attrs.field(default=30.0, validator=instance_of(float)) train_config: TrainConfig = attrs.field( default=None, converter=to_config(TrainConfig) From 5848c1eb6938e9b6ff80d787a8e473b5abc273bf Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 16:56:03 -0500 Subject: [PATCH 08/40] Set best_loss to be the least total loss --- cellulus/train.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cellulus/train.py b/cellulus/train.py index 9c255e8..a18b8f2 100644 --- a/cellulus/train.py +++ b/cellulus/train.py @@ -127,8 +127,8 @@ def lambda_(iteration): logger.plot() if iteration % train_config.save_model_every == 0: - is_lowest = oce_loss < lowest_loss - lowest_loss = min(oce_loss, lowest_loss) + is_lowest = loss < lowest_loss + lowest_loss = min(loss, lowest_loss) state = { "iteration": iteration, "lowest_loss": lowest_loss, @@ -150,11 +150,6 @@ def train_iteration(batch, model, criterion, optimizer, device): model.train() prediction = model(batch.to(device)) loss, oce_loss, regularization_loss = criterion(prediction) - loss, oce_loss, regularization_loss = ( - loss.mean(), - oce_loss.mean(), - regularization_loss.mean(), - ) optimizer.zero_grad() loss.backward() optimizer.step() From a711b03caf1ea5efe79cff360e6c026c894c0e8e Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 30 Jan 2024 16:56:54 -0500 Subject: [PATCH 09/40] Refactor greedy clustering --- cellulus/utils/greedy_cluster.py | 39 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/cellulus/utils/greedy_cluster.py b/cellulus/utils/greedy_cluster.py index e991229..ee48ab4 100644 --- a/cellulus/utils/greedy_cluster.py +++ b/cellulus/utils/greedy_cluster.py @@ -6,14 +6,15 @@ class Cluster2d: """ Class for Greedy Clustering of Embeddings on 2D samples. """ + def __init__(self, width, height, fg_mask, device): """Initializes objects of class `Cluster2d`. Parameters ---------- - - width: - + + width: + Width (`W`) of the Raw Image, in number of pixels. height: @@ -25,12 +26,12 @@ def __init__(self, width, height, fg_mask, device): Foreground Mask corresponding to the region which should be partitioned into individual objects. - device: + device: Device on which inference is being run. """ - + xm = torch.linspace(0, width - 1, width).view(1, 1, -1).expand(1, height, width) ym = ( torch.linspace(0, height - 1, height) @@ -59,7 +60,7 @@ def cluster( Embeddings predicted for the whole raw imnage sample. - bandwidth: + bandwidth: Clustering bandwidth or sigma. @@ -67,7 +68,7 @@ def cluster( Clusters below the `min_object_size` are ignored. - seed_thresh (default = 0.9): + seed_thresh (default = 0.9): Pixels with certainty below 0.9 are ignored to be object centers. @@ -77,7 +78,7 @@ def cluster( If number of pixels which have not been clustered yet falls below min_unclustered_sum, the clustering proces stops. - + """ prediction = torch.from_numpy(prediction).to(self.device) height, width = prediction.size(1), prediction.size(2) @@ -123,15 +124,15 @@ class Cluster3d: """ Class for Greedy Clustering of Embeddings for 3D samples. """ + def __init__(self, width, height, depth, fg_mask, device): - """Initializes objects of class `Cluster3d`. Parameters ---------- - - width: - + + width: + Width (`W`) of the Raw Image, in number of pixels. height: @@ -140,14 +141,14 @@ def __init__(self, width, height, depth, fg_mask, device): depth: - Depth (`D`) of the Raw Image, in number of pixels. + Depth (`D`) of the Raw Image, in number of pixels. fg_mask: (shape is `D` x `H` x `W`) Foreground Mask corresponding to the region which should be partitioned into individual objects. - device: + device: Device on which inference is being run. @@ -180,7 +181,6 @@ def cluster( seed_thresh=0.9, min_unclustered_sum=0, ): - """Cluster Function.. Parameters @@ -190,7 +190,7 @@ def cluster( Embeddings predicted for the whole raw imnage sample. - bandwidth: + bandwidth: Clustering bandwidth or sigma. @@ -198,7 +198,7 @@ def cluster( Clusters below the `min_object_size` are ignored. - seed_thresh (default = 0.9): + seed_thresh (default = 0.9): Pixels with certainty below 0.9 are ignored to be object centers. @@ -208,12 +208,13 @@ def cluster( If number of pixels which have not been clustered yet falls below min_unclustered_sum, the clustering proces stops. - """ + """ prediction = torch.from_numpy(prediction).to(self.device) depth, height, width = ( prediction.size(1), prediction.size(2), - prediction.size(3)) + prediction.size(3), + ) xyzm_s = self.xyzm[:, 0:depth, 0:height, 0:width] embeddings = prediction[0:3] + xyzm_s # 3 x d x h x w seed_map = prediction[3:4] # 1 x d x h x w From cae9377ced84a2527a856c2338b7d8478a76130c Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 31 Jan 2024 13:29:01 -0500 Subject: [PATCH 10/40] Cast torch tensor to float type before moving to device --- cellulus/utils/greedy_cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellulus/utils/greedy_cluster.py b/cellulus/utils/greedy_cluster.py index ee48ab4..b6f2add 100644 --- a/cellulus/utils/greedy_cluster.py +++ b/cellulus/utils/greedy_cluster.py @@ -80,7 +80,7 @@ def cluster( """ - prediction = torch.from_numpy(prediction).to(self.device) + prediction = torch.from_numpy(prediction).float().to(self.device) height, width = prediction.size(1), prediction.size(2) xym_s = self.xym[:, 0:height, 0:width] embeddings = prediction[0:2] + xym_s # 2 x h x w From 40f36147028320f993ed8a02701a14f3caacaf8e Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 31 Jan 2024 13:29:30 -0500 Subject: [PATCH 11/40] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 710bf15..625644a 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ pip install -e . ### Getting Started -Try out `2D Example` available **[here](https://funkelab.github.io/cellulus)**. +Try out the `2D Example` available **[here](https://funkelab.github.io/cellulus)**. ### Citation From c42c69e54fb8d57a0698887704b31556aa3bce98 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 31 Jan 2024 13:30:02 -0500 Subject: [PATCH 12/40] Rename API Tab to be API Reference --- mkdocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mkdocs.yml b/mkdocs.yml index 7513185..e13f6dc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,7 +65,7 @@ nav: - 01: examples/2d/01-data.py - 02: examples/2d/02-train.py - 03: examples/2d/03-infer.py - - API: + - API Reference: - Configs: - DatasetConfig: api/dataset_config.md - ExperimentConfig: api/experiment_config.md From 5c637918517ed0fd507dc009d1726ae4e5371261 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 31 Jan 2024 14:32:19 -0500 Subject: [PATCH 13/40] Update 2d example files --- docs/examples/2d/01-data.py | 3 ++- docs/examples/2d/02-train.py | 1 + docs/examples/2d/03-infer.py | 9 ++++++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/examples/2d/01-data.py b/docs/examples/2d/01-data.py index 5dcd358..83dc89a 100644 --- a/docs/examples/2d/01-data.py +++ b/docs/examples/2d/01-data.py @@ -1,6 +1,7 @@ # # Download Data -# In this notebook, we will download data and convert it to a zarr dataset. +# In this notebook, we will download data and convert it to a zarr dataset.
+# This tutorial was written by Henry Westmacott and Manan Lalit. # For demonstration, we will use a subset of images of `Fluo-N2DL-HeLa` available on the [Cell Tracking Challenge](http://celltrackingchallenge.net/2d-datasets/) webpage. diff --git a/docs/examples/2d/02-train.py b/docs/examples/2d/02-train.py index efff05c..5d83938 100644 --- a/docs/examples/2d/02-train.py +++ b/docs/examples/2d/02-train.py @@ -56,3 +56,4 @@ # + # from cellulus.train import train # train(experiment_config) +# - diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index f3e747a..5fd0f6f 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -16,6 +16,7 @@ from cellulus.configs.model_config import ModelConfig from cellulus.infer import infer from cellulus.utils.misc import visualize_2d +from IPython.utils import io from matplotlib.colors import ListedColormap # ## Specify config values for datasets @@ -79,6 +80,7 @@ prediction_dataset_config=asdict(prediction_dataset_config), segmentation_dataset_config=asdict(segmentation_dataset_config), post_processed_dataset_config=asdict(post_processed_dataset_config), + post_processing="intensity", device=device, ) @@ -91,9 +93,10 @@ ) # Now we are ready to start the inference!!
-# (This takes around 7 minutes on a Mac Book Pro with an Apple M2 Max chip). +# (This takes around 7 minutes on a Mac Book Pro with an Apple M2 Max chip. To see the output of the cell below, remove the first line `io.capture_output()`). -infer(experiment_config) +with io.capture_output() as captured: + infer(experiment_config) # ## Inspect predictions @@ -123,7 +126,7 @@ top_right=embedding[-1], bottom_left=embedding[0], bottom_right=embedding[1], - top_right_label="STD_DEV", + top_right_label="UNCERTAINTY", bottom_left_label="OFFSET_X", bottom_right_label="OFFSET_Y", ) From 29ede2fb54654737919cf87c3a9035108857f15d Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 3 Feb 2024 15:35:01 -0500 Subject: [PATCH 14/40] Introduce elastic_deform parameter --- cellulus/configs/train_config.py | 10 +++++++--- cellulus/datasets/__init__.py | 2 ++ cellulus/datasets/zarr_dataset.py | 15 +++++++++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/cellulus/configs/train_config.py b/cellulus/configs/train_config.py index f085e2b..f05f571 100644 --- a/cellulus/configs/train_config.py +++ b/cellulus/configs/train_config.py @@ -63,16 +63,20 @@ class TrainConfig: The number of sub-processes to use for data-loading. + elastic_deform (default = False): + + If set to True, the data is elastically deformed in order to increase training samples. + control_point_spacing (default = 64): The distance in pixels between control points used for elastic - deformation of the raw data during training. + deformation of the raw data during training. Only used if `elastic_deform` is set to True. control_point_jitter (default = 2.0): How much to jitter the control points for elastic deformation of the raw data during training, given as the standard deviation of - a normal distribution with zero mean. + a normal distribution with zero mean. Only used if `elastic_deform` is set to True. train_data_config: @@ -110,7 +114,7 @@ class TrainConfig: save_model_every: int = attrs.field(default=1_000, validator=instance_of(int)) save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int)) num_workers: int = attrs.field(default=8, validator=instance_of(int)) - + elastic_deform: bool = attrs.field(default=False, validator=instance_of(bool)) control_point_spacing: int = attrs.field(default=64, validator=instance_of(int)) control_point_jitter: float = attrs.field(default=2.0, validator=instance_of(float)) device: str = attrs.field(default="cuda:0", validator=instance_of(str)) diff --git a/cellulus/datasets/__init__.py b/cellulus/datasets/__init__.py index 2b1f464..5bf346c 100644 --- a/cellulus/datasets/__init__.py +++ b/cellulus/datasets/__init__.py @@ -8,12 +8,14 @@ def get_dataset( dataset_config: DatasetConfig, crop_size: Tuple[int, ...], + elastic_deform: bool, control_point_spacing: int, control_point_jitter: float, ) -> ZarrDataset: return ZarrDataset( dataset_config=dataset_config, crop_size=crop_size, + elastic_deform=elastic_deform, control_point_spacing=control_point_spacing, control_point_jitter=control_point_jitter, ) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index 5453076..f9bbe6d 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -14,6 +14,7 @@ def __init__( self, dataset_config: DatasetConfig, crop_size: Tuple[int, ...], + elastic_deform: bool, control_point_spacing: int, control_point_jitter: float, ): @@ -39,20 +40,25 @@ def __init__( should be equal to the input size of the model that predicts the OCEs. + elastic_deform: + + Whether to elastically deform data in order to augment training samples? + control_point_spacing: The distance in pixels between control points used for elastic - deformation of the raw data. + deformation of the raw data. Only used, if `elastic_deform` is set to True. control_point_jitter: How much to jitter the control points for elastic deformation of the raw data, given as the standard deviation of a normal - distribution with zero mean. + distribution with zero mean. Only used if `elastic_deform` is set to True. """ self.dataset_config = dataset_config self.crop_size = crop_size + self.elastic_deform = elastic_deform self.control_point_spacing = control_point_spacing self.control_point_jitter = control_point_jitter self.__read_meta_data() @@ -84,7 +90,9 @@ def __setup_pipeline(self): array_specs={self.raw: raw_spec}, ) + gp.RandomLocation() - + gp.ElasticAugment( + ) + if self.elastic_deform: + self.pipeline += gp.ElasticAugment( control_point_spacing=(self.control_point_spacing,) * self.num_spatial_dims, jitter_sigma=(self.control_point_jitter,) * self.num_spatial_dims, @@ -94,7 +102,6 @@ def __setup_pipeline(self): spatial_dims=self.num_spatial_dims, ) # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) - ) def __yield_sample(self): """An infinite generator of crops.""" From 582e1f180a6acdea04ce5fb0e2a153555b34066b Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 3 Feb 2024 15:35:54 -0500 Subject: [PATCH 15/40] Introduce elastic deform parameter --- cellulus/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cellulus/train.py b/cellulus/train.py index a18b8f2..30b7c5b 100644 --- a/cellulus/train.py +++ b/cellulus/train.py @@ -27,6 +27,7 @@ def train(experiment_config): train_dataset = get_dataset( dataset_config=train_config.train_data_config, crop_size=tuple(train_config.crop_size), + elastic_deform=train_config.elastic_deform, control_point_spacing=train_config.control_point_spacing, control_point_jitter=train_config.control_point_jitter, ) From 9a532c5d1e44f15ae441046e00d4884c49a2d41e Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 3 Feb 2024 15:36:13 -0500 Subject: [PATCH 16/40] Save state at last iteration --- cellulus/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cellulus/train.py b/cellulus/train.py index 30b7c5b..e33764a 100644 --- a/cellulus/train.py +++ b/cellulus/train.py @@ -127,7 +127,10 @@ def lambda_(iteration): logger.write() logger.plot() - if iteration % train_config.save_model_every == 0: + if ( + iteration % train_config.save_model_every == 0 + or iteration == train_config.max_iterations - 1 + ): is_lowest = loss < lowest_loss lowest_loss = min(loss, lowest_loss) state = { From a9b904b2a6ef1bbde41defe0f666fd58edf1684b Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 3 Feb 2024 15:38:05 -0500 Subject: [PATCH 17/40] Use index in visualization of results --- docs/examples/2d/02-train.py | 4 ++-- docs/examples/2d/03-infer.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/examples/2d/02-train.py b/docs/examples/2d/02-train.py index 5d83938..89d73f2 100644 --- a/docs/examples/2d/02-train.py +++ b/docs/examples/2d/02-train.py @@ -33,9 +33,9 @@ # Then, we specify training-specific parameters such as the `device`, which indicates the actual device to run the training on. #
The device could be set equal to `cuda:n` (where `n` is the index of the GPU, for e.g. `cuda:0`), `cpu` or `mps`.
-# We set the `max_iterations` equal to `5e3` for demonstration purposes.
(This takes around 20 minutes on a Mac Book Pro with an Apple M2 Max chip). +# We set the `max_iterations` equal to 5000 for demonstration purposes.
(This takes around 20 minutes on a Mac Book Pro with an Apple M2 Max chip). -device = "mps" +device = "cuda:0" max_iterations = 5000 train_config = TrainConfig( diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index 5fd0f6f..d6c212e 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -71,7 +71,7 @@ # Then, we specify inference-specific parameters such as the `device`, which indicates the actual device to run the inference on. #
The device could be set equal to `cuda:n` (where `n` is the index of the GPU, for e.g. `cuda:0`), `cpu` or `mps`. -device = "mps" +device = "cuda:0" # We initialize the `inference_config` which contains our `embeddings_dataset_config`, `segmentation_dataset_config` and `post_processed_dataset_config`. @@ -93,7 +93,7 @@ ) # Now we are ready to start the inference!!
-# (This takes around 7 minutes on a Mac Book Pro with an Apple M2 Max chip. To see the output of the cell below, remove the first line `io.capture_output()`). +# (This takes around 7 minutes on a Mac Book Pro with an Apple M2 Max chip (i.e. `device = 'mps'`). To see the output of the cell below, remove the first line `io.capture_output()`). with io.capture_output() as captured: infer(experiment_config) @@ -112,7 +112,7 @@ # Change the value of `index` below to look at the raw image (left), x-offset (bottom-left), y-offset (bottom-right) and uncertainty of the embedding (top-right). # + -index = 0 +index = 10 f = zarr.open(name + ".zarr") ds = f["train/raw"] @@ -135,8 +135,6 @@ # As you can see the magnitude of the uncertainty of the embedding (top-right) is low for most of the foreground cells.
This enables extraction of the foreground, which is eventually clustered into individual instances. # + -index = 0 - f = zarr.open(name + ".zarr") ds = f["train/raw"] ds2 = f["segmentation"] @@ -145,8 +143,8 @@ visualize_2d( image, top_right=embedding[-1] < skimage.filters.threshold_otsu(embedding[-1]), - bottom_left=ds2[0, 0], - bottom_right=ds3[0, 0], + bottom_left=ds2[index, 0], + bottom_right=ds3[index, 0], top_right_label="THRESHOLDED F.G.", bottom_left_label="SEGMENTATION", bottom_right_label="POSTPROCESSED", @@ -154,3 +152,4 @@ bottom_left_cmap=new_cmp, bottom_right_cmap=new_cmp, ) +# - From 170abfd437effa98d4d793dbc2cd71ba0f19620c Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 3 Feb 2024 18:01:58 -0500 Subject: [PATCH 18/40] Check for best model weights every few iterations --- cellulus/configs/train_config.py | 9 ++++++-- cellulus/train.py | 38 +++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/cellulus/configs/train_config.py b/cellulus/configs/train_config.py index f05f571..e6feaf5 100644 --- a/cellulus/configs/train_config.py +++ b/cellulus/configs/train_config.py @@ -51,11 +51,15 @@ class TrainConfig: Neighborhood radius to extract patches from. - save_model_every (default = 1e3): + save_model_every (default = 1000): The model weights are saved every few iterations. - save_snapshot_every (default = 1e3): + save_best_model_every (default = 100): + + The best loss is evaluated every few iterations. + + save_snapshot_every (default = 1000): The zarr snapshot is saved every few iterations. @@ -112,6 +116,7 @@ class TrainConfig: regularizer_weight: float = attrs.field(default=1e-5, validator=instance_of(float)) reduce_mean: bool = attrs.field(default=True, validator=instance_of(bool)) save_model_every: int = attrs.field(default=1_000, validator=instance_of(int)) + save_best_model_every: int = attrs.field(default=100, validator=instance_of(int)) save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int)) num_workers: int = attrs.field(default=8, validator=instance_of(int)) elastic_deform: bool = attrs.field(default=False, validator=instance_of(bool)) diff --git a/cellulus/train.py b/cellulus/train.py index e33764a..c435684 100644 --- a/cellulus/train.py +++ b/cellulus/train.py @@ -92,8 +92,9 @@ def lambda_(iteration): # resume training start_iteration = 0 - lowest_loss = 1e0 - + lowest_loss = 1.0 + epoch_loss = 0 + num_iterations = 0 if model_config.checkpoint is None: pass else: @@ -106,6 +107,7 @@ def lambda_(iteration): logger.data = state["logger_data"] # call `train_iteration` + for iteration, batch in tqdm( zip( range(start_iteration, train_config.max_iterations), @@ -127,12 +129,29 @@ def lambda_(iteration): logger.write() logger.plot() + # Check if lowest loss + epoch_loss += loss + num_iterations += 1 + if iteration % train_config.save_best_model_every == 0: + is_lowest = epoch_loss / (num_iterations) < lowest_loss + lowest_loss = min(epoch_loss / num_iterations, lowest_loss) + if is_lowest: + state = { + "iteration": iteration, + "lowest_loss": lowest_loss, + "model_state_dict": model.state_dict(), + "optim_state_dict": optimizer.state_dict(), + "logger_data": logger.data, + } + save_model(state, iteration, is_lowest) + epoch_loss = 0 + num_iterations = 0 + + # Save model at specific intervals if ( iteration % train_config.save_model_every == 0 or iteration == train_config.max_iterations - 1 ): - is_lowest = loss < lowest_loss - lowest_loss = min(loss, lowest_loss) state = { "iteration": iteration, "lowest_loss": lowest_loss, @@ -140,8 +159,9 @@ def lambda_(iteration): "optim_state_dict": optimizer.state_dict(), "logger_data": logger.data, } - save_model(state, iteration, is_lowest) + save_model(state, iteration) + # Save snapshots at specific intervals if iteration % train_config.save_snapshot_every == 0: save_snapshot( batch, @@ -161,12 +181,14 @@ def train_iteration(batch, model, criterion, optimizer, device): def save_model(state, iteration, is_lowest=False): - file_name = os.path.join("models", str(iteration).zfill(6) + ".pth") - torch.save(state, file_name) - print(f"Checkpoint saved at iteration {iteration}") if is_lowest: file_name = os.path.join("models", "best_loss.pth") torch.save(state, file_name) + print(f"Best model weights saved at iteration {iteration}") + else: + file_name = os.path.join("models", str(iteration).zfill(6) + ".pth") + torch.save(state, file_name) + print(f"Checkpoint saved at iteration {iteration}") def save_snapshot(batch, prediction, iteration): From ac1ee4f0ee09e70fa5f36bb9ba67592119510b51 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 5 Feb 2024 15:03:48 -0500 Subject: [PATCH 19/40] Specify axis names, resolution and offset of embeddings prior to populating them --- cellulus/predict.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cellulus/predict.py b/cellulus/predict.py index de49286..ed3a6f5 100644 --- a/cellulus/predict.py +++ b/cellulus/predict.py @@ -107,6 +107,13 @@ def predict(model: torch.nn.Module, inference_config: InferenceConfig) -> None: dtype=float, ) + ds.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + -dataset_meta_data.num_spatial_dims : + ] + + ds.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + pipeline = ( gp.ZarrSource( dataset_config.container_path, @@ -128,10 +135,3 @@ def predict(model: torch.nn.Module, inference_config: InferenceConfig) -> None: # request to pipeline for ROI of whole image/volume with gp.build(pipeline): pipeline.request_batch(request) - - ds.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ - -dataset_meta_data.num_spatial_dims : - ] - - ds.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims From ca0a26bd68ef00f4a6fd00c119033e203e63b3c3 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 7 Feb 2024 15:22:21 -0500 Subject: [PATCH 20/40] Revert "Specify axis names, resolution and offset of embeddings prior to populating them" This reverts commit ac1ee4f0ee09e70fa5f36bb9ba67592119510b51. --- cellulus/predict.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cellulus/predict.py b/cellulus/predict.py index ed3a6f5..de49286 100644 --- a/cellulus/predict.py +++ b/cellulus/predict.py @@ -107,13 +107,6 @@ def predict(model: torch.nn.Module, inference_config: InferenceConfig) -> None: dtype=float, ) - ds.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ - -dataset_meta_data.num_spatial_dims : - ] - - ds.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - pipeline = ( gp.ZarrSource( dataset_config.container_path, @@ -135,3 +128,10 @@ def predict(model: torch.nn.Module, inference_config: InferenceConfig) -> None: # request to pipeline for ROI of whole image/volume with gp.build(pipeline): pipeline.request_batch(request) + + ds.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + -dataset_meta_data.num_spatial_dims : + ] + + ds.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims From a2bdab9116e701aa730faf48e8c57f358eff52aa Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 7 Feb 2024 23:02:55 -0500 Subject: [PATCH 21/40] Handle case when GT is provided sparsely --- cellulus/evaluate.py | 77 ++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/cellulus/evaluate.py b/cellulus/evaluate.py index 57a0067..0004d23 100644 --- a/cellulus/evaluate.py +++ b/cellulus/evaluate.py @@ -18,23 +18,33 @@ def evaluate(inference_config: InferenceConfig) -> None: ds_groundtruth = f[inference_config.evaluation_dataset_config.dataset_name] for bandwidth in range(inference_config.num_bandwidths): - F1_list, SEG_list, TP_list, FP_list, FN_list = [], [], [], [], [] + sample_list, F1_list, SEG_list, TP_list, FP_list, FN_list = ( + [], + [], + [], + [], + [], + [], + ) SEG_dataset, n_ids_dataset = 0, 0 for sample in tqdm(range(dataset_meta_data.num_samples)): groundtruth = ds_groundtruth[sample, 0].astype(np.uint16) prediction = ds_segmentation[sample, bandwidth].astype(np.uint16) - IoU, SEG_image, n_GTids_image = compute_pairwise_IoU( - prediction, groundtruth - ) - F1_image, TP_image, FP_image, FN_image = compute_F1(IoU) - F1_list.append(F1_image) - SEG_list.append(SEG_image / n_GTids_image) - SEG_dataset += SEG_image - n_ids_dataset += n_GTids_image - TP_list.append(TP_image) - FP_list.append(FP_image) - FN_list.append(FN_image) - print(f"{sample}:, F1={F1_image:.3f}, SEG={SEG_image/n_GTids_image:.3f}") + returned_values = compute_pairwise_IoU(prediction, groundtruth) + if returned_values is not None: + IoU, SEG_image, n_GTids_image = returned_values + F1_image, TP_image, FP_image, FN_image = compute_F1(IoU) + F1_list.append(F1_image) + SEG_list.append(SEG_image / n_GTids_image) + SEG_dataset += SEG_image + n_ids_dataset += n_GTids_image + TP_list.append(TP_image) + FP_list.append(FP_image) + FN_list.append(FN_image) + sample_list.append(sample) + print( + f"{sample}:, F1={F1_image:.3f}, SEG={SEG_image/n_GTids_image:.3f}" + ) F1_dataset = 2 * sum(TP_list) / (2 * sum(TP_list) + sum(FP_list) + sum(FN_list)) @@ -45,9 +55,9 @@ def evaluate(inference_config: InferenceConfig) -> None: with open(txt_file, "w") as f: f.writelines("file index, F1, SEG, TP, FP, FN \n") f.writelines("+++++++++++++++++++++++++++++++++\n") - for sample in range(dataset_meta_data.num_samples): + for sample in range(len(sample_list)): f.writelines( - f"{sample}," + f"{sample_list[sample]}," + f" {F1_list[sample]:.05f}," + f" {SEG_list[sample]:.05f}," + f" {TP_list[sample]}," @@ -67,23 +77,26 @@ def compute_pairwise_IoU(prediction, groundtruth): groundtruth_ids = np.unique(groundtruth) groundtruth_ids = groundtruth_ids[groundtruth_ids != 0] # ignore background - IoU_table = np.zeros((len(prediction_ids), len(groundtruth_ids)), dtype=float) - IoG_table = np.zeros((len(prediction_ids), len(groundtruth_ids)), dtype=float) - for j in range(len(prediction_ids)): - for k in range(len(groundtruth_ids)): - intersection = (prediction == prediction_ids[j]) & ( - groundtruth == groundtruth_ids[k] - ) - union = (prediction == prediction_ids[j]) | ( - groundtruth == groundtruth_ids[k] - ) - IoU_table[j, k] = np.sum(intersection) / np.sum(union) - IoG_table[j, k] = np.sum(intersection) / np.sum( - groundtruth == groundtruth_ids[k] - ) - # Note for SEG, we consider it a match if it is strictly - # greater than `0.5` IoU - return IoU_table, np.sum(IoU_table[IoG_table > 0.5]), len(groundtruth_ids) + if len(groundtruth_ids) == 0: + return None + else: + IoU_table = np.zeros((len(prediction_ids), len(groundtruth_ids)), dtype=float) + IoG_table = np.zeros((len(prediction_ids), len(groundtruth_ids)), dtype=float) + for j in range(len(prediction_ids)): + for k in range(len(groundtruth_ids)): + intersection = (prediction == prediction_ids[j]) & ( + groundtruth == groundtruth_ids[k] + ) + union = (prediction == prediction_ids[j]) | ( + groundtruth == groundtruth_ids[k] + ) + IoU_table[j, k] = np.sum(intersection) / np.sum(union) + IoG_table[j, k] = np.sum(intersection) / np.sum( + groundtruth == groundtruth_ids[k] + ) + # Note for SEG, we consider it a match if it is strictly + # greater than `0.5` IoU + return IoU_table, np.sum(IoU_table[IoG_table > 0.5]), len(groundtruth_ids) def compute_F1(IoU_table, threshold=0.5): From 6db2482974c9bb53c5e4c1976339e4a54e850b2a Mon Sep 17 00:00:00 2001 From: lmanan Date: Thu, 22 Feb 2024 10:58:58 -0500 Subject: [PATCH 22/40] Update config docs --- cellulus/configs/train_config.py | 11 +++++++---- cellulus/datasets/zarr_dataset.py | 8 +++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cellulus/configs/train_config.py b/cellulus/configs/train_config.py index e6feaf5..836aa6c 100644 --- a/cellulus/configs/train_config.py +++ b/cellulus/configs/train_config.py @@ -67,20 +67,23 @@ class TrainConfig: The number of sub-processes to use for data-loading. - elastic_deform (default = False): + elastic_deform (default = True): - If set to True, the data is elastically deformed in order to increase training samples. + If set to True, the data is elastically deformed + in order to increase training samples. control_point_spacing (default = 64): The distance in pixels between control points used for elastic - deformation of the raw data during training. Only used if `elastic_deform` is set to True. + deformation of the raw data during training. + Only used if `elastic_deform` is set to True. control_point_jitter (default = 2.0): How much to jitter the control points for elastic deformation of the raw data during training, given as the standard deviation of - a normal distribution with zero mean. Only used if `elastic_deform` is set to True. + a normal distribution with zero mean. + Only used if `elastic_deform` is set to True. train_data_config: diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index f9bbe6d..d34ba40 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -47,13 +47,15 @@ def __init__( control_point_spacing: The distance in pixels between control points used for elastic - deformation of the raw data. Only used, if `elastic_deform` is set to True. + deformation of the raw data. + Only used, if `elastic_deform` is set to True. control_point_jitter: How much to jitter the control points for elastic deformation of the raw data, given as the standard deviation of a normal - distribution with zero mean. Only used if `elastic_deform` is set to True. + distribution with zero mean. + Only used if `elastic_deform` is set to True. """ self.dataset_config = dataset_config @@ -68,7 +70,6 @@ def __init__( f'spatial(temporal) dimensions of the "{self.dataset_config.dataset_name}" ' f"dataset which is {self.num_spatial_dims}, but it is {crop_size}" ) - self.__setup_pipeline() def __iter__(self): @@ -117,6 +118,7 @@ def __yield_sample(self): ) sample = self.pipeline.request_batch(request) + yield sample[self.raw].data[0] def __read_meta_data(self): From 30aba0475dfe7a8a6fd53871bda07536ff128199 Mon Sep 17 00:00:00 2001 From: lmanan Date: Thu, 22 Feb 2024 11:05:54 -0500 Subject: [PATCH 23/40] Update default value of elastic deform --- cellulus/configs/train_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellulus/configs/train_config.py b/cellulus/configs/train_config.py index 836aa6c..4eeaa14 100644 --- a/cellulus/configs/train_config.py +++ b/cellulus/configs/train_config.py @@ -122,7 +122,7 @@ class TrainConfig: save_best_model_every: int = attrs.field(default=100, validator=instance_of(int)) save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int)) num_workers: int = attrs.field(default=8, validator=instance_of(int)) - elastic_deform: bool = attrs.field(default=False, validator=instance_of(bool)) + elastic_deform: bool = attrs.field(default=True, validator=instance_of(bool)) control_point_spacing: int = attrs.field(default=64, validator=instance_of(int)) control_point_jitter: float = attrs.field(default=2.0, validator=instance_of(float)) device: str = attrs.field(default="cuda:0", validator=instance_of(str)) From 26fc61d28d073d55c539af0e7718f55e8df1f88b Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:30:19 -0500 Subject: [PATCH 24/40] Refactor __init__.py --- cellulus/configs/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellulus/configs/__init__.py b/cellulus/configs/__init__.py index 8cdc58d..3dcca81 100644 --- a/cellulus/configs/__init__.py +++ b/cellulus/configs/__init__.py @@ -1,4 +1,4 @@ from .dataset_config import DatasetConfig from .experiment_config import ExperimentConfig -__all__ = ["ExperimentConfig", "DatasetConfig"] +__all__ = ["DatasetConfig", "ExperimentConfig"] From 6820af71591bfb4e8130b71400d4c8c8e7f4685b Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:30:51 -0500 Subject: [PATCH 25/40] Add documentation about dataset_name and secondary_dataset_name --- cellulus/configs/dataset_config.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/cellulus/configs/dataset_config.py b/cellulus/configs/dataset_config.py index cc304e9..75aeefe 100644 --- a/cellulus/configs/dataset_config.py +++ b/cellulus/configs/dataset_config.py @@ -21,8 +21,16 @@ class DatasetConfig: secondary_dataset_name: - The name of the dataset containing the data which needs processing. - + The name of the secondary dataset containing the data which needs + processing. + + 'dataset_name' and 'secondary_dataset_name' can be thought of as the + output and input to a certain task, respectively. + For example, during segmentation, 'dataset_name' would refer to the output + segmentation masks and 'secondary_dataset_name' would refer to the input + predicted embeddings. + During evaluation, 'dataset_name' would refer to the ground truth masks + and 'secondary_dataset_name' would refer to the input segmentation masks. """ From d933c90c0ee575114f51e0937e3e953263153336 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:31:07 -0500 Subject: [PATCH 26/40] Add normalization_factor as a config parameter --- cellulus/configs/experiment_config.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cellulus/configs/experiment_config.py b/cellulus/configs/experiment_config.py index 14b1f03..1c659a6 100644 --- a/cellulus/configs/experiment_config.py +++ b/cellulus/configs/experiment_config.py @@ -20,15 +20,21 @@ class ExperimentConfig: A unique name for the experiment. - object_size: (default = 30.0) + object_size: (default = 30) A rough estimate of the size of objects in the image, given in world units. The "patch size" of the network will be chosen based on this estimate. + normalization_factor: (default = None) + + The factor to use, for dividing the raw image pixel intensities. + If 'None', a factor is chosen based on the dtype of the array . + (e.g., np.uint8 would result in a factor of 1.0/255). + model_config: - The model configuration. + Configuration object for the model. train_config: @@ -43,7 +49,10 @@ class ExperimentConfig: experiment_name: str = attrs.field( default=datetime.today().strftime("%Y-%m-%d"), validator=instance_of(str) ) - object_size: float = attrs.field(default=30.0, validator=instance_of(float)) + normalization_factor: float = attrs.field( + default=None, validator=attrs.validators.optional(instance_of(float)) + ) + object_size: int = attrs.field(default=30, validator=instance_of(int)) train_config: TrainConfig = attrs.field( default=None, converter=to_config(TrainConfig) From 690ac39478fff30e231702d69432e61e1dfe4112 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:31:55 -0500 Subject: [PATCH 27/40] Add documentation about reduction_probability config parameter --- cellulus/configs/inference_config.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cellulus/configs/inference_config.py b/cellulus/configs/inference_config.py index ed2e3d4..707333f 100644 --- a/cellulus/configs/inference_config.py +++ b/cellulus/configs/inference_config.py @@ -28,7 +28,7 @@ class InferenceConfig: post_processed_dataset_config: - Configuration object produced by postprocess.py. + Configuration object produced by post_process.py. evaluation_dataset_config: @@ -49,7 +49,7 @@ class InferenceConfig: bandwidth (default = None): - Band-width used to perform mean-shift clustering on the predicted + Bandwidth used to perform mean-shift clustering on the predicted embeddings. threshold (default = None): @@ -76,6 +76,14 @@ class InferenceConfig: Number of bandwidths to obtain segmentations for. + reduction_probability (default = 0.1): + + If set to less than 1.0, this fraction of available pixels are used + to determine the clusters (fitting stage) while performing + meanshift clustering. + Once clusters are available, they are used to predict the cluster assignment + of the remaining pixels (prediction stage). + min_size (default = None): Objects below `min_size` pixels will be removed. @@ -132,15 +140,13 @@ class InferenceConfig: default="meanshift", validator=in_(["meanshift", "greedy"]) ) bandwidth = attrs.field( - default=None, validator=attrs.validators.optional(instance_of(int)) + default=None, validator=attrs.validators.optional(instance_of(float)) ) num_bandwidths = attrs.field(default=1, validator=instance_of(int)) reduction_probability = attrs.field(default=0.1, validator=instance_of(float)) min_size = attrs.field( default=None, validator=attrs.validators.optional(instance_of(int)) ) - post_processing = attrs.field( - default="morphological", validator=in_(["morphological", "intensity"]) - ) + post_processing = attrs.field(default="cell", validator=in_(["cell", "nucleus"])) grow_distance = attrs.field(default=3, validator=instance_of(int)) shrink_distance = attrs.field(default=6, validator=instance_of(int)) From 955c2c482d54dc66017d38fe1e76f76655165107 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:32:13 -0500 Subject: [PATCH 28/40] Remove reduce_mean config parameter --- cellulus/configs/train_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cellulus/configs/train_config.py b/cellulus/configs/train_config.py index 4eeaa14..fa81a3a 100644 --- a/cellulus/configs/train_config.py +++ b/cellulus/configs/train_config.py @@ -117,7 +117,6 @@ class TrainConfig: kappa: float = attrs.field(default=10.0, validator=instance_of(float)) temperature: float = attrs.field(default=10.0, validator=instance_of(float)) regularizer_weight: float = attrs.field(default=1e-5, validator=instance_of(float)) - reduce_mean: bool = attrs.field(default=True, validator=instance_of(bool)) save_model_every: int = attrs.field(default=1_000, validator=instance_of(int)) save_best_model_every: int = attrs.field(default=100, validator=instance_of(int)) save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int)) From 6da255eff6a0bb0a7e79c208238e67a447d983a0 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:33:19 -0500 Subject: [PATCH 29/40] Move kappa parameter to ZarrDataset --- cellulus/criterions/__init__.py | 4 - cellulus/criterions/oce_loss.py | 141 +++----------------------------- 2 files changed, 12 insertions(+), 133 deletions(-) diff --git a/cellulus/criterions/__init__.py b/cellulus/criterions/__init__.py index ea57d4f..e4dbbfc 100644 --- a/cellulus/criterions/__init__.py +++ b/cellulus/criterions/__init__.py @@ -5,17 +5,13 @@ def get_loss( temperature, regularizer_weight, density, - kappa, num_spatial_dims, - reduce_mean, device, ): return OCELoss( temperature, regularizer_weight, density, - kappa, num_spatial_dims, - reduce_mean, device, ) diff --git a/cellulus/criterions/oce_loss.py b/cellulus/criterions/oce_loss.py index 31df503..d984d30 100644 --- a/cellulus/criterions/oce_loss.py +++ b/cellulus/criterions/oce_loss.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch.nn as nn @@ -9,9 +8,7 @@ def __init__( temperature: float, regularization_weight: float, density: float, - kappa: float, num_spatial_dims: int, - reduce_mean: bool, device: torch.device, ): """Class definition for loss. @@ -30,17 +27,9 @@ def __init__( Determines the fraction of patches to sample per crop, during training. - kappa: - Neighborhood radius to extract patches from. - num_spatial_dims: Should be equal to 2 for 2D and 3 for 3D. - reduce_mean: - Should be set to True if the loss should be averaged over all - pixels, and set to False, if the sum of the loss over all pixels - is expected. - device: The device to train on. Set to 'cpu' to train without GPU. @@ -50,131 +39,25 @@ def __init__( self.temperature = temperature self.regularization_weight = regularization_weight self.density = density - self.kappa = kappa self.num_spatial_dims = num_spatial_dims - self.reduce_mean = reduce_mean self.device = device - def distance_function(self, e0, e1): - diff = e0 - e1 - return diff.norm(2, dim=-1) + @staticmethod + def distance_function(embedding_0, embedding_1): + difference = embedding_0 - embedding_1 + return difference.norm(2, dim=-1) - def nonlinearity(self, distance): + def non_linearity(self, distance): return 1 - (-distance.pow(2) / self.temperature).exp() - def forward(self, prediction): - if self.num_spatial_dims == 2: - b, c, h, w = prediction.shape - - h_ = h - 2 * self.kappa # unbiased height - w_ = w - 2 * self.kappa # unbiased width - - num_anchors = int(self.density * h_ * w_) - anchor_coordinates_y = np.random.randint( - self.kappa, h - self.kappa, num_anchors - ) - anchor_coordinates_x = np.random.randint( - self.kappa, w - self.kappa, num_anchors - ) - anchor_coordinates = np.stack( - (anchor_coordinates_x, anchor_coordinates_y), axis=1 - ) # N x 2 - elif self.num_spatial_dims == 3: - b, c, d, h, w = prediction.shape - - d_ = d - 2 * self.kappa # unbiased depth - h_ = h - 2 * self.kappa # unbiased height - w_ = w - 2 * self.kappa # unbiased width - - num_anchors = int(self.density * d_ * h_ * w_) - anchor_coordinates_z = np.random.randint( - self.kappa, d - self.kappa, num_anchors - ) - anchor_coordinates_y = np.random.randint( - self.kappa, h - self.kappa, num_anchors - ) - anchor_coordinates_x = np.random.randint( - self.kappa, w - self.kappa, num_anchors - ) - anchor_coordinates = np.stack( - (anchor_coordinates_x, anchor_coordinates_y, anchor_coordinates_z), - axis=1, - ) # N x 3 - num_references = int(self.density * np.pi * self.kappa**2) - anchor_coordinates = np.repeat(anchor_coordinates, num_references, axis=0) - offsets = self.sample_offsets( - radius=self.kappa, - num_samples=len(anchor_coordinates), - ) - reference_coordinates = anchor_coordinates + offsets - anchor_coordinates = anchor_coordinates[np.newaxis, ...] - reference_coordinates = reference_coordinates[np.newaxis, ...] - anchor_coordinates = torch.from_numpy(np.repeat(anchor_coordinates, b, 0)).to( - self.device - ) - reference_coordinates = torch.from_numpy( - np.repeat(reference_coordinates, b, 0) - ).to(self.device) - anchor_embeddings = self.get_embeddings( - prediction, - anchor_coordinates, - ) # B x N x 2/3 - reference_embeddings = self.get_embeddings( - prediction, - reference_coordinates, - ) # B x N x 2/3 + def forward(self, anchor_embedding, reference_embedding): distance = self.distance_function( - anchor_embeddings, reference_embeddings.detach() + anchor_embedding, reference_embedding.detach() ) - oce_loss = self.nonlinearity(distance) - regularization_loss = self.regularization_weight * anchor_embeddings.norm( - 2, dim=-1 + non_linear_distance = self.non_linearity(distance) + oce_loss = non_linear_distance.sum() + regularization_loss = ( + self.regularization_weight * anchor_embedding.norm(2, dim=-1).sum() ) - loss = oce_loss + regularization_loss - if self.reduce_mean: - return loss.mean(), oce_loss.mean(), regularization_loss.mean() - else: - return loss.sum(), oce_loss.sum(), regularization_loss.sum() - - def sample_offsets(self, radius, num_samples): - if self.num_spatial_dims == 2: - offset_x = np.random.randint(-radius, radius + 1, size=2 * num_samples) - offset_y = np.random.randint(-radius, radius + 1, size=2 * num_samples) - - offset_coordinates = np.stack((offset_x, offset_y), axis=1) - elif self.num_spatial_dims == 3: - offset_x = np.random.randint(-radius, radius + 1, size=3 * num_samples) - offset_y = np.random.randint(-radius, radius + 1, size=3 * num_samples) - offset_z = np.random.randint(-radius, radius + 1, size=3 * num_samples) - - offset_coordinates = np.stack((offset_x, offset_y, offset_z), axis=1) - in_circle = (offset_coordinates**2).sum(axis=1) < radius**2 - offset_coordinates = offset_coordinates[in_circle] - not_zero = np.absolute(offset_coordinates).sum(axis=1) > 0 - offset_coordinates = offset_coordinates[not_zero] - if len(offset_coordinates) < num_samples: - return self.sample_offsets(radius, num_samples) - - return offset_coordinates[:num_samples] - - def get_embeddings(self, predictions, coordinates): - selection = [] - for prediction, coordinate in zip(predictions, coordinates): - if self.num_spatial_dims == 2: - embedding = prediction[ - :, coordinate[:, 1].long(), coordinate[:, 0].long() - ] - elif self.num_spatial_dims == 3: - embedding = prediction[ - :, - coordinate[:, 2].long(), - coordinate[:, 1].long(), - coordinate[:, 0].long(), - ] - embedding = embedding.transpose(1, 0) - embedding += coordinate - selection.append(embedding) - - # selection.shape = (b, c, p) where p is the number of selected positions - return torch.stack(selection, dim=0) + return loss, oce_loss, regularization_loss From 433acc5f1b4e061d7bdbbb47d80ab32eb7c77f7b Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:34:03 -0500 Subject: [PATCH 30/40] Add density, kappa and normalization_factor as parameter in ZarrDataset --- cellulus/datasets/__init__.py | 6 ++ cellulus/datasets/zarr_dataset.py | 109 +++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/cellulus/datasets/__init__.py b/cellulus/datasets/__init__.py index 5bf346c..c19e532 100644 --- a/cellulus/datasets/__init__.py +++ b/cellulus/datasets/__init__.py @@ -11,6 +11,9 @@ def get_dataset( elastic_deform: bool, control_point_spacing: int, control_point_jitter: float, + density: float, + kappa: int, + normalization_factor: float, ) -> ZarrDataset: return ZarrDataset( dataset_config=dataset_config, @@ -18,4 +21,7 @@ def get_dataset( elastic_deform=elastic_deform, control_point_spacing=control_point_spacing, control_point_jitter=control_point_jitter, + density=density, + kappa=kappa, + normalization_factor=normalization_factor, ) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index d34ba40..1a6ad78 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -2,6 +2,7 @@ from typing import Tuple import gunpowder as gp +import numpy as np from torch.utils.data import IterableDataset from cellulus.configs import DatasetConfig @@ -17,6 +18,9 @@ def __init__( elastic_deform: bool, control_point_spacing: int, control_point_jitter: float, + density: float, + kappa: float, + normalization_factor: float, ): """A dataset that serves random samples from a zarr container. @@ -56,6 +60,20 @@ def __init__( of the raw data, given as the standard deviation of a normal distribution with zero mean. Only used if `elastic_deform` is set to True. + + density: + + Determines the fraction of patches to sample per crop, during training. + + kappa: + + Neighborhood radius to extract patches from. + + normalization_factor: + + The factor to use, for dividing the raw image pixel intensities. + If 'None', a factor is chosen based on the dtype of the array . + (e.g., np.uint8 would result in a factor of 1.0/255). """ self.dataset_config = dataset_config @@ -63,6 +81,7 @@ def __init__( self.elastic_deform = elastic_deform self.control_point_spacing = control_point_spacing self.control_point_jitter = control_point_jitter + self.normalization_factor = normalization_factor self.__read_meta_data() assert len(crop_size) == self.num_spatial_dims, ( @@ -70,6 +89,13 @@ def __init__( f'spatial(temporal) dimensions of the "{self.dataset_config.dataset_name}" ' f"dataset which is {self.num_spatial_dims}, but it is {crop_size}" ) + self.density = density + self.kappa = kappa + self.output_shape = tuple(int(_ - 16) for _ in self.crop_size) + self.normalization_factor = normalization_factor + self.unbiased_shape = tuple( + int(_ - (2 * self.kappa)) for _ in self.output_shape + ) self.__setup_pipeline() def __iter__(self): @@ -91,7 +117,9 @@ def __setup_pipeline(self): array_specs={self.raw: raw_spec}, ) + gp.RandomLocation() + + gp.Normalize(self.raw, factor=self.normalization_factor) ) + if self.elastic_deform: self.pipeline += gp.ElasticAugment( control_point_spacing=(self.control_point_spacing,) @@ -118,8 +146,9 @@ def __yield_sample(self): ) sample = self.pipeline.request_batch(request) - - yield sample[self.raw].data[0] + sample_data = sample[self.raw].data[0] + anchor_samples, reference_samples = self.sample_coordinates() + yield sample_data, anchor_samples, reference_samples def __read_meta_data(self): meta_data = DatasetMetaData.from_dataset_config(self.dataset_config) @@ -137,3 +166,79 @@ def get_num_channels(self): def get_num_spatial_dims(self): return self.num_spatial_dims + + def sample_offsets_within_radius(self, radius, number_offsets): + if self.num_spatial_dims == 2: + offsets_x = np.random.randint(-radius, radius + 1, size=2 * number_offsets) + offsets_y = np.random.randint(-radius, radius + 1, size=2 * number_offsets) + offsets_coordinates = np.stack((offsets_x, offsets_y), axis=1) + elif self.num_spatial_dims == 3: + offsets_x = np.random.randint(-radius, radius + 1, size=3 * number_offsets) + offsets_y = np.random.randint(-radius, radius + 1, size=3 * number_offsets) + offsets_z = np.random.randint(-radius, radius + 1, size=3 * number_offsets) + offsets_coordinates = np.stack((offsets_x, offsets_y, offsets_z), axis=1) + + in_circle = (offsets_coordinates**2).sum(axis=1) < radius**2 + offsets_coordinates = offsets_coordinates[in_circle] + not_zero = np.absolute(offsets_coordinates).sum(axis=1) > 0 + offsets_coordinates = offsets_coordinates[not_zero] + + if len(offsets_coordinates) < number_offsets: + return self.sample_offsets_within_radius(radius, number_offsets) + + return offsets_coordinates[:number_offsets] + + def sample_coordinates(self): + num_anchors = self.get_num_anchors() + num_references = self.get_num_references() + + if self.num_spatial_dims == 2: + anchor_coordinates_x = np.random.randint( + self.kappa, + self.output_shape[0] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates_y = np.random.randint( + self.kappa, + self.output_shape[1] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates = np.stack( + (anchor_coordinates_x, anchor_coordinates_y), axis=1 + ) + elif self.num_spatial_dims == 3: + anchor_coordinates_x = np.random.randint( + self.kappa, + self.output_shape[0] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates_y = np.random.randint( + self.kappa, + self.output_shape[1] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates_z = np.random.randint( + self.kappa, + self.output_shape[2] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates = np.stack( + (anchor_coordinates_x, anchor_coordinates_y, anchor_coordinates_z), + axis=1, + ) + anchor_samples = np.repeat(anchor_coordinates, num_references, axis=0) + offset_in_pos_radius = self.sample_offsets_within_radius( + self.kappa, len(anchor_samples) + ) + reference_samples = anchor_samples + offset_in_pos_radius + + return anchor_samples, reference_samples + + def get_num_anchors(self): + return int(self.density * self.unbiased_shape[0] * self.unbiased_shape[1]) + + def get_num_references(self): + return int(self.density * self.kappa**2 * np.pi) + + def get_num_samples(self): + return self.get_num_anchors() * self.get_num_references() From e9e133ffbfeedf5996b9b542eaa2c7754d8e65b9 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:34:20 -0500 Subject: [PATCH 31/40] Remove comma after colon --- cellulus/evaluate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cellulus/evaluate.py b/cellulus/evaluate.py index 0004d23..dad49ea 100644 --- a/cellulus/evaluate.py +++ b/cellulus/evaluate.py @@ -42,9 +42,7 @@ def evaluate(inference_config: InferenceConfig) -> None: FP_list.append(FP_image) FN_list.append(FN_image) sample_list.append(sample) - print( - f"{sample}:, F1={F1_image:.3f}, SEG={SEG_image/n_GTids_image:.3f}" - ) + print(f"{sample}: F1={F1_image:.3f}, SEG={SEG_image/n_GTids_image:.3f}") F1_dataset = 2 * sum(TP_list) / (2 * sum(TP_list) + sum(FP_list) + sum(FN_list)) From 9cae0f809dc62cfb02177f60cc6db41809dcbd1e Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:35:31 -0500 Subject: [PATCH 32/40] Specify post_processing as one of cell or nucleus --- cellulus/post_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cellulus/post_process.py b/cellulus/post_process.py index a5407e7..4c31994 100644 --- a/cellulus/post_process.py +++ b/cellulus/post_process.py @@ -40,7 +40,7 @@ def post_process(inference_config: InferenceConfig) -> None: ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims # remove halo - if inference_config.post_processing == "morphological": + if inference_config.post_processing == "cell": for sample in tqdm(range(dataset_meta_data.num_samples)): # first instance label masks are expanded by `grow_distance` # next, expanded instance label masks are shrunk by `shrink_distance` @@ -51,7 +51,7 @@ def post_process(inference_config: InferenceConfig) -> None: distance_background = dtedt(expanded_mask) segmentation[distance_background < inference_config.shrink_distance] = 0 ds_postprocessed[sample, bandwidth_factor, ...] = segmentation - elif inference_config.post_processing == "intensity": + elif inference_config.post_processing == "nucleus": ds_raw = f[inference_config.dataset_config.dataset_name] for sample in tqdm(range(dataset_meta_data.num_samples)): for bandwidth_factor in range(inference_config.num_bandwidths): From 3f01fa5e1e0fb0123198c52924b74abb97ad6ab6 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:36:24 -0500 Subject: [PATCH 33/40] Remove scheduler from cellulus/train.py --- cellulus/train.py | 51 +++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/cellulus/train.py b/cellulus/train.py index c435684..b5a9f46 100644 --- a/cellulus/train.py +++ b/cellulus/train.py @@ -3,7 +3,6 @@ import numpy as np import torch import zarr -from IPython.display import clear_output from tqdm import tqdm from cellulus.criterions import get_loss @@ -30,6 +29,9 @@ def train(experiment_config): elastic_deform=train_config.elastic_deform, control_point_spacing=train_config.control_point_spacing, control_point_jitter=train_config.control_point_jitter, + density=train_config.density, + kappa=train_config.kappa, + normalization_factor=experiment_config.normalization_factor, ) # create train dataloader @@ -69,30 +71,22 @@ def train(experiment_config): criterion = get_loss( regularizer_weight=train_config.regularizer_weight, temperature=train_config.temperature, - kappa=train_config.kappa, density=train_config.density, num_spatial_dims=train_dataset.get_num_spatial_dims(), - reduce_mean=train_config.reduce_mean, device=device, ) # set optimizer optimizer = torch.optim.Adam( - model.parameters(), - lr=train_config.initial_learning_rate, + model.parameters(), lr=train_config.initial_learning_rate, weight_decay=0.01 ) - # set scheduler: - - def lambda_(iteration): - return pow((1 - ((iteration) / train_config.max_iterations)), 0.9) - # set logger logger = get_logger(keys=["loss", "oce_loss"], title="loss") # resume training start_iteration = 0 - lowest_loss = 1.0 + lowest_loss = 1e6 epoch_loss = 0 num_iterations = 0 if model_config.checkpoint is None: @@ -107,22 +101,15 @@ def lambda_(iteration): logger.data = state["logger_data"] # call `train_iteration` - for iteration, batch in tqdm( zip( range(start_iteration, train_config.max_iterations), train_dataloader, ) ): - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 - ) - loss, oce_loss, prediction = train_iteration( batch, model=model, criterion=criterion, optimizer=optimizer, device=device ) - scheduler.step() - clear_output(wait=True) print(f"===> loss: {loss:.6f}, oce loss: {oce_loss:.6f}") logger.add(key="loss", value=loss) logger.add(key="oce_loss", value=oce_loss) @@ -171,13 +158,26 @@ def lambda_(iteration): def train_iteration(batch, model, criterion, optimizer, device): + raw, anchor_coordinates, reference_coordinates = batch + raw, anchor_coordinates, reference_coordinates = ( + raw.to(device), + anchor_coordinates.to(device), + reference_coordinates.to(device), + ) + model.train() - prediction = model(batch.to(device)) - loss, oce_loss, regularization_loss = criterion(prediction) + offsets = model(raw) + embeddings_anchor = model.select_and_add_coordinates(offsets, anchor_coordinates) + embeddings_reference = model.select_and_add_coordinates( + offsets, reference_coordinates + ) + loss, oce_loss, regularization_loss = criterion( + embeddings_anchor, embeddings_reference + ) optimizer.zero_grad() loss.backward() optimizer.step() - return loss.item(), oce_loss.item(), prediction + return loss.item(), oce_loss.item(), offsets def save_model(state, iteration, is_lowest=False): @@ -192,17 +192,18 @@ def save_model(state, iteration, is_lowest=False): def save_snapshot(batch, prediction, iteration): - num_spatial_dims = len(batch.shape) - 2 + raw, anchor_coordinates, reference_coordinates = batch + num_spatial_dims = len(raw.shape) - 2 axis_names = ["s", "c"] + ["t", "z", "y", "x"][-num_spatial_dims:] prediction_offset = tuple( (a - b) / 2 for a, b in zip( - batch.shape[-num_spatial_dims:], prediction.shape[-num_spatial_dims:] + raw.shape[-num_spatial_dims:], prediction.shape[-num_spatial_dims:] ) ) f = zarr.open("snapshots.zarr", "a") - f[f"{iteration}/raw"] = batch.detach().cpu().numpy() + f[f"{iteration}/raw"] = raw.detach().cpu().numpy() f[f"{iteration}/raw"].attrs["axis_names"] = axis_names f[f"{iteration}/raw"].attrs["resolution"] = [ 1, @@ -221,5 +222,3 @@ def save_snapshot(batch, prediction, iteration): f[f"{iteration}/prediction"].attrs["resolution"] = [ 1, ] * num_spatial_dims - - print(f"Snapshot saved at iteration {iteration}") From 2e68fb51247d23d944852c4a9f2972e0df78affb Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:36:56 -0500 Subject: [PATCH 34/40] Add normalization node during prediction --- cellulus/predict.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cellulus/predict.py b/cellulus/predict.py index de49286..439dd24 100644 --- a/cellulus/predict.py +++ b/cellulus/predict.py @@ -6,7 +6,11 @@ from cellulus.datasets.meta_data import DatasetMetaData -def predict(model: torch.nn.Module, inference_config: InferenceConfig) -> None: +def predict( + model: torch.nn.Module, + inference_config: InferenceConfig, + normalization_factor: float, +) -> None: # get the dataset_config data out of inference_config dataset_config = inference_config.dataset_config dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) @@ -113,6 +117,7 @@ def predict(model: torch.nn.Module, inference_config: InferenceConfig) -> None: {raw: dataset_config.dataset_name}, {raw: gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)}, ) + + gp.Normalize(raw, factor=normalization_factor) + gp.Pad(raw, context, mode="reflect") + predict + gp.ZarrWrite( From 17b4e94dd7d841d1169e27e2210359621e596e54 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:37:36 -0500 Subject: [PATCH 35/40] Modify default value of bandwidth, if not specified by user --- cellulus/infer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cellulus/infer.py b/cellulus/infer.py index 9e541d9..5cf884b 100644 --- a/cellulus/infer.py +++ b/cellulus/infer.py @@ -17,6 +17,7 @@ def infer(experiment_config): print(experiment_config) inference_config = experiment_config.inference_config + normalization_factor = experiment_config.normalization_factor model_config = experiment_config.model_config @@ -25,7 +26,7 @@ def infer(experiment_config): ) if inference_config.bandwidth is None: - inference_config.bandwidth = int(0.5 * experiment_config.object_size) + inference_config.bandwidth = int(0.25 * experiment_config.object_size) if inference_config.min_size is None: if dataset_meta_data.num_spatial_dims == 2: @@ -67,13 +68,13 @@ def infer(experiment_config): # get predicted embeddings... if inference_config.prediction_dataset_config is not None: - predict(model, inference_config) + predict(model, inference_config, normalization_factor) # ...turn them into a segmentation... if inference_config.segmentation_dataset_config is not None: segment(inference_config) # ...and post-process the segmentation if inference_config.post_processed_dataset_config is not None: post_process(inference_config) - # ...and evaluate if groundtruth exists + # ...and evaluate if ground-truth exists if inference_config.evaluation_dataset_config is not None: evaluate(inference_config) From 3673beaaae2c02786bcba192cda234e5e8d219da Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:38:25 -0500 Subject: [PATCH 36/40] Add select_and_add_coordinates method to unet.py --- cellulus/models/unet.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/cellulus/models/unet.py b/cellulus/models/unet.py index 48ed2bd..6c0a4d5 100644 --- a/cellulus/models/unet.py +++ b/cellulus/models/unet.py @@ -62,12 +62,6 @@ def __init__( nn.Conv3d(self.features_in_last_layer, out_channels, 1), ) - def set_infer(self, p_salt_pepper, num_infer_iterations, device): - self.mode = "infer" - self.p_salt_pepper = p_salt_pepper - self.num_infer_iterations = num_infer_iterations - self.device: torch.device = device - def head_forward(self, backbone_output): out_head = self.head(backbone_output) return out_head @@ -86,12 +80,12 @@ def forward(self, raw): noisy_input = raw_sample.detach().clone() rnd = torch.rand(*noisy_input.shape).to(self.device) noisy_input[rnd <= self.p_salt_pepper] = val - pred = ( + prediction = ( self.head_forward(self.backbone(noisy_input))[0] .detach() .cpu() ) - predictions.append(pred) + predictions.append(prediction) embedding_std, embedding_mean = torch.std_mean( torch.stack(predictions, dim=0), @@ -104,3 +98,27 @@ def forward(self, raw): embeddings.append(torch.cat((embedding_mean, embedding_std), dim=0)) return torch.stack(embeddings, dim=0) + + def set_infer(self, p_salt_pepper, num_infer_iterations, device): + self.mode = "infer" + self.p_salt_pepper = p_salt_pepper + self.num_infer_iterations = num_infer_iterations + self.device: torch.device = device + + @staticmethod + def select_and_add_coordinates(outputs, coordinates): + selections = [] + # outputs.shape = (b, c, h, w) or (b, c, d, h, w) + for output, coordinate in zip(outputs, coordinates): + if output.ndim == 3: + selection = output[:, coordinate[:, 1], coordinate[:, 0]] + elif output.ndim == 4: + selection = output[ + :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0] + ] + selection = selection.transpose(1, 0) + selection += coordinate + selections.append(selection) + + # selection.shape = (b, c, p) where p is the number of selected positions + return torch.stack(selections, dim=0) From daa525cc26902c7b1c1e64c5751655ea8d607981 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 21:39:18 -0500 Subject: [PATCH 37/40] Update 2d examples --- docs/examples/2d/01-data.py | 11 +++---- docs/examples/2d/02-train.py | 29 ++++++++++------- docs/examples/2d/03-infer.py | 60 ++++++++++++++++++++++++++---------- 3 files changed, 67 insertions(+), 33 deletions(-) diff --git a/docs/examples/2d/01-data.py b/docs/examples/2d/01-data.py index 83dc89a..21fd8c2 100644 --- a/docs/examples/2d/01-data.py +++ b/docs/examples/2d/01-data.py @@ -3,7 +3,9 @@ # In this notebook, we will download data and convert it to a zarr dataset.
# This tutorial was written by Henry Westmacott and Manan Lalit. -# For demonstration, we will use a subset of images of `Fluo-N2DL-HeLa` available on the [Cell Tracking Challenge](http://celltrackingchallenge.net/2d-datasets/) webpage. +# For demonstration, we will use a subset of images of `Fluo-N2DL-HeLa` available +# on the [Cell Tracking Challenge](http://celltrackingchallenge.net/2d-datasets/) +# webpage. # Firstly, the `tif` raw images are downloaded to a directory indicated by `data_dir`. @@ -27,7 +29,7 @@ ) # - -# Next, these raw images are intensity-normalized and appended in a list. Here, we use the percentile normalization technique. +# Next, a channel dimension is added to these images and they are appended in a list. # + container_path = zarr.open(name + ".zarr") @@ -38,10 +40,7 @@ for i in tqdm(range(len(image_filenames))): im = normalize( - tifffile.imread(image_filenames[i]).astype(np.float32), - pmin=1, - pmax=99.8, - axis=(0, 1), + tifffile.imread(image_filenames[i]).astype(np.float32), 1, 99.8, axis=(0, 1) ) image_list.append(im[np.newaxis, ...]) diff --git a/docs/examples/2d/02-train.py b/docs/examples/2d/02-train.py index 89d73f2..66170ea 100644 --- a/docs/examples/2d/02-train.py +++ b/docs/examples/2d/02-train.py @@ -10,7 +10,8 @@ # ## Specify config values for dataset -# In the next cell, we specify the name of the zarr container and the dataset within it from which data would be read. +# In the next cell, we specify the name of the zarr container and the dataset +# within it from which data would be read. name = "2d-data-demo" dataset_name = "train/raw" @@ -21,8 +22,10 @@ # ## Specify config values for model -# In the next cell, we specify the number of feature maps (`num_fmaps`) in the first layer in our model.
-# Additionally, we specify `fmap_inc_factor`, which indicates by how much the number of feature maps increase between adjacent layers. +# In the next cell, we specify the number of feature maps (`num_fmaps`) in the +# first layer in our model.
+# Additionally, we specify `fmap_inc_factor`, which indicates by how much the +# number of feature maps increase between adjacent layers. num_fmaps = 24 fmap_inc_factor = 3 @@ -31,11 +34,14 @@ # ## Specify config values for the training process -# Then, we specify training-specific parameters such as the `device`, which indicates the actual device to run the training on. -#
The device could be set equal to `cuda:n` (where `n` is the index of the GPU, for e.g. `cuda:0`), `cpu` or `mps`.
-# We set the `max_iterations` equal to 5000 for demonstration purposes.
(This takes around 20 minutes on a Mac Book Pro with an Apple M2 Max chip). +# Then, we specify training-specific parameters such as the `device`, +# which indicates the actual device to run the training on. +#
The device could be set equal to `cuda:n` (where `n` is the index of +# the GPU, for e.g. `cuda:0`), `cpu` or `mps`.
+# We set the `max_iterations` equal to 5000 for demonstration purposes. +#
(This takes around 20 minutes on a Mac Book Pro with an Apple M2 Max chip). -device = "cuda:0" +device = "mps" # 'mps', 'cpu', 'cuda:0' max_iterations = 5000 train_config = TrainConfig( @@ -44,16 +50,17 @@ max_iterations=max_iterations, ) -# Next, we initialize the experiment config which puts together the config objects (`train_config` and `model_config`) which we defined above. +# Next, we initialize the experiment config which puts together the config +# objects (`train_config` and `model_config`) which we defined above. experiment_config = ExperimentConfig( - train_config=asdict(train_config), model_config=asdict(model_config) + train_config=asdict(train_config), + model_config=asdict(model_config), + normalization_factor=1.0, ) # Now we can begin the training!
# Uncomment the next two lines to train the model. -# + # from cellulus.train import train # train(experiment_config) -# - diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index d6c212e..ca2ad3d 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -1,6 +1,7 @@ # # Infer using Trained Model -# In this notebook, we will use the `cellulus` model trained in the previous step to obtain instance segmentations. +# In this notebook, we will use the `cellulus` model trained in the previous +# step to obtain instance segmentations. import urllib import zipfile @@ -21,12 +22,17 @@ # ## Specify config values for datasets -# We again specify `name` of the zarr container, and `dataset_name` which identifies the path to the raw image data, which needs to be segmented. +# We again specify `name` of the zarr container, and `dataset_name` which +# identifies the path to the raw image data, which needs to be segmented. name = "2d-data-demo" dataset_name = "train/raw" -# We initialize the `dataset_config` which relates to the raw image data, `prediction_dataset_config` which relates to the per-pixel embeddings and the uncertainty, the `segmentation_dataset_config` which relates to the segmentations post the mean-shift clustering and the `post_processed_config` which relates to the segmentations after some post-processing. +# We initialize the `dataset_config` which relates to the raw image data, +# `prediction_dataset_config` which relates to the per-pixel embeddings and the +# uncertainty, the `segmentation_dataset_config` which relates to the +# segmentations post the mean-shift clustering and the `post_processed_config` +# which relates to the segmentations after some post-processing. dataset_config = DatasetConfig(container_path=name + ".zarr", dataset_name=dataset_name) prediction_dataset_config = DatasetConfig( @@ -45,10 +51,13 @@ # ## Specify config values for the model -# We must also specify the `num_fmaps`, `fmap_inc_factor` (use same values as in the training step) and set `checkpoint` equal to `models/best_loss.pth` (best in terms of the lowest loss obtained). +# We must also specify the `num_fmaps`, `fmap_inc_factor` (use same values as +# in the training step) and set `checkpoint` equal to `models/best_loss.pth` +# (best in terms of the lowest loss obtained). # Here, we download a pretrained model trained by us for `5e3` iterations.
-# But please comment the next cell to use your own trained model, which should be available in the `models` directory. +# But please comment the next cell to use your own trained model, which +# should be available in the `models` directory. torch.hub.download_url_to_file( url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/2d-demo-model.zip", @@ -68,32 +77,45 @@ # ## Initialize `inference_config` -# Then, we specify inference-specific parameters such as the `device`, which indicates the actual device to run the inference on. -#
The device could be set equal to `cuda:n` (where `n` is the index of the GPU, for e.g. `cuda:0`), `cpu` or `mps`. +# Then, we specify inference-specific parameters such as the `device`, which +# indicates the actual device to run the inference on. +#
The device could be set equal to `cuda:n` (where `n` is the index of +# the GPU, for e.g. `cuda:0`), `cpu` or `mps`. -device = "cuda:0" +device = "mps" # "cuda:0", 'mps', 'cpu' -# We initialize the `inference_config` which contains our `embeddings_dataset_config`, `segmentation_dataset_config` and `post_processed_dataset_config`. +# We initialize the `inference_config` which contains our +# `embeddings_dataset_config`, `segmentation_dataset_config` and +# `post_processed_dataset_config`. +# We set post_processing to one of `cell` or `nucleus`, depending on if we +# would like the cell membrane to be segmented or the nucleus. + +post_processing = "nucleus" +bandwidth = 15.0 inference_config = InferenceConfig( dataset_config=asdict(dataset_config), prediction_dataset_config=asdict(prediction_dataset_config), segmentation_dataset_config=asdict(segmentation_dataset_config), post_processed_dataset_config=asdict(post_processed_dataset_config), - post_processing="intensity", + post_processing=post_processing, device=device, + bandwidth=bandwidth, ) # ## Initialize `experiment_config` -# Lastly we initialize the `experiment_config` which contains the `inference_config` and `model_config` initialized above. +# Lastly we initialize the `experiment_config` which contains the +# `inference_config` and `model_config` initialized above. experiment_config = ExperimentConfig( - inference_config=asdict(inference_config), model_config=asdict(model_config) + inference_config=asdict(inference_config), + model_config=asdict(model_config), + normalization_factor=1.0, ) # Now we are ready to start the inference!!
-# (This takes around 7 minutes on a Mac Book Pro with an Apple M2 Max chip (i.e. `device = 'mps'`). To see the output of the cell below, remove the first line `io.capture_output()`). +# To see the output of the cell below, remove the first line `io.capture_output()`). with io.capture_output() as captured: infer(experiment_config) @@ -101,7 +123,8 @@ # ## Inspect predictions # Let's look at some of the predicted embeddings.
-# We will first load a glasbey-like color map to show individual cells with a unique color. +# We will first load a glasbey-like color map to show individual cells +# with a unique color. urllib.request.urlretrieve( "https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/cmap_60.npy", @@ -109,7 +132,9 @@ ) new_cmp = ListedColormap(np.load("cmap_60.npy")) -# Change the value of `index` below to look at the raw image (left), x-offset (bottom-left), y-offset (bottom-right) and uncertainty of the embedding (top-right). +# Change the value of `index` below to look at the raw image (left), +# x-offset (bottom-left), y-offset (bottom-right) and uncertainty of the +# embedding (top-right). # + index = 10 @@ -132,7 +157,10 @@ ) # - -# As you can see the magnitude of the uncertainty of the embedding (top-right) is low for most of the foreground cells.
This enables extraction of the foreground, which is eventually clustered into individual instances. +# As you can see the magnitude of the uncertainty of the embedding (top-right) +# is low for most of the foreground cells.
+# This enables extraction of the foreground, which is eventually clustered +# into individual instances. # + f = zarr.open(name + ".zarr") From b29dce73c9cba3645ad66ed9c370ebe806caf652 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 26 Feb 2024 23:18:44 -0500 Subject: [PATCH 38/40] Set bandwidth to be float type. Default value is 0.5 times object size --- cellulus/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellulus/infer.py b/cellulus/infer.py index 5cf884b..0ddff24 100644 --- a/cellulus/infer.py +++ b/cellulus/infer.py @@ -26,7 +26,7 @@ def infer(experiment_config): ) if inference_config.bandwidth is None: - inference_config.bandwidth = int(0.25 * experiment_config.object_size) + inference_config.bandwidth = 0.5 * experiment_config.object_size if inference_config.min_size is None: if dataset_meta_data.num_spatial_dims == 2: From f45278512b1560ed2c2b6ee5d8daa00c27da8fad Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 27 Feb 2024 00:18:13 -0500 Subject: [PATCH 39/40] Update device to be cuda:0 --- docs/examples/2d/02-train.py | 3 ++- docs/examples/2d/03-infer.py | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/examples/2d/02-train.py b/docs/examples/2d/02-train.py index 66170ea..619ca0d 100644 --- a/docs/examples/2d/02-train.py +++ b/docs/examples/2d/02-train.py @@ -41,7 +41,7 @@ # We set the `max_iterations` equal to 5000 for demonstration purposes. #
(This takes around 20 minutes on a Mac Book Pro with an Apple M2 Max chip). -device = "mps" # 'mps', 'cpu', 'cuda:0' +device = "cuda:0" # 'mps', 'cpu', 'cuda:0' max_iterations = 5000 train_config = TrainConfig( @@ -62,5 +62,6 @@ # Now we can begin the training!
# Uncomment the next two lines to train the model. +# + # from cellulus.train import train # train(experiment_config) diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index ca2ad3d..1e9d20e 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -82,7 +82,7 @@ #
The device could be set equal to `cuda:n` (where `n` is the index of # the GPU, for e.g. `cuda:0`), `cpu` or `mps`. -device = "mps" # "cuda:0", 'mps', 'cpu' +device = "cuda:0" # "cuda:0", 'mps', 'cpu' # We initialize the `inference_config` which contains our # `embeddings_dataset_config`, `segmentation_dataset_config` and @@ -91,7 +91,6 @@ # would like the cell membrane to be segmented or the nucleus. post_processing = "nucleus" -bandwidth = 15.0 inference_config = InferenceConfig( dataset_config=asdict(dataset_config), @@ -100,7 +99,6 @@ post_processed_dataset_config=asdict(post_processed_dataset_config), post_processing=post_processing, device=device, - bandwidth=bandwidth, ) # ## Initialize `experiment_config` @@ -180,4 +178,3 @@ bottom_left_cmap=new_cmp, bottom_right_cmap=new_cmp, ) -# - From e21592c75238e84e33f7c6c97dba71ca7b473003 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 27 Feb 2024 00:18:45 -0500 Subject: [PATCH 40/40] Update title --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 820e543..6161c27 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ --- template: home.html -title: Title +title: social: cards_layout_options: title: Documentation for Cellulus