diff --git a/src/mednet/libs/classification/engine/predictor.py b/src/mednet/libs/classification/engine/predictor.py
index 771d918f7ef454656038380d743d2da45e7745cd..d0298a3ee95652cc90ffeb2db846d51c8b46fedd 100644
--- a/src/mednet/libs/classification/engine/predictor.py
+++ b/src/mednet/libs/classification/engine/predictor.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch
 import torch.utils.data
@@ -18,6 +19,79 @@ from ..models.typing import (
 
 logger = logging.getLogger("mednet")
 
 
+class _JSONMetadataCollector(lightning.pytorch.callbacks.BasePredictionWriter):
+    """Collects sample metadata to store alongside predictions.
+
+    This callback accumulates, in memory, the metadata (sample name and
+    target) we typically keep with each prediction.
+
+    Parameters
+    ----------
+    write_interval
+        When this callback should be active.
+    """
+
+    def __init__(
+        self,
+        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"] = "batch",
+    ):
+        super().__init__(write_interval=write_interval)
+        self._data: list[BinaryPrediction] | list[MultiClassPrediction] = []
+
+    def write_on_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        prediction: typing.Any,
+        batch_indices: typing.Sequence[int] | None,
+        batch: typing.Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Record the predictions of a batch, and their metadata, in memory.
+
+        Parameters
+        ----------
+        trainer
+            The trainer being used.
+        pl_module
+            The pytorch module.
+        prediction
+            The actual predictions to record.
+        batch_indices
+            The relative position of samples on the epoch.
+        batch
+            The current batch.
+        batch_idx
+            Index of the batch overall.
+        dataloader_idx
+            Index of the dataloader overall.
+        """
+        for k, sample_pred in enumerate(prediction):
+            sample_name: str = batch[1]["name"][k]
+            target_shape = batch[1]["target"][k].shape
+            self._data.append(
+                (
+                    sample_name,
+                    batch[1]["target"][k].cpu().numpy().tolist(),
+                    sample_pred.cpu().numpy().reshape(target_shape).tolist(),
+                )
+            )
+
+    def reset(self) -> list[BinaryPrediction] | list[MultiClassPrediction]:
+        """Summary of collected objects.
+
+        Also resets the internal state.
+
+        Returns
+        -------
+        A list containing a summary of all samples collected so far.
+ """ + retval = self._data + self._data = [] + return retval + + def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, @@ -77,35 +151,38 @@ def run( from lightning.pytorch.loggers.logger import DummyLogger + collector = _JSONMetadataCollector() + accelerator, devices = device_manager.lightning_accelerator() trainer = lightning.pytorch.Trainer( accelerator=accelerator, devices=devices, logger=DummyLogger(), + callbacks=[collector], ) - def _flatten(p: list[list]): - return [sample for batch in p for sample in batch] - dataloaders = datamodule.predict_dataloader() if isinstance(dataloaders, torch.utils.data.DataLoader): logger.info("Running prediction on a single dataloader...") - return _flatten(trainer.predict(model, dataloaders)) # type: ignore + trainer.predict(model, dataloaders, return_predictions=False) + return collector.reset() if isinstance(dataloaders, list): retval_list = [] for k, dataloader in enumerate(dataloaders): logger.info(f"Running prediction on split `{k}`...") - retval_list.append(_flatten(trainer.predict(model, dataloader))) # type: ignore - return retval_list + trainer.predict(model, dataloader, return_predictions=False) + retval_list.append(collector.reset()) + return retval_list # type: ignore if isinstance(dataloaders, dict): retval_dict = {} for name, dataloader in dataloaders.items(): logger.info(f"Running prediction on `{name}` split...") - retval_dict[name] = _flatten(trainer.predict(model, dataloader)) # type: ignore - return retval_dict + trainer.predict(model, dataloader, return_predictions=False) + retval_dict[name] = collector.reset() + return retval_dict # type: ignore if dataloaders is None: logger.warning("Datamodule did not return any prediction dataloaders!") diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index eab264185b16343936e0daf2d6f1935190ca15f3..7e20c5943a25a2852f7c38e892af04f91262f977 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -13,8 +13,6 @@ import torchvision.models as models from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model -from .separate import separate - logger = logging.getLogger("mednet") @@ -117,13 +115,7 @@ class Alexnet(Model): ) self.normalizer = make_imagenet_normalizer() else: - from .normalizer import make_z_normalizer - - logger.info( - f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader.", - ) - self.normalizer = make_z_normalizer(dataloader) + super().set_normalizer(dataloader) def training_step(self, batch, _): images = batch[0] @@ -160,5 +152,4 @@ class Alexnet(Model): def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) - probabilities = torch.sigmoid(outputs) - return separate((probabilities, batch[1])) + return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index 428ba28187328822c3eb241468ea09249d1786ff..5f41f1aebfbc8fb4b2315996a33dd39f4301d3c6 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -13,8 +13,6 @@ import torchvision.models as models from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model -from .separate import separate - logger = logging.getLogger("mednet") @@ -120,13 +118,7 @@ 
diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index eab264185b16343936e0daf2d6f1935190ca15f3..7e20c5943a25a2852f7c38e892af04f91262f977 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -13,8 +13,6 @@ import torchvision.models as models
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 
-from .separate import separate
-
 logger = logging.getLogger("mednet")
@@ -117,13 +115,7 @@ class Alexnet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, _):
         images = batch[0]
@@ -160,5 +152,4 @@ class Alexnet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index 428ba28187328822c3eb241468ea09249d1786ff..5f41f1aebfbc8fb4b2315996a33dd39f4301d3c6 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -13,8 +13,6 @@ import torchvision.models as models
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 
-from .separate import separate
-
 logger = logging.getLogger("mednet")
@@ -120,13 +118,7 @@ class Densenet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, _):
         images = batch[0]
@@ -158,5 +150,4 @@ class Densenet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/logistic_regression.py b/src/mednet/libs/classification/models/logistic_regression.py
index f203e35221671f6abb8a9db87f12d3d82f87201f..f9fe1847be16de65dc3dfbd88223b89db6eeec97 100644
--- a/src/mednet/libs/classification/models/logistic_regression.py
+++ b/src/mednet/libs/classification/models/logistic_regression.py
@@ -8,8 +8,6 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 
-from .separate import separate
-
 
 class LogisticRegression(pl.LightningModule):
     """Logistic regression classifier with a single output.
@@ -62,7 +60,7 @@ class LogisticRegression(pl.LightningModule):
         self.linear = nn.Linear(input_size, 1)
 
     def forward(self, x):
-        return self.linear(x)
+        return self.linear(self.normalizer(x))
 
     def training_step(self, batch, batch_idx):
         _input = batch[1]
@@ -105,8 +103,7 @@ class LogisticRegression(pl.LightningModule):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/mednet/libs/classification/models/mlp.py b/src/mednet/libs/classification/models/mlp.py
index e8e4b2904264d8fc9485ad28adbf4b90213f51eb..bd928410a3133141a0048de8537df4bf65cd8e49 100644
--- a/src/mednet/libs/classification/models/mlp.py
+++ b/src/mednet/libs/classification/models/mlp.py
@@ -7,8 +7,6 @@ import typing
 import lightning.pytorch as pl
 import torch
 
-from .separate import separate
-
 
 class MultiLayerPerceptron(pl.LightningModule):
     """MLP with a variable number of inputs and hidden neurons (single layer).
@@ -66,7 +64,7 @@ class MultiLayerPerceptron(pl.LightningModule):
         self.fc2 = torch.nn.Linear(hidden_size, 1)
 
     def forward(self, x):
-        return self.fc2(self.relu(self.fc1(x)))
+        return self.fc2(self.relu(self.fc1(self.normalizer(x))))
 
     def training_step(self, batch, batch_idx):
         _input = batch[1]
@@ -109,8 +107,7 @@ class MultiLayerPerceptron(pl.LightningModule):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 478c9397c30d506cc59c08107dcad5a8dd4e1ffb..37415a7690a128728d5cd280a47ebef194114330 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -13,8 +13,6 @@ import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 
-from .separate import separate
-
 logger = logging.getLogger("mednet")
@@ -223,5 +221,4 @@ class Pasa(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
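
Every classification model's `predict_step()` now returns a bare probability tensor; pairing probabilities with sample names and targets happens once, in `_JSONMetadataCollector`, instead of per-model through `separate()` (deleted below). A small self-contained illustration of the bookkeeping that moved into the callback (sample names and values hypothetical):

    import torch

    batch = (
        torch.tensor([[0.2], [0.9]]),  # sigmoid outputs for a batch of two
        {"name": ["a.png", "b.png"], "target": torch.tensor([[0.0], [1.0]])},
    )
    # what the collector records for each sample in the batch:
    records = [
        (
            batch[1]["name"][k],
            batch[1]["target"][k].numpy().tolist(),
            batch[0][k].numpy().tolist(),
        )
        for k in range(len(batch[0]))
    ]
    # -> [("a.png", [0.0], [0.2]), ("b.png", [1.0], [0.9])]  (up to float32 rounding)
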
diff --git a/src/mednet/libs/classification/models/separate.py b/src/mednet/libs/classification/models/separate.py
deleted file mode 100644
index 9b575e8ee1a39258b2437e1c968b303aff7b028b..0000000000000000000000000000000000000000
--- a/src/mednet/libs/classification/models/separate.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Contains the inverse :py:func:`torch.utils.data.default_collate`."""
-
-import typing
-
-import torch
-from mednet.libs.common.data.typing import Sample
-
-from .typing import BinaryPrediction, MultiClassPrediction
-
-
-def _as_predictions(
-    samples: typing.Iterable[Sample],
-) -> list[BinaryPrediction | MultiClassPrediction]:
-    """Take a list of separated batch predictions and transforms it into a list
-    of formal predictions.
-
-    Parameters
-    ----------
-    samples
-        A sequence of samples as returned by :py:func:`separate`.
-
-    Returns
-    -------
-    list[BinaryPrediction | MultiClassPrediction]
-        A list of typed predictions that can be saved to disk.
-    """
-
-    return [(v[1]["name"], v[1]["target"].item(), v[0].item()) for v in samples]
-
-
-def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
-    """Separate a collated batch, reconstituting its samples.
-
-    This function implements the inverse of
-    :py:func:`torch.utils.data.default_collate`, and can separate, into
-    samples, batches of data with different attributes. It follows the inverse
-    path of that function, and implements the following separation algorithms:
-
-    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
-      dimension, via :py:func:`torch.flatten`)
-    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
-
-    Parameters
-    ----------
-    batch
-        A batch, as output by torch model forwarding.
-
-    Returns
-    -------
-    A list of predictions that contains the predictions and associated metadata
-    for each processed sample.
-    """
-
-    # as of now, this is really simple - to be made more complex upon need.
-    metadata = [
-        {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0]))
-    ]
-    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py
index 8371c9a4dc754d875d306d500d45f418913b2f5a..0e5c290f1ed28e975dc6a3bd702353039ba0ae34 100644
--- a/src/mednet/libs/segmentation/engine/predictor.py
+++ b/src/mednet/libs/segmentation/engine/predictor.py
@@ -38,7 +38,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
     ):
         super().__init__(write_interval=write_interval)
         self.output_folder = output_folder
-        self._written: list[list[str]] = []
+        self._data: list[tuple[str, str]] = []
 
     def write_on_batch_end(
         self,
@@ -69,39 +69,40 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
         dataloader_idx
             Index of the dataloader overall.
         """
-        for k, p in enumerate(prediction):
-            stem = pathlib.Path(p[0]).with_suffix(".hdf5")
+        for k, sample_pred in enumerate(prediction):
+            sample_name: str = batch[1]["name"][k]
+            stem = pathlib.Path(sample_name).with_suffix(".hdf5")
             output_path = self.output_folder / stem
-            tqdm.tqdm.write(f"`{p[0]}` -> `{str(output_path)}`")
+            tqdm.tqdm.write(f"`{sample_name}` -> `{str(output_path)}`")
             output_path.parent.mkdir(parents=True, exist_ok=True)
             with h5py.File(output_path, "w") as f:
                 f.create_dataset(
                     "image",
-                    data=batch[0][k].numpy(),
+                    data=batch[0][k].cpu().numpy(),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "prediction",
-                    data=p[3].numpy().squeeze(0),
+                    data=sample_pred.cpu().numpy().squeeze(0),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "target",
-                    data=(batch[1]["target"][k].squeeze(0).numpy() > 0.5),
+                    data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "mask",
-                    data=(batch[1]["mask"][k].squeeze(0).numpy() > 0.5),
+                    data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
-            self._written.append([p[0], str(stem)])
+            self._data.append((sample_name, str(stem)))
 
-    def written(self) -> list[list[str]]:
+    def reset(self) -> list[tuple[str, str]]:
         """Summary of written objects.
 
         Also resets the internal state.
@@ -110,8 +111,8 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
         -------
         A list containing a summary of all samples written.
         """
-        retval = self._written
-        self._written = []
+        retval = self._data
+        self._data = []
         return retval
 
 
@@ -120,7 +121,12 @@ def run(
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
     output_folder: pathlib.Path,
-) -> dict[str, list[list[str]]] | list[list[list[str]]] | list[list[str]] | None:
+) -> (
+    dict[str, list[tuple[str, str]]]
+    | list[list[tuple[str, str]]]
+    | list[tuple[str, str]]
+    | None
+):
     """Run inference on input data, output predictions.
 
     Parameters
@@ -154,14 +160,14 @@ def run(
 
     from lightning.pytorch.loggers.logger import DummyLogger
 
-    writer = _HDF5Writer(output_folder)
+    collector = _HDF5Writer(output_folder)
 
     accelerator, devices = device_manager.lightning_accelerator()
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
         logger=DummyLogger(),
-        callbacks=[writer],
+        callbacks=[collector],
     )
 
     dataloaders = datamodule.predict_dataloader()
@@ -169,14 +175,14 @@ def run(
     if isinstance(dataloaders, torch.utils.data.DataLoader):
         logger.info("Running prediction on a single dataloader...")
         trainer.predict(model, dataloaders, return_predictions=False)
-        return writer.written()
+        return collector.reset()
 
     if isinstance(dataloaders, list):
         retval_list = []
         for k, dataloader in enumerate(dataloaders):
             logger.info(f"Running prediction on split `{k}`...")
             trainer.predict(model, dataloader, return_predictions=False)
-            retval_list.append(writer.written())
+            retval_list.append(collector.reset())
         return retval_list
 
     if isinstance(dataloaders, dict):
@@ -184,7 +190,7 @@
         retval_dict = {}
         for name, dataloader in dataloaders.items():
             logger.info(f"Running prediction on `{name}` split...")
             trainer.predict(model, dataloader, return_predictions=False)
-            retval_dict[name] = writer.written()
+            retval_dict[name] = collector.reset()
         return retval_dict
 
     if dataloaders is None:
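
The writer stores one HDF5 file per sample, named after the sample's own name with an ``.hdf5`` suffix. A minimal read-back sketch (file path hypothetical; dataset names as created above):

    import h5py

    with h5py.File("predictions/sample-001.hdf5", "r") as f:
        image = f["image"][:]            # the input, as fed to the model
        prediction = f["prediction"][:]  # per-pixel probabilities
        target = f["target"][:]          # boolean ground-truth mask (> 0.5)
        mask = f["mask"][:]              # boolean evaluation mask (> 0.5)
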
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index cd77fb28feff97df2f608b63660810be664434b5..6dbdef7880bbdf6c937a96c19739ec36af8817d4 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -14,7 +14,6 @@ from mednet.libs.common.models.model import Model
 
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -133,8 +132,7 @@ class DRIU(Model):
         self.head = DRIUHead([64, 128, 256, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -160,7 +158,7 @@ class DRIU(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -180,8 +178,7 @@ class DRIU(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 4c19c267c92ffc153f53e85ca4c5a3690ac8e684..07ddc62508f79d0717edca85171404ff980dcd56 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -14,7 +14,6 @@ from mednet.libs.common.models.model import Model
 
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -136,8 +135,7 @@ class DRIUBN(Model):
         self.head = DRIUBNHead([64, 128, 256, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -163,7 +161,7 @@ class DRIUBN(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -183,8 +181,7 @@ class DRIUBN(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 308f87e19f551119c23765a22142f4671bbbb6e5..1e1802af8f973ac9792a5bb7ab99114629dd929f 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -15,7 +15,6 @@ from .backbones.vgg import vgg16_for_segmentation
 from .driu import ConcatFuseBlock
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -118,8 +117,7 @@ class DRIUOD(Model):
         self.head = DRIUODHead([128, 256, 512, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -145,7 +143,7 @@ class DRIUOD(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -165,8 +163,7 @@ class DRIUOD(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index a4cbbc5b41c0a8aea287be949a2391091500abac..29942643e4c6e9f44d2121534ec15d6bca707c43 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -15,7 +15,6 @@ from .backbones.vgg import vgg16_for_segmentation
 from .driu import ConcatFuseBlock
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -122,8 +121,7 @@ class DRIUPix(Model):
         self.head = DRIUPIXHead([64, 128, 256, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -149,7 +147,7 @@ class DRIUPix(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -169,8 +167,7 @@ class DRIUPix(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 779c48cfce7f77ea88c2e5c19d58f36c8fd4c457..e3bcd094814ff4981ddead877ae079322b3546d4 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -13,7 +13,6 @@ from mednet.libs.common.models.model import Model
 
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import MultiSoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -137,8 +136,7 @@ class HED(Model):
         self.head = HEDHead([64, 128, 256, 512, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -164,7 +162,7 @@ class HED(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -184,8 +182,7 @@ class HED(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 7a21ad58c026c66bec17105f89a6ff5665d2b57b..90507e4dcd513e9f0a8d12a368f08a010c9ff77c 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -23,8 +23,6 @@ from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 
-from .separate import separate
-
 
 def _conv1x1(in_planes, out_planes, stride=1):
     return torch.nn.Conv2d(
@@ -341,8 +339,9 @@ class LittleWNet(Model):
         )
 
     def forward(self, x):
-        x1 = self.unet1(x)
-        x2 = self.unet2(torch.cat([x, x1], dim=1))
+        xn = self.normalizer(x)
+        x1 = self.unet1(xn)
+        x2 = self.unet2(torch.cat([xn, x1], dim=1))
         return x1, x2
@@ -364,8 +363,7 @@ class LittleWNet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
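
All segmentation backbones now share one convention: input normalization happens inside `forward()`, and `self.normalizer` is always set, either to an ImageNet normalizer (pretrained weights) or via the `Model` base class. A runnable sketch of the shared pattern, assuming `set_normalizer()` installs a z-normalizer computed from the train dataloader, as the removed per-model code in the classification models did (class and layer choices illustrative):

    import torch

    class TinySegModel(torch.nn.Module):
        # stand-in for the shared Model subclass structure
        def __init__(self):
            super().__init__()
            self.normalizer = torch.nn.Identity()  # replaced by set_normalizer()
            self.backbone = torch.nn.Conv2d(3, 8, 3, padding=1)
            self.head = torch.nn.Conv2d(8, 1, 1)

        def forward(self, x):
            x = self.normalizer(x)  # always set; the `is not None` guard is gone
            x = self.backbone(x)
            return self.head(x)

    print(TinySegModel()(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 1, 32, 32])
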
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 24587582641680c419eaddb1a8f5b49aad248f94..d934881ce9082694effad5ed4e0c4a431d51e654 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -13,7 +13,6 @@ from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from torchvision.models.mobilenetv2 import InvertedResidual
 
 from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -185,8 +184,7 @@ class M2UNET(Model):
         self.head = M2UNetHead(in_channels_list=[16, 24, 32, 96])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -212,7 +210,7 @@ class M2UNET(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -232,8 +230,7 @@ class M2UNET(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
deleted file mode 100644
index 4f2628f8b41d4b1fea238b8cbda441200cfe109e..0000000000000000000000000000000000000000
--- a/src/mednet/libs/segmentation/models/separate.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Contains the inverse :py:func:`torch.utils.data.default_collate`."""
-
-import typing
-
-from mednet.libs.common.data.typing import Sample
-
-from .typing import SegmentationPrediction
-
-
-def _as_predictions(
-    samples: typing.Iterable[Sample],
-) -> list[SegmentationPrediction]:
-    """Take a list of separated batch predictions and transforms it into a list
-    of formal predictions.
-
-    Parameters
-    ----------
-    samples
-        A sequence of samples as returned by :py:func:`separate`.
-
-    Returns
-    -------
-    A list of typed predictions that can be saved to disk.
-    """
-    return [(v[1]["name"], v[1]["target"], v[1]["mask"], v[0]) for v in samples]
-
-
-def separate(batch: Sample) -> list[SegmentationPrediction]:
-    """Separate a collated batch, reconstituting its samples.
-
-    This function implements the inverse of
-    :py:func:`torch.utils.data.default_collate`, and can separate, into
-    samples, batches of data with different attributes. It follows the inverse
-    path of that function, and implements the following separation algorithms:
-
-    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
-      dimension, via :py:func:`torch.flatten`)
-    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
-
-    Parameters
-    ----------
-    batch
-        A batch, as output by torch model forwarding.
-
-    Returns
-    -------
-    A list of predictions that contains the predictions and associated metadata
-    for each processed sample.
-    """
-
-    # as of now, this is really simple - to be made more complex upon need.
-    metadata = [
-        {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0]))
-    ]
-
-    return _as_predictions(zip(batch[0], metadata))
diff --git a/src/mednet/libs/segmentation/models/typing.py b/src/mednet/libs/segmentation/models/typing.py
deleted file mode 100644
index 11ec1457ff1ece9fa26d779190c529c0bd78c724..0000000000000000000000000000000000000000
--- a/src/mednet/libs/segmentation/models/typing.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Defines most common types used in code."""
-
-import pathlib
-import typing
-
-Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
-"""Definition of a lightning checkpoint."""
-
-SegmentationPrediction: typing.TypeAlias = tuple[
-    pathlib.Path, pathlib.Path, pathlib.Path, pathlib.Path
-]
-"""The sample name, the target, mask, and the prediction."""
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 0943d283cdac1f2b72f306a116de3d3fe579ce5b..04317572c856704091c5b5022a843049c69e53fe 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -13,7 +13,6 @@ from mednet.libs.common.models.model import Model
 
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UnetBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
@@ -126,8 +125,7 @@ class Unet(Model):
         self.head = UNetHead([64, 128, 256, 512, 512], pixel_shuffle=False)
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
@@ -153,7 +151,7 @@ class Unet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -173,8 +171,7 @@ class Unet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/scripts/utils.py b/src/mednet/libs/segmentation/scripts/utils.py
deleted file mode 100644
index 890d1c5437d80c71a9a6956fa9fdaf7e259f4259..0000000000000000000000000000000000000000
--- a/src/mednet/libs/segmentation/scripts/utils.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Utilities for command-line scripts."""
-
-import json
-import logging
-import pathlib
-import re
-import shutil
-
-import lightning.pytorch
-import lightning.pytorch.callbacks
-import torch.nn
-from mednet.libs.common.engine.device import SupportedPytorchDevice
-
-logger = logging.getLogger("mednet")
-
-
-def model_summary(
-    model: torch.nn.Module,
-) -> dict[str, int | list[tuple[str, str, int]]]:
-    """Save a little summary of the model in a txt file.
-
-    Parameters
-    ----------
-    model
-        Instance of the model for which to save the summary.
-
-    Returns
-    -------
-    tuple[lightning.pytorch.callbacks.ModelSummary, int]
-        A tuple with the model summary in a text format and number of parameters of the model.
- """ - - s = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore - model, - ) - - return dict( - model_summary=list(zip(s.layer_names, s.layer_types, s.param_nums)), - model_size=s.total_parameters, - ) - - -def device_properties( - device_type: SupportedPytorchDevice, -) -> dict[str, int | float | str]: - """Generate information concerning hardware properties. - - Parameters - ---------- - device_type - The type of compute device we are using. - - Returns - ------- - Static properties of the current machine. - """ - - from mednet.libs.common.utils.resources import ( - cpu_constants, - cuda_constants, - mps_constants, - ) - - retval: dict[str, int | float | str] = {} - retval.update(cpu_constants()) - - match device_type: - case "cpu": - pass - case "cuda": - results = cuda_constants() - if results is not None: - retval.update(results) - case "mps": - results = mps_constants() - if results is not None: - retval.update(results) - case _: - pass - - return retval - - -def execution_metadata() -> dict[str, int | float | str | dict[str, str]]: - """Produce metadata concerning the running script, in the form of a - dictionary. - - This function returns potentially useful metadata concerning program - execution. It contains a certain number of preset variables. - - Returns - ------- - A dictionary that contains the following fields: - - * ``package-name``: current package name (e.g. ``mednet``) - * ``package-version``: current package version (e.g. ``1.0.0b0``) - * ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``) - * ``user``: username (e.g. ``johndoe``) - * ``conda-env``: if set, the name of the current conda environment - * ``path``: current path when executing the command - * ``command-line``: the command-line that is being run - * ``hostname``: machine hostname (e.g. ``localhost``) - * ``platform``: machine platform (e.g. ``darwin``) - """ - - import importlib.metadata - import importlib.util - import os - import sys - - args = [] - for k in sys.argv: - if " " in k: - args.append(f"'{k}'") - else: - args.append(k) - - # current date time, in ISO8610 format - datetime = __import__("datetime").datetime.now().astimezone().isoformat() - - # collects dependence information - package_name = __package__.split(".")[0] - requires = importlib.metadata.requires(package_name) or [] - dependence_names = [re.split(r"(\=|~|!|>|<|;|\s)+", k)[0] for k in requires] - dependencies = { - k: importlib.metadata.version(k) # version number as str - for k in dependence_names - if importlib.util.find_spec(k) is not None # if is installed - } - - # checks if the current version corresponds to a dirty (uncommitted) change - # set, issues a warning to the user - current_version = importlib.metadata.version(package_name) - try: - import versioningit - - actual_version = versioningit.get_version(".", config={}) - if current_version != actual_version: - logger.warning( - f"Version mismatch between current version set " - f"({current_version}) and actual version returned by " - f"versioningit ({actual_version}). This typically happens " - f"when you commit changes locally and do not re-install the " - f"package. Run `pip install -e .` or equivalent to fix this.", - ) - except Exception as e: - # not in a git repo? 
- logger.debug(f"Error {e}") - pass - - return { - "datetime": datetime, - "package-name": __package__.split(".")[0], - "package-version": current_version, - "dependencies": dependencies, - "user": __import__("getpass").getuser(), - "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""), - "path": os.path.realpath(os.curdir), - "command-line": " ".join(args), - "hostname": __import__("platform").node(), - "platform": sys.platform, - } - - -def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None: - """Save a dictionary into a JSON file with path checking and backup. - - This function will save a dictionary into a JSON file. It will check to - the existence of the directory leading to the file and create it if - necessary. If the file already exists on the destination folder, it is - backed-up before a new file is created with the new contents. - - Parameters - ---------- - path - The full path where to save the JSON data. - data - The data to save on the JSON file. - """ - - logger.info(f"Writing run metadata at `{path}`...") - - path.parent.mkdir(parents=True, exist_ok=True) - if path.exists(): - backup = path.parent / (path.name + "~") - shutil.copy(path, backup) - - with path.open("w") as f: - json.dump(data, f, indent=2) diff --git a/src/mednet/libs/segmentation/scripts/view.py b/src/mednet/libs/segmentation/scripts/view.py index e0cb9706d3e8d1e646a2f329210efd3fd8a98a3e..79fcb468abde40e8253bc1209663f62524630889 100644 --- a/src/mednet/libs/segmentation/scripts/view.py +++ b/src/mednet/libs/segmentation/scripts/view.py @@ -24,7 +24,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") epilog="""Examples: \b - 1. Runs evaluation on an existing dataset configuration: + 1. Runs view on an existing dataset configuration: .. code:: sh @@ -146,8 +146,8 @@ def view( ) from mednet.libs.segmentation.engine.viewer import view - evaluation_filename = "evaluation.json" - evaluation_file = output_folder / evaluation_filename + view_filename = "view.json" + view_file = output_folder / view_filename with predictions.open("r") as f: predict_data = json.load(f) @@ -164,7 +164,7 @@ def view( ), ) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data) + save_json_with_backup(view_file.with_suffix(".meta.json"), json_data) threshold = validate_threshold(threshold, predict_data) threshold_list = numpy.arange( @@ -203,5 +203,6 @@ def view( alpha=alpha, ) dest = (output_folder / sample[1]).with_suffix(".png") + dest.parent.mkdir(parents=True, exist_ok=True) tqdm.tqdm.write(f"{sample[1]} -> {dest}") image.save(dest)