Introducing first version of semi-supervised functionality #3

Draft · wants to merge 3 commits into main
49 changes: 49 additions & 0 deletions cellulus/configs/semi_supervised_experiment_config.py
@@ -0,0 +1,49 @@
import attrs
from attrs.validators import instance_of

from .inference_config import InferenceConfig
from .model_config import ModelConfig
from .semi_supervised_train_config import SemiSupervisedTrainConfig
from .utils import to_config


@attrs.define
class SemiSupervisedExperimentConfig:
    """Top-level config for a semi-supervised experiment
    (containing training and prediction).

    Parameters:

        experiment_name:

            A unique name for the experiment.

        object_size:

            A rough estimate of the size of objects in the image, given in
            world units. The "patch size" of the network will be chosen based
            on this estimate.

        model_config:

            The model configuration.

        semi_sup_train_config:

            Configuration object for training the semi-supervised model.

        inference_config:

            Configuration object for prediction.
    """

    experiment_name: str = attrs.field(validator=instance_of(str))
    object_size: float = attrs.field(validator=instance_of(float))

    model_config: ModelConfig = attrs.field(converter=to_config(ModelConfig))
    semi_sup_train_config: SemiSupervisedTrainConfig = attrs.field(
        default=None, converter=to_config(SemiSupervisedTrainConfig)
    )
    inference_config: InferenceConfig = attrs.field(
        default=None, converter=to_config(InferenceConfig)
    )
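
A minimal construction sketch for review (hedged: `model_cfg`, `train_cfg`, and `inference_cfg` are hypothetical stand-ins for ModelConfig, SemiSupervisedTrainConfig, and InferenceConfig objects built elsewhere):

    experiment_config = SemiSupervisedExperimentConfig(
        experiment_name="semi_sup_v1",
        object_size=30.0,
        model_config=model_cfg,
        semi_sup_train_config=train_cfg,
        inference_config=inference_cfg,
    )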
113 changes: 113 additions & 0 deletions cellulus/configs/semi_supervised_train_config.py
@@ -0,0 +1,113 @@
from typing import List

import attrs
from attrs.validators import instance_of

from .dataset_config import DatasetConfig
from .utils import to_config


@attrs.define
class SemiSupervisedTrainConfig:
    """Train configuration.

    Parameters:

        raw_data_config:

            Configuration object for the raw training data.

        pseudo_data_config:

            Configuration object for the pseudo-ground-truth labels.

        supervised_data_config:

            Configuration object for the ground-truth labels/annotations.

        crop_size:

            The size of the crops - specified as a list of pixel extents -
            extracted from the raw images, used during training.

        batch_size:

            The number of samples to use per batch.

        max_iterations:

            The maximum number of iterations to train for.

        initial_learning_rate (default = 4e-5):

            Initial learning rate of the optimizer.

        temperature (default = 10.0):

            Factor used to scale the Gaussian function and control the
            rate of damping.

        regularizer_weight (default = 1e-5):

            The weight of the L2 regularizer on the object-centric embeddings.

        reduce_mean (default = True):

            If True, the loss contribution is averaged across all pairs
            of patches.

        density (default = 0.2):

            The fraction of patches to sample per crop during training.

        kappa (default = 10.0):

            Neighborhood radius, in pixels, within which patches are
            extracted.

        save_model_every (default = 1e3):

            The model weights are saved every `save_model_every` iterations.

        save_snapshot_every (default = 1e3):

            A zarr snapshot is saved every `save_snapshot_every` iterations.

        num_workers (default = 8):

            The number of sub-processes to use for data-loading.

        control_point_spacing (default = 64):

            The distance in pixels between control points used for elastic
            deformation of the raw data during training.

        control_point_jitter (default = 2.0):

            How much to jitter the control points for elastic deformation
            of the raw data during training, given as the standard deviation
            of a normal distribution with zero mean.
    """

    raw_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig))
    pseudo_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig))
    supervised_data_config: DatasetConfig = attrs.field(
        converter=to_config(DatasetConfig)
    )

    # use a factory so instances do not share one mutable default list
    crop_size: List = attrs.field(
        factory=lambda: [252, 252], validator=instance_of(list)
    )
    batch_size: int = attrs.field(default=8, validator=instance_of(int))
    max_iterations: int = attrs.field(default=100_000, validator=instance_of(int))
    initial_learning_rate: float = attrs.field(
        default=4e-5, validator=instance_of(float)
    )
    density: float = attrs.field(default=0.2, validator=instance_of(float))
    kappa: float = attrs.field(default=10.0, validator=instance_of(float))
    temperature: float = attrs.field(default=10.0, validator=instance_of(float))
    regularizer_weight: float = attrs.field(default=1e-5, validator=instance_of(float))
    reduce_mean: bool = attrs.field(default=True, validator=instance_of(bool))
    save_model_every: int = attrs.field(default=1_000, validator=instance_of(int))
    save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int))
    num_workers: int = attrs.field(default=8, validator=instance_of(int))

    control_point_spacing: int = attrs.field(default=64, validator=instance_of(int))
    control_point_jitter: float = attrs.field(default=2.0, validator=instance_of(float))
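
Only the three dataset configs are required; every other field falls back to the defaults above. A hedged sketch (`raw_cfg`, `pseudo_cfg`, and `gt_cfg` are hypothetical DatasetConfig instances):

    train_config = SemiSupervisedTrainConfig(
        raw_data_config=raw_cfg,
        pseudo_data_config=pseudo_cfg,
        supervised_data_config=gt_cfg,
        crop_size=[252, 252],  # override any default the same way
        batch_size=8,
    )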
118 changes: 118 additions & 0 deletions cellulus/criterions/__init__.py
@@ -1,3 +1,8 @@
import gunpowder as gp
import numpy as np
import stardist
from inferno.io.transform import Transform

from cellulus.criterions.oce_loss import OCELoss


@@ -19,3 +24,116 @@ def get_loss(
        reduce_mean,
        device,
    )


class TransformStardist(gp.BatchFilter):
    """Gunpowder node that replaces a label array with stardist targets."""

    def __init__(self, array):
        self.array = array

    def prepare(self, request):
        # the requested ROI for the array; expects (17, x, y)
        roi = request[self.array].roi

        self.stardist_shape = roi.get_shape()
        self.stardist_roi = roi

        # upstream, request the full 2D label plane that the stardist
        # targets are computed from
        roi = gp.Roi(
            (0, 0, 0, 0), (1, 1, self.stardist_shape[1], self.stardist_shape[2])
        )

        # create a new request with our dependencies
        deps = gp.BatchRequest()
        deps[self.array] = roi
        return deps

    def process(self, batch, request):
        # convert the labels into a (1 + n_rays, y, x) stardist target array
        data = batch[self.array].data
        self.data_shape = data.shape
        batch[self.array].data = stardist_transform(data)


def stardist_transform(gt, n_rays=16, fill_label_holes=False):
    """Convert an integer label image into a (1 + n_rays, y, x) array
    holding the object probability map followed by the radial distances."""
    if len(gt.shape) > 2:
        gt = np.squeeze(gt)

    # relabel if the ids are not representable as uint16
    if np.any(gt - gt.astype(np.uint16)):
        mapping = {v: k for k, v in enumerate(np.unique(gt))}
        u, inv = np.unique(gt, return_inverse=True)
        Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape)
        gt = Y1.astype(np.uint16)

    if fill_label_holes:
        gt = stardist.fill_label_holes(gt)

    dist = stardist.geometry.star_dist(gt, n_rays=n_rays)
    dist_mask = stardist.utils.edt_prob(gt.astype(int))

    if gt.min() < 0:
        # negative ids mark an ignore region: zero the distances and
        # flag the probability map with -1 there
        ignore_mask = gt < 0
        dist[ignore_mask] = 0
        dist_mask[ignore_mask] = -1

    dist_mask = dist_mask[None]
    dist = np.transpose(dist, (2, 0, 1))

    mask_and_dist = np.concatenate([dist_mask, dist], axis=0)
    return mask_and_dist
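
A quick self-contained check of the output layout (shapes follow star_dist and edt_prob as used above):

    import numpy as np

    labels = np.zeros((64, 64), dtype=np.uint16)
    labels[10:30, 10:30] = 1  # one square object
    target = stardist_transform(labels, n_rays=16)
    # target[0]  -> edt-based object probability map
    # target[1:] -> 16 radial distances per pixel
    print(target.shape)  # (17, 64, 64)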


class StardistTf(Transform):
    """Convert a segmentation into stardist targets."""

    def __init__(self, n_rays=16, fill_label_holes=False):
        super().__init__()
        self.n_rays = n_rays
        self.fill_label_holes = fill_label_holes

    def tensor_function(self, gt):
        # relabel if the ids are not representable as uint16
        if np.any(gt - gt.astype(np.uint16)):
            mapping = {v: k for k, v in enumerate(np.unique(gt))}
            u, inv = np.unique(gt, return_inverse=True)
            Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape)
            gt = Y1.astype(np.uint16)

        if self.fill_label_holes:
            gt = stardist.fill_label_holes(gt)

        dist = stardist.geometry.star_dist(gt, n_rays=self.n_rays)
        dist_mask = stardist.utils.edt_prob(gt)

        if gt.min() < 0:
            # negative ids mark an ignore region
            ignore_mask = gt < 0
            dist[ignore_mask] = 0
            dist_mask[ignore_mask] = -1

        dist_mask = dist_mask[None]
        dist = np.transpose(dist, (2, 0, 1))

        mask_and_dist = np.concatenate([dist_mask, dist], axis=0)
        return mask_and_dist
63 changes: 63 additions & 0 deletions cellulus/criterions/stardist_loss.py
@@ -0,0 +1,63 @@
import torch.nn as nn
from torch.nn import functional as F


class StardistLoss(nn.Module):
    """Loss for stardist predictions: combines BCE loss for the object
    probabilities with MAE (L1) loss for the radial distances.

    Args:
        weight: Distance loss weight. The total loss is
            bce_loss + weight * l1_loss.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, prediction, target, mask=None):
        if target.shape != prediction.shape:
            prediction = prediction.squeeze(1)

        target_prob = target[:, :1]
        predicted_prob = prediction[:, :1]
        target_dist = target[:, 1:]
        predicted_dist = prediction[:, 1:]

        if mask is not None:
            target_prob = mask * target_prob
            # do not train the foreground prediction when a mask is supplied
            predicted_prob = predicted_prob.detach()

        # predicted distance errors are weighted by the object probability
        l1loss_pp = F.l1_loss(predicted_dist, target_dist, reduction="none")

        ignore_mask_provided = target_prob.min() < 0
        if ignore_mask_provided:
            # an ignore label was supplied; restrict both losses to the
            # non-ignored pixels
            ignore_mask = target_prob >= 0.0
            imsum = ignore_mask.sum()
            if imsum == 0:
                # guard against batches that contain only the ignore label
                print("WARNING: Batch with only ignore label encountered!")
                return 0 * l1loss_pp.sum()

            l1loss = ((target_prob * ignore_mask) * l1loss_pp).sum() / imsum

            bceloss = (
                F.binary_cross_entropy_with_logits(
                    predicted_prob[ignore_mask],
                    target_prob[ignore_mask].float(),
                    reduction="sum",
                )
                / imsum
            )
            return self.weight * l1loss + bceloss

        # weight the distance errors by the target probabilities
        l1loss = (target_prob * l1loss_pp).mean()

        bceloss = F.binary_cross_entropy_with_logits(
            predicted_prob, target_prob.float(), reduction="mean"
        )

        return (self.weight * l1loss) + bceloss
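
A smoke-test sketch for the loss (the 1 + n_rays channel layout matches the transform above):

    import torch

    loss_fn = StardistLoss(weight=1.0)
    prediction = torch.randn(2, 17, 64, 64, requires_grad=True)  # prob logits + 16 distances
    target = torch.rand(2, 17, 64, 64)  # prob in [0, 1] + distances
    loss = loss_fn(prediction, target)
    loss.backward()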
6 changes: 6 additions & 0 deletions cellulus/datasets/__init__.py
@@ -10,10 +10,16 @@ def get_dataset(
    crop_size: Tuple[int, ...],
    control_point_spacing: int,
    control_point_jitter: float,
    semi_supervised: bool = False,
    supervised_dataset_config: DatasetConfig = None,
    pseudo_dataset_config: DatasetConfig = None,
) -> ZarrDataset:
    return ZarrDataset(
        dataset_config=dataset_config,
        crop_size=crop_size,
        control_point_spacing=control_point_spacing,
        control_point_jitter=control_point_jitter,
        semi_supervised=semi_supervised,
        supervised_dataset_config=supervised_dataset_config,
        pseudo_dataset_config=pseudo_dataset_config,
    )
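
A hedged call sketch (`raw_cfg`, `gt_cfg`, and `pseudo_cfg` are hypothetical DatasetConfig instances):

    dataset = get_dataset(
        dataset_config=raw_cfg,
        crop_size=(252, 252),
        control_point_spacing=64,
        control_point_jitter=2.0,
        semi_supervised=True,
        supervised_dataset_config=gt_cfg,
        pseudo_dataset_config=pseudo_cfg,
    )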