diff --git a/pyproject.toml b/pyproject.toml
index 3a5d57eb7f5428f65a675c9950ed5e4be87b5315..d4e1aef6e781053509e570e7f57a0808d6f51e91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,7 +76,7 @@ mednet = "mednet.scripts.cli:cli"
 [tool.pixi.project]
 channels = ["conda-forge", "pytorch"]
 platforms = ["linux-64", "osx-arm64"]
-conda-pypi-map = {"https://conda.anaconda.org/pytorch" = ".pixi-pytorch-mapping.json"}
+conda-pypi-map = { "https://conda.anaconda.org/pytorch" = ".pixi-pytorch-mapping.json" }
 
 [tool.pixi.system-requirements]
 linux = "4.19.0"
@@ -181,7 +181,7 @@ cuda = "12.1"
 [tool.pixi.feature.cuda.target.linux-64.dependencies]
 #cuda = { version = "*", channel = "nvidia" }
 pytorch-cuda = { version = "12.1.*", channel = "pytorch" }
-pip = "*" # required for docker image building
+pip = "*"  # required for docker image building
 
 [tool.pixi.environments]
 default = { features = ["qa", "build", "doc", "test", "dev", "py312", "self"] }
@@ -416,6 +416,13 @@ montgomery-shenzhen-indian-padchest = "mednet.libs.classification.config.data.mo
 # VISCERAL dataset
 visceral = "mednet.config.data.visceral.default"
 
+[project.entry-points."mednet.libs.segmentation.config"]
+
+lwnet = "mednet.libs.segmentation.config.models.lwnet"
+
+# drive dataset - retinal vessel segmentation
+drive = "mednet.libs.segmentation.config.data.drive.default"
+
 [tool.ruff]
 line-length = 88
 target-version = "py310"
