diff --git a/doc/api.rst b/doc/api.rst index d4e441f00698cd6bbb8bbf42c233f0b3048f8c46..84df54ed4c6881f34eb37c5d29dc30fe4179d4a5 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -12,68 +12,135 @@ This section includes information for using the Python API of ``mednet``. -.. _mednet.libs.classification.api.data: +Common library +-------------- + +This common library contains methods and scripts that can be reused by more specialized libraries. + +.. _mednet.libs.common.api.data: Data Methods ------------- +^^^^^^^^^^^^ Auxiliary classes and methods to define raw dataset iterators. .. autosummary:: :toctree: api/data - mednet.libs.classification.data.augmentations - mednet.libs.classification.data.datamodule - mednet.libs.classification.data.image_utils - mednet.libs.classification.data.split - mednet.libs.classification.data.typing + mednet.libs.common.data.augmentations + mednet.libs.common.data.datamodule + mednet.libs.common.data.image_utils + mednet.libs.common.data.split + mednet.libs.common.data.typing -.. _mednet.libs.classification.api.models: +.. _mednet.libs.common.api.engines: + +Command engines +^^^^^^^^^^^^^^^ + +Functions to actuate on the data. + +.. autosummary:: + :toctree: api/engine + + mednet.libs.common.engine.callbacks + mednet.libs.common.engine.device + mednet.libs.common.engine.loggers + mednet.libs.common.engine.trainer + + +.. _mednet.libs.common.api.models: Models ------- +^^^^^^ -CNN and other models implemented. +Common model utilities. .. autosummary:: :toctree: api/models - mednet.libs.classification.models.pasa - mednet.libs.classification.models.alexnet - mednet.libs.classification.models.densenet - mednet.libs.classification.models.logistic_regression - mednet.libs.classification.models.loss_weights - mednet.libs.classification.models.mlp - mednet.libs.classification.models.model - mednet.libs.classification.models.normalizer - mednet.libs.classification.models.separate - mednet.libs.classification.models.transforms - mednet.libs.classification.models.typing + mednet.libs.common.models.loss_weights + mednet.libs.common.models.model + mednet.libs.common.models.normalizer + mednet.libs.common.models.typing + + +.. _mednet.libs.common.api.utils: + +Utils +^^^^^ + +Reusable auxiliary functions. + +.. autosummary:: + :toctree: api/utils + + mednet.libs.common.utils.checkpointer + mednet.libs.common.utils.gitlab + mednet.libs.common.utils.resources + mednet.libs.common.utils.summary + mednet.libs.common.utils.tensorboard + + +Classification library +---------------------- + +Library for training models on classification tasks + + +.. _mednet.libs.classification.api.data: + +Data +^^^^ + +Classification-specific data methods + +.. autosummary:: + :toctree: api/data + + mednet.libs.classification.data.typing .. _mednet.libs.classification.api.engines: Command engines ---------------- +^^^^^^^^^^^^^^^ Functions to actuate on the data. .. autosummary:: :toctree: api/engine - mednet.libs.classification.engine.callbacks - mednet.libs.classification.engine.device mednet.libs.classification.engine.evaluator - mednet.libs.classification.engine.loggers mednet.libs.classification.engine.predictor - mednet.libs.classification.engine.trainer + + +.. _mednet.libs.classification.api.models: + +Models +^^^^^^ + +CNN and other models implemented. + +.. autosummary:: + :toctree: api/models + + mednet.libs.classification.models.pasa + mednet.libs.classification.models.alexnet + mednet.libs.classification.models.densenet + mednet.libs.classification.models.loss_weights + mednet.libs.classification.models.logistic_regression + mednet.libs.classification.models.mlp + mednet.libs.classification.models.separate + mednet.libs.classification.models.transforms + mednet.libs.classification.models.typing .. _mednet.libs.classification.api.saliency: Saliency Map Generation and Analysis ------------------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Engines to generate and analyze saliency mapping techniques. @@ -89,19 +156,79 @@ Engines to generate and analyze saliency mapping techniques. .. _mednet.libs.classification.api.utils: -Various utilities ------------------ +Utils +^^^^^ -Reusable auxiliary functions. +Classification-specific utilities. .. autosummary:: :toctree: api/utils - mednet.utils.checkpointer - mednet.utils.gitlab - mednet.utils.rc - mednet.libs.common.utils.resources - mednet.utils.tensorboard + mednet.libs.segmentation.utils.rc + + +Segmentation library +-------------------- + +Library for training models on segmentation tasks + + +.. _mednet.libs.segmentation.api.data: + +Data +^^^^ + +Segmentation-specific data methods + +.. autosummary:: + :toctree: api/data + + mednet.libs.segmentation.data.typing + + +.. _mednet.libs.segmentation.api.engines: + +Command engines +^^^^^^^^^^^^^^^ + +Functions to actuate on the data. + +.. autosummary:: + :toctree: api/engine + + mednet.libs.segmentation.engine.evaluator + + +.. _mednet.libs.segmentation.api.models: + +Models +^^^^^^ + +CNN and other models implemented. + +.. autosummary:: + :toctree: api/models + + mednet.libs.segmentation.models.losses + mednet.libs.segmentation.models.lwnet + mednet.libs.segmentation.models.separate + mednet.libs.segmentation.models.typing + + +.. _mednet.libs.segmentation.api.utils: + +Utils +^^^^^ + +Segmentation-specific utilities. + +.. autosummary:: + :toctree: api/utils + + mednet.libs.segmentation.utils.measure + mednet.libs.segmentation.utils.plot + mednet.libs.segmentation.utils.rc + mednet.libs.segmentation.utils.table .. include:: links.rst diff --git a/doc/cli.rst b/doc/cli.rst index c9f96653b97feec52dd3e683e41a476ccc3c3119..fd01e8a7a92859ec4757a293a02910c3b5268042 100644 --- a/doc/cli.rst +++ b/doc/cli.rst @@ -2,7 +2,7 @@ .. .. SPDX-License-Identifier: GPL-3.0-or-later -.. _mednet.libs.classification.cli: +.. _mednet.libs.common.cli: ======================== Command-line Interface @@ -12,7 +12,7 @@ This section contains an overview of command-line applications shipped with this package. -.. click:: mednet.libs.classification.scripts.cli:cli +.. click:: mednet.scripts.cli:cli :prog: mednet :nested: full diff --git a/doc/conf.py b/doc/conf.py index 1191dc74d9681e1ac13003c09032c07ab726bd16..6b179cc3a834e6ff3216cab0362578b17280f177 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -117,6 +117,7 @@ autodoc_default_options = { auto_intersphinx_packages = [ "matplotlib", "numpy", + "pandas", "pillow", "psutil", "scipy", diff --git a/doc/references.rst b/doc/references.rst index f70d92d25853be6cc98de4c15813241b99f12d3e..6d5189cbb276963b532e7969f121804125d8dee4 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -81,3 +81,28 @@ **A Consistent and Efficient Evaluation Strategy for Attribution Methods** in Proceedings of the 39th International Conference on Machine Learning, PMLR, Jun. 2022, pp. 18770–18795. https://proceedings.mlr.press/v162/rong22a.html + +.. [IGLOVIKOV-2018] *V. Iglovikov, S. Seferbekov, A. Buslaev and A. Shvets*, + **TernausNetV2: Fully Convolutional Network for Instance Segmentation**, + 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition + Workshops (CVPRW), Salt Lake City, UT, 2018, pp. 228-2284. + https://doi.org/10.1109/CVPRW.2018.00042 + +.. [XIE-2015] *S. Xie and Z. Tu*, **Holistically-Nested Edge Detection**, 2015 + IEEE International Conference on Computer Vision (ICCV), Santiago, 2015, pp. + 1395-1403. https://doi.org/10.1109/ICCV.2015.164 + +.. [MANINIS-2016] *K.-K. Maninis, J. Pont-Tuset, P. Arbeláez, and L. Van Gool*, + **Deep Retinal Image Understanding**, in Medical Image Computing and + Computer-Assisted Intervention – MICCAI 2016, Cham, 2016, pp. 140–148. + https://doi.org/10.1007/978-3-319-46723-8_17 + +.. [GALDRAN-2020] *A. Galdran, A. Anjos, J. Dolz, H. Chakor, H. Lombaert, and + I. Ben Ayed*, **The Little W-Net That Could: State-of-the-Art Retinal Vessel + Segmentation with Minimalistic Models**, 2020. + https://arxiv.org/abs/2009.01907 + +.. [GOUTTE-2005] *C. Goutte and E. Gaussier*, **A probabilistic interpretation + of precision, recall and F-score, with implication for evaluation**, + European conference on Advances in Information Retrieval Research, 2005. + https://doi.org/10.1007/978-3-540-31865-1_25 diff --git a/doc/results/index.rst b/doc/results/index.rst index 1fec86c3f6e384c816ed8805532bc2669530b06e..85b646d1b2333ffce3dcc238d4a9d728ec105c0c 100644 --- a/doc/results/index.rst +++ b/doc/results/index.rst @@ -59,9 +59,9 @@ Stratified k-folding has been used (10 folds) to generate these results. .. tip:: To generate the following results, you first need to predict TB on each - fold, then use the :ref:`aggregpred command <mednet.cli>` to aggregate the + fold, then use the :ref:`aggregpred command <mednet.libs.common.cli>` to aggregate the predictions together, and finally evaluate this new file using the - :ref:`compare command <mednet.cli>`. + :ref:`compare command <mednet.libs.common.cli>`. Pasa and DenseNet-121 (random initialization) """"""""""""""""""""""""""""""""""""""""""""" @@ -113,37 +113,37 @@ Thresholds used: :scale: 50% :alt: Testing sets ROC curves for Pasa model trained on normalized-kfold MC - :py:mod:`Pasa <mednet.config.models.pasa>`: Pasa trained on normalized-kfold MC + :py:mod:`Pasa <mednet.libs.classification.config.models.pasa>`: Pasa trained on normalized-kfold MC - .. figure:: img/compare_pasa_mc_ch_kfold_500.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for Pasa model trained on normalized-kfold MC-CH - :py:mod:`Pasa <mednet.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH + :py:mod:`Pasa <mednet.libs.classification.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH - .. figure:: img/compare_pasa_mc_ch_in_kfold_500.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for Pasa model trained on normalized-kfold MC-CH-IN - :py:mod:`Pasa <mednet.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH-IN + :py:mod:`Pasa <mednet.libs.classification.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH-IN * - .. figure:: img/compare_densenet_mc_kfold_2000.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC - :py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC - .. figure:: img/compare_densenet_mc_ch_kfold_2000.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH - :py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH - .. figure:: img/compare_densenet_mc_ch_in_kfold_2000.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH-IN - :py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN DenseNet-121 (pretrained on ImageNet) """"""""""""""""""""""""""""""""""""" @@ -180,19 +180,19 @@ Thresholds used: :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC - :py:mod:`DenseNet <mednet.config.models.densenet>` DenseNet trained on normalized-kfold MC + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>` DenseNet trained on normalized-kfold MC - .. figure:: img/compare_densenetpreIN_mc_ch_kfold_600.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH - :py:mod:`DenseNet <mednet.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH - .. figure:: img/compare_densenetpreIN_mc_ch_ch_kfold_600.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH-IN - :py:mod:`DenseNet <mednet.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH-IN + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH-IN Logistic Regression Classifier """""""""""""""""""""""""""""" @@ -229,19 +229,19 @@ Thresholds used: :scale: 50% :alt: Testing sets ROC curves for LogReg model trained on normalized-kfold MC - :py:mod:`LogReg <mednet.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC + :py:mod:`LogReg <mednet.libs.classification.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC - .. figure:: img/compare_logreg_mc_ch_kfold_100.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for LogReg model trained on normalized-kfold MC-CH - :py:mod:`LogReg <mednet.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH + :py:mod:`LogReg <mednet.libs.classification.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH - .. figure:: img/compare_logreg_mc_ch_in_kfold_100.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for LogReg model trained on normalized-kfold MC-CH-IN - :py:mod:`LogReg <mednet.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH-IN + :py:mod:`LogReg <mednet.libs.classification.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH-IN DenseNet-121 (pretrained on ImageNet and NIH CXR14) """"""""""""""""""""""""""""""""""""""""""""""""""" @@ -278,19 +278,19 @@ Thresholds used: :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC (pretrained on NIH) - :py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC (pretrained on NIH) + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC (pretrained on NIH) - .. figure:: img/compare_densenetpre_mc_ch_kfold_300.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH (pretrained on NIH) - :py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH (pretrained on NIH) + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH (pretrained on NIH) - .. figure:: img/compare_densenetpre_mc_ch_in_kfold_300.jpg :align: center :scale: 50% :alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH-IN (pretrained on NIH) - :py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN (pretrained on NIH) + :py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN (pretrained on NIH) Global sensitivity analysis (relevance) @@ -301,7 +301,7 @@ Model used to generate the following figures: LogReg trained on MC-CH-IN fold 0 .. tip:: Use the ``--relevance-analysis`` argument of the :ref:`predict command - <mednet.cli>` to generate the following plots. + <mednet.libs.common.cli>` to generate the following plots. * Green color: likely TB * Orange color: Could be TB diff --git a/doc/usage/evaluation.rst b/doc/usage/evaluation.rst index 1288bab2007f377d961a606e835feaef86eb5d45..45519265f27d221f2feed66f527cc24b4204b5a4 100644 --- a/doc/usage/evaluation.rst +++ b/doc/usage/evaluation.rst @@ -20,7 +20,7 @@ Inference In inference (or prediction) mode, we input a model, a dataset, a model checkpoint generated during training, and output a json file containing the prediction outputs for every input image. -To run inference, use the sub-command :ref:`predict <mednet.libs.classification.cli>`. +To run inference, use the sub-command :ref:`predict <mednet.libs.common.cli>`. Examples ======== @@ -40,7 +40,7 @@ Evaluation In evaluation, we input predictions to generate performance summaries that help analysis of a trained model. The generated files are a .pdf containing various plots and a table of metrics for each dataset split. -Evaluation is done using the :ref:`evaluate command <mednet.libs.classification.cli>` followed by the json file generated during +Evaluation is done using the :ref:`evaluate command <mednet.libs.common.cli>` followed by the json file generated during the inference step and a threshold. Use ``mednet evaluate --help`` for more information. diff --git a/doc/usage/experiment.rst b/doc/usage/experiment.rst index 8dc75b9d76e432dd91ca13e94f9585daf7e4e2a9..597178bf6ab1773da236e72950c4e54910641362 100644 --- a/doc/usage/experiment.rst +++ b/doc/usage/experiment.rst @@ -8,7 +8,7 @@ Running complete experiments ============================== -We provide an :ref:`experiment command <mednet.libs.classification.cli>` +We provide an :ref:`experiment command <mednet.libs.common.cli>` that runs training, followed by prediction and evaluation. After running, you will be able to find results from model fitting, prediction and evaluation under a single output directory. diff --git a/doc/usage/index.rst b/doc/usage/index.rst index ce1c5c065c24ad02bbf1fd97b768b2fd74d5032f..4b643aab8827939b059c962ea05324171a38157d 100644 --- a/doc/usage/index.rst +++ b/doc/usage/index.rst @@ -51,7 +51,7 @@ Direct detection * Comparison: Use predictions results to compare performance of multiple systems. -We provide :ref:`command-line interfaces (CLI) <mednet.libs.classification.cli>` that implement +We provide :ref:`command-line interfaces (CLI) <mednet.libs.common.cli>` that implement each of the phases above. This interface is configurable using :ref:`clapper's extensible configuration framework <clapper.config>`. In essence, each command-line option may be provided as a variable with the same name in a diff --git a/doc/usage/saliency.rst b/doc/usage/saliency.rst index ebd03bf72a2fd2967a992aa6100183c94104e617..178aca289f4a43de29a1363b00dd2683f5809a93 100644 --- a/doc/usage/saliency.rst +++ b/doc/usage/saliency.rst @@ -17,7 +17,7 @@ Some of the scripts require the use of a database with human-annotated saliency Generation ---------- -Saliency maps can be generated with the :ref:`saliency generate command <mednet.libs.classification.cli>`. +Saliency maps can be generated with the :ref:`saliency generate command <mednet.libs.common.cli>`. They are represented as numpy arrays of the same size as thes images, with values in the range [0-1] and saved in .npy files. Several mapping algorithms are available to choose from, which can be specified with the -s option. @@ -36,7 +36,7 @@ objects on the output directory: Viewing ------- -To overlay saliency maps over the original images, use the :ref:`saliency view command <mednet.libs.classification.cli>`. +To overlay saliency maps over the original images, use the :ref:`saliency view command <mednet.libs.common.cli>`. Results are saved as PNG images in which brigter pixels correspond to areas with higher saliency. Examples diff --git a/doc/usage/training.rst b/doc/usage/training.rst index a123fce48ffa5be39ca69cf0fa886ef3acdc79c6..380e5ef5233c9d7a02085940b74779cce32cd02b 100644 --- a/doc/usage/training.rst +++ b/doc/usage/training.rst @@ -71,7 +71,7 @@ Plotting training metrics Various metrics are recorded at each epoch during training, such as the execution time, loss and resource usage. These are saved in a Tensorboard file, located in a `logs` subdirectory of the training output folder. -Mednet provides a :ref:`train-analysis <mednet.libs.classification.cli>` convenience script that graphs the scalars stored in these files and saves them in a .pdf file. +Mednet provides a :ref:`train-analysis <mednet.libs.common.cli>` convenience script that graphs the scalars stored in these files and saves them in a .pdf file. Examples ======== diff --git a/src/mednet/libs/common/scripts/upload.py b/src/mednet/libs/common/scripts/upload.py index ccca1ce840f68010bbdb93658733a90cdf2b068c..409bcc36fdbf74e13359390cb2d55e845daaf4e0 100644 --- a/src/mednet/libs/common/scripts/upload.py +++ b/src/mednet/libs/common/scripts/upload.py @@ -14,7 +14,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.command( - entry_point_group="mednet.config", cls=ConfigCommand, epilog="""Examples: diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py index 3af10f0dd434132c992197ec310b8e33042ffd3f..797da8371d2874637b715f02a6ad9b60e60b5b23 100644 --- a/src/mednet/libs/segmentation/models/losses.py +++ b/src/mednet/libs/segmentation/models/losses.py @@ -5,10 +5,9 @@ """Loss implementations.""" import torch -from torch.nn.modules.loss import _Loss -class WeightedBCELogitsLoss(_Loss): +class WeightedBCELogitsLoss(torch.nn.Module): """Calculates sum of weighted cross entropy loss. Implements Equation 1 in [MANINIS-2016]_. The weight depends on the @@ -29,17 +28,14 @@ class WeightedBCELogitsLoss(_Loss): sample Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. - target Ground-truth information with the shape ``[n, c, h, w]``. - mask Mask to be use for specifying the region of interest where to compute the loss, with the shape ``[n, c, h, w]``. Returns ------- - loss The average loss for all input data. """ @@ -57,7 +53,7 @@ class WeightedBCELogitsLoss(_Loss): ) -class SoftJaccardBCELogitsLoss(_Loss): +class SoftJaccardBCELogitsLoss(torch.nn.Module): r"""Implement the generalized loss function of Equation (3) in. [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary @@ -89,17 +85,14 @@ class SoftJaccardBCELogitsLoss(_Loss): tensor Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. - target Ground-truth information with the shape ``[n, c, h, w]``. - mask Mask to be use for specifying the region of interest where to compute the loss, with the shape ``[n, c, h, w]``. Returns ------- - loss Loss, in a single entry. """ @@ -135,10 +128,8 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): tensor Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. - target Ground-truth information with the shape ``[n, c, h, w]``. - mask Mask to be use for specifying the region of interest where to compute the loss, with the shape ``[n, c, h, w]``. @@ -181,10 +172,8 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): tensor Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. - target Ground-truth information with the shape ``[n, c, h, w]``. - mask Mask to be use for specifying the region of interest where to compute the loss, with the shape ``[n, c, h, w]``. @@ -204,80 +193,76 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): ).mean() -class MixJacLoss(_Loss): - """ - Parameters - ---------- - lambda_u - Determines the weighting of SoftJaccard and BCE. - - jacalpha - Determines the weighting of J and H. - - size_average - By default, the losses are averaged over each loss element in the batch. Note that for - some losses, there are multiple elements per sample. If the field :attr:`size_average` - is set to ``False``, the losses are instead summed for each minibatch. Ignored - when :attr:`reduce` is ``False``. Default: ``True``. - - reduce - By default, the losses are averaged or summed over observations for each minibatch depending - on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per - batch element instead and ignores :attr:`size_average`. Default: ``True``. - - reduction - Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in the meantime, - specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``. - """ - - def __init__( - self, - lambda_u: int = 100, - jacalpha=0.7, - size_average=None, - reduce=None, - reduction="mean", - ): - super().__init__(size_average, reduce, reduction) - self.lambda_u = lambda_u - self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) - self.unlabeled_loss = torch.nn.BCEWithLogitsLoss() - - def forward( - self, - tensor: torch.Tensor, - target: torch.Tensor, - unlabeled_tensor: torch.Tensor, - unlabeled_target: torch.Tensor, - ramp_up_factor: float, - ) -> tuple: - """Forward pass. - - Parameters - ---------- - tensor - Value produced by the model to be evaluated, with the shape ``[L, - n, c, h, w]``. - - target - Ground-truth information with the shape ``[n, c, h, w]``. - - unlabeled_tensor - - unlabeled_target - - ramp_up_factor - - Returns - ------- - list - """ - ll = self.labeled_loss(tensor, target) - ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target) - - loss = ll + self.lambda_u * ramp_up_factor * ul - return loss, ll, ul +# class MixJacLoss(torch.nn.Module): +# """Implements Mix Jaccard Loss. + +# Parameters +# ---------- +# lambda_u +# Determines the weighting of SoftJaccard and BCE. +# jacalpha +# Determines the weighting of J and H. +# size_average +# By default, the losses are averaged over each loss element in the batch. Note that for +# some losses, there are multiple elements per sample. If the field `size_average` +# is set to ``False``, the losses are instead summed for each minibatch. Ignored +# when `reduce` is ``False``. Default: ``True``. +# reduce +# By default, the losses are averaged or summed over observations for each minibatch depending +# on `size_average`. When `reduce` is ``False``, returns a loss per +# batch element instead and ignores `size_average`. Default: ``True``. +# reduction +# Specifies the reduction to apply to the output: +# ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, +# ``'mean'``: the sum of the output will be divided by the number of +# elements in the output, ``'sum'``: the output will be summed. Note: `size_average` +# and `reduce` are in the process of being deprecated, and in the meantime, +# specifying either of those two args will override `reduction`. Default: ``'mean'``. +# """ + +# def __init__( +# self, +# lambda_u: int = 100, +# jacalpha=0.7, +# size_average=None, +# reduce=None, +# reduction="mean", +# ): +# super().__init__(size_average, reduce, reduction) +# self.lambda_u = lambda_u +# self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) +# self.unlabeled_loss = torch.nn.BCEWithLogitsLoss() + +# def forward( +# self, +# tensor: torch.Tensor, +# target: torch.Tensor, +# unlabeled_tensor: torch.Tensor, +# unlabeled_target: torch.Tensor, +# ramp_up_factor: float, +# ) -> tuple: +# """Forward pass. + +# Parameters +# ---------- +# tensor +# Value produced by the model to be evaluated, with the shape ``[L, +# n, c, h, w]``. +# target +# Ground-truth information with the shape ``[n, c, h, w]``. + +# unlabeled_tensor + +# unlabeled_target + +# ramp_up_factor + +# Returns +# ------- +# list +# """ +# ll = self.labeled_loss(tensor, target) +# ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target) + +# loss = ll + self.lambda_u * ramp_up_factor * ul +# return loss, ll, ul diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 593c2f9675b77ae8872bbe7f6e73827231f0699c..a5cfc9bd711c552dece3a5724b60290dff2603bc 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -34,6 +34,22 @@ def _conv1x1(in_planes, out_planes, stride=1): class ConvBlock(torch.nn.Module): + """Convolution block. + + Parameters + ---------- + in_c + Number of input channels. + out_c + Number of output channels. + k_sz + Kernel Size. + shortcut + If True, adds a Conv2d layer. + pool + If True, adds a MaxPool2d layer. + """ + def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True): super().__init__() if shortcut is True: @@ -74,6 +90,18 @@ class ConvBlock(torch.nn.Module): class UpsampleBlock(torch.nn.Module): + """Upsample block implementation. + + Parameters + ---------- + in_c + Number of input channels. + out_c + Number of output channels. + up_mode + Upsampling mode. + """ + def __init__(self, in_c, out_c, up_mode="transp_conv"): super().__init__() block = [] @@ -98,6 +126,16 @@ class UpsampleBlock(torch.nn.Module): class ConvBridgeBlock(torch.nn.Module): + """ConvBridgeBlock implementation. + + Parameters + ---------- + channels + Number of channels. + k_sz + Kernel Size. + """ + def __init__(self, channels, k_sz=3): super().__init__() pad = (k_sz - 1) // 2 @@ -116,6 +154,24 @@ class ConvBridgeBlock(torch.nn.Module): class UpConvBlock(torch.nn.Module): + """UpConvBlock implementation. + + Parameters + ---------- + in_c + Number of input channels. + out_c + Number of output channels. + k_sz + Kernel Size. + up_mode + Upsampling mode. + conv_bridge + If True, adds a ConvBridgeBlock layer. + shortcut + If True, adds a Conv2d layer. + """ + def __init__( self, in_c, @@ -150,19 +206,19 @@ class LittleUNet(torch.nn.Module): Parameters ---------- in_c - + Number of input channels. n_classes Number of outputs (classes) for this model. - layers - + Number of layers of the model. k_sz - + Kernel Size. up_mode - + Upsampling mode. conv_bridge - + If True, adds a ConvBridgeBlock layer. shortcut + If True, adds a Conv2d layer. """ def __init__( @@ -327,16 +383,3 @@ class LittleWNet(Model): return self._optimizer_type( self.parameters(), **self._optimizer_arguments ) - - """def configure_optimizers(self): - optimizer = getattr( - self, 'optimizer', Adam(self.parameters(), lr=1e-3) - ) - if optimizer is None: - raise ValueError("Optimizer not found. Please provide an optimizer.") - - scheduler = getattr(self, 'scheduler', None) - if scheduler is None: - return {'optimizer': optimizer} - else: - return {'optimizer': optimizer, 'lr_scheduler': scheduler}""" diff --git a/src/mednet/libs/segmentation/utils/plot.py b/src/mednet/libs/segmentation/utils/plot.py index 746a1877cef8c87d3d2ab49535e537c47ccf88c3..1cd19b320b7ce740fef05500d17cbe29f0ee7033 100644 --- a/src/mednet/libs/segmentation/utils/plot.py +++ b/src/mednet/libs/segmentation/utils/plot.py @@ -152,7 +152,6 @@ def precision_recall_f1iso( Returns ------- - figure A matplotlib figure you can save or display (uses an ``agg`` backend). """ @@ -258,7 +257,6 @@ def loss_curve(df: pandas.DataFrame) -> matplotlib.figure.Figure: Returns ------- - figure A figure, that may be saved or displayed. """ diff --git a/src/mednet/libs/segmentation/utils/table.py b/src/mednet/libs/segmentation/utils/table.py index eaf3ee4e43490cb4e76445d7bd3b59874675a38a..f433835ec04df9461f7dcfdc19d0061bd418b352 100644 --- a/src/mednet/libs/segmentation/utils/table.py +++ b/src/mednet/libs/segmentation/utils/table.py @@ -44,7 +44,6 @@ def performance_table(data: dict[str, dict[str, typing.Any]], fmt: str) -> str: Returns ------- - table A table in a specific format. """