Skip to content

Commit

Permalink
Merge pull request #6 from funkelab/update-post-processing
Browse files Browse the repository at this point in the history
Update post processing
  • Loading branch information
lmanan authored Feb 27, 2024
2 parents 7752a94 + e21592c commit 6dc2834
Show file tree
Hide file tree
Showing 26 changed files with 854 additions and 358 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pip install -e .

### Getting Started

Try out `2D Example` available **[here](https://funkelab.github.io/cellulus)**.
Try out the `2D Example` available **[here](https://funkelab.github.io/cellulus)**.

### Citation

Expand Down
2 changes: 1 addition & 1 deletion cellulus/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .dataset_config import DatasetConfig
from .experiment_config import ExperimentConfig

__all__ = ["ExperimentConfig", "DatasetConfig"]
__all__ = ["DatasetConfig", "ExperimentConfig"]
15 changes: 12 additions & 3 deletions cellulus/configs/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,29 @@
class DatasetConfig:
"""Dataset configuration.
Parameters:
Parameters
----------
container_path:
A path to the zarr/N5 container.
dataset_name:
The name of the dataset containing raw data in the container.
The name of the dataset containing the raw data in the container.
secondary_dataset_name:
The name of the dataset containing the data which needs processing.
The name of the secondary dataset containing the data which needs
processing.
'dataset_name' and 'secondary_dataset_name' can be thought of as the
output and input to a certain task, respectively.
For example, during segmentation, 'dataset_name' would refer to the output
segmentation masks and 'secondary_dataset_name' would refer to the input
predicted embeddings.
During evaluation, 'dataset_name' would refer to the ground truth masks
and 'secondary_dataset_name' would refer to the input segmentation masks.
"""

Expand Down
18 changes: 14 additions & 4 deletions cellulus/configs/experiment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,28 @@
class ExperimentConfig:
"""Top-level config for an experiment (containing training and prediction).
Parameters:
Parameters
----------
experiment_name: (default = 'YYYY-MM-DD')
A unique name for the experiment.
object_size: (default = 26.0)
object_size: (default = 30)
A rough estimate of the size of objects in the image, given in
world units. The "patch size" of the network will be chosen based
on this estimate.
normalization_factor: (default = None)
The factor to use, for dividing the raw image pixel intensities.
If 'None', a factor is chosen based on the dtype of the array
(e.g., np.uint8 would result in a factor of 1.0/255).
model_config:
The model configuration.
Configuration object for the model.
train_config:
Expand All @@ -42,7 +49,10 @@ class ExperimentConfig:
experiment_name: str = attrs.field(
default=datetime.today().strftime("%Y-%m-%d"), validator=instance_of(str)
)
object_size: float = attrs.field(default=26.0, validator=instance_of(float))
normalization_factor: float = attrs.field(
default=None, validator=attrs.validators.optional(instance_of(float))
)
object_size: int = attrs.field(default=30, validator=instance_of(int))

train_config: TrainConfig = attrs.field(
default=None, converter=to_config(TrainConfig)
Expand Down
70 changes: 55 additions & 15 deletions cellulus/configs/inference_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List

import attrs
from attrs.validators import instance_of
from attrs.validators import in_, instance_of

from .dataset_config import DatasetConfig
from .utils import to_config
Expand All @@ -11,7 +11,8 @@
class InferenceConfig:
"""Inference configuration.
Parameters:
Parameters
----------
dataset_config:
Expand All @@ -27,51 +28,86 @@ class InferenceConfig:
post_processed_dataset_config:
Configuration object produced by postprocess.py.
Configuration object produced by post_process.py.
evaluation_dataset_config:
Configuration object for the ground truth masks.
crop_size:
crop_size (default = [252, 252]):
ROI used by the scan node in gunpowder.
p_salt_pepper:
p_salt_pepper (default = 0.01):
Fraction of pixels that will have salt-pepper noise.
num_infer_iterations:
num_infer_iterations (default = 16):
Number of times the salt-pepper noise is added to the raw image.
This is used to infer the foreground and background in the raw image.
bandwidth:
bandwidth (default = None):
Band-width used to perform mean-shift clustering on the predicted
Bandwidth used to perform mean-shift clustering on the predicted
embeddings.
threshold:
threshold (default = None):
Threshold to use for binary partitioning into foreground and background
pixel regions. If None, this is figured out automatically by performing
Otsu Thresholding on the last channel of the predicted embeddings.
reduction_probability:
min_size (default = None):
min_size:
Ignore objects which are smaller than min_size number of pixels.
Ignore objects which are smaller than `min_size` number of pixels.
device (default = 'cuda:0'):
The device to infer on.
Set to 'cpu' to infer without GPU.
clustering (default = 'meanshift'):
How to cluster the embeddings?
Can be one of 'meanshift' or 'greedy'.
num_bandwidths (default = 1):
Number of bandwidths to obtain segmentations for.
reduction_probability (default = 0.1):
If set to less than 1.0, this fraction of available pixels are used
to determine the clusters (fitting stage) while performing
meanshift clustering.
Once clusters are available, they are used to predict the cluster assignment
of the remaining pixels (prediction stage).
min_size (default = None):
Objects below `min_size` pixels will be removed.
post_processing (default = 'cell'):
Can be one of 'cell' or 'nucleus' operations.
If 'cell', the individual detections grow and shrink by
'grow_distance' and 'shrink_distance' number of pixels.
If 'nucleus', each detection is partitioned based on a binary intensity
threshold calculated automatically from the raw image data.
By default, the channel `0` in the raw image is used for
intensity thresholding.
grow_distance (default = 3):
Only used if post_processing (see above) is equal to 'cell'.
shrink_distance (default = 6):
Only used if post_processing (see above) is equal to
'cell'.
"""

dataset_config: DatasetConfig = attrs.field(
Expand Down Expand Up @@ -100,13 +136,17 @@ class InferenceConfig:
threshold = attrs.field(
default=None, validator=attrs.validators.optional(instance_of(float))
)
clustering = attrs.field(
default="meanshift", validator=in_(["meanshift", "greedy"])
)
bandwidth = attrs.field(
default=None, validator=attrs.validators.optional(instance_of(int))
default=None, validator=attrs.validators.optional(instance_of(float))
)
num_bandwidths = attrs.field(default=1, validator=instance_of(int))
reduction_probability = attrs.field(default=0.1, validator=instance_of(float))
min_size = attrs.field(
default=None, validator=attrs.validators.optional(instance_of(int))
)
post_processing = attrs.field(default="cell", validator=in_(["cell", "nucleus"]))
grow_distance = attrs.field(default=3, validator=instance_of(int))
shrink_distance = attrs.field(default=6, validator=instance_of(int))
15 changes: 8 additions & 7 deletions cellulus/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

@attrs.define
class ModelConfig:
"""Model configuration.
"""Model Configuration.
Parameters:
Parameters
----------
num_fmaps:
Expand All @@ -22,27 +23,27 @@ class ModelConfig:
The factor by which to increase the number of feature maps between
levels of the U-Net.
features_in_last_layer (optional, default = 64):
features_in_last_layer (default = 64):
The number of feature channels in the last layer of the U-Net.
downsampling_factors:
downsampling_factors (default = [[2,2]]):
A list of downsampling factors, each given per dimension (e.g.,
[[2,2], [3,3]] would correspond to two downsample layers, one with
an isotropic factor of 2, and another one with 3). This parameter
will also determine the number of levels in the U-Net.
checkpoint (optional, default ``None``):
checkpoint (default = None):
A path to a checkpoint of the network. Needs to be set for networks
that are used for prediction. If set during training, the
checkpoint will be used to resume training, otherwise the network
will be trained from scratch.
initialize (default: True)
initialize (default = True):
If True, initialize the model weights with Kaiming Normal
If True, initialize the model weights with Kaiming Normal.
"""

Expand Down
28 changes: 20 additions & 8 deletions cellulus/configs/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
class TrainConfig:
"""Train configuration.
Parameters:
Parameters
----------
crop_size:
crop_size (default = [252, 252]):
The size of the crops - specified as a list of number of pixels -
extracted from the raw images, used during training.
batch_size:
batch_size (default = 8):
The number of samples to use per batch.
max_iterations:
max_iterations (default = 100000):
The maximum number of iterations to train for.
Expand Down Expand Up @@ -50,28 +51,39 @@ class TrainConfig:
Neighborhood radius to extract patches from.
save_model_every (default = 1e3):
save_model_every (default = 1000):
The model weights are saved every few iterations.
save_snapshot_every (default = 1e3):
save_best_model_every (default = 100):
The best loss is evaluated every few iterations.
save_snapshot_every (default = 1000):
The zarr snapshot is saved every few iterations.
num_workers (default = 8):
The number of sub-processes to use for data-loading.
elastic_deform (default = True):
If set to True, the data is elastically deformed
in order to increase training samples.
control_point_spacing (default = 64):
The distance in pixels between control points used for elastic
deformation of the raw data during training.
Only used if `elastic_deform` is set to True.
control_point_jitter (default = 2.0):
How much to jitter the control points for elastic deformation
of the raw data during training, given as the standard deviation of
a normal distribution with zero mean.
Only used if `elastic_deform` is set to True.
train_data_config:
Expand Down Expand Up @@ -105,11 +117,11 @@ class TrainConfig:
kappa: float = attrs.field(default=10.0, validator=instance_of(float))
temperature: float = attrs.field(default=10.0, validator=instance_of(float))
regularizer_weight: float = attrs.field(default=1e-5, validator=instance_of(float))
reduce_mean: bool = attrs.field(default=True, validator=instance_of(bool))
save_model_every: int = attrs.field(default=1_000, validator=instance_of(int))
save_best_model_every: int = attrs.field(default=100, validator=instance_of(int))
save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int))
num_workers: int = attrs.field(default=8, validator=instance_of(int))

elastic_deform: bool = attrs.field(default=True, validator=instance_of(bool))
control_point_spacing: int = attrs.field(default=64, validator=instance_of(int))
control_point_jitter: float = attrs.field(default=2.0, validator=instance_of(float))
device: str = attrs.field(default="cuda:0", validator=instance_of(str))
4 changes: 0 additions & 4 deletions cellulus/criterions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ def get_loss(
temperature,
regularizer_weight,
density,
kappa,
num_spatial_dims,
reduce_mean,
device,
):
return OCELoss(
temperature,
regularizer_weight,
density,
kappa,
num_spatial_dims,
reduce_mean,
device,
)
Loading

0 comments on commit 6dc2834

Please sign in to comment.