diff --git a/src/mednet/libs/segmentation/__init__.py b/src/mednet/libs/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/config/data/drive/__init__.py b/src/mednet/libs/segmentation/config/data/drive/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecab53d8e9f3bbf598fb919ff97bd55c248cd6a4
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""DRIVE dataset for Vessel Segmentation."""
+
+import importlib.resources
+import os
+from pathlib import Path
+
+import PIL.Image
+from mednet.libs.common.data.datamodule import CachingDataModule
+from mednet.libs.common.data.split import JSONDatabaseSplit
+from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+from torchvision.transforms.functional import to_tensor
+
+from ....utils.rc import load_rc
+
+CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
+"""Key to search for in the configuration file for the root directory of this
+database."""
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the DRIVE dataset."""
+
+    datadir: str
+    """This variable contains the base directory where the database raw data is
+    stored."""
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, str, str]) -> Sample:
+        """Load a single image sample from the disk.
+
+        Parameters
+        ----------
+        sample
+            A tuple containing the path suffixes, within the dataset root folder,
+            of the image, the vessel annotation, and the mask to be loaded for
+            this sample.
+
+        Returns
+        -------
+        The sample representation.
+        """
+
+        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+            mode="RGB"
+        )
+        tensor = to_tensor(image)
+        label = PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+            mode="1", dither=None
+        )
+        mask = PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+            mode="1", dither=None
+        )
+        return tensor, dict(
+            label=to_tensor(label), mask=to_tensor(mask), name=sample[0]
+        )  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, str, str]) -> int:
+        """Load a single image sample label from the disk.
+
+        Parameters
+        ----------
+        sample
+            A tuple containing the path suffixes, within the dataset root folder,
+            of the image, the vessel annotation, and the mask associated with
+            this sample.
+
+        Returns
+        -------
+        int
+            The integer label associated with the sample.
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Return a database split for the DRIVE database.
+
+    Parameters
+    ----------
+    basename
+        Name of the .json file containing the split to load.
+
+    Returns
+    -------
+    An instance of DatabaseSplit.
+    """
+
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """DRIVE dataset for Vessel Segmentation.
+
+    The DRIVE database has been established to enable comparative studies on
+    segmentation of blood vessels in retinal images.
+
+    * Reference: [DRIVE-2004]_
+    * Original resolution (height x width): 584 x 565
+    * Split reference: [DRIVE-2004]_
+    * Protocol ``default``:
+
+      * Training samples: 20 (including labels and masks)
+      * Test samples: 20 (including labels from annotator 1 and masks)
+
+    * Protocol ``second-annotator``:
+
+      * Test samples: 20 (including labels from annotator 2 and masks)
+
+    Parameters
+    ----------
+    split_filename
+        Name of the .json file containing the split to load.
+ """ + + def __init__(self, split_filename: str): + super().__init__( + database_split=make_split(split_filename), + raw_data_loader=RawDataLoader(), + ) diff --git a/src/mednet/libs/segmentation/config/data/drive/default.json b/src/mednet/libs/segmentation/config/data/drive/default.json new file mode 100644 index 0000000000000000000000000000000000000000..6707e6edd93546915d7f9960f8ba85d3449a8d6f --- /dev/null +++ b/src/mednet/libs/segmentation/config/data/drive/default.json @@ -0,0 +1,206 @@ +{ + "train": [ + [ + "training/images/21_training.tif", + "training/1st_manual/21_manual1.gif", + "training/mask/21_training_mask.gif" + ], + [ + "training/images/22_training.tif", + "training/1st_manual/22_manual1.gif", + "training/mask/22_training_mask.gif" + ], + [ + "training/images/23_training.tif", + "training/1st_manual/23_manual1.gif", + "training/mask/23_training_mask.gif" + ], + [ + "training/images/24_training.tif", + "training/1st_manual/24_manual1.gif", + "training/mask/24_training_mask.gif" + ], + [ + "training/images/25_training.tif", + "training/1st_manual/25_manual1.gif", + "training/mask/25_training_mask.gif" + ], + [ + "training/images/26_training.tif", + "training/1st_manual/26_manual1.gif", + "training/mask/26_training_mask.gif" + ], + [ + "training/images/27_training.tif", + "training/1st_manual/27_manual1.gif", + "training/mask/27_training_mask.gif" + ], + [ + "training/images/28_training.tif", + "training/1st_manual/28_manual1.gif", + "training/mask/28_training_mask.gif" + ], + [ + "training/images/29_training.tif", + "training/1st_manual/29_manual1.gif", + "training/mask/29_training_mask.gif" + ], + [ + "training/images/30_training.tif", + "training/1st_manual/30_manual1.gif", + "training/mask/30_training_mask.gif" + ], + [ + "training/images/31_training.tif", + "training/1st_manual/31_manual1.gif", + "training/mask/31_training_mask.gif" + ], + [ + "training/images/32_training.tif", + "training/1st_manual/32_manual1.gif", + "training/mask/32_training_mask.gif" + ], + [ + "training/images/33_training.tif", + "training/1st_manual/33_manual1.gif", + "training/mask/33_training_mask.gif" + ], + [ + "training/images/34_training.tif", + "training/1st_manual/34_manual1.gif", + "training/mask/34_training_mask.gif" + ], + [ + "training/images/35_training.tif", + "training/1st_manual/35_manual1.gif", + "training/mask/35_training_mask.gif" + ], + [ + "training/images/36_training.tif", + "training/1st_manual/36_manual1.gif", + "training/mask/36_training_mask.gif" + ], + [ + "training/images/37_training.tif", + "training/1st_manual/37_manual1.gif", + "training/mask/37_training_mask.gif" + ], + [ + "training/images/38_training.tif", + "training/1st_manual/38_manual1.gif", + "training/mask/38_training_mask.gif" + ], + [ + "training/images/39_training.tif", + "training/1st_manual/39_manual1.gif", + "training/mask/39_training_mask.gif" + ], + [ + "training/images/40_training.tif", + "training/1st_manual/40_manual1.gif", + "training/mask/40_training_mask.gif" + ] + ], + "test": [ + [ + "test/images/01_test.tif", + "test/1st_manual/01_manual1.gif", + "test/mask/01_test_mask.gif" + ], + [ + "test/images/02_test.tif", + "test/1st_manual/02_manual1.gif", + "test/mask/02_test_mask.gif" + ], + [ + "test/images/03_test.tif", + "test/1st_manual/03_manual1.gif", + "test/mask/03_test_mask.gif" + ], + [ + "test/images/04_test.tif", + "test/1st_manual/04_manual1.gif", + "test/mask/04_test_mask.gif" + ], + [ + "test/images/05_test.tif", + "test/1st_manual/05_manual1.gif", + "test/mask/05_test_mask.gif" + ], 
+ [ + "test/images/06_test.tif", + "test/1st_manual/06_manual1.gif", + "test/mask/06_test_mask.gif" + ], + [ + "test/images/07_test.tif", + "test/1st_manual/07_manual1.gif", + "test/mask/07_test_mask.gif" + ], + [ + "test/images/08_test.tif", + "test/1st_manual/08_manual1.gif", + "test/mask/08_test_mask.gif" + ], + [ + "test/images/09_test.tif", + "test/1st_manual/09_manual1.gif", + "test/mask/09_test_mask.gif" + ], + [ + "test/images/10_test.tif", + "test/1st_manual/10_manual1.gif", + "test/mask/10_test_mask.gif" + ], + [ + "test/images/11_test.tif", + "test/1st_manual/11_manual1.gif", + "test/mask/11_test_mask.gif" + ], + [ + "test/images/12_test.tif", + "test/1st_manual/12_manual1.gif", + "test/mask/12_test_mask.gif" + ], + [ + "test/images/13_test.tif", + "test/1st_manual/13_manual1.gif", + "test/mask/13_test_mask.gif" + ], + [ + "test/images/14_test.tif", + "test/1st_manual/14_manual1.gif", + "test/mask/14_test_mask.gif" + ], + [ + "test/images/15_test.tif", + "test/1st_manual/15_manual1.gif", + "test/mask/15_test_mask.gif" + ], + [ + "test/images/16_test.tif", + "test/1st_manual/16_manual1.gif", + "test/mask/16_test_mask.gif" + ], + [ + "test/images/17_test.tif", + "test/1st_manual/17_manual1.gif", + "test/mask/17_test_mask.gif" + ], + [ + "test/images/18_test.tif", + "test/1st_manual/18_manual1.gif", + "test/mask/18_test_mask.gif" + ], + [ + "test/images/19_test.tif", + "test/1st_manual/19_manual1.gif", + "test/mask/19_test_mask.gif" + ], + [ + "test/images/20_test.tif", + "test/1st_manual/20_manual1.gif", + "test/mask/20_test_mask.gif" + ] + ] +} diff --git a/src/mednet/libs/segmentation/config/data/drive/default.py b/src/mednet/libs/segmentation/config/data/drive/default.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcaa0c67cd2889fedb2919802f888e04c9bdd11 --- /dev/null +++ b/src/mednet/libs/segmentation/config/data/drive/default.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""DRIVE dataset for Vessel Segmentation (default protocol). + +* Split reference: [DRIVE-2004]_ +* This configuration resolution: 544 x 544 (center-crop) +* See :py:mod:`deepdraw.data.drive` for dataset details +* This dataset offers a second-annotator comparison for the test set only +""" + +from mednet.libs.segmentation.config.data.drive.datamodule import DataModule + +datamodule = DataModule("default.json") diff --git a/src/mednet/libs/segmentation/config/models/__init__.py b/src/mednet/libs/segmentation/config/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb08ea3b463ed79c9b2c8c4dbc670eae7b9fa64 --- /dev/null +++ b/src/mednet/libs/segmentation/config/models/lwnet.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Little W-Net for image segmentation. + +The Little W-Net architecture contains roughly around 70k parameters and +closely matches (or outperforms) other more complex techniques. 
+
+Reference: [GALDRAN-2020]_
+"""
+
+from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
+from mednet.libs.segmentation.models.lwnet import LittleWNet
+from torch.optim import Adam
+
+max_lr = 0.01  # start
+min_lr = 1e-08  # valley
+cycle = 50  # epochs for a complete scheduling cycle
+
+model = LittleWNet(
+    train_loss=MultiWeightedBCELogitsLoss(),
+    validation_loss=MultiWeightedBCELogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=max_lr),
+    augmentation_transforms=[],
+)
diff --git a/src/mednet/libs/segmentation/data/__init__.py b/src/mednet/libs/segmentation/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/models/__init__.py b/src/mednet/libs/segmentation/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af10f0dd434132c992197ec310b8e33042ffd3f
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/losses.py
@@ -0,0 +1,283 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Loss implementations."""
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class WeightedBCELogitsLoss(_Loss):
+    """Calculate the weighted binary cross-entropy loss.
+
+    Implements Equation 1 in [MANINIS-2016]_. The weight depends on the
+    current proportion between negatives and positives in the ground-
+    truth sample being analyzed.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self, sample: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        sample
+            Value produced by the model to be evaluated, with the shape ``[n, c,
+            h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+        loss
+            The average loss for all input data.
+        """
+
+        # calculates the proportion of negatives to the total number of pixels
+        # available in the masked region
+        valid = mask > 0.5
+        num_pos = target[valid].sum()
+        num_neg = valid.sum() - num_pos
+        pos_weight = num_neg / num_pos
+        return torch.nn.functional.binary_cross_entropy_with_logits(
+            sample[valid],
+            target[valid],
+            reduction="mean",
+            pos_weight=pos_weight,
+        )
+
+
+class SoftJaccardBCELogitsLoss(_Loss):
+    r"""Implement the generalized loss function of Equation (3) in
+    [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary
+    Cross-Entropy Loss:
+
+    .. math::
+
+        L = \alpha H + (1-\alpha)(1-J)
+
+    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.
+
+    Parameters
+    ----------
+    alpha
+        Determines the weighting of J and H. Default: ``0.7``.
+    """
+
+    def __init__(self, alpha: float = 0.7):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(
+        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[n, c,
+            h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+        loss
+            Loss, in a single entry.
+        """
+
+        eps = 1e-8
+        valid = mask > 0.5
+        probabilities = torch.sigmoid(tensor[valid])
+        intersection = (probabilities * target[valid]).sum()
+        sums = probabilities.sum() + target[valid].sum()
+        j = intersection / (sums - intersection + eps)
+
+        # this implements the support for looking just into the RoI
+        h = torch.nn.functional.binary_cross_entropy_with_logits(
+            tensor[valid], target[valid], reduction="mean"
+        )
+        return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
+
+
+class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
+    """Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
+    Holistically-Nested Edge Detection in [XIE-2015]_).
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[L,
+            n, c, h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+        The average loss for all input data.
+        """
+
+        return torch.cat(
+            [
+                super(MultiWeightedBCELogitsLoss, self)
+                .forward(i, target, mask)
+                .unsqueeze(0)
+                for i in tensor
+            ]
+        ).mean()
+
+
+class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
+    """Implements Equation 3 in [IGLOVIKOV-2018]_ for multi-output
+    networks such as HED or Little W-Net.
+
+    Parameters
+    ----------
+    alpha : float
+        Determines the weighting of SoftJaccard and BCE. Default: ``0.7``.
+    """
+
+    def __init__(self, alpha: float = 0.7):
+        super().__init__(alpha=alpha)
+
+    def forward(
+        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[L,
+            n, c, h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+        The average loss for all input data.
+        """
+
+        return torch.cat(
+            [
+                super(MultiSoftJaccardBCELogitsLoss, self)
+                .forward(i, target, mask)
+                .unsqueeze(0)
+                for i in tensor
+            ]
+        ).mean()
+
+
+class MixJacLoss(_Loss):
+    """
+    Parameters
+    ----------
+    lambda_u
+        Determines the weighting of the unlabeled loss term.
+
+    jacalpha
+        Determines the weighting of J and H.
+
+    size_average
+        By default, the losses are averaged over each loss element in the batch. Note that for
+        some losses, there are multiple elements per sample. If the field :attr:`size_average`
+        is set to ``False``, the losses are instead summed for each minibatch. Ignored
+        when :attr:`reduce` is ``False``. Default: ``True``.
+
+    reduce
+        By default, the losses are averaged or summed over observations for each minibatch depending
+        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+        batch element instead and ignores :attr:`size_average`. Default: ``True``.
+
+    reduction
+        Specifies the reduction to apply to the output:
+        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+        ``'mean'``: the sum of the output will be divided by the number of
+        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``.
+    """
+
+    def __init__(
+        self,
+        lambda_u: int = 100,
+        jacalpha=0.7,
+        size_average=None,
+        reduce=None,
+        reduction="mean",
+    ):
+        super().__init__(size_average, reduce, reduction)
+        self.lambda_u = lambda_u
+        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
+        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
+
+    def forward(
+        self,
+        tensor: torch.Tensor,
+        target: torch.Tensor,
+        unlabeled_tensor: torch.Tensor,
+        unlabeled_target: torch.Tensor,
+        ramp_up_factor: float,
+    ) -> tuple:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[L,
+            n, c, h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        unlabeled_tensor
+            Value produced by the model for the unlabeled samples.
+
+        unlabeled_target
+            Target used to compute the unlabeled loss.
+
+        ramp_up_factor
+            Factor by which the unlabeled loss contribution is scaled.
+
+        Returns
+        -------
+        tuple
+            A tuple containing the combined loss, the labeled loss and the
+            unlabeled loss.
+        """
+        ll = self.labeled_loss(tensor, target)
+        ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
+
+        loss = ll + self.lambda_u * ramp_up_factor * ul
+        return loss, ll, ul
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cad67c925c5e9ff0a38527ba328b36242b62ab3
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -0,0 +1,351 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Little W-Net.
+
+Code was originally developed by Adrian Galdran
+(https://github.com/agaldran/lwnet), loosely inspired by
+https://github.com/jvanvugt/pytorch-unet
+
+It is based on two simple U-Nets with 3 layers concatenated to each other. The
+first U-Net produces a segmentation map that is used by the second to better
+guide segmentation.
+ +Reference: [GALDRAN-2020]_ +""" + +import typing + +import lightning.pytorch as pl +import torch +import torch.nn +import torchvision.transforms +from mednet.libs.common.data.typing import TransformSequence +from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss +from torchvision.transforms.v2 import CenterCrop + +from .separate import separate + + +def _conv1x1(in_planes, out_planes, stride=1): + return torch.nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False + ) + + +class ConvBlock(torch.nn.Module): + def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True): + super().__init__() + if shortcut is True: + self.shortcut = torch.nn.Sequential( + _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c) + ) + else: + self.shortcut = False + pad = (k_sz - 1) // 2 + + block = [] + if pool: + self.pool = torch.nn.MaxPool2d(kernel_size=2) + else: + self.pool = False + + block.append( + torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad) + ) + block.append(torch.nn.ReLU()) + block.append(torch.nn.BatchNorm2d(out_c)) + + block.append( + torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad) + ) + block.append(torch.nn.ReLU()) + block.append(torch.nn.BatchNorm2d(out_c)) + + self.block = torch.nn.Sequential(*block) + + def forward(self, x): + if self.pool: + x = self.pool(x) + out = self.block(x) + if self.shortcut: + return out + self.shortcut(x) + return out + + +class UpsampleBlock(torch.nn.Module): + def __init__(self, in_c, out_c, up_mode="transp_conv"): + super().__init__() + block = [] + if up_mode == "transp_conv": + block.append( + torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2) + ) + elif up_mode == "up_conv": + block.append( + torch.nn.Upsample( + mode="bilinear", scale_factor=2, align_corners=False + ) + ) + block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1)) + else: + raise Exception("Upsampling mode not supported") + + self.block = torch.nn.Sequential(*block) + + def forward(self, x): + return self.block(x) + + +class ConvBridgeBlock(torch.nn.Module): + def __init__(self, channels, k_sz=3): + super().__init__() + pad = (k_sz - 1) // 2 + block = [] + + block.append( + torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad) + ) + block.append(torch.nn.ReLU()) + block.append(torch.nn.BatchNorm2d(channels)) + + self.block = torch.nn.Sequential(*block) + + def forward(self, x): + return self.block(x) + + +class UpConvBlock(torch.nn.Module): + def __init__( + self, + in_c, + out_c, + k_sz=3, + up_mode="up_conv", + conv_bridge=False, + shortcut=False, + ): + super().__init__() + self.conv_bridge = conv_bridge + + self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode) + self.conv_layer = ConvBlock( + 2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False + ) + if self.conv_bridge: + self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz) + + def forward(self, x, skip): + up = self.up_layer(x) + if self.conv_bridge: + out = torch.cat([up, self.conv_bridge_layer(skip)], dim=1) + else: + out = torch.cat([up, skip], dim=1) + return self.conv_layer(out) + + +class LittleUNet(torch.nn.Module): + """Little U-Net model. + + Parameters + ---------- + in_c + + n_classes + Number of outputs (classes) for this model. 
+ + layers + + k_sz + + up_mode + + conv_bridge + + shortcut + """ + + def __init__( + self, + in_c, + n_classes, + layers, + k_sz=3, + up_mode="transp_conv", + conv_bridge=True, + shortcut=True, + ): + super().__init__() + self.n_classes = n_classes + self.first = ConvBlock( + in_c=in_c, out_c=layers[0], k_sz=k_sz, shortcut=shortcut, pool=False + ) + + self.down_path = torch.nn.ModuleList() + for i in range(len(layers) - 1): + block = ConvBlock( + in_c=layers[i], + out_c=layers[i + 1], + k_sz=k_sz, + shortcut=shortcut, + pool=True, + ) + self.down_path.append(block) + + self.up_path = torch.nn.ModuleList() + reversed_layers = list(reversed(layers)) + for i in range(len(layers) - 1): + block = UpConvBlock( + in_c=reversed_layers[i], + out_c=reversed_layers[i + 1], + k_sz=k_sz, + up_mode=up_mode, + conv_bridge=conv_bridge, + shortcut=shortcut, + ) + self.up_path.append(block) + + # init, shamelessly lifted from torchvision/models/resnet.py + for m in self.modules(): + if isinstance(m, torch.nn.Conv2d): + torch.nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + elif isinstance(m, torch.nn.BatchNorm2d | torch.nn.GroupNorm): + torch.nn.init.constant_(m.weight, 1) + torch.nn.init.constant_(m.bias, 0) + + self.final = torch.nn.Conv2d(layers[0], n_classes, kernel_size=1) + + def forward(self, x): + x = self.first(x) + down_activations = [] + for i, down in enumerate(self.down_path): + down_activations.append(x) + x = down(x) + down_activations.reverse() + for i, up in enumerate(self.up_path): + x = up(x, down_activations[i]) + return self.final(x) + + +class LittleWNet(pl.LightningModule): + """Little W-Net model, concatenating two Little U-Net models. + + Parameters + ---------- + train_loss + The loss to be used during the training. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + validation_loss + The loss to be used for validation (may be different from the training + loss). If extra-validation sets are provided, the same loss will be + used throughout. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + optimizer_type + The type of optimizer to use for training. + optimizer_arguments + Arguments to the optimizer after ``params``. + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + num_classes + Number of outputs (classes) for this model. 
+ """ + + def __init__( + self, + train_loss=MultiWeightedBCELogitsLoss(), + validation_loss=MultiWeightedBCELogitsLoss(), + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + ): + super().__init__() + + self.name = "lwnet" + self.num_classes = num_classes + + self.model_transforms = [CenterCrop(size=(544, 544))] + + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss + ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments + + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) + + self.unet1 = LittleUNet( + in_c=3, + n_classes=self.num_classes, + layers=(8, 16, 32), + conv_bridge=True, + shortcut=True, + ) + self.unet2 = LittleUNet( + in_c=3 + self.num_classes, + n_classes=self.num_classes, + layers=(8, 16, 32), + conv_bridge=True, + shortcut=True, + ) + + def forward(self, x): + x1 = self.unet1(x) + x2 = self.unet2(torch.cat([x, x1], dim=1)) + + return x1, x2 + + def training_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["label"] + masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3] + + outputs = self(self._augmentation_transforms(images)) + + return self._train_loss(outputs, ground_truths, masks) + + def validation_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["label"] + masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3] + + outputs = self(images) + return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + outputs = self(batch[0]) + probabilities = torch.sigmoid(outputs) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type( + self.parameters(), **self._optimizer_arguments + ) + + """def configure_optimizers(self): + optimizer = getattr( + self, 'optimizer', Adam(self.parameters(), lr=1e-3) + ) + if optimizer is None: + raise ValueError("Optimizer not found. Please provide an optimizer.") + + scheduler = getattr(self, 'scheduler', None) + if scheduler is None: + return {'optimizer': optimizer} + else: + return {'optimizer': optimizer, 'lr_scheduler': scheduler}""" diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py new file mode 100644 index 0000000000000000000000000000000000000000..8abd2329030c055606056741c5642e5ad29256dd --- /dev/null +++ b/src/mednet/libs/segmentation/models/separate.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Contains the inverse :py:func:`torch.utils.data.default_collate`.""" + +import typing + +import torch +from mednet.libs.common.data.typing import Sample + +from .typing import BinaryPrediction, MultiClassPrediction + + +def _as_predictions( + samples: typing.Iterable[Sample], +) -> list[BinaryPrediction | MultiClassPrediction]: + """Take a list of separated batch predictions and transforms it into a list + of formal predictions. + + Parameters + ---------- + samples + A sequence of samples as returned by :py:func:`separate`. + + Returns + ------- + list[BinaryPrediction | MultiClassPrediction] + A list of typed predictions that can be saved to disk. 
+ """ + return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] + + +def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: + """Separate a collated batch, reconstituting its samples. + + This function implements the inverse of + :py:func:`torch.utils.data.default_collate`, and can separate, into + samples, batches of data with different attributes. It follows the inverse + path of that function, and implements the following separation algorithms: + + * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer + dimension, via :py:func:`torch.flatten`) + * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]`` + + Parameters + ---------- + batch + A batch, as output by torch model forwarding. + + Returns + ------- + A list of predictions that contains the predictions and associated metadata + for each processed sample. + """ + + # as of now, this is really simple - to be made more complex upon need. + metadata = [ + {key: value[i] for key, value in batch[1].items()} + for i in range(len(batch[0])) + ] + return _as_predictions(zip(torch.flatten(batch[0]), metadata)) diff --git a/src/mednet/libs/segmentation/models/typing.py b/src/mednet/libs/segmentation/models/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb9017c1e7bbbf860cf7be67622cdfe5db513df --- /dev/null +++ b/src/mednet/libs/segmentation/models/typing.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Defines most common types used in code.""" + +import typing + +Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any] +"""Definition of a lightning checkpoint.""" + +BinaryPrediction: typing.TypeAlias = tuple[str, int, float] +"""The sample name, the target, and the predicted value.""" + +MultiClassPrediction: typing.TypeAlias = tuple[ + str, typing.Sequence[int], typing.Sequence[float] +] +"""The sample name, the target, and the predicted value.""" + +BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[ + str, typing.Sequence[BinaryPrediction] +] +"""A series of predictions for different database splits.""" + +MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ + str, typing.Sequence[MultiClassPrediction] +] +"""A series of predictions for different database splits.""" diff --git a/src/mednet/libs/segmentation/scripts/__init__.py b/src/mednet/libs/segmentation/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..add5fdac1b11a770b04c21c6f2ea20ae13abf881 --- /dev/null +++ b/src/mednet/libs/segmentation/scripts/cli.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import click +from clapper.click import AliasedGroup + +from . 
import ( + # analyze, + # compare, + # config, + # dataset, + # evaluate, + # experiment, + # mkmask, + # predict, + # significance, + train, +) + + +@click.group( + cls=AliasedGroup, + context_settings=dict(help_option_names=["-?", "-h", "--help"]), +) +def segmentation(): + """Binary Segmentation Benchmark.""" + pass + + +# segmentation.add_command(analyze.analyze) +# segmentation.add_command(compare.compare) +# segmentation.add_command(config.config) +# segmentation.add_command(dataset.dataset) +# segmentation.add_command(evaluate.evaluate) +# segmentation.add_command(experiment.experiment) +# segmentation.add_command(mkmask.mkmask) +# segmentation.add_command(predict.predict) +# segmentation.add_command(significance.significance) +segmentation.add_command(train.train) +# segmentation.add_command(train_analysis.train_analysis) diff --git a/src/mednet/libs/segmentation/scripts/click.py b/src/mednet/libs/segmentation/scripts/click.py new file mode 100644 index 0000000000000000000000000000000000000000..b028e6a737345fbfc1c61474e180e84e50679be9 --- /dev/null +++ b/src/mednet/libs/segmentation/scripts/click.py @@ -0,0 +1,28 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import click +from clapper.click import ConfigCommand as _BaseConfigCommand + + +class ConfigCommand(_BaseConfigCommand): + """A click command-class that has the properties of :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog formatting.""" + + def format_epilog( + self, _: click.core.Context, formatter: click.formatting.HelpFormatter + ) -> None: + """Format the command epilog during --help. + + Parameters + ---------- + _ + The current parsing context. + formatter + The formatter to use for printing text. + """ + + if self.epilog: + formatter.write_paragraph() + for line in self.epilog.split("\n"): + formatter.write(line + "\n") diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0796c3d77df3d28fc491de60545a056ebde9cf9e --- /dev/null +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -0,0 +1,376 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import functools +import pathlib +import typing +from pathlib import Path + +import click +from clapper.click import ResourceOption, verbosity_option +from clapper.logging import setup + +from .click import ConfigCommand + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +def reusable_options(f): + """The options that can be re-used by top-level scripts (i.e. + ``experiment``). + + This decorator equips the target function ``f`` with all (reusable) + ``train`` script options. + + Parameters + ---------- + f + The target function to equip with options. This function must have + parameters that accept such options. 
+ + Returns + ------- + The decorated version of function ``f`` + """ # noqa D401 + + @click.option( + "--output-folder", + "-o", + help="Directory in which to store results (created if does not exist)", + required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), + default="results", + cls=ResourceOption, + ) + @click.option( + "--model", + "-m", + help="A lightning module instance implementing the network to be trained", + required=True, + cls=ResourceOption, + ) + @click.option( + "--datamodule", + "-d", + help="A lightning DataModule containing the training and validation sets.", + required=True, + cls=ResourceOption, + ) + @click.option( + "--batch-size", + "-b", + help="Number of samples in every batch (this parameter affects " + "memory requirements for the network). If the number of samples in " + "the batch is larger than the total number of samples available for " + "training, this value is truncated. If this number is smaller, then " + "batches of the specified size are created and fed to the network " + "until there are no more new samples to feed (epoch is finished). " + "If the total number of training samples is not a multiple of the " + "batch-size, the last batch will be smaller than the first, unless " + "--drop-incomplete-batch is set, in which case this batch is not used.", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--batch-chunk-count", + "-c", + help="Number of chunks in every batch (this parameter affects " + "memory requirements for the network). The number of samples " + "loaded for every iteration will be batch-size/batch-chunk-count. " + "batch-size needs to be divisible by batch-chunk-count, otherwise an " + "error will be raised. This parameter is used to reduce the number of " + "samples loaded in each iteration, in order to reduce the memory usage " + "in exchange for processing time (more iterations). This is especially " + "interesting when one is training on GPUs with limited RAM. The " + "default of 1 forces the whole batch to be processed at once. Otherwise " + "the batch is broken into batch-chunk-count pieces, and gradients are " + "accumulated to complete each batch.", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--drop-incomplete-batch/--no-drop-incomplete-batch", + "-D", + help="If set, the last batch in an epoch will be dropped if " + "incomplete. If you set this option, you should also consider " + "increasing the total number of epochs of training, as the total number " + "of training steps may be reduced.", + required=True, + show_default=True, + default=False, + cls=ResourceOption, + ) + @click.option( + "--epochs", + "-e", + help="""Number of epochs (complete training set passes) to train for. + If continuing from a saved checkpoint, ensure to provide a greater + number of epochs than was saved in the checkpoint to be loaded.""", + show_default=True, + required=True, + default=1000, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--validation-period", + "-p", + help="""Number of epochs after which validation happens. By default, + we run validation after every training epoch (period=1). You can + change this to make validation more sparse, by increasing the + validation period. Notice that this affects checkpoint saving. 
While
+        checkpoints are created after every training step (the last training
+        step always triggers the overriding of latest checkpoint), and
+        this process is independent of validation runs, evaluation of the
+        'best' model obtained so far based on those will be influenced by this
+        setting.""",
+        show_default=True,
+        required=True,
+        default=1,
+        type=click.IntRange(min=1),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--device",
+        "-x",
+        help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+        show_default=True,
+        required=True,
+        default="cpu",
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--cache-samples/--no-cache-samples",
+        help="If set to True, loads the samples into memory, "
+        "otherwise loads them at runtime.",
+        required=True,
+        show_default=True,
+        default=False,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--seed",
+        "-s",
+        help="Seed to use for the random number generator",
+        show_default=True,
+        required=False,
+        default=42,
+        type=click.IntRange(min=0),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--parallel",
+        "-P",
+        help="""Use multiprocessing for data loading: if set to -1 (default),
+        disables multiprocessing data loading. Set to 0 to enable as many data
+        loading instances as processing cores available in the system. Set to
+        >= 1 to enable that many multiprocessing instances for data
+        loading.""",
+        type=click.IntRange(min=-1),
+        show_default=True,
+        required=True,
+        default=-1,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--monitoring-interval",
+        "-I",
+        help="""Time between checks for the use of resources during each training
+        epoch, in seconds. An interval of 5 seconds, for example, will lead to
+        CPU and GPU resources being probed every 5 seconds during each training
+        epoch. Values registered in the training logs correspond to averages
+        (or maxima) observed through possibly many probes in each epoch.
+        Notice that setting a very small value may cause the probing process to
+        become extremely busy, potentially biasing the overall perception of
+        resource usage.""",
+        type=click.FloatRange(min=0.1),
+        show_default=True,
+        required=True,
+        default=5.0,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--balance-classes/--no-balance-classes",
+        "-B/-N",
+        help="""If set, balances weights of the random sampler during
+        training so that samples from all sample classes are picked
+        equitably.""",
+        required=True,
+        show_default=True,
+        default=True,
+        cls=ResourceOption,
+    )
+    @functools.wraps(f)
+    def wrapper_reusable_options(*args, **kwargs):
+        return f(*args, **kwargs)
+
+    return wrapper_reusable_options
+
+
+@click.command(
+    entry_point_group="mednet.libs.segmentation.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+1. Train a Little W-Net model with the DRIVE dataset, on a GPU (``cuda:0``):
+
+   .. code:: sh
+
+      mednet segmentation train -vv lwnet drive --batch-size=4 --device="cuda:0"
+""",
+)
+@reusable_options
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+def train(
+    model,
+    output_folder,
+    epochs,
+    batch_size,
+    batch_chunk_count,
+    drop_incomplete_batch,
+    datamodule,
+    validation_period,
+    device,
+    cache_samples,
+    seed,
+    parallel,
+    monitoring_interval,
+    balance_classes,
+    **_,
+) -> None:  # numpydoc ignore=PR01
+    """Train a CNN to perform image segmentation.
+
+    Training is performed for a configurable number of epochs, and
+    generates checkpoints. Checkpoints are model files with a .ckpt
+    extension that are used in subsequent tasks or from which training
+    can be resumed.
+ """ + + import torch + from lightning.pytorch import seed_everything + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.engine.trainer import run + from mednet.libs.common.utils.checkpointer import ( + get_checkpoint_to_resume_training, + ) + + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) + + checkpoint_file = None + if Path.is_dir(output_folder): + try: + checkpoint_file = get_checkpoint_to_resume_training(output_folder) + except FileNotFoundError: + logger.info( + f"Folder {output_folder} already exists, but I did not" + f" find any usable checkpoint file to resume training" + f" from. Starting from scratch..." + ) + + seed_everything(seed) + + # reset datamodule with user configurable options + datamodule.set_chunk_size(batch_size, batch_chunk_count) + datamodule.drop_incomplete_batch = drop_incomplete_batch + datamodule.cache_samples = cache_samples + datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="fit") + + # If asked, rebalances the loss criterion based on the relative proportion + # of class examples available in the training set. Also affects the + # validation loss if a validation set is available on the DataModule. + if balance_classes: + logger.info("Applying DataModule train sampler balancing...") + datamodule.balance_sampler_by_class = True + # logger.info("Applying train/valid loss balancing...") + # model.balance_losses_by_class(datamodule) + else: + logger.info( + "Skipping sample class/dataset ownership balancing on user request" + ) + + logger.info(f"Training for at most {epochs} epochs.") + + arguments = {} + arguments["max_epoch"] = epochs + arguments["epoch"] = 0 + + if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): + # Sets the model normalizer with the unaugmented-train-subset if we are + # starting from scratch and/or the model does not contain its own + # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This + # call may be a NOOP, if the model comes from outside this framework, + # and expects different weights for the normalisation layer. + if hasattr(model, "set_normalizer"): + model.set_normalizer(datamodule.unshuffled_train_dataloader()) + else: + logger.warning( + f"Model {model.name} has no `set_normalizer` method. " + "Skipping normalization setup (unsupported external model)." + ) + else: + # Normalizer will be loaded during model.on_load_checkpoint + checkpoint = torch.load(checkpoint_file) + start_epoch = checkpoint["epoch"] + logger.info( + f"Resuming from epoch {start_epoch} " + f"(checkpoint file: `{str(checkpoint_file)}`)..." 
+ ) + + device_manager = DeviceManager(device) + + # stores all information we can think of, to reproduce this later + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + split_name=datamodule.split_name, + epochs=epochs, + batch_size=batch_size, + batch_chunk_count=batch_chunk_count, + drop_incomplete_batch=drop_incomplete_batch, + validation_period=validation_period, + cache_samples=cache_samples, + seed=seed, + parallel=parallel, + monitoring_interval=monitoring_interval, + balance_classes=balance_classes, + model_name=model.name, + ), + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output_folder / "meta.json", json_data) + + run( + model=model, + datamodule=datamodule, + validation_period=validation_period, + device_manager=device_manager, + max_epochs=epochs, + output_folder=output_folder, + monitoring_interval=monitoring_interval, + batch_chunk_count=batch_chunk_count, + checkpoint=checkpoint_file, + ) diff --git a/src/mednet/libs/segmentation/scripts/utils.py b/src/mednet/libs/segmentation/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..890d1c5437d80c71a9a6956fa9fdaf7e259f4259 --- /dev/null +++ b/src/mednet/libs/segmentation/scripts/utils.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Utilities for command-line scripts.""" + +import json +import logging +import pathlib +import re +import shutil + +import lightning.pytorch +import lightning.pytorch.callbacks +import torch.nn +from mednet.libs.common.engine.device import SupportedPytorchDevice + +logger = logging.getLogger("mednet") + + +def model_summary( + model: torch.nn.Module, +) -> dict[str, int | list[tuple[str, str, int]]]: + """Save a little summary of the model in a txt file. + + Parameters + ---------- + model + Instance of the model for which to save the summary. + + Returns + ------- + tuple[lightning.pytorch.callbacks.ModelSummary, int] + A tuple with the model summary in a text format and number of parameters of the model. + """ + + s = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore + model, + ) + + return dict( + model_summary=list(zip(s.layer_names, s.layer_types, s.param_nums)), + model_size=s.total_parameters, + ) + + +def device_properties( + device_type: SupportedPytorchDevice, +) -> dict[str, int | float | str]: + """Generate information concerning hardware properties. + + Parameters + ---------- + device_type + The type of compute device we are using. + + Returns + ------- + Static properties of the current machine. + """ + + from mednet.libs.common.utils.resources import ( + cpu_constants, + cuda_constants, + mps_constants, + ) + + retval: dict[str, int | float | str] = {} + retval.update(cpu_constants()) + + match device_type: + case "cpu": + pass + case "cuda": + results = cuda_constants() + if results is not None: + retval.update(results) + case "mps": + results = mps_constants() + if results is not None: + retval.update(results) + case _: + pass + + return retval + + +def execution_metadata() -> dict[str, int | float | str | dict[str, str]]: + """Produce metadata concerning the running script, in the form of a + dictionary. 
+ + This function returns potentially useful metadata concerning program + execution. It contains a certain number of preset variables. + + Returns + ------- + A dictionary that contains the following fields: + + * ``package-name``: current package name (e.g. ``mednet``) + * ``package-version``: current package version (e.g. ``1.0.0b0``) + * ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``) + * ``user``: username (e.g. ``johndoe``) + * ``conda-env``: if set, the name of the current conda environment + * ``path``: current path when executing the command + * ``command-line``: the command-line that is being run + * ``hostname``: machine hostname (e.g. ``localhost``) + * ``platform``: machine platform (e.g. ``darwin``) + """ + + import importlib.metadata + import importlib.util + import os + import sys + + args = [] + for k in sys.argv: + if " " in k: + args.append(f"'{k}'") + else: + args.append(k) + + # current date time, in ISO8610 format + datetime = __import__("datetime").datetime.now().astimezone().isoformat() + + # collects dependence information + package_name = __package__.split(".")[0] + requires = importlib.metadata.requires(package_name) or [] + dependence_names = [re.split(r"(\=|~|!|>|<|;|\s)+", k)[0] for k in requires] + dependencies = { + k: importlib.metadata.version(k) # version number as str + for k in dependence_names + if importlib.util.find_spec(k) is not None # if is installed + } + + # checks if the current version corresponds to a dirty (uncommitted) change + # set, issues a warning to the user + current_version = importlib.metadata.version(package_name) + try: + import versioningit + + actual_version = versioningit.get_version(".", config={}) + if current_version != actual_version: + logger.warning( + f"Version mismatch between current version set " + f"({current_version}) and actual version returned by " + f"versioningit ({actual_version}). This typically happens " + f"when you commit changes locally and do not re-install the " + f"package. Run `pip install -e .` or equivalent to fix this.", + ) + except Exception as e: + # not in a git repo? + logger.debug(f"Error {e}") + pass + + return { + "datetime": datetime, + "package-name": __package__.split(".")[0], + "package-version": current_version, + "dependencies": dependencies, + "user": __import__("getpass").getuser(), + "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""), + "path": os.path.realpath(os.curdir), + "command-line": " ".join(args), + "hostname": __import__("platform").node(), + "platform": sys.platform, + } + + +def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None: + """Save a dictionary into a JSON file with path checking and backup. + + This function will save a dictionary into a JSON file. It will check to + the existence of the directory leading to the file and create it if + necessary. If the file already exists on the destination folder, it is + backed-up before a new file is created with the new contents. + + Parameters + ---------- + path + The full path where to save the JSON data. + data + The data to save on the JSON file. 
+ """ + + logger.info(f"Writing run metadata at `{path}`...") + + path.parent.mkdir(parents=True, exist_ok=True) + if path.exists(): + backup = path.parent / (path.name + "~") + shutil.copy(path, backup) + + with path.open("w") as f: + json.dump(data, f, indent=2) diff --git a/src/mednet/libs/segmentation/utils/rc.py b/src/mednet/libs/segmentation/utils/rc.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b3c48d83a9882f28d9392a36058e35b6405df9 --- /dev/null +++ b/src/mednet/libs/segmentation/utils/rc.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from clapper.rc import UserDefaults + + +def load_rc() -> UserDefaults: + """Return global configuration variables. + + Returns + ------- + The user defaults read from the user .toml configuration file. + """ + return UserDefaults("deepdraw.toml") diff --git a/src/mednet/scripts/cli.py b/src/mednet/scripts/cli.py index 322b8844ea9345b8ceca181f4aebf3b64b590da2..25b08c05e39ac1117bf731105ab792a572048191 100644 --- a/src/mednet/scripts/cli.py +++ b/src/mednet/scripts/cli.py @@ -1,6 +1,7 @@ import click from clapper.click import AliasedGroup from mednet.libs.classification.scripts.cli import classification +from mednet.libs.segmentation.scripts.cli import segmentation @click.group( @@ -13,3 +14,4 @@ def cli(): cli.add_command(classification) +cli.add_command(segmentation)