Commit 1075431f authored by Daniel CARRON, committed by André Anjos

[doc] Add segmentation documentation

parent b4b8c620
Merge request !46: Create common library
......@@ -12,68 +12,135 @@ This section includes information for using the Python API of
``mednet``.
.. _mednet.libs.classification.api.data:
Common library
--------------
This common library contains methods and scripts that can be reused by more specialized libraries.
.. _mednet.libs.common.api.data:
Data Methods
------------
^^^^^^^^^^^^
Auxiliary classes and methods to define raw dataset iterators.
.. autosummary::
:toctree: api/data
mednet.libs.classification.data.augmentations
mednet.libs.classification.data.datamodule
mednet.libs.classification.data.image_utils
mednet.libs.classification.data.split
mednet.libs.classification.data.typing
mednet.libs.common.data.augmentations
mednet.libs.common.data.datamodule
mednet.libs.common.data.image_utils
mednet.libs.common.data.split
mednet.libs.common.data.typing
.. _mednet.libs.classification.api.models:
.. _mednet.libs.common.api.engines:
Command engines
^^^^^^^^^^^^^^^
Functions that act on the data.
.. autosummary::
:toctree: api/engine
mednet.libs.common.engine.callbacks
mednet.libs.common.engine.device
mednet.libs.common.engine.loggers
mednet.libs.common.engine.trainer
.. _mednet.libs.common.api.models:
Models
------
^^^^^^
CNN and other models implemented.
Common model utilities.
.. autosummary::
:toctree: api/models
mednet.libs.classification.models.pasa
mednet.libs.classification.models.alexnet
mednet.libs.classification.models.densenet
mednet.libs.classification.models.logistic_regression
mednet.libs.classification.models.loss_weights
mednet.libs.classification.models.mlp
mednet.libs.classification.models.model
mednet.libs.classification.models.normalizer
mednet.libs.classification.models.separate
mednet.libs.classification.models.transforms
mednet.libs.classification.models.typing
mednet.libs.common.models.loss_weights
mednet.libs.common.models.model
mednet.libs.common.models.normalizer
mednet.libs.common.models.typing
.. _mednet.libs.common.api.utils:
Utils
^^^^^
Reusable auxiliary functions.
.. autosummary::
:toctree: api/utils
mednet.libs.common.utils.checkpointer
mednet.libs.common.utils.gitlab
mednet.libs.common.utils.resources
mednet.libs.common.utils.summary
mednet.libs.common.utils.tensorboard
Classification library
----------------------
Library for training models on classification tasks.
.. _mednet.libs.classification.api.data:
Data
^^^^
Classification-specific data methods.
.. autosummary::
:toctree: api/data
mednet.libs.classification.data.typing
.. _mednet.libs.classification.api.engines:
Command engines
---------------
^^^^^^^^^^^^^^^
Functions that act on the data.
.. autosummary::
:toctree: api/engine
mednet.libs.classification.engine.callbacks
mednet.libs.classification.engine.device
mednet.libs.classification.engine.evaluator
mednet.libs.classification.engine.loggers
mednet.libs.classification.engine.predictor
mednet.libs.classification.engine.trainer
.. _mednet.libs.classification.api.models:
Models
^^^^^^
CNN and other models implemented.
.. autosummary::
:toctree: api/models
mednet.libs.classification.models.pasa
mednet.libs.classification.models.alexnet
mednet.libs.classification.models.densenet
mednet.libs.classification.models.loss_weights
mednet.libs.classification.models.logistic_regression
mednet.libs.classification.models.mlp
mednet.libs.classification.models.separate
mednet.libs.classification.models.transforms
mednet.libs.classification.models.typing
.. _mednet.libs.classification.api.saliency:
Saliency Map Generation and Analysis
------------------------------------
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Engines to generate and analyze saliency mapping techniques.
......@@ -89,19 +156,79 @@ Engines to generate and analyze saliency mapping techniques.
.. _mednet.libs.classification.api.utils:
Various utilities
-----------------
Utils
^^^^^
Reusable auxiliary functions.
Classification-specific utilities.
.. autosummary::
:toctree: api/utils
mednet.utils.checkpointer
mednet.utils.gitlab
mednet.utils.rc
mednet.libs.common.utils.resources
mednet.utils.tensorboard
mednet.libs.segmentation.utils.rc
Segmentation library
--------------------
Library for training models on segmentation tasks.
.. _mednet.libs.segmentation.api.data:
Data
^^^^
Segmentation-specific data methods.
.. autosummary::
:toctree: api/data
mednet.libs.segmentation.data.typing
.. _mednet.libs.segmentation.api.engines:
Command engines
^^^^^^^^^^^^^^^
Functions that act on the data.
.. autosummary::
:toctree: api/engine
mednet.libs.segmentation.engine.evaluator
.. _mednet.libs.segmentation.api.models:
Models
^^^^^^
CNN and other models implemented.
.. autosummary::
:toctree: api/models
mednet.libs.segmentation.models.losses
mednet.libs.segmentation.models.lwnet
mednet.libs.segmentation.models.separate
mednet.libs.segmentation.models.typing
.. _mednet.libs.segmentation.api.utils:
Utils
^^^^^
Segmentation-specific utilities.
.. autosummary::
:toctree: api/utils
mednet.libs.segmentation.utils.measure
mednet.libs.segmentation.utils.plot
mednet.libs.segmentation.utils.rc
mednet.libs.segmentation.utils.table
.. include:: links.rst
......@@ -2,7 +2,7 @@
..
.. SPDX-License-Identifier: GPL-3.0-or-later
.. _mednet.libs.classification.cli:
.. _mednet.libs.common.cli:
========================
Command-line Interface
......@@ -12,7 +12,7 @@ This section contains an overview of command-line applications shipped with
this package.
.. click:: mednet.libs.classification.scripts.cli:cli
.. click:: mednet.scripts.cli:cli
:prog: mednet
:nested: full
......
......@@ -117,6 +117,7 @@ autodoc_default_options = {
auto_intersphinx_packages = [
"matplotlib",
"numpy",
"pandas",
"pillow",
"psutil",
"scipy",
......
......@@ -81,3 +81,28 @@
**A Consistent and Efficient Evaluation Strategy for Attribution Methods** in
Proceedings of the 39th International Conference on Machine Learning, PMLR,
Jun. 2022, pp. 18770–18795. https://proceedings.mlr.press/v162/rong22a.html
.. [IGLOVIKOV-2018] *V. Iglovikov, S. Seferbekov, A. Buslaev and A. Shvets*,
**TernausNetV2: Fully Convolutional Network for Instance Segmentation**,
2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition
Workshops (CVPRW), Salt Lake City, UT, 2018, pp. 228-2284.
https://doi.org/10.1109/CVPRW.2018.00042
.. [XIE-2015] *S. Xie and Z. Tu*, **Holistically-Nested Edge Detection**, 2015
IEEE International Conference on Computer Vision (ICCV), Santiago, 2015, pp.
1395-1403. https://doi.org/10.1109/ICCV.2015.164
.. [MANINIS-2016] *K.-K. Maninis, J. Pont-Tuset, P. Arbeláez, and L. Van Gool*,
**Deep Retinal Image Understanding**, in Medical Image Computing and
Computer-Assisted Intervention – MICCAI 2016, Cham, 2016, pp. 140–148.
https://doi.org/10.1007/978-3-319-46723-8_17
.. [GALDRAN-2020] *A. Galdran, A. Anjos, J. Dolz, H. Chakor, H. Lombaert, and
I. Ben Ayed*, **The Little W-Net That Could: State-of-the-Art Retinal Vessel
Segmentation with Minimalistic Models**, 2020.
https://arxiv.org/abs/2009.01907
.. [GOUTTE-2005] *C. Goutte and E. Gaussier*, **A probabilistic interpretation
of precision, recall and F-score, with implication for evaluation**,
European conference on Advances in Information Retrieval Research, 2005.
https://doi.org/10.1007/978-3-540-31865-1_25
......@@ -59,9 +59,9 @@ Stratified k-folding has been used (10 folds) to generate these results.
.. tip::
To generate the following results, you first need to predict TB on each
fold, then use the :ref:`aggregpred command <mednet.cli>` to aggregate the
fold, then use the :ref:`aggregpred command <mednet.libs.common.cli>` to aggregate the
predictions together, and finally evaluate this new file using the
:ref:`compare command <mednet.cli>`.
:ref:`compare command <mednet.libs.common.cli>`.
Pasa and DenseNet-121 (random initialization)
"""""""""""""""""""""""""""""""""""""""""""""
......@@ -113,37 +113,37 @@ Thresholds used:
:scale: 50%
:alt: Testing sets ROC curves for Pasa model trained on normalized-kfold MC
:py:mod:`Pasa <mednet.config.models.pasa>`: Pasa trained on normalized-kfold MC
:py:mod:`Pasa <mednet.libs.classification.config.models.pasa>`: Pasa trained on normalized-kfold MC
- .. figure:: img/compare_pasa_mc_ch_kfold_500.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for Pasa model trained on normalized-kfold MC-CH
:py:mod:`Pasa <mednet.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH
:py:mod:`Pasa <mednet.libs.classification.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH
- .. figure:: img/compare_pasa_mc_ch_in_kfold_500.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for Pasa model trained on normalized-kfold MC-CH-IN
:py:mod:`Pasa <mednet.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH-IN
:py:mod:`Pasa <mednet.libs.classification.config.models.pasa>`: Pasa trained on normalized-kfold MC-CH-IN
* - .. figure:: img/compare_densenet_mc_kfold_2000.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC
:py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC
- .. figure:: img/compare_densenet_mc_ch_kfold_2000.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH
:py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH
- .. figure:: img/compare_densenet_mc_ch_in_kfold_2000.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH-IN
:py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN
DenseNet-121 (pretrained on ImageNet)
"""""""""""""""""""""""""""""""""""""
......@@ -180,19 +180,19 @@ Thresholds used:
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC
:py:mod:`DenseNet <mednet.config.models.densenet>` DenseNet trained on normalized-kfold MC
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>` DenseNet trained on normalized-kfold MC
- .. figure:: img/compare_densenetpreIN_mc_ch_kfold_600.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH
:py:mod:`DenseNet <mednet.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH
- .. figure:: img/compare_densenetpreIN_mc_ch_ch_kfold_600.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH-IN
:py:mod:`DenseNet <mednet.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH-IN
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>` DenseNet trained on normalized-kfold MC-CH-IN
Logistic Regression Classifier
""""""""""""""""""""""""""""""
......@@ -229,19 +229,19 @@ Thresholds used:
:scale: 50%
:alt: Testing sets ROC curves for LogReg model trained on normalized-kfold MC
:py:mod:`LogReg <mednet.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC
:py:mod:`LogReg <mednet.libs.classification.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC
- .. figure:: img/compare_logreg_mc_ch_kfold_100.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for LogReg model trained on normalized-kfold MC-CH
:py:mod:`LogReg <mednet.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH
:py:mod:`LogReg <mednet.libs.classification.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH
- .. figure:: img/compare_logreg_mc_ch_in_kfold_100.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for LogReg model trained on normalized-kfold MC-CH-IN
:py:mod:`LogReg <mednet.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH-IN
:py:mod:`LogReg <mednet.libs.classification.config.models.logistic_regression>`: LogReg trained on normalized-kfold MC-CH-IN
DenseNet-121 (pretrained on ImageNet and NIH CXR14)
"""""""""""""""""""""""""""""""""""""""""""""""""""
......@@ -278,19 +278,19 @@ Thresholds used:
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC (pretrained on NIH)
:py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC (pretrained on NIH)
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC (pretrained on NIH)
- .. figure:: img/compare_densenetpre_mc_ch_kfold_300.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH (pretrained on NIH)
:py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH (pretrained on NIH)
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH (pretrained on NIH)
- .. figure:: img/compare_densenetpre_mc_ch_in_kfold_300.jpg
:align: center
:scale: 50%
:alt: Testing sets ROC curves for DenseNet model trained on normalized-kfold MC-CH-IN (pretrained on NIH)
:py:mod:`DenseNet <mednet.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN (pretrained on NIH)
:py:mod:`DenseNet <mednet.libs.classification.config.models.densenet>`: DenseNet trained on normalized-kfold MC-CH-IN (pretrained on NIH)
Global sensitivity analysis (relevance)
......@@ -301,7 +301,7 @@ Model used to generate the following figures: LogReg trained on MC-CH-IN fold 0
.. tip::
Use the ``--relevance-analysis`` argument of the :ref:`predict command
<mednet.cli>` to generate the following plots.
<mednet.libs.common.cli>` to generate the following plots.
* Green color: likely TB
* Orange color: could be TB
......
......@@ -20,7 +20,7 @@ Inference
In inference (or prediction) mode, we input a model, a dataset, and a model checkpoint generated during training, and output
a JSON file containing the predictions for every input image.
To run inference, use the sub-command :ref:`predict <mednet.libs.classification.cli>`.
To run inference, use the sub-command :ref:`predict <mednet.libs.common.cli>`.
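As a quick sanity check, the resulting file can be loaded with standard tools; a minimal sketch (the output path and record layout below are assumptions, not a documented schema):

.. code-block:: python

   import json

   # hypothetical output path; the real name depends on your invocation
   with open("results/predictions.json") as f:
       predictions = json.load(f)

   print(type(predictions), len(predictions))  # one entry per input image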
Examples
========
......@@ -40,7 +40,7 @@ Evaluation
In evaluation, we input predictions to generate performance summaries that help analyze a trained model.
The generated files are a .pdf containing various plots and a table of metrics for each dataset split.
Evaluation is done using the :ref:`evaluate command <mednet.libs.classification.cli>` followed by the json file generated during
Evaluation is done using the :ref:`evaluate command <mednet.libs.common.cli>` followed by the json file generated during
the inference step and a threshold.
Use ``mednet evaluate --help`` for more information.
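Conceptually, the threshold turns scores into binary decisions before the metrics are tabulated; a hedged, generic sketch (not the ``evaluate`` implementation):

.. code-block:: python

   # hedged sketch: apply a fixed threshold to per-image scores
   scores = [0.10, 0.40, 0.80, 0.95]
   labels = [0, 0, 1, 1]
   threshold = 0.5

   decisions = [int(s >= threshold) for s in scores]
   accuracy = sum(d == t for d, t in zip(decisions, labels)) / len(labels)
   print(decisions, accuracy)  # [0, 0, 1, 1] 1.0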
......
......@@ -8,7 +8,7 @@
Running complete experiments
==============================
We provide an :ref:`experiment command <mednet.libs.classification.cli>`
We provide an :ref:`experiment command <mednet.libs.common.cli>`
that runs training, followed by prediction and evaluation.
After running, you will be able to find results from model fitting,
prediction and evaluation under a single output directory.
......
......@@ -51,7 +51,7 @@ Direct detection
* Comparison: Use prediction results to compare the performance of multiple
systems.
We provide :ref:`command-line interfaces (CLI) <mednet.libs.classification.cli>` that implement
We provide :ref:`command-line interfaces (CLI) <mednet.libs.common.cli>` that implement
each of the phases above. This interface is configurable using :ref:`clapper's
extensible configuration framework <clapper.config>`. In essence, each
command-line option may be provided as a variable with the same name in a
......
......@@ -17,7 +17,7 @@ Some of the scripts require the use of a database with human-annotated saliency
Generation
----------
Saliency maps can be generated with the :ref:`saliency generate command <mednet.libs.classification.cli>`.
Saliency maps can be generated with the :ref:`saliency generate command <mednet.libs.common.cli>`.
They are represented as numpy arrays of the same size as the images, with values in the range [0, 1], and saved in .npy files.
Several mapping algorithms are available to choose from; the algorithm to use can be specified with the ``-s`` option.
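For instance, a generated map can be loaded back and checked with numpy (the file name below is hypothetical):

.. code-block:: python

   import numpy

   saliency = numpy.load("saliencies/sample.npy")  # same H x W as the image
   assert saliency.min() >= 0.0 and saliency.max() <= 1.0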
......@@ -36,7 +36,7 @@ objects on the output directory:
Viewing
-------
To overlay saliency maps over the original images, use the :ref:`saliency view command <mednet.libs.classification.cli>`.
To overlay saliency maps over the original images, use the :ref:`saliency view command <mednet.libs.common.cli>`.
Results are saved as PNG images in which brighter pixels correspond to areas with higher saliency.
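The overlay amounts to alpha-blending the saliency map over the image; a minimal matplotlib sketch under that assumption (file names hypothetical):

.. code-block:: python

   import matplotlib.pyplot as plt
   import numpy
   from PIL import Image

   image = numpy.asarray(Image.open("images/sample.png").convert("L"))
   saliency = numpy.load("saliencies/sample.npy")

   plt.imshow(image, cmap="gray")
   plt.imshow(saliency, cmap="jet", alpha=0.5)  # brighter = more salient
   plt.axis("off")
   plt.savefig("overlay.png", bbox_inches="tight")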
Examples
......
......@@ -71,7 +71,7 @@ Plotting training metrics
Various metrics are recorded at each epoch during training, such as the execution time, loss and resource usage.
These are saved in a Tensorboard file, located in a `logs` subdirectory of the training output folder.
Mednet provides a :ref:`train-analysis <mednet.libs.classification.cli>` convenience script that graphs the scalars stored in these files and saves them in a .pdf file.
Mednet provides a :ref:`train-analysis <mednet.libs.common.cli>` convenience script that graphs the scalars stored in these files and saves them in a .pdf file.
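The same scalars can also be read programmatically with TensorBoard's event accumulator; a hedged sketch (the log directory is hypothetical):

.. code-block:: python

   from tensorboard.backend.event_processing.event_accumulator import (
       EventAccumulator,
   )

   accumulator = EventAccumulator("output/logs")
   accumulator.Reload()  # parse the event files on disk
   for tag in accumulator.Tags()["scalars"]:
       events = accumulator.Scalars(tag)
       print(tag, [(e.step, e.value) for e in events[:3]])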
Examples
========
......
......@@ -14,7 +14,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
entry_point_group="mednet.config",
cls=ConfigCommand,
epilog="""Examples:
......
......@@ -5,10 +5,9 @@
"""Loss implementations."""
import torch
from torch.nn.modules.loss import _Loss
class WeightedBCELogitsLoss(_Loss):
class WeightedBCELogitsLoss(torch.nn.Module):
"""Calculates sum of weighted cross entropy loss.
Implements Equation 1 in [MANINIS-2016]_. The weight depends on the
......@@ -29,17 +28,14 @@ class WeightedBCELogitsLoss(_Loss):
sample
Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be used for specifying the region of interest on which to
compute the loss, with the shape ``[n, c, h, w]``.
Returns
-------
loss
The average loss for all input data.
"""
......@@ -57,7 +53,7 @@ class WeightedBCELogitsLoss(_Loss):
)
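# A minimal, hedged sketch (not the mednet implementation) of the weighted
# BCE idea of Equation 1 in [MANINIS-2016]: each class is re-weighted by the
# frequency of the opposite class inside the region of interest, so the rarer
# class contributes proportionally more. Shapes and the masking convention
# are assumptions based on the docstring above.
import torch

def weighted_bce_sketch(sample, target, mask):
    valid = mask > 0.5  # restrict the loss to the region of interest
    logits = sample[valid]
    labels = (target[valid] > 0.5).float()
    num_pos = labels.sum()
    num_neg = labels.numel() - num_pos
    # positives weighted by the negative fraction, and vice versa
    weight = torch.where(
        labels > 0, num_neg / labels.numel(), num_pos / labels.numel()
    )
    return torch.nn.functional.binary_cross_entropy_with_logits(
        logits, labels, weight=weight, reduction="sum"
    )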
class SoftJaccardBCELogitsLoss(_Loss):
class SoftJaccardBCELogitsLoss(torch.nn.Module):
r"""Implement the generalized loss function of Equation (3) in.
[IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary
......@@ -89,17 +85,14 @@ class SoftJaccardBCELogitsLoss(_Loss):
tensor
Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be used for specifying the region of interest on which to
compute the loss, with the shape ``[n, c, h, w]``.
Returns
-------
loss
Loss, in a single entry.
"""
......@@ -135,10 +128,8 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
tensor
Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be used for specifying the region of interest on which to
compute the loss, with the shape ``[n, c, h, w]``.
......@@ -181,10 +172,8 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
tensor
Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be used for specifying the region of interest on which to
compute the loss, with the shape ``[n, c, h, w]``.
......@@ -204,80 +193,76 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
).mean()
class MixJacLoss(_Loss):
"""
Parameters
----------
lambda_u
Determines the weighting of SoftJaccard and BCE.
jacalpha
Determines the weighting of J and H.
size_average
By default, the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when :attr:`reduce` is ``False``. Default: ``True``.
reduce
By default, the losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``.
reduction
Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``.
"""
def __init__(
self,
lambda_u: int = 100,
jacalpha=0.7,
size_average=None,
reduce=None,
reduction="mean",
):
super().__init__(size_average, reduce, reduction)
self.lambda_u = lambda_u
self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
def forward(
self,
tensor: torch.Tensor,
target: torch.Tensor,
unlabeled_tensor: torch.Tensor,
unlabeled_target: torch.Tensor,
ramp_up_factor: float,
) -> tuple:
"""Forward pass.
Parameters
----------
tensor
Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
unlabeled_tensor
unlabeled_target
ramp_up_factor
Returns
-------
list
"""
ll = self.labeled_loss(tensor, target)
ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
loss = ll + self.lambda_u * ramp_up_factor * ul
return loss, ll, ul
# class MixJacLoss(torch.nn.Module):
# """Implements Mix Jaccard Loss.
# Parameters
# ----------
# lambda_u
# Determines the weighting of SoftJaccard and BCE.
# jacalpha
# Determines the weighting of J and H.
# size_average
# By default, the losses are averaged over each loss element in the batch. Note that for
# some losses, there are multiple elements per sample. If the field `size_average`
# is set to ``False``, the losses are instead summed for each minibatch. Ignored
# when `reduce` is ``False``. Default: ``True``.
# reduce
# By default, the losses are averaged or summed over observations for each minibatch depending
# on `size_average`. When `reduce` is ``False``, returns a loss per
# batch element instead and ignores `size_average`. Default: ``True``.
# reduction
# Specifies the reduction to apply to the output:
# ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
# ``'mean'``: the sum of the output will be divided by the number of
# elements in the output, ``'sum'``: the output will be summed. Note: `size_average`
# and `reduce` are in the process of being deprecated, and in the meantime,
# specifying either of those two args will override `reduction`. Default: ``'mean'``.
# """
# def __init__(
# self,
# lambda_u: int = 100,
# jacalpha=0.7,
# size_average=None,
# reduce=None,
# reduction="mean",
# ):
# super().__init__(size_average, reduce, reduction)
# self.lambda_u = lambda_u
# self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
# self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
# def forward(
# self,
# tensor: torch.Tensor,
# target: torch.Tensor,
# unlabeled_tensor: torch.Tensor,
# unlabeled_target: torch.Tensor,
# ramp_up_factor: float,
# ) -> tuple:
# """Forward pass.
# Parameters
# ----------
# tensor
# Value produced by the model to be evaluated, with the shape ``[L,
# n, c, h, w]``.
# target
# Ground-truth information with the shape ``[n, c, h, w]``.
# unlabeled_tensor
# unlabeled_target
# ramp_up_factor
# Returns
# -------
# list
# """
# ll = self.labeled_loss(tensor, target)
# ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
# loss = ll + self.lambda_u * ramp_up_factor * ul
# return loss, ll, ul
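# The commented-out MixJacLoss above combined a supervised and an
# unsupervised term as ``loss = ll + lambda_u * ramp_up_factor * ul``.
# A minimal sketch of that combination (the linear ramp-up schedule itself
# is an assumption; only the combination rule comes from the code above):
def combine_losses(ll, ul, epoch, lambda_u=100.0, ramp_up_epochs=50):
    ramp_up_factor = min(1.0, epoch / ramp_up_epochs)  # grows linearly to 1
    return ll + lambda_u * ramp_up_factor * ul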
......@@ -34,6 +34,22 @@ def _conv1x1(in_planes, out_planes, stride=1):
class ConvBlock(torch.nn.Module):
"""Convolution block.
Parameters
----------
in_c
Number of input channels.
out_c
Number of output channels.
k_sz
Kernel Size.
shortcut
If True, adds a Conv2d layer.
pool
If True, adds a MaxPool2d layer.
"""
def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
super().__init__()
if shortcut is True:
......@@ -74,6 +90,18 @@ class ConvBlock(torch.nn.Module):
class UpsampleBlock(torch.nn.Module):
"""Upsample block implementation.
Parameters
----------
in_c
Number of input channels.
out_c
Number of output channels.
up_mode
Upsampling mode.
"""
def __init__(self, in_c, out_c, up_mode="transp_conv"):
super().__init__()
block = []
......@@ -98,6 +126,16 @@ class UpsampleBlock(torch.nn.Module):
class ConvBridgeBlock(torch.nn.Module):
"""ConvBridgeBlock implementation.
Parameters
----------
channels
Number of channels.
k_sz
Kernel Size.
"""
def __init__(self, channels, k_sz=3):
super().__init__()
pad = (k_sz - 1) // 2
......@@ -116,6 +154,24 @@ class ConvBridgeBlock(torch.nn.Module):
class UpConvBlock(torch.nn.Module):
"""UpConvBlock implementation.
Parameters
----------
in_c
Number of input channels.
out_c
Number of output channels.
k_sz
Kernel Size.
up_mode
Upsampling mode.
conv_bridge
If True, adds a ConvBridgeBlock layer.
shortcut
If True, adds a Conv2d layer.
"""
def __init__(
self,
in_c,
......@@ -150,19 +206,19 @@ class LittleUNet(torch.nn.Module):
Parameters
----------
in_c
Number of input channels.
n_classes
Number of outputs (classes) for this model.
layers
Number of layers of the model.
k_sz
Kernel Size.
up_mode
Upsampling mode.
conv_bridge
If True, adds a ConvBridgeBlock layer.
shortcut
If True, adds a Conv2d layer.
"""
def __init__(
......@@ -327,16 +383,3 @@ class LittleWNet(Model):
return self._optimizer_type(
self.parameters(), **self._optimizer_arguments
)
"""def configure_optimizers(self):
optimizer = getattr(
self, 'optimizer', Adam(self.parameters(), lr=1e-3)
)
if optimizer is None:
raise ValueError("Optimizer not found. Please provide an optimizer.")
scheduler = getattr(self, 'scheduler', None)
if scheduler is None:
return {'optimizer': optimizer}
else:
return {'optimizer': optimizer, 'lr_scheduler': scheduler}"""
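# A self-contained sketch (not the mednet class) mirroring the ConvBlock
# parameters documented above: two k_sz x k_sz convolutions, an optional 1x1
# shortcut, and an optional 2x2 max-pooling stage. The internal layout is an
# assumption; only the parameter meanings come from the docstrings above.
import torch

class ConvBlockSketch(torch.nn.Module):
    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
        super().__init__()
        pad = (k_sz - 1) // 2  # keep spatial size through the convolutions
        self.body = torch.nn.Sequential(
            torch.nn.Conv2d(in_c, out_c, k_sz, padding=pad),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_c, out_c, k_sz, padding=pad),
            torch.nn.ReLU(inplace=True),
        )
        self.shortcut = torch.nn.Conv2d(in_c, out_c, 1) if shortcut else None
        self.pool = torch.nn.MaxPool2d(2) if pool else None

    def forward(self, x):
        y = self.body(x)
        if self.shortcut is not None:
            y = y + self.shortcut(x)  # 1x1 projection of the input
        if self.pool is not None:
            y = self.pool(y)
        return y

# shape check: a 64x64 input halves to 32x32 when pool=True
assert ConvBlockSketch(3, 16)(torch.randn(1, 3, 64, 64)).shape == (1, 16, 32, 32)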
......@@ -152,7 +152,6 @@ def precision_recall_f1iso(
Returns
-------
figure
A matplotlib figure you can save or display (uses an ``agg`` backend).
"""
......@@ -258,7 +257,6 @@ def loss_curve(df: pandas.DataFrame) -> matplotlib.figure.Figure:
Returns
-------
figure
A figure, that may be saved or displayed.
"""
......
......@@ -44,7 +44,6 @@ def performance_table(data: dict[str, dict[str, typing.Any]], fmt: str) -> str:
Returns
-------
table
A table in a specific format.
"""
......