diff --git a/.gitignore b/.gitignore index 6ca994ebbf196f8fc29f3d9f2815f9416952e996..f49b592282bb3b64e26250494dcdb88fc39a8c6a 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ _work/ .mypy_cache/ .pytest_cache/ results*/ +trainlog.pdf diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 505a38db3fbc742ebc929bd2ab81ef238bb84724..2c305176c0f4da15bffc19a7f6dd34cc207c3437 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,10 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: + - repo: https://github.com/numpy/numpydoc + rev: v1.6.0 + hooks: + - id: numpydoc-validation - repo: https://github.com/psf/black rev: 23.12.1 hooks: @@ -14,6 +18,9 @@ repos: rev: v1.7.5 hooks: - id: docformatter + args: [ + --wrap-summaries=0, + ] - repo: https://github.com/pycqa/isort rev: 5.13.2 hooks: diff --git a/doc/api.rst b/doc/api.rst index 201bcf562df09c76e6f66c923c63c119447b0152..39b4d40f6218f6100b7f7ab50fb826aaf32e541a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -42,9 +42,12 @@ CNN and other models implemented. mednet.models.pasa mednet.models.alexnet mednet.models.densenet - mednet.models.normalizer mednet.models.logistic_regression + mednet.models.loss_weights mednet.models.mlp + mednet.models.normalizer + mednet.models.separate + mednet.models.transforms mednet.models.typing @@ -58,11 +61,12 @@ Functions to actuate on the data. .. autosummary:: :toctree: api/engine - mednet.engine.device mednet.engine.callbacks - mednet.engine.trainer - mednet.engine.predictor + mednet.engine.device mednet.engine.evaluator + mednet.engine.loggers + mednet.engine.predictor + mednet.engine.trainer .. _mednet.api.saliency: @@ -75,9 +79,11 @@ Engines to generate and analyze saliency mapping techniques. .. autosummary:: :toctree: api/saliency - mednet.engine.saliency.generator mednet.engine.saliency.completeness + mednet.engine.saliency.evaluator + mednet.engine.saliency.generator mednet.engine.saliency.interpretability + mednet.engine.saliency.viewer .. _mednet.api.utils: diff --git a/doc/catalog.json b/doc/catalog.json index c837093f96ae2e8cd85cbfa6af52af7100e7e7f3..529a23b480af03f8020f89720ff522d6d9f14d7f 100644 --- a/doc/catalog.json +++ b/doc/catalog.json @@ -14,6 +14,15 @@ "environment": "lightning" } }, + "tensorboardx": { + "versions": { + "stable": "https://tensorboardx.readthedocs.io/en/stable/", + "latest": "https://tensorboardx.readthedocs.io/en/latest/" + }, + "sources": { + "readthedocs": "tensorboardx" + } + }, "tabulate": { "versions": { "latest": "https://tabulate.readthedocs.io/en/latest/", diff --git a/doc/conf.py b/doc/conf.py index 1411e92d8303f385882966b713d2b95e1d0970d1..7cf72ca51fb2b9e6d4f22c2cd8143d63f32893da 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -123,6 +123,7 @@ auto_intersphinx_packages = [ "torch", "torchvision", "lightning", + "tensorboardx", ("clapper", "latest"), ("python", "3"), ] diff --git a/doc/config.rst b/doc/config.rst index 1d6c5af91be6671db036e03fa5e730e20524b58a..307e5d3e41099284aef71b4d8386e09b33b0e183 100644 --- a/doc/config.rst +++ b/doc/config.rst @@ -8,7 +8,7 @@ Preset Configurations --------------------- This module contains preset configurations for baseline CNN architectures and -datamodules. +DataModules. .. _mednet.config.models: @@ -38,9 +38,9 @@ DataModule support ================== Base DataModules and raw data loaders for the various databases currently -supported in this package, for your reference. Each pre-configured data module +supported in this package, for your reference. Each pre-configured DataModule can receive the name of one or more splits as argument to build a fully -functional data module that can be used in training, prediction or testing. +functional DataModule that can be used in training, prediction or testing. .. autosummary:: :toctree: api/config.datamodules @@ -67,7 +67,7 @@ Pre-configured DataModules DataModules provide access to preset pytorch dataloaders for training, validating, testing and running prediction tasks. Each of the pre-configured -DataModule is based on one (or more) of the :ref:`supported base data modules +DataModule is based on one (or more) of the :ref:`supported base DataModules <mednet.config.datamodules>`. .. autosummary:: @@ -97,8 +97,8 @@ Cross-validation DataModules We support cross-validation with precise preset folds. In this section, you will find the configuration for the first fold (fold-0) for all supported -datamodules. Nine other folds are available for every configuration (from 1 to -9), making up 10 folds per supported datamodule. +DataModules. Nine other folds are available for every configuration (from 1 to +9), making up 10 folds per supported DataModule. .. autosummary:: diff --git a/doc/contribute.rst b/doc/contribute.rst new file mode 100644 index 0000000000000000000000000000000000000000..0c25cfa8076036c639e31bc20ffc39ece73bf29f --- /dev/null +++ b/doc/contribute.rst @@ -0,0 +1,23 @@ +.. Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +.. +.. SPDX-License-Identifier: GPL-3.0-or-later + +.. _mednet.contribute: + +=================================== + Getting Involved and Contributing +=================================== + +We will happily accept external contributions, but substantial contributions +require a signed Contributor or `Copyright License Agreement <cla_>`_ (CLA). +Our CLA, based on `Project Harmony`_, leaves copyright with you (the +contributor), but allows us to relicense the code, with a restriction based on +the license the contribution was made under. + +Contact our `Technology Transfer Officer <tto_>`_ to get a copy of the CLA_ for +this project. If you work for a company and your contributions are tied to +your job, ensure you have the legal right to sign this CLA, or refer to the +responsible person during your e-mail exchange with our TTO_. + + +.. include:: links.rst diff --git a/doc/data_model.rst b/doc/data_model.rst new file mode 100644 index 0000000000000000000000000000000000000000..d028419572a6a56145fbe098ea9e20f785327942 --- /dev/null +++ b/doc/data_model.rst @@ -0,0 +1,69 @@ +.. Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +.. +.. SPDX-License-Identifier: GPL-3.0-or-later + +.. _mednet.datamodel: + +============ + Data model +============ + +The following describes the various parts of our data model, which are used in this documentation and throughout the codebase. + + +Database +-------- +Data that is downloaded from a data provider, and contains samples in their raw data format. +The database may contain both data and metadata, and is supposed to exist on disk (or any other storage device) +in an arbitrary location that is user-configurable, in the user environment. +For example, databases 1 and 2 for user A may be under /home/user-a/databases/database-1 and /home/user-a/databases/database-2, +while for user B, they may sit in /groups/medical-data/DatabaseOne and /groups/medical-data/DatabaseTwo. + + +Sample +------ +The in-memory representation of the raw database samples. +In this package, it is specified as a two-tuple with a tensor, and metadata (typically label, name, etc.). + + +RawDataLoader +------------- +A concrete "functor" that allows one to load the raw data and associated metadata, to create a in-memory Sample representation. +RawDataLoaders are typically Database-specific due to raw data and metadata encoding varying quite a lot on different databases. +RawDataLoaders may also embed various pre-processing transformations to render data readily usable such as pre-cropping of black pixel areas, +or 16-bit to 8-bit auto-level conversion. + + +TransformSequence +----------------- +A sequence of callables that allows one to transform torch.Tensor objects into other torch.Tensor objects, +typically to crop, resize, convert Color-spaces, and the such on raw-data. + + +DatabaseSplit +------------- +A dictionary that represents an organization of the available raw data in the database to perform +an evaluation protocol (e.g. train, validation, test) through datasets (or subsets). +It is represented as dictionary mapping dataset names to lists of "raw-data" sample representations, which vary in format +depending on Database metadata availability. RawDataLoaders receive this raw representations and can convert these to in-memory Sample's. + + +ConcatDatabaseSplit +------------------- +An extension of a DatabaseSplit, in which the split can be formed by cannibalising various other DatabaseSplits to construct a new evaluation protocol. +Examples of this are cross-database tests, or the construction of multi-Database training and validation subsets. + + +Dataset +------- +An iterable object over in-memory Samples, inherited from the pytorch Dataset definition. +A dataset in our framework may be completely cached in memory or have in-memory representation of samples loaded on demand. +After data loading, our datasets can optionally apply a TransformSequence, composed of pre-processing steps defined on a per-model level +before optionally caching in-memory Sample representations. The "raw" representation of a dataset are the split dictionary values (ie. not the keys). + + +DataModule +---------- +A DataModule aggregates Splits and RawDataLoaders to provide lightning a known-interface to the complete evaluation protocol (train, validation, prediction and testing) +required for a full experiment to take place. It automates control over data loading parallelisation and caching inside our framework, +providing final access to readily-usable pytorch DataLoaders. diff --git a/doc/index.rst b/doc/index.rst index b0126ac0b4051d09395d31a4eb25a88424f8a832..1864f3db1b95e61854160189f13f1f82d1bd53ca 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -52,10 +52,12 @@ User Guide install usage/index results/index + data_model references cli config api + contribute Indices and tables diff --git a/doc/install.rst b/doc/install.rst index 151bfef284149a33451542f16b5579f8bf487eb5..aaf5baaae32b3bdf6319edc2c31e393ef3e7c75e 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -90,10 +90,10 @@ Here is an example configuration file that may be useful as a starting point: .. code:: sh - mednet dataset list + mednet database list - You must procure and download datasets by yourself. The raw data is not + You must procure and download databases by yourself. The raw data is not included in this package as we are not authorised to redistribute it. To check whether the downloaded version is consistent with the structure @@ -101,7 +101,7 @@ Here is an example configuration file that may be useful as a starting point: .. code:: sh - mednet dataset check montgomery + mednet database check <database_name> .. _mednet.setup.databases: @@ -109,8 +109,8 @@ Here is an example configuration file that may be useful as a starting point: Supported Databases =================== -Here is a list of currently supported datasets in this package, alongside -notable properties. Each dataset name is linked to the location where +Here is a list of currently supported databases in this package, alongside +notable properties. Each database name is linked to the location where raw data can be downloaded. The list of images in each split is available in the source code. @@ -120,13 +120,13 @@ in the source code. Tuberculosis databases ~~~~~~~~~~~~~~~~~~~~~~ -The following datasets contain only the tuberculosis final diagnosis (0 or 1). +The following databases contain only the tuberculosis final diagnosis (0 or 1). In addition to the splits presented in the following table, 10 folds -(for cross-validation) randomly generated are available for these datasets. +(for cross-validation) randomly generated are available for these databases. .. list-table:: - * - Dataset + * - Database - Reference - H x W - Samples @@ -156,20 +156,20 @@ In addition to the splits presented in the following table, 10 folds - 52 -.. _mednet.setup.datasets.tb+signs: +.. _mednet.setup.databases.tb+signs: Tuberculosis multilabel databases ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The following dataset contains the labels healthy, sick & non-TB, active TB, -and latent TB. The implemented tbx11k dataset in this package is based on +The following databases contain the labels healthy, sick & non-TB, active TB, +and latent TB. The implemented tbx11k database in this package is based on the simplified version, which is just a more compact version of the original. In addition to the splits presented in the following table, 10 folds -(for cross-validation) randomly generated are available for these datasets. +(for cross-validation) randomly generated are available for these databases. .. list-table:: - * - Dataset + * - Database - Reference - H x W - Samples @@ -192,17 +192,17 @@ In addition to the splits presented in the following table, 10 folds - 2800 -.. _mednet.setup.datasets.tbmultilabel+signs: +.. _mednet.setup.databases.tbmultilabel+signs: -Tuberculosis + radiological findings dataset -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Tuberculosis + radiological findings databases +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The following dataset contains both the tuberculosis final diagnosis (0 or 1) +The following databases contain both the tuberculosis final diagnosis (0 or 1) and radiological findings. .. list-table:: - * - Dataset + * - Database - Reference - H x W - Samples @@ -216,12 +216,12 @@ and radiological findings. - 0 -.. _mednet.setup.datasets.signs: +.. _mednet.setup.databases.signs: -Radiological findings datasets -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Radiological findings databases +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The following dataset contains only the radiological findings without any +The following database contains only the radiological findings without any information about tuberculosis. .. note:: @@ -231,7 +231,7 @@ information about tuberculosis. .. list-table:: - * - Dataset + * - Database - Reference - H x W - Samples @@ -247,20 +247,20 @@ information about tuberculosis. - 4'054 -.. _mednet.setup.datasets.hiv-tb: +.. _mednet.setup.databases.hiv-tb: -HIV-Tuberculosis datasets -~~~~~~~~~~~~~~~~~~~~~~~~~ +HIV-Tuberculosis databases +~~~~~~~~~~~~~~~~~~~~~~~~~~ -The following datasets contain only the tuberculosis final diagnosis (0 or 1) +The following databases contain only the tuberculosis final diagnosis (0 or 1) and come from HIV infected patients. 10 folds (for cross-validation) randomly -generated are available for these datasets. +generated are available for these databases. -Please contact the authors of these datasets to have access to the data. +Please contact the authors of these databases to have access to the data. .. list-table:: - * - Dataset + * - Database - Reference - H x W - Samples diff --git a/doc/links.rst b/doc/links.rst index f4ba3310e86051b9555e9c919180187372125de0..a692f286e9317c7aebe80976ef1cfa9e3802c085 100644 --- a/doc/links.rst +++ b/doc/links.rst @@ -5,8 +5,11 @@ .. place re-used URLs here, then include this file .. on your other RST sources. -.. _conda: https://conda.io .. _idiap: http://www.idiap.ch +.. _cla: https://en.wikipedia.org/wiki/Contributor_License_Agreement +.. _project harmony: http://www.harmonyagreements.org/ +.. _tto: mailto:tto@idiap.ch +.. _conda: https://conda.io .. _python: http://www.python.org .. _pip: https://pip.pypa.io/en/stable/ .. _mamba: https://mamba.readthedocs.io/en/latest/index.html diff --git a/doc/usage/aggregpred.rst b/doc/usage/aggregpred.rst deleted file mode 100644 index 661750bdfcdd541231833745cf99888c167c5247..0000000000000000000000000000000000000000 --- a/doc/usage/aggregpred.rst +++ /dev/null @@ -1,24 +0,0 @@ -.. Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -.. -.. SPDX-License-Identifier: GPL-3.0-or-later - -.. _mednet.usage.aggregpred: - -======================================================= - Aggregate multiple prediction files into a single one -======================================================= - -This guide explains how to aggregate multiple prediction files into a single -one. It can be used when doing cross-validation to aggregate the predictions of -k different models before evaluating the aggregated predictions. We input -multiple prediction files (CSV files) and output a single one. - -Use the sub-command :ref:`aggregpred <mednet.cli>` aggregate your prediction -files together: - -.. code:: sh - - mednet aggregpred -vv path/to/fold0/predictions.csv path/to/fold1/predictions.csv --output-folder=aggregpred - - -.. include:: ../links.rst diff --git a/doc/usage/evaluation.rst b/doc/usage/evaluation.rst index d490f991c105990d81f4d29237ac2fbc49cac941..4bdd7e843696d396158f0bb31fcce847b6925feb 100644 --- a/doc/usage/evaluation.rst +++ b/doc/usage/evaluation.rst @@ -8,7 +8,7 @@ Inference and Evaluation ========================== -This guides explains how to run inference or a complete evaluation using +This guide explains how to run inference or a complete evaluation using command-line tools. Inference produces probability of TB presence for input images, while evaluation will analyze such output against existing annotations and produce performance figures. @@ -17,61 +17,42 @@ and produce performance figures. Inference --------- -In inference (or prediction) mode, we input data, the trained model, and output -a CSV file containing the prediction outputs for every input image. +In inference (or prediction) mode, we input a model, a dataset, a model checkpoint generated during training, and output +a json file containing the prediction outputs for every input image. -To run inference, use the sub-command :ref:`predict <mednet.cli>` to run -prediction on an existing dataset: - -.. code:: sh - - mednet predict -vv <model> -w <path/to/model.pth> <dataset> +To run inference, use the sub-command :ref:`predict <mednet.cli>`. +Examples +======== -Replace ``<model>`` and ``<dataset>`` by the appropriate :ref:`configuration -files <mednet.config>`. Replace ``<path/to/model.pth>`` to a path leading to -the pre-trained model. +To run inference using a trained Pasa CNN on the Montgomery dataset: -.. tip:: +.. code:: sh - An option to generate grad-CAMs is available for the :py:mod:`DensenetRS - <mednet.config.models.densenet_rs>` model. To activate it, use the - ``--grad-cams`` argument. + mednet predict -vv pasa montgomery --weight=<path/to/model.ckpt> --output=<results/folder/predictions.json> -.. tip:: - An option to generate a relevance analysis plot is available. To activate - it, use the ``--relevance-analysis`` argument. +Replace ``<path/to/model.ckpt>`` to a path leading to the pre-trained model. Evaluation ---------- -In evaluation, we input a dataset and predictions to generate performance -summaries that help analysis of a trained model. Evaluation is done using the -:ref:`evaluate command <mednet.cli>` followed by the model and the annotated -dataset configuration, and the path to the pretrained weights via the -``--weight`` argument. +In evaluation, we input predictions to generate performance summaries that help analysis of a trained model. +The generated files are a .pdf containing various plots and a table of metrics for each dataset split. +Evaluation is done using the :ref:`evaluate command <mednet.cli>` followed by the json file generated during +the inference step and a threshold. Use ``mednet evaluate --help`` for more information. -E.g. run evaluation on predictions from the Montgomery set, do the following: - -.. code:: sh - - mednet evaluate -vv montgomery -p /predictions/folder -o /eval/results/folder - - -Comparing Systems ------------------ +Examples +======== -To compare multiple systems together and generate combined plots and tables, -use the :ref:`compare command <mednet.cli>`. Use ``--help`` for a quick -guide. +To run evaluation on predictions generated in the inference step, using an optimal threshold computed from the validation set, do the following: .. code:: sh - mednet compare -vv A A/metrics.csv B B/metrics.csv --output-figure=plot.pdf --output-table=table.txt --threshold=0.5 + mednet evaluate -vv --predictions=<path/to/predictions.json> --output-folder=<results/folder> --threshold=validation .. include:: ../links.rst diff --git a/doc/usage/experiment.rst b/doc/usage/experiment.rst new file mode 100644 index 0000000000000000000000000000000000000000..52f55f2b040145d73d2da83a9417dfba7a1829a1 --- /dev/null +++ b/doc/usage/experiment.rst @@ -0,0 +1,28 @@ +.. Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +.. +.. SPDX-License-Identifier: GPL-3.0-or-later + +.. _mednet.experiment: + +============================== + Running complete experiments +============================== + +We provide an :ref:`experiment command <mednet.cli>` +that runs training, followed by prediction and evaluation. +After running, you will be able to find results from model fitting, +prediction and evaluation under a single output directory. + +For example, to train a pasa model on the montgomery database +evaluate its performance and output predictions and performance curves, +run the following: + +.. code-block:: sh + + $ mednet experiment -vv pasa montgomery + # check results in the "results" folder + +You may run the system on a GPU by using the ``--device=cuda:0`` option. + + +.. include:: ../links.rst diff --git a/doc/usage/index.rst b/doc/usage/index.rst index 05dd69181eb1aca0dd8b538b844ea3efd9367870..9e5090b78231f444bf02a8e6ec32003f5607aa41 100644 --- a/doc/usage/index.rst +++ b/doc/usage/index.rst @@ -16,6 +16,7 @@ tuberculosis detection with support for the following activities. .. _mednet.usage.direct-detection: + Direct detection ---------------- @@ -24,33 +25,31 @@ Direct detection automatically, via error back propagation. The objective of this phase is to produce a CNN model. * Inference (prediction): The CNN is used to generate TB predictions. -* Evaluation: Predications are used to evaluate CNN performance against +* Evaluation: Predictions are used to evaluate CNN performance against provided annotations, and to generate measure files and score tables. Optimal - thresholds are also calculated. -* Comparison: Use predictions results to compare performance of multiple - systems. + thresholds can also be calculated. -.. _mednet.usage.indirect-detection: +.. \_mednet.usage.indirect-detection: -Indirect detection ------------------- +.. Indirect detection + ------------------ -* Training (step 1): Images are fed to a Convolutional Neural Network (CNN), +.. * Training (step 1): Images are fed to a Convolutional Neural Network (CNN), that is trained to detect the presence of radiological signs automatically, via error back propagation. The objective of this phase is to produce a CNN model. -* Inference (prediction): The CNN is used to generate radiological signs - predictions. -* Conversion of the radiological signs predictions into a new dataset. -* Training (step 2): Radiological signs are fed to a shallow network, that is - trained to detect the presence of tuberculosis automatically, via error back - propagation. The objective of this phase is to produce a shallow model. -* Inference (prediction): The shallow model is used to generate TB predictions. -* Evaluation: Predications are used to evaluate CNN performance against - provided annotations, and to generate measure files and score tables. -* Comparison: Use predictions results to compare performance of multiple - systems. + * Inference (prediction): The CNN is used to generate radiological signs + predictions. + * Conversion of the radiological signs predictions into a new dataset. + * Training (step 2): Radiological signs are fed to a shallow network, that is + trained to detect the presence of tuberculosis automatically, via error back + propagation. The objective of this phase is to produce a shallow model. + * Inference (prediction): The shallow model is used to generate TB predictions. + * Evaluation: Predications are used to evaluate CNN performance against + provided annotations, and to generate measure files and score tables. + * Comparison: Use predictions results to compare performance of multiple + systems. We provide :ref:`command-line interfaces (CLI) <mednet.cli>` that implement each of the phases above. This interface is configurable using :ref:`clapper's @@ -63,7 +62,7 @@ to an application. For reproducibility, we recommend you stick to configuration files when parameterizing our CLI. Notice some of the options in the CLI interface - (e.g. ``--dataset``) cannot be passed via the actual command-line as it + (e.g. ``--datamodule``) cannot be passed via the actual command-line as it may require complex Python types that cannot be synthetized in a single input parameter. @@ -80,12 +79,12 @@ Commands -------- .. toctree:: - :maxdepth: 2 + :maxdepth: 2 - training - evaluation - predtojson - aggregpred + experiment + training + evaluation + saliency .. include:: ../links.rst diff --git a/doc/usage/predtojson.rst b/doc/usage/predtojson.rst deleted file mode 100644 index 30ff645a379f29b09d4898c12899c286929314fe..0000000000000000000000000000000000000000 --- a/doc/usage/predtojson.rst +++ /dev/null @@ -1,24 +0,0 @@ -.. Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -.. -.. SPDX-License-Identifier: GPL-3.0-or-later - -.. _mednet.usage.predtojson: - -======================================== - Converting predictions to JSON dataset -======================================== - -This guide explains how to convert radiological signs predictions from a model -into a JSON dataset. It can be used to create new versions of TB datasets with -the predicted radiological signs to be able to use a shallow model. We input -predictions (CSV files) and output a ``dataset.json`` file. - -Use the sub-command :ref:`predtojson <mednet.cli>` to create your JSON dataset -file: - -.. code:: sh - - mednet predtojson -vv train train/predictions.csv test test/predictions.csv --output-folder=pred_to_json - - -.. include:: ../links.rst diff --git a/doc/usage/saliency.rst b/doc/usage/saliency.rst new file mode 100644 index 0000000000000000000000000000000000000000..65074ae5b6585c1a14e5c6c529a632b4679a0479 --- /dev/null +++ b/doc/usage/saliency.rst @@ -0,0 +1,110 @@ +.. Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +.. +.. SPDX-License-Identifier: GPL-3.0-or-later + +.. _mednet.usage.saliency: + +========== + Saliency +========== + +A saliency map highlights areas of interest within an image. In the context of TB detection, this would be the locations in a chest X-ray image where tuberculosis is present. + +This package provides scripts that can generate saliency maps and compute relevant metrics for interpretability purposes. + +Some of the scripts require the use of a database with human-annotated saliency information. + +Generation +---------- + +Saliency maps can be generated with the :ref:`saliency generate command <mednet.cli>`. +They are represented as numpy arrays of the same size as thes images, with values in the range [0-1] and saved in .npy files. + +Several mapping algorithms are available to choose from, which can be specified with the -s option. + +Examples +======== + +Generates saliency maps for all prediction dataloaders on a DataModule, +using a pre-trained pasa model, and saves them as numpy-pickeled +objects on the output directory: + +.. code:: sh + + mednet saliency generate -vv pasa tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output + +Viewing +------- + +To overlay saliency maps over the original images, use the :ref:`saliency view command <mednet.cli>`. +Results are saved as PNG images in which brigter pixels correspond to areas with higher saliency. + +Examples +======== + +Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: + +.. code:: sh + + # input-folder is the location of the saliency maps created with `mednet generate` + mednet saliency view -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations + + +Interpretability +---------------- + +Given a target label, the interpretability step computes the proportional energy and average saliency focus in a DataModule. + +The proportional energy is defined as the quantity of activation that lies within the ground truth boxes compared to the total sum of the activations. +The average saliency focus is the sum of the values of the saliency map over the ground-truth bounding boxes, normalized by the total area covered by all ground-truth bounding boxes. + +This requires a DataModule containing human-annotated bounding boxes. + +Examples +======== + +Evaluate the generated saliency maps for their localization performance: + +.. code:: sh + + mednet saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json + + +Completeness +------------ +The saliency completeness script computes ROAD scores of saliency maps and saves them in a .json file. + +The ROAD algorithm estimates the explainability (in the completeness sense) of saliency maps by substituting +relevant pixels in the input image by a local average, re-running prediction on the altered image, +and measuring changes in the output classification score when said perturbations are in place. +By substituting most or least relevant pixels with surrounding averages, the ROAD algorithm estimates +the importance of such elements in the produced saliency map. + +More information can be found in [ROAD-2022]_. + +This requires a DataModule containing human-annotated bounding boxes. + +Examples +======== + +Calculates the ROAD scores for an existing dataset configuration and stores them in .json files: + +.. code:: sh + + mednet saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda:0" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-json=path/to/completeness-scores.json + + +Evaluation +---------- +The saliency evaluation step generates tables and plots from the results of the interpretability and completeness steps. + +Examples +======== + +Tabulates and generates plots for two saliency map algorithms: + +.. code:: sh + + mednet saliency evaluate -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json + +.. include:: ../links.rst diff --git a/doc/usage/training.rst b/doc/usage/training.rst index 4172732e030e6bdb5b7a9e345d2a5da75bf1022c..7cd58605e4d6fc05fa2ec0a8564cd05d91f80612 100644 --- a/doc/usage/training.rst +++ b/doc/usage/training.rst @@ -19,7 +19,7 @@ containing more detailed instructions. .. tip:: - We strongly advice training with a GPU (using ``--device="cuda:0"``). + We strongly advise training with a GPU (using ``--device="cuda:0"``). Depending on the available GPU memory you might have to adjust your batch size (``--batch``). @@ -33,41 +33,54 @@ To train Pasa CNN on the Montgomery dataset: mednet train -vv pasa montgomery --batch-size=4 --epochs=150 -To train DensenetRS CNN on the NIH CXR14 dataset: -.. code:: sh +.. Logistic regressor or shallow network + ------------------------------------- - mednet train -vv nih_cxr14 densenet_rs --batch-size=8 --epochs=10 + To train a logistic regressor or a shallow network, use the command-line + interface (CLI) application ``mednet train``, available on your prompt. To use + this CLI, you must define the input dataset that will be used to train the + model, as well as the type of model that will be trained. + You may issue ``mednet train --help`` for a help message containing more + detailed instructions. + Examples + ======== -Logistic regressor or shallow network -------------------------------------- + To train a logistic regressor using predictions from DensenetForRS on the + Montgomery dataset: -To train a logistic regressor or a shallow network, use the command-line -interface (CLI) application ``mednet train``, available on your prompt. To use -this CLI, you must define the input dataset that will be used to train the -model, as well as the type of model that will be trained. -You may issue ``mednet train --help`` for a help message containing more -detailed instructions. + .. code:: sh -Examples -======== + mednet train -vv logistic_regression montgomery_rs --batch-size=4 --epochs=20 -To train a logistic regressor using predictions from DensenetForRS on the -Montgomery dataset: -.. code:: sh + To train an multi-layer perceptron (MLP) using predictions from a densenet + pre-trained to detect radiological findings (using NIH CXR-14), on the Shenzhen + dataset: + + .. code:: sh + + mednet train -vv mlp shenzhen_rs --batch-size=4 --epochs=20 - mednet train -vv logistic_regression montgomery_rs --batch-size=4 --epochs=20 -To train an multi-layer perceptron (MLP) using predictions from a densenet -pre-trained to detect radiological findings (using NIH CXR-14), on the Shenzhen -dataset: +Plotting training metrics +------------------------- + +Various metrics are recorded at each epoch during training, such as the execution time, loss and resource usage. +These are saved in a Tensorboard file, located in a `logs` subdirectory of the training output folder. + +Mednet provides a :ref:`train-analysis <mednet.cli>` convenience script that graphs the scalars stored in these files and saves them in a .pdf file. + +Examples +======== + +Generates a .pdf file with plots showing the evolution of logged metrics in time: .. code:: sh - mednet train -vv mlp shenzhen_rs --batch-size=4 --epochs=20 + mednet train-analysis -vv <results/logs/folder> -o <results/trainlog.pdf> .. include:: ../links.rst diff --git a/pyproject.toml b/pyproject.toml index a8b12aae164d7f03c775022628fe418d765a66be..a182a7e0be6ba6f5577b46b18d7da5fee6975ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "lightning <2.2.0a0,>=2.1.0", "tensorboard", "grad-cam>=1.4.8", + "numpydoc", ] [project.urls] @@ -256,3 +257,30 @@ line-length = 80 addopts = ["--cov=mednet", "--cov-report=term-missing", "--import-mode=append"] junit_logging = "all" junit_log_passing_tests = false + +[tool.numpydoc_validation] +checks = [ + "all", # report on all checks, except the below + "ES01", # Not all functions require extended summaries + "EX01", # Not all functions require examples + "GL01", # Expects text to be on the line after the opening quotes but that is in direct opposition of the sphinx recommendations and conflicts with other pre-commit hooks. + "GL08", # Causes issues if we don't have a docstring at the top of the file. Disabling this might fail to catch actual missing docstrings. + "PR04", # numpydoc does not currently support PEP484 typehints, which we are using + "RT03", # Since sphinx is unable to understand type annotations we need to remove some types from 'Returns', which breaks this check. + "SA01", # We do not use Also sections + "SS06", # Summary will span multiple lines if too long because of reformatting by other hooks. +] + +exclude = [ # don't report on objects that match any of these regex + '\.__len__$', + '\.__getitem__$', + '\.__iter__$', + '\.__exit__$', +] + +override_SS05 = [ # override SS05 to allow docstrings starting with these words + '^Process ', + '^Assess ', + '^Access ', + '^This', +] diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index d5bf81038eb704705fc32e599edd721d6b3d18ca..7a5617330e86b9c8bdd2f57a0a20ed3902cafd90 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for computer-aided diagnosis (only BMP files) +"""HIV-TB dataset for computer-aided diagnosis (only BMP files). Database reference: [HIV-TB-2019]_ """ @@ -40,7 +40,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: tuple[str, int]) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -49,10 +49,9 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The sample representation + The sample representation. """ image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert( "L" @@ -69,7 +68,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] def label(self, sample: tuple[str, int]) -> int: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -78,16 +77,26 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The integer label associated with the sample + int + The integer label associated with the sample. """ return sample[1] def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the HIV-TB database.""" + """Return a database split for the HIV-TB database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -95,7 +104,7 @@ def make_split(basename: str) -> DatabaseSplit: class DataModule(CachingDataModule): - """HIV-TB dataset for computer-aided diagnosis (only BMP files) + """HIV-TB dataset for computer-aided diagnosis (only BMP files). * Database reference: [HIV-TB-2019]_ * Original resolution, varying with most images being 2048 x 2500 pixels @@ -122,6 +131,11 @@ class DataModule(CachingDataModule): * Grayscale, encoded as a single plane tensor, 32-bit floats, square at 2048 x 2048 pixels * Labels: 0 (healthy), 1 (active tuberculosis) + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index c8b762fc069d95df01af5715050b9788ed1b3b72..08a507223436be949b4ae03de1c2671060697c1c 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -19,7 +19,17 @@ database.""" def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the Indian database.""" + """Return a database split for the Indian database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -59,6 +69,11 @@ class DataModule(CachingDataModule): * Grayscale, encoded as a single plane tensor, 32-bit floats, square, with varying resolutions, depending on the input raw image * Labels: 0 (healthy), 1 (active tuberculosis) + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/indian/fold_0.py b/src/mednet/config/data/indian/fold_0.py index 2c94e91d0a58f81581ee618a8b6199428ecdee99..3f6d60e77a8c6e88dce67dc8485121780b182dd6 100644 --- a/src/mednet/config/data/indian/fold_0.py +++ b/src/mednet/config/data/indian/fold_0.py @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Indian collection dataset for computer-aided diagnosis (cross validation -fold 0). +"""Indian collection dataset for computer-aided diagnosis (cross validationfold +0). Database reference: [INDIAN-2013]_ diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index ec1b7c1422a0ce4bd5948c7446c14f6d50478fb2..e19f32009e7a3ea7935a1fa3aea325c493728310 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection. +"""Montgomery DataModule for TB detection. Database reference: [MONTGOMERY-SHENZHEN-2014]_ """ @@ -39,7 +39,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: tuple[str, int]) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -48,10 +48,9 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The sample representation + The sample representation. """ # N.B.: Montgomery images are encoded as grayscale PNGs, so no need to # convert them again with Image.convert("L"). @@ -68,7 +67,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] def label(self, sample: tuple[str, int]) -> int: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -77,16 +76,26 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The integer label associated with the sample + int + The integer label associated with the sample. """ return sample[1] def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the Montgomery database.""" + """Return a database split for the Montgomery database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -94,7 +103,7 @@ def make_split(basename: str) -> DatabaseSplit: class DataModule(CachingDataModule): - """Montgomery datamodule for TB detection. + """Montgomery DataModule for TB detection. The standard digital image database for Tuberculosis was created by the National Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s @@ -124,6 +133,11 @@ class DataModule(CachingDataModule): * Grayscale, encoded as a single plane tensor, 32-bit floats, square at 4020 x 4020 pixels * Labels: 0 (healthy), 1 (active tuberculosis) + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/montgomery/default.py b/src/mednet/config/data/montgomery/default.py index edd014e8231350bb86e1e9fe28ab48409665c10c..afb62d5db67dd2bf402fb052a884284106cc3ec1 100644 --- a/src/mednet/config/data/montgomery/default.py +++ b/src/mednet/config/data/montgomery/default.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection. +"""Montgomery DataModule for TB detection. Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_0.py b/src/mednet/config/data/montgomery/fold_0.py index 6d62330303dc33c2d220133b8a1677ad3499afef..02597fad834853a69612c89096ee66a2de47378b 100644 --- a/src/mednet/config/data/montgomery/fold_0.py +++ b/src/mednet/config/data/montgomery/fold_0.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 0). +"""Montgomery DataModule for TB detection (cross validation fold 0). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_1.py b/src/mednet/config/data/montgomery/fold_1.py index 649f625f75266727e7beec87a5d226d959ce0798..5aff117fafeba1aab3508e8fcec94408ef19ac68 100644 --- a/src/mednet/config/data/montgomery/fold_1.py +++ b/src/mednet/config/data/montgomery/fold_1.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 1). +"""Montgomery DataModule for TB detection (cross validation fold 1). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_2.py b/src/mednet/config/data/montgomery/fold_2.py index 28f55b902a5f675f01b69574089820bbf060eaaa..879562db0df7f165e66a65eaffcd0c825ccb678d 100644 --- a/src/mednet/config/data/montgomery/fold_2.py +++ b/src/mednet/config/data/montgomery/fold_2.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 2). +"""Montgomery DataModule for TB detection (cross validation fold 2). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_3.py b/src/mednet/config/data/montgomery/fold_3.py index e357f27efb0fe79bc313112a044649c4bfb11d7a..1e8a31e002cbda5a722e80143045b569f51ee327 100644 --- a/src/mednet/config/data/montgomery/fold_3.py +++ b/src/mednet/config/data/montgomery/fold_3.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 3). +"""Montgomery DataModule for TB detection (cross validation fold 3). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_4.py b/src/mednet/config/data/montgomery/fold_4.py index 57402c018068cf7b24f29ce61442e73067148c4c..eb396a7a2a44aae2cc8fb21dd496ff9f1346461a 100644 --- a/src/mednet/config/data/montgomery/fold_4.py +++ b/src/mednet/config/data/montgomery/fold_4.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 4). +"""Montgomery DataModule for TB detection (cross validation fold 4). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_5.py b/src/mednet/config/data/montgomery/fold_5.py index 30b172411d3f9e726ad458bc85540fabdaddff5f..b3620674900c81634d5d17c577dbe240cc2b8729 100644 --- a/src/mednet/config/data/montgomery/fold_5.py +++ b/src/mednet/config/data/montgomery/fold_5.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 5). +"""Montgomery DataModule for TB detection (cross validation fold 5). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_6.py b/src/mednet/config/data/montgomery/fold_6.py index d8a2a8e2ff934c02a40480f0f6709425277dcbfc..298fc9c5125e45453e1cda5c9637283a3c3a575f 100644 --- a/src/mednet/config/data/montgomery/fold_6.py +++ b/src/mednet/config/data/montgomery/fold_6.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 6). +"""Montgomery DataModule for TB detection (cross validation fold 6). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_7.py b/src/mednet/config/data/montgomery/fold_7.py index 86d16e0c12ce0423454f4f32f69b99830f2b40b2..93ca3c00c8bdceed09bc911c8451873fcec2e56f 100644 --- a/src/mednet/config/data/montgomery/fold_7.py +++ b/src/mednet/config/data/montgomery/fold_7.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 7). +"""Montgomery DataModule for TB detection (cross validation fold 7). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_8.py b/src/mednet/config/data/montgomery/fold_8.py index 67a337d2fcc8b7cda4fe888f63cbe253e7f0e933..05b2b1d14e1c322e4ae333d9bca984bbe425f33d 100644 --- a/src/mednet/config/data/montgomery/fold_8.py +++ b/src/mednet/config/data/montgomery/fold_8.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 8). +"""Montgomery DataModule for TB detection (cross validation fold 8). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery/fold_9.py b/src/mednet/config/data/montgomery/fold_9.py index 032cdbfef25fc2b142910b1d754417856bcebefb..ac6539d960584394e298fb12013aff3827e2b2a2 100644 --- a/src/mednet/config/data/montgomery/fold_9.py +++ b/src/mednet/config/data/montgomery/fold_9.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (cross validation fold 9). +"""Montgomery DataModule for TB detection (cross validation fold 9). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index 9699f19d7fde69565d70c8767538d0136c325ff5..fa83fdde5165e24c0d08b0b9c123b87bd98e3805 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -10,7 +10,13 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split class DataModule(ConcatDataModule): - """Aggregated datamodule composed of Montgomery and Shenzhen datasets.""" + """Aggregated DataModule composed of Montgomery and Shenzhen datasets. + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. + """ def __init__(self, split_filename: str): montgomery_loader = MontgomeryLoader() diff --git a/src/mednet/config/data/montgomery_shenzhen/default.py b/src/mednet/config/data/montgomery_shenzhen/default.py index 106bb415262932cd4fcfa954322f6f8d7353f554..24dcc4b3b21190f86e04288caec1575bebdfd11d 100644 --- a/src/mednet/config/data/montgomery_shenzhen/default.py +++ b/src/mednet/config/data/montgomery_shenzhen/default.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (default +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (default split). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_0.py b/src/mednet/config/data/montgomery_shenzhen/fold_0.py index d03c9c544afd5ec5a3abec3c17984d5e571c3bec..8d3191561980e3b31d809061257ceacf757cca9e 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_0.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_0.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 0). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_1.py b/src/mednet/config/data/montgomery_shenzhen/fold_1.py index 948a75b07de9fe76df372964fec4c147063fcc8d..a9095a0db8a69f056576d57d89f2fcddaf06b9ab 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_1.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_1.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 1). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_2.py b/src/mednet/config/data/montgomery_shenzhen/fold_2.py index c627fa2253cf034a402135f8575ae99042487682..79203be407919e4485836981e7e34b546003630a 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_2.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_2.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 2). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_3.py b/src/mednet/config/data/montgomery_shenzhen/fold_3.py index 9ba252f3114625f05b21cbbad1a420c5ea9dc834..fb3114fd681d8174891dbd95bc8f5f9f83743c92 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_3.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_3.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 3). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_4.py b/src/mednet/config/data/montgomery_shenzhen/fold_4.py index 720a3adc225b864c104a2cb8bd77250fde45e7f3..24cbc745242824ef96366887a1ce71ad9920f962 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_4.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_4.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 4). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_5.py b/src/mednet/config/data/montgomery_shenzhen/fold_5.py index c784a4ea01bfeacb2140b7ffdff5d06a618865ab..887cc682989d94d480d79776935381ed67111d8a 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_5.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_5.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 5). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_6.py b/src/mednet/config/data/montgomery_shenzhen/fold_6.py index eaa4647240a73b2ae6d9b8b4e5ce7a76f2a6ad16..81a8a906366a6f60a433ff044441e9274d059eb2 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_6.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_6.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 6). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_7.py b/src/mednet/config/data/montgomery_shenzhen/fold_7.py index 7a6dd33b281e0c41b0823dacba70d443e19310d2..298f19616a867b9395ae3233d2d205fd1e487f93 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_7.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_7.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 7). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_8.py b/src/mednet/config/data/montgomery_shenzhen/fold_8.py index 8f295426865d2523caffac09e52f1d67fd138379..b8c2ff10e351aee4525827ac8a70a91e9c21619b 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_8.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_8.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 8). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen/fold_9.py b/src/mednet/config/data/montgomery_shenzhen/fold_9.py index 1ec041b8c56923b60eb744f744dfe579bc3fd434..30528cae581a3dbb274cec9511b0c8ab23c3d687 100644 --- a/src/mednet/config/data/montgomery_shenzhen/fold_9.py +++ b/src/mednet/config/data/montgomery_shenzhen/fold_9.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery and Shenzhen datasets (cross +"""Aggregated DataModule composed of Montgomery and Shenzhen datasets (cross validation fold 9). See :py:class:`.montgomery_shenzhen.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index 2001857379da827cc03c0bf7a1a21466cadc3c79..676fa8ef96d26b794341dee5c5cd09caf55d06d0 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -1,8 +1,7 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian -datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.""" from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader @@ -14,8 +13,13 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split class DataModule(ConcatDataModule): - """Aggregated datamodule composed of Montgomery, Shenzhen and Indian - datasets.""" + """Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets. + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. + """ def __init__(self, split_filename: str): montgomery_loader = MontgomeryLoader() diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/default.py b/src/mednet/config/data/montgomery_shenzhen_indian/default.py index a1ac8047aa6f526b68e5b4df03a87a55de98b140..38fca2a93cf5860978397bd737ad7b8cde611704 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/default.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/default.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets. +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets. See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_0.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_0.py index aecbbd107ab4a4d8e1bbe15b318deb8f58fe1238..83d67fe3506611d6d4b1a4bf352c6faa27ad7b08 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_0.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_0.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 0). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_1.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_1.py index bef0e4aee628a0a22f1b49a5052ee9124d6a2ad9..a44b6a68c73942700c4a86165bce53d12a854680 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_1.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_1.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 1). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_2.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_2.py index 4a75d3329902f0ca439bf17bafea950f5bc23de8..5dac88e8d3a965da37da6f1c76a52721779fad9c 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_2.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_2.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 2). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_3.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_3.py index a85428cb606030252413c841b2849b128a98713c..8fe095a351874f51457b031ff71e91876177ae7d 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_3.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_3.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 3). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_4.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_4.py index 857a5b622fbe5f814b8aec418fc631297632083d..6524f589f7bd67b979868ae8b51613da45df6cfd 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_4.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_4.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 4). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_5.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_5.py index 58ac206ca6b30af085b4c7d742757414267e934a..8b0acbe6e6fdf7b5739fa2ddb45b5daadb1d1283 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_5.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_5.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 5). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_6.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_6.py index db77fcc0fcb92446a7220c806fa63a6b7b22bf4b..d8b565adfb2e08cd2891593e74798c15187dadbf 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_6.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_6.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 6). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_7.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_7.py index aa7f79e5d7d18d7676953ec208da8fc7acddbc54..248f6621cd9cfb048007889618c672ada29e9398 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_7.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_7.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 7). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_8.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_8.py index 31fca75024ac528cc04879fbddc3150367026239..41c5e5cab86184e74eaf924612ee04f7c672855e 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_8.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_8.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 8). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/fold_9.py b/src/mednet/config/data/montgomery_shenzhen_indian/fold_9.py index 54c9a07e798db588cd2857331485302fd770af76..f658cf37c0c0cc47fb3fc92ab3807af12998f0a4 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/fold_9.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/fold_9.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen and Indian datasets +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets (cross validation fold 9). See :py:class:`.montgomery_shenzhen_indian.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 8414ad94077aa995b26ad79bf141d5d2c69ee914..2876af8f0335acd87313c02ca18ad8b47dd21bcc 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -1,8 +1,7 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and PadChest -datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader @@ -16,8 +15,16 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split class DataModule(ConcatDataModule): - """Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and - PadChest datasets.""" + """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and + PadChest datasets. + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. + padchest_split_filename + Name of the .json file from padchest containing the split to load. + """ def __init__(self, split_filename: str, padchest_split_filename: str): montgomery_loader = MontgomeryLoader() diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/default.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/default.py index 07b695cce4573ecc80347df15dfb44ea281864a7..7a0c7dce64501dca4d6017874ebf0cdf0ae26be2 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/default.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/default.py @@ -1,8 +1,7 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest -datasets.""" +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest datasets.""" from mednet.config.data.montgomery_shenzhen_indian_padchest.datamodule import ( DataModule, diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 783f7251a6cf0f9bfd384d829bf94050bb8fce46..8dd831981465af7afa276d67287de5c09008bc36 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -1,8 +1,7 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k -datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader @@ -16,10 +15,18 @@ from ..tbx11k.datamodule import make_split as make_tbx11k_split class DataModule(ConcatDataModule): - """Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and - TBX11k datasets.""" + """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and + TBX11k datasets. - def __init__(self, split_filename: str, tbx11_split_filename: str): + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. + tbx11k_split_filename + Name of the .json file from tbx11k containing the split to load. + """ + + def __init__(self, split_filename: str, tbx11k_split_filename: str): montgomery_loader = MontgomeryLoader() montgomery_split = make_montgomery_split(split_filename) shenzhen_loader = ShenzhenLoader() @@ -27,7 +34,7 @@ class DataModule(ConcatDataModule): indian_loader = IndianLoader() indian_split = make_indian_split(split_filename) tbx11k_loader = TBX11kLoader() - tbx11k_split = make_tbx11k_split(tbx11_split_filename) + tbx11k_split = make_tbx11k_split(tbx11k_split_filename) super().__init__( splits={ diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_0.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_0.py index 88cd03e2c8f11c86490864fb1c4362702796fec0..a866e229c5220058d5f838288eed1e44406a288d 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_0.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_0.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 0). This remix dataset combines ``fold-0`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_1.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_1.py index 10f57c42bcc895ef47e447150859cbb7665f3d43..cc9d26ce771ee1833111e665ab102efbbf077120 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_1.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_1.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 1). This remix dataset combines ``fold-1`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_2.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_2.py index 6ec20140b059eb0f98c60341479d0ebefd318ca7..c9f77703b1ec5363ca2431bbf4d39ac009dbc32e 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_2.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_2.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 2). This remix dataset combines ``fold-2`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_3.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_3.py index 569ed8017632595bd2627a46571587f6bfb0f898..639ac053f3b34d9b4278d51b6bed76b0e6c77c51 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_3.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_3.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 3). This remix dataset combines ``fold-3`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_4.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_4.py index 04a74cd35ccb133634d8ca5db18936861ef2eab0..b08f0588e44a040ad4dbc3c9c918e8a82b260023 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_4.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_4.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 4). This remix dataset combines ``fold-4`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_5.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_5.py index b528a7b6799f76da1cb8c50c8071f088f6c082e2..7cfd77253141d959731bbb857cf04d7b090b2b53 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_5.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_5.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 5). This remix dataset combines ``fold-5`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_6.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_6.py index fdf55b4274f8329dda7aa28f654bde654df2fc33..c8341e7e45ea8411f51ece99b4accb6b851a5151 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_6.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_6.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 6). This remix dataset combines ``fold-6`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_7.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_7.py index 833d2c04b2b8699e5143f9ad99231440133ff9bc..9f9b19d54bbf0c8c5211794e86d30d2ba92b2099 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_7.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_7.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 7). This remix dataset combines ``fold-7`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_8.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_8.py index a0dd7e822386ea260808a0ead9475d78136ce91e..a212479358eb129cc64baadc45865519594e9fe9 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_8.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_8.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 8). This remix dataset combines ``fold-8`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_9.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_9.py index ea00d8ecf7726c308fc9007031ab7af21b5f4008..b2dff65bd393055eec1135a922339cbc3ee09832 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_9.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_fold_9.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 9). This remix dataset combines ``fold-9`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_healthy_vs_atb.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_healthy_vs_atb.py index 278ac64eeadc31a36b5df90ee5c76ccdc92e7a09..0f4d1393f0fb0973cf7b57cb3312f98c3bafd164 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_healthy_vs_atb.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v1_healthy_vs_atb.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (v1-healthy-vs-atb). This remix dataset combines the ``default`` split from Montgomery, Shenzhen, diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_0.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_0.py index 022824a1795eb56bd7e95ccf7845b86452020098..93fe390dd8de655a89f4cdb8f74382734def4182 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_0.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_0.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 0). This remix dataset combines ``fold-0`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_1.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_1.py index e48e83e797e6f69942d7345d387f5e50d2946c67..0ed16e701663a05756d4c44ce2e84d1d44ca56dc 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_1.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_1.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 1). This remix dataset combines ``fold-1`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_2.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_2.py index 1e8f38fbd000c8e026c5451abfceffe3338bbd00..2ff42bcd694ec2142f114666d5452fa5f4361b1f 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_2.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_2.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 2). This remix dataset combines ``fold-2`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_3.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_3.py index 553ea54762ee64a672a02f9e7478ef2bd21b2405..dc4b1a4c5fed9e272fe332432a5405f866a3cbde 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_3.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_3.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 3). This remix dataset combines ``fold-3`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_4.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_4.py index 672c9bbe3a09c88b6b50ac27d6dccbbf594a2e82..55113094b5095463ba083cc2b12d72acde53b2fc 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_4.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_4.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 4). This remix dataset combines ``fold-4`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_5.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_5.py index b7b3c5eddaa9daf3b1181b422667c48991d77628..0bffb95e6ff791c8d95adcea66ffe5d49619e821 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_5.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_5.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 5). This remix dataset combines ``fold-5`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_6.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_6.py index e30017374dcb6b78bb9b829e5956878666551f05..a5507999d10facdb2f000637a3e47cc028ee9233 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_6.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_6.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 6). This remix dataset combines ``fold-6`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_7.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_7.py index 4f3b3e9167c998306e4836840f51fded3de5963b..944c1eb9bc18f6e88093a4ab3a3ea2ac2d565b07 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_7.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_7.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 7). This remix dataset combines ``fold-7`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_8.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_8.py index afdb0f75625348df60c5b01760cd541d684c2dae..db5331c6b84d887bb2aa1e98d325d434853e1825 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_8.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_8.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 8). This remix dataset combines ``fold-8`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_9.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_9.py index 31b1e4400fd22674df823791d404fc3d0817e347..898a46aea2dddcfc0eafc15f06b0b77b15f6f806 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_9.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_fold_9.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (cross-validation fold 9). This remix dataset combines ``fold-9`` from Montgomery, Shenzhen, and Indian diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_others_vs_atb.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_others_vs_atb.py index 9dca260604e2c65e85f9a1b06cd96fd853560e3d..b0bce0aeed787d7d3ad9dc23dba9321d3e9785a5 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_others_vs_atb.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/v2_others_vs_atb.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated datamodule composed of Montgomery, Shenzhen, Indian, and TBX11k +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets (v2-others-vs-atb). This remix dataset combines the ``default`` split from Montgomery, Shenzhen, diff --git a/src/mednet/config/data/nih_cxr14/cardiomegaly.py b/src/mednet/config/data/nih_cxr14/cardiomegaly.py index 8be9fde0b963123a6becccdf3c850ea193826371..255be72e79840527a0e401ac8c5e048033038a2a 100644 --- a/src/mednet/config/data/nih_cxr14/cardiomegaly.py +++ b/src/mednet/config/data/nih_cxr14/cardiomegaly.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 (relabeled) datamodule for computer-aided diagnosis (cardiomegaly +"""NIH CXR14 (relabeled) DataModule for computer-aided diagnosis (cardiomegaly split). Database reference: [NIH-CXR14-2017]_ diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index b42f1d1c815694a213e90b0e01cd7faa6f08b321..5967ee63152c6e6dac0aa9b1fb6bd13d3f24a438 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 (relabeled) datamodule for computer-aided diagnosis. +"""NIH CXR14 (relabeled) DataModule for computer-aided diagnosis. Database reference: [NIH-CXR14-2017]_ """ @@ -63,7 +63,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: tuple[str, list[int]]) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -72,10 +72,9 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The sample representation + The sample representation. """ file_path = sample[0] # default if self.idiap_file_organisation: @@ -103,7 +102,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] def label(self, sample: tuple[str, list[int]]) -> list[int]: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -112,16 +111,26 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The integer labels associated with the sample + list[int] + The integer labels associated with the sample. """ return sample[1] def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the NIH CXR-14 database.""" + """Return a database split for the NIH CXR-14 database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -129,7 +138,7 @@ def make_split(basename: str) -> DatabaseSplit: class DataModule(CachingDataModule): - """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis. + """NIH CXR14 (relabeled) DataModule for computer-aided diagnosis. This dataset was extracted from the clinical PACS database at the National Institutes of Health Clinical Center (USA) and represents 60% of all their @@ -172,6 +181,11 @@ class DataModule(CachingDataModule): * fibrosis * edema * consolidation + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/nih_cxr14/default.py b/src/mednet/config/data/nih_cxr14/default.py index 2fa797be4371e78115516c22545a9c9d457fdc97..8c15cd71dd85ea5c3e9a9704ab81a877d926633f 100644 --- a/src/mednet/config/data/nih_cxr14/default.py +++ b/src/mednet/config/data/nih_cxr14/default.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 (relabeled) datamodule (``default`` protocol). +"""NIH CXR14 (relabeled) DataModule (``default`` protocol). * Training samples: 98637 * Validation samples: 6350 diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index 09fd623f9d6503429bfa7856ce162b7d3c67e49e..2c793c7980156f14f0ab4946652bd1f627cd22c0 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -11,7 +11,15 @@ from ..padchest.datamodule import make_split as make_padchest_split class DataModule(ConcatDataModule): """Aggregated dataset composed of NIH CXR14 relabeld and PadChest - (normalized) datasets.""" + (normalized) datasets. + + Parameters + ---------- + cxr14_split_filename + Name of the .json file from crx14 containing the split to load. + padchest_split_filename + Name of the .json file from padchest containing the split to load. + """ def __init__(self, cxr14_split_filename: str, padchest_split_filename): cxr14_loader = CXR14Loader() diff --git a/src/mednet/config/data/nih_cxr14_padchest/idiap.py b/src/mednet/config/data/nih_cxr14_padchest/idiap.py index aed16c8fec30a2304f068f802b0a762917941bc6..6ac62f99a2097a4add1349e47652c65e9eb8a913 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/idiap.py +++ b/src/mednet/config/data/nih_cxr14_padchest/idiap.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) -datasets (no-tb-idiap split).""" +"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) datasets (no-tb-idiap split).""" from mednet.config.data.nih_cxr14_padchest.datamodule import DataModule diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index a065deceaf9676219f9afd7234a40f9d9124de0d..94743f68edec801d8e73c5badb69f28f45ae8beb 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -41,7 +41,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: tuple[str, int | list[int]]) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -50,10 +50,9 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The sample representation + The sample representation. """ # N.B.: PadChest images are encoded as 16-bit grayscale images image = PIL.Image.open(os.path.join(self.datadir, sample[0])) @@ -70,7 +69,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] def label(self, sample: tuple[str, int | list[int]]) -> int | list[int]: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -79,16 +78,26 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The integer labels associated with the sample + list[int] + The integer labels associated with the sample. """ return sample[1] def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the NIH CXR-14 database.""" + """Return a database split for the NIH CXR-14 database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -322,6 +331,11 @@ class DataModule(CachingDataModule): * vertebral degenerative changes * vertebral fracture * volume loss + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index 6409179465984e94e52c29f12bfb7bcc50b09596..4b04f871849abc6e3c2e997300c4349f42cb0725 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis. +"""Shenzhen DataModule for computer-aided diagnosis. Database reference: [MONTGOMERY-SHENZHEN-2014]_ """ @@ -27,7 +27,14 @@ database.""" class RawDataLoader(_BaseRawDataLoader): - """A specialized raw-data-loader for the Shenzhen dataset.""" + """A specialized raw-data-loader for the Shenzhen dataset. + + Parameters + ---------- + config_variable + Key to search for in the configuration file for the root directory of this + database. + """ datadir: str """This variable contains the base directory where the database raw data is @@ -41,7 +48,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: tuple[str, int]) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -50,10 +57,9 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The sample representation + The sample representation. """ # N.B.: Image.convert("L") is required to normalize grayscale back to # normal (instead of inverted). @@ -72,7 +78,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] def label(self, sample: tuple[str, int]) -> int: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -81,16 +87,26 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The integer label associated with the sample + int + The integer label associated with the sample. """ return sample[1] def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the Shenzhen database.""" + """Return a database split for the Shenzhen database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -98,7 +114,7 @@ def make_split(basename: str) -> DatabaseSplit: class DataModule(CachingDataModule): - """Shenzhen datamodule for computer-aided diagnosis. + """Shenzhen DataModule for computer-aided diagnosis. The standard digital image database for Tuberculosis was created by the National Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s @@ -129,6 +145,11 @@ class DataModule(CachingDataModule): * Grayscale, encoded as a single plane tensor, 32-bit floats, square with varying resolutions, depending on the input image * Labels: 0 (healthy), 1 (active tuberculosis) + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/shenzhen/fold_0.py b/src/mednet/config/data/shenzhen/fold_0.py index 17ab27e504e1467da338d22a860a7a1de7fb9bc8..218211fa31368d1abeb869b7f6dae58900c0effc 100644 --- a/src/mednet/config/data/shenzhen/fold_0.py +++ b/src/mednet/config/data/shenzhen/fold_0.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 0). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 0). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_1.py b/src/mednet/config/data/shenzhen/fold_1.py index 61ea13aa984be81b5f94f740c4848f374659d35f..cee2fd624dc617291aa9cc0a730eeb94a4a0a58c 100644 --- a/src/mednet/config/data/shenzhen/fold_1.py +++ b/src/mednet/config/data/shenzhen/fold_1.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 1). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 1). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_2.py b/src/mednet/config/data/shenzhen/fold_2.py index edb9247796cb6baa1f1cea0b72ad8ac3728b1f48..1373a64ccc1ecbbe3104e659aa6739bb6684450f 100644 --- a/src/mednet/config/data/shenzhen/fold_2.py +++ b/src/mednet/config/data/shenzhen/fold_2.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 2). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 2). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_3.py b/src/mednet/config/data/shenzhen/fold_3.py index 618008481b7c2892448d1dbe205cda68e33764b7..8fa1c4c2552d240126e4e502e0a74dd5366a8ca3 100644 --- a/src/mednet/config/data/shenzhen/fold_3.py +++ b/src/mednet/config/data/shenzhen/fold_3.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 3). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 3). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_4.py b/src/mednet/config/data/shenzhen/fold_4.py index 478af2ee758e675722d0578a316a33f7a0e47898..3998c2d2246bd98e6ec20a591dd84682b2bd0d12 100644 --- a/src/mednet/config/data/shenzhen/fold_4.py +++ b/src/mednet/config/data/shenzhen/fold_4.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 4). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 4). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_5.py b/src/mednet/config/data/shenzhen/fold_5.py index 21dcf8ff1ec64781172cb01976e4f7eca2b79d1f..71452ef25039f7195784e29f35af3d7e1bf06260 100644 --- a/src/mednet/config/data/shenzhen/fold_5.py +++ b/src/mednet/config/data/shenzhen/fold_5.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 5). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 5). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_6.py b/src/mednet/config/data/shenzhen/fold_6.py index fbc6aac8cb40ca898790040edcffda62999ae3e9..69a51ccef6dc87bbb66cbd89d92a26a472f48ce6 100644 --- a/src/mednet/config/data/shenzhen/fold_6.py +++ b/src/mednet/config/data/shenzhen/fold_6.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 6). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 6). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_7.py b/src/mednet/config/data/shenzhen/fold_7.py index 204f7b0b2cc981229ba4252acd4ca0fb99983f04..619a28d0625a0344053f01d1c25d96f6a2549482 100644 --- a/src/mednet/config/data/shenzhen/fold_7.py +++ b/src/mednet/config/data/shenzhen/fold_7.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 7). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 7). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_8.py b/src/mednet/config/data/shenzhen/fold_8.py index 3679fa18b58af7f5eba68d5e41eca06fcbfa5232..1eb4278a7f3c491914d98038018ea9383bf4d8e7 100644 --- a/src/mednet/config/data/shenzhen/fold_8.py +++ b/src/mednet/config/data/shenzhen/fold_8.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 8). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 8). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/shenzhen/fold_9.py b/src/mednet/config/data/shenzhen/fold_9.py index 32afdbd8b786f94597198f3bf99ac38a9e6773a9..c112edf906ffda4eb03bdc6414c95e5d972a27be 100644 --- a/src/mednet/config/data/shenzhen/fold_9.py +++ b/src/mednet/config/data/shenzhen/fold_9.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (cross validation fold 9). +"""Shenzhen DataModule for computer-aided diagnosis (cross validation fold 9). Database reference: [MONTGOMERY-SHENZHEN-2014]_ diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index ffe59568fd9524962e534deac1a72fcc1408cedd..e9d55cef980f681c078c29603a2f50906a2e316c 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -35,7 +35,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: tuple[str, int]) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -44,10 +44,9 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- - The sample representation + The sample representation. """ # images from TBPOC are encoded as grayscale JPEGs, no need to # call convert("L") here. @@ -64,7 +63,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] def label(self, sample: tuple[str, int]) -> int: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -73,7 +72,6 @@ class RawDataLoader(_BaseRawDataLoader): where to find the image to be loaded, and an integer, representing the sample label. - Returns ------- The integer label associated with the sample @@ -82,7 +80,17 @@ class RawDataLoader(_BaseRawDataLoader): def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the TB-POC database.""" + """Return a database split for the TB-POC database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -118,6 +126,11 @@ class DataModule(CachingDataModule): square with varying resolutions (2048 x 2048 being the maximum), but also depending on black borders' sizes on the input image. * Labels: 0 (healthy), 1 (active tuberculosis) + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index ae71e46b87863b1a7134f396df5a49a2dfe0f6b0..9c76bb5af431793e2668ef4a181e4e98b28e5f25 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -48,10 +48,11 @@ class BoundingBox: height: int def area(self) -> int: - """Computes the bounding box area. + """Compute the bounding box area. Returns ------- + int The area in square-pixels. """ return self.width * self.height @@ -65,7 +66,7 @@ class BoundingBox: return self.ymin + self.height - 1 def intersection(self, other: typing_extensions.Self) -> int: - """Computes the area intersection between bounding boxes. + """Compute the area intersection between bounding boxes. Notice that screen geometry dictates is slightly different from floating point metrics. Consider a 1D example for the evaluation of the @@ -75,14 +76,14 @@ class BoundingBox: * 2 pixels of index : i1 = 1 and i2 = 3, the segment from pixel i1 to i2 contains 3 pixels ie l = i2 - i1 + 1 - Parameters ---------- other - The other bounding box to check intersections for + The other bounding box to check intersections for. Returns ------- + int The area intersection between this and the other bounding-box in square pixels. """ @@ -96,7 +97,13 @@ class BoundingBox: class BoundingBoxes(collections.abc.Sequence[BoundingBox]): - """A collection of bounding boxes.""" + """A collection of bounding boxes. + + Parameters + ---------- + t + A sequence of BoundingBox. + """ def __init__(self, t: typing.Sequence[BoundingBox] = []): self.t = tuple(t) @@ -111,9 +118,15 @@ class BoundingBoxes(collections.abc.Sequence[BoundingBox]): # We update the default collate function map to use our custom function as # explained at: # https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate -def _collate_boundingboxes_fn(batch, *, collate_fn_map=None): - """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes - objects.""" +def _collate_boundingboxes_fn( + batch, *, collate_fn_map=None +): # numpydoc ignore=PR01 + """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects. + + Returns + ------- + The given batch. + """ return batch @@ -148,7 +161,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def sample(self, sample: DatabaseSample) -> Sample: - """Loads a single image sample from the disk. + """Load a single image sample from the disk. Parameters ---------- @@ -158,10 +171,9 @@ class RawDataLoader(_BaseRawDataLoader): sample label, and possible radiological findings represented by bounding boxes. - Returns ------- - The sample representation + The sample representation. """ image = PIL.Image.open(os.path.join(self.datadir, sample[0])) tensor = to_tensor(image) @@ -178,7 +190,7 @@ class RawDataLoader(_BaseRawDataLoader): ) def label(self, sample: DatabaseSample) -> int: - """Loads a single image sample label from the disk. + """Load a single image sample label from the disk. Parameters ---------- @@ -188,15 +200,15 @@ class RawDataLoader(_BaseRawDataLoader): sample label, and possible radiological findings represented by bounding boxes. - Returns ------- - The integer label associated with the sample + int + The integer label associated with the sample. """ return sample[1] def bounding_boxes(self, sample: DatabaseSample) -> BoundingBoxes: - """Loads image annotated bounding-boxes from the disk. + """Load image annotated bounding-boxes from the disk. Parameters ---------- @@ -206,9 +218,9 @@ class RawDataLoader(_BaseRawDataLoader): sample label, and possible radiological findings represented by bounding boxes. - Returns ------- + BoundingBoxes Bounding box annotations, if any available with the sample. """ if len(sample) > 2: @@ -218,7 +230,17 @@ class RawDataLoader(_BaseRawDataLoader): def make_split(basename: str) -> DatabaseSplit: - """Returns a database split for the Montgomery database.""" + """Return a database split for the Montgomery database. + + Parameters + ---------- + basename + Name of the .json file containing the split to load. + + Returns + ------- + An instance of DatabaseSplit. + """ return JSONDatabaseSplit( importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) @@ -322,6 +344,11 @@ class DataModule(CachingDataModule): (512x512 pixels) - Labels: 0 (healthy, latent tb or sick but no tb depending on the protocol), 1 (active tuberculosis) + + Parameters + ---------- + split_filename + Name of the .json file containing the split to load. """ def __init__(self, split_filename: str): diff --git a/src/mednet/config/data/tbx11k/make_splits_from_database.py b/src/mednet/config/data/tbx11k/make_splits_from_database.py index c432a2979c225a880d122a3b384cf17e3eb82194..5c097d1749ceb4e0bd69376941a2034cf02c4725 100644 --- a/src/mednet/config/data/tbx11k/make_splits_from_database.py +++ b/src/mednet/config/data/tbx11k/make_splits_from_database.py @@ -59,7 +59,18 @@ from sklearn.model_selection import StratifiedKFold, train_test_split def reorder(data: dict) -> list: - """Reorders data from TBX11K into a sample-based organisation.""" + """Reorder data from TBX11K into a sample-based organisation. + + Parameters + ---------- + data + A dictionary containing the loaded data. + + Returns + ------- + list + The reordered data. + """ categories = {k["id"]: k["name"] for k in data["categories"]} assert len(set(categories.values())) == len( @@ -98,7 +109,7 @@ def reorder(data: dict) -> list: def normalize_labels(data: list) -> list: - """Decides on the final labels for each sample. + """Decide on the final labels for each sample. Categories are decided on the following principles: @@ -112,6 +123,16 @@ def normalize_labels(data: list) -> list: bounding boxes with label 0 and no bounding box with label 1 4: sick (but no tb), comes from the imgs/sick subdir, does not have any annotated bounding box. + + Parameters + ---------- + data + A list of samples. + + Returns + ------- + list + A list of labels per sample. """ def _set_label(s: list) -> int: @@ -163,7 +184,13 @@ def normalize_labels(data: list) -> list: def print_statistics(d: dict): - """Print some statistics about the dataset.""" + """Print some statistics about the dataset. + + Parameters + ---------- + d + A dictionary of database splits. + """ label_translations = { -1: "Unknown", @@ -175,7 +202,13 @@ def print_statistics(d: dict): } def _print_dataset(ds: list): - """Print stats only for the dataset.""" + """Print stats only for the dataset. + + Parameters + ---------- + ds + The dataset to print stats for. + """ class_count = collections.Counter([k[1] for k in ds]) for k, v in class_count.items(): print(f" - {label_translations[k]}: {v}") @@ -202,21 +235,22 @@ def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict: 1. The original validation set becomes the test set. 2. The original training set is split into new training and validation sets. The selection of samples is stratified (respects class - proportions in Özgür's way - see comments) - + proportions in Özgür's way - see comments). Parameters ---------- - d - The original dataset that will be split - + The original dataset that will be split. seed - The seed to use at the relevant RNG - + The seed to use at the relevant RNG. validation_size The proportion of data when we split the training set to make a train and validation sets. + + Returns + ------- + dict + A dict containing the various v1 splits. """ # filter cases (only interested in labels 0:healthy or 1:active-tb) @@ -255,6 +289,21 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict: 2. The original training set is split into new training and validation sets. The selection of samples is stratified (respects class proportions in Özgür's way - see comments) + + Parameters + ---------- + d + The original dataset that will be split. + seed + The seed to use at the relevant RNG. + validation_size + The proportion of data when we split the training set to make a + train and validation sets. + + Returns + ------- + dict + A dict containing the various v2 splits. """ # filter cases (only interested in labels 0:healthy or 1:active-tb) @@ -291,23 +340,24 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict: def create_folds( d: dict, n: int, seed: int, validation_size: float ) -> list[dict]: - """Creates folds from existing splits. + """Create folds from existing splits. Parameters ---------- - d - The original split to consider - + The original split to consider. n - The number of folds to produce - + The number of folds to produce. + seed + The seed to use at the relevant RNG. + validation_size + The proportion of data when we split the training set to make a + train and validation sets. Returns ------- - - folds - All the ``n`` folds + list[dict] + All the ``n`` folds. """ X = d["train"] + d["validation"] + d["test"] diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py index a104ec4ac26b1540ed2a94a0509563fbeff0b4f8..8dbb4520c91c3338c7d3c9bb5e60c84b91f1a709 100644 --- a/src/mednet/data/augmentations.py +++ b/src/mednet/data/augmentations.py @@ -35,46 +35,37 @@ def _elastic_deformation_on_image( mode: str = "nearest", p: float = 1.0, ) -> torch.Tensor: - """Performs elastic deformation on an image. + """Perform elastic deformation on an image. This implementation is based on 2 scipy functions (:py:func:`scipy.ndimage.gaussian_filter` and :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it - requires data is moved off the current running device and then back. - + requires data to be moved off the current running device and then back. Parameters ---------- - img - The input image to apply elastic deformation at. This image should + The input image to apply elastic deformation to. This image should always have this shape: ``[C, H, W]``. It should always represent a tensor on the CPU. - alpha - A multiplier for the gaussian filter outputs - + A multiplier for the gaussian filter outputs. sigma Standard deviation for Gaussian kernel. - spline_order The order of the spline interpolation, default is 1. The order has to be in the range 0-5. - mode The mode parameter determines how the input array is extended beyond its boundaries. - p Probability that this transformation will be applied. Meaningful when using it as a data augmentation technique. - Returns ------- - tensor - A tensor on the CPU. + The image with elastic deformation applied, as a tensor on the CPU. """ if random.random() < p: @@ -143,46 +134,37 @@ def _elastic_deformation_on_batch( p: float = 1.0, pool: multiprocessing.pool.Pool | None = None, ) -> torch.Tensor: - """Performs elastic deformation on a batch of images. + """Perform elastic deformation on a batch of images. This implementation is based on 2 scipy functions (:py:func:`scipy.ndimage.gaussian_filter` and :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it - requires data is moved off the current running device and then back. - + requires data to be moved off the current running device and then back. Parameters ---------- - - img - The input image to apply elastic deformation at. This image should - always have this shape: ``[C, H, W]``. It should always represent a - tensor on the CPU. - + batch + The batch to apply elastic deformation to. alpha - A multiplier for the gaussian filter outputs - + A multiplier for the gaussian filter outputs. sigma Standard deviation for Gaussian kernel. - spline_order The order of the spline interpolation, default is 1. The order has to be in the range 0-5. - mode The mode parameter determines how the input array is extended beyond its boundaries. - p Probability that this transformation will be applied. Meaningful when using it as a data augmentation technique. - + pool + The multiprocessing pool to use. Returns ------- - tensor - A tensor on the CPU. + A batch of images with elastic deformation applied, as a tensor on the CPU. """ # transforms our custom functions into simpler callables partial = functools.partial( @@ -210,43 +192,37 @@ class ElasticDeformation: This implementation is based on 2 scipy functions (:py:func:`scipy.ndimage.gaussian_filter` and :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it - requires data is moved off the current running device and then back. + requires data to be moved off the current running device and then back. .. warning:: Furthermore, this transform is not scriptable and therefore cannot run - on a CUDA or MPS device. Applying it, effectively creates a bottleneck + on a CUDA or MPS device. Applying it effectively creates a bottleneck in model training. Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0 - Parameters ---------- - alpha - + A multiplier for the gaussian filter outputs. sigma Standard deviation for Gaussian kernel. - spline_order The order of the spline interpolation, default is 1. The order has to be in the range 0-5. - mode The mode parameter determines how the input array is extended beyond its boundaries. - p Probability that this transformation will be applied. Meaningful when using it as a data augmentation technique. - parallel Use multiprocessing for processing batches of data: if set to -1 (default), disables multiprocessing. If set to -2, then enable auto-tune (use the minimum value between the first batch size and total number of processing cores). Set to 0 to enable as many processes as - processing cores as available in the system. Set to >= 1 to enable that + processing cores available in the system. Set to >= 1 to enable that many processes. """ @@ -267,14 +243,19 @@ class ElasticDeformation: self.parallel = parallel @property - def parallel(self): + def parallel(self) -> int: """Use multiprocessing for data augmentation. If set to -1 (default), disables multiprocessing. If set to -2, then enable auto-tune (use the minimum value between the first batch size and total number of processing cores). Set to 0 to - enable as many processes as processing cores as available in the + enable as many processes as processing cores available in the system. Set to >= 1 to enable that many processes. + + Returns + ------- + int + The multiprocessing type. """ return self._parallel diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 82353be3fb2080b9e761dba9d2b52899316b9dd9..4f79857b88eeb7a7fada2bf40f29ebb9dcfe7209 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -37,16 +37,27 @@ def _sample_size_bytes(s: Sample) -> int: Parameters ---------- s - The sample to be analyzed - + The sample to be analyzed. Returns ------- - The size in bytes occupied by this sample + int + The size in bytes occupied by this sample. """ def _tensor_size_bytes(t: torch.Tensor) -> int: - """Returns a tensor size in bytes.""" + """Return a tensor size in bytes. + + Parameters + ---------- + t + A torch Tensor. + + Returns + ------- + int + The size of the Tensor in bytes. + """ return int(t.element_size() * torch.prod(torch.tensor(t.shape))) size = sys.getsizeof(s[0]) # tensor metadata @@ -65,19 +76,16 @@ def _sample_size_bytes(s: Sample) -> int: class _DelayedLoadingDataset(Dataset): """A list that loads its samples on demand. - This list mimics a pytorch Dataset, except raw data loading is done + This list mimics a pytorch Dataset, except that raw data loading is done on-the-fly, as the samples are requested through the bracket operator. - Parameters ---------- raw_dataset An iterable containing the raw dataset samples representing one of the database split datasets. - loader An object instance that can load samples and labels from storage. - transforms A set of transforms that should be applied on-the-fly for this dataset, to fit the output of the raw-data-loader to the model of interest. @@ -103,7 +111,13 @@ class _DelayedLoadingDataset(Dataset): logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") def labels(self) -> list[int | list[int]]: - """Returns the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return [self.loader.label(k) for k in self.raw_dataset] def __getitem__(self, key: int) -> Sample: @@ -129,19 +143,17 @@ def _apply_loader_and_transforms( Parameters ---------- info - The sample information, as loaded from its raw dataset dictionary - + The sample information, as loaded from its raw dataset dictionary. load - The raw-data loader function to use for loading the sample - + The raw-data loader function to use for loading the sample. model_transform A callable that will transform the loaded tensor into something suitable for the model it will train. Typically, this will be a composed transform. - Returns ------- + Sample The loaded and transformed sample. """ sample = load(info) @@ -152,25 +164,21 @@ class _CachedDataset(Dataset): """Basically, a list of preloaded samples. This dataset will load all samples from the raw dataset during construction - instead of delaying that to the indexing. Beyong raw-data-loading, + instead of delaying that to the indexing. Beyond raw-data-loading, ``transforms`` given upon construction contribute to the cached samples. - Parameters ---------- raw_dataset An iterable containing the raw dataset samples representing one of the database split datasets. - loader An object instance that can load samples and labels from storage. - parallel Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading - instances as processing cores as available in the system. Set to >= 1 + instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading. - transforms A set of transforms that should be applied to the cached samples for this dataset, to fit the output of the raw-data-loader to the model of @@ -216,7 +224,13 @@ class _CachedDataset(Dataset): ) def labels(self) -> list[int | list[int]]: - """Returns the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return [k[1]["label"] for k in self.data] def __getitem__(self, key: int) -> Sample: @@ -248,7 +262,13 @@ class _ConcatDataset(Dataset): ] def labels(self) -> list[int | list[int]]: - """Returns the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return list(itertools.chain(*[k.labels() for k in self._datasets])) def __getitem__(self, key: int) -> Sample: @@ -267,7 +287,7 @@ def _make_balanced_random_sampler( dataset: Dataset, target: str = "label", ) -> torch.utils.data.WeightedRandomSampler: - """Generates a pytorch sampler that samples according to class + """Generate a pytorch sampler that samples according to class probabilities. This function takes as input a torch Dataset, and computes the weights to @@ -305,39 +325,35 @@ def _make_balanced_random_sampler( this case). To verify this, notice that the probability of picking a sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`. 2. The probability of picking a sample with ``target=0`` from Dataset 2 is - 3 times higher than those from Dataset 1. As there are 3 times less + 3 times higher than those from Dataset 1. As there are 3 times fewer samples in Dataset 2 with ``target=0``, this makes choosing samples from Dataset 1 proportionally less likely. 3. The probability of picking a sample with ``target=1`` from Dataset 2 is - 3 times lower than those from Dataset 1. As there are 3 times less + 3 times lower than those from Dataset 1. As there are 3 times fewer samples in Dataset 1 with ``target=1``, this makes choosing samples from Dataset 2 proportionally less likely. This function assumes targets are stored on a dictionary entry named ``target`` inside the metadata information for the - :py:data:`.typing.Sample`, and that its value is integer. + :py:data:`.typing.Sample`, and that its value is an integer. We then instantiate a pytorch sampler using the inverse probabilities (the - more samples of a class, the less likely it becomes to be sampled. - + more samples in a class, the less likely it becomes to be sampled. Parameters ---------- dataset An instance of torch Dataset. :py:class:`torch.utils.data.ConcatDataset` are supported. - target The name of a metadata key pointing to an integer property that allows balancing the dataset. - Returns ------- A sampler, to be used in a dataloader equipped with the same dataset used to calculate the relative sample weights. - Raises ------ RuntimeError @@ -403,22 +419,21 @@ def _make_balanced_random_sampler( class ConcatDataModule(lightning.LightningDataModule): - """A conveninent data module with dictionary split loading, mini- batching, + """A conveninent DataModule with dictionary split loading, mini- batching, parallelisation and caching, all in one. Instances of this class can load and concatenate an arbitrary number of data-split (a.k.a. protocol) definitions for (possibly disjoint) databases, and can manage raw data-loading from disk. An optional caching mechanism - stores the data at associated CPU memory, which can improve data serving + stores the data in associated CPU memory, which can improve data serving while training and evaluating models. - This datamodule defines basic operations to handle data loading and + This DataModule defines basic operations to handle data loading and mini-batch handling within this package's framework. It can return :py:class:`torch.utils.data.DataLoader` objects for training, validation, prediction and testing conditions. Parallelisation is handled by a simple input flag. - Parameters ---------- splits @@ -435,7 +450,7 @@ class ConcatDataModule(lightning.LightningDataModule): .. tip:: - To check the split and the loader function works correctly, you may + To check the split and that the loader function works correctly, you may use :py:func:`.split.check_database_split_loading`. This class expects at least one entry called ``train`` to exist in the @@ -443,55 +458,49 @@ class ConcatDataModule(lightning.LightningDataModule): Entries named ``monitor-...`` will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. - cache_samples If set, then issue raw data loading during ``prepare_data()``, and serves samples from CPU memory. Otherwise, loads samples from disk on demand. Running from CPU memory will offer increased speeds in exchange for CPU memory. Sufficient CPU memory must be available before you set - this attribute to ``True``. It is typicall useful for relatively small + this attribute to ``True``. It is typically useful for relatively small datasets. - balance_sampler_by_class If set, then modifies the random sampler used during training and validation to balance sample picking probability, making sample across classes **and** datasets equitable. - batch_size Number of samples in every **training** batch (this parameter affects memory requirements for the network). If the number of samples in the batch is larger than the total number of samples available for training, this value is truncated. If this number is smaller, then - batches of the specified size are created and fed to the network until + batches of the specified size are created and fed to the network until there are no more new samples to feed (epoch is finished). If the total number of training samples is not a multiple of the batch-size, the last batch will be smaller than the first, unless ``drop_incomplete_batch`` is set to ``true``, in which case this batch is not used. - batch_chunk_count Number of chunks in every batch (this parameter affects memory requirements for the network). The number of samples loaded for every iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` needs to be divisible by ``batch_chunk_count``, otherwise an error will - be raised. This parameter is used to reduce number of samples loaded in + be raised. This parameter is used to reduce the number of samples loaded in each iteration, in order to reduce the memory usage in exchange for - processing time (more iterations). This is specially interesting whe - one is running with GPUs with limited RAM. The default of 1 forces the + processing time (more iterations). This is especially interesting when + one is running on GPUs with limited RAM. The default of 1 forces the whole batch to be processed at once. Otherwise the batch is broken into batch-chunk-count pieces, and gradients are accumulated to complete each batch. - drop_incomplete_batch - If set, then may drop the last batch in an epoch, in case it is + If set, then may drop the last batch in an epoch in case it is incomplete. If you set this option, you should also consider - increasing the total number of epochs of training, as the total number + increasing the total number of training epochs, as the total number of training steps may be reduced. - parallel Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading - instances as processing cores as available in the system. Set to >= 1 + instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading. """ @@ -540,7 +549,7 @@ class ConcatDataModule(lightning.LightningDataModule): Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as - many data loading instances as processing cores as available in + many data loading instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading. @@ -570,6 +579,11 @@ class ConcatDataModule(lightning.LightningDataModule): - ``parallel`` - Runs mini-batch data loading on as many external processes as set on ``parallel`` + + Returns + ------- + int + The value of self._parallel. """ return self._parallel @@ -597,17 +611,22 @@ class ConcatDataModule(lightning.LightningDataModule): @property def model_transforms(self) -> list[Transform] | None: - """Transforms required to fit data into the model. + """Transform required to fit data into the model. A list of transforms (torch modules) that will be applied after - raw- data-loading. and just before data is fed into the model or + raw-data-loading. and just before data is fed into the model or eventual data-augmentation transformations for all data loaders - produced by this data module. This part of the pipeline - receives data as output by the raw-data-loader, or model-related + produced by this DataModule. This part of the pipeline receives + data as output by the raw-data-loader, or model-related transforms (e.g. resize adaptions), if any is specified. If data is cached, it is cached **after** model-transforms are applied, as that is a potential memory saver (e.g., if it contains a resizing operation to smaller images). + + Returns + ------- + list + A list containing the model tansforms. """ return self._model_transforms @@ -619,14 +638,14 @@ class ConcatDataModule(lightning.LightningDataModule): # datasets that have been setup() for the current stage are reset if value != old_value and len(self._datasets): logger.warning( - f"Reseting {len(self._datasets)} loaded datasets due " + f"Resetting {len(self._datasets)} loaded datasets due " "to changes in model-transform properties. If you were caching " "data loading, this will (eventually) trigger a reload." ) self._datasets = {} @property - def balance_sampler_by_class(self): + def balance_sampler_by_class(self) -> bool: """Whether to balance samples across labels/datasets. If set, then modifies the random sampler used during training @@ -640,6 +659,11 @@ class ConcatDataModule(lightning.LightningDataModule): samples acording to their ground-truth (labels). If you'd like to have samples balanced per dataset, then implement your own data module inheriting from this one. + + Returns + ------- + bool + True if self._train_sample is set, else False. """ return self._train_sampler is not None @@ -655,7 +679,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._train_sampler = None def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: - """Coherently sets the batch-chunk-size after validation. + """Coherently set the batch-chunk-size after validation. Parameters ---------- @@ -670,7 +694,6 @@ class ConcatDataModule(lightning.LightningDataModule): the last batch will be smaller than the first, unless ``drop_incomplete_batch`` is set to ``true``, in which case this batch is not used. - batch_chunk_count Number of chunks in every batch (this parameter affects memory requirements for the network). The number of samples loaded for every @@ -678,8 +701,8 @@ class ConcatDataModule(lightning.LightningDataModule): needs to be divisible by ``batch_chunk_count``, otherwise an error will be raised. This parameter is used to reduce number of samples loaded in each iteration, in order to reduce the memory usage in exchange for - processing time (more iterations). This is specially interesting whe - one is running with GPUs with limited RAM. The default of 1 forces the + processing time (more iterations). This is especially interesting when + one is running on GPUs with limited RAM. The default of 1 forces the whole batch to be processed at once. Otherwise the batch is broken into batch-chunk-count pieces, and gradients are accumulated to complete each batch. @@ -697,7 +720,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._chunk_size = self._batch_size // self._batch_chunk_count def _setup_dataset(self, name: str) -> None: - """Sets-up a single dataset from the input data split. + """Set-up a single dataset from the input data split. Parameters ---------- @@ -754,13 +777,19 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets[name] = _ConcatDataset(datasets) def _val_dataset_keys(self) -> list[str]: - """Returns list of validation dataset names.""" + """Return list of validation dataset names. + + Returns + ------- + list[str] + The list of validation dataset names. + """ return ["validation"] + [ k for k in self.splits.keys() if k.startswith("monitor-") ] def setup(self, stage: str) -> None: - """Sets up datasets for different tasks on the pipeline. + """Set up datasets for different tasks on the pipeline. This method should setup (load, pre-process, etc) all datasets required for a particular ``stage`` (fit, validate, test, predict), and keep @@ -770,11 +799,10 @@ class ConcatDataModule(lightning.LightningDataModule): If you have set ``cache_samples``, samples are loaded at this stage and cached in memory. - Parameters ---------- stage - Name of the stage to which the setup is applicable. Can be one of + Name of the stage in which the setup is applicable. Can be one of ``fit``, ``validate``, ``test`` or ``predict``. Each stage typically uses the following data loaders: @@ -805,14 +833,13 @@ class ConcatDataModule(lightning.LightningDataModule): This method unsets (unload, remove from memory, etc) all datasets required for a particular ``stage`` (fit, validate, test, predict). - If you have set ``cache_samples``, samples are loaded, this may + If you have set ``cache_samples``, samples are loaded and this may effectivley release all the associated memory. - Parameters ---------- stage - Name of the stage to which the teardown is applicable. Can be one of + Name of the stage in which the teardown is applicable. Can be one of ``fit``, ``validate``, ``test`` or ``predict``. Each stage typically uses the following data loaders: @@ -824,7 +851,12 @@ class ConcatDataModule(lightning.LightningDataModule): super().teardown(stage) def train_dataloader(self) -> DataLoader: - """Returns the train data loader.""" + """Return the train data loader. + + Returns + ------- + The train data loader(s). + """ return torch.utils.data.DataLoader( self._datasets["train"], @@ -837,7 +869,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def unshuffled_train_dataloader(self) -> DataLoader: - """Returns the train data loader without shuffling.""" + """Return the train data loader without shuffling. + + Returns + ------- + The train data loader without shuffling. + """ return torch.utils.data.DataLoader( self._datasets["train"], @@ -848,7 +885,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def val_dataloader(self) -> dict[str, DataLoader]: - """Returns the validation data loader(s)""" + """Return the validation data loader(s). + + Returns + ------- + The validation data loader(s). + """ validation_loader_opts = { "batch_size": self._chunk_size, @@ -866,7 +908,12 @@ class ConcatDataModule(lightning.LightningDataModule): } def test_dataloader(self) -> dict[str, DataLoader]: - """Returns the test data loader(s)""" + """Return the test data loader(s). + + Returns + ------- + The test data loader(s). + """ return dict( test=torch.utils.data.DataLoader( @@ -880,7 +927,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def predict_dataloader(self) -> dict[str, DataLoader]: - """Returns the prediction data loader(s)""" + """Return the prediction data loader(s). + + Returns + ------- + The prediction data loader(s). + """ return { k: torch.utils.data.DataLoader( @@ -896,12 +948,11 @@ class ConcatDataModule(lightning.LightningDataModule): class CachingDataModule(ConcatDataModule): - """A simplified version of our data module for a single split. + """A simplified version of our DataModule for a single split. - Apart from construction, the behaviour of this data module is very similar + Apart from construction, the behaviour of this DataModule is very similar to its simpler counterpart, serving training, validation and test sets. - Parameters ---------- database_split @@ -923,12 +974,10 @@ class CachingDataModule(ConcatDataModule): Entries named ``monitor-...`` will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. - raw_data_loader An object instance that can load samples and labels from storage. - **kwargs - List if named parameters matching those of + List of named parameters matching those of :py:class:`ConcatDataModule`, other than ``splits``. """ diff --git a/src/mednet/data/image_utils.py b/src/mednet/data/image_utils.py index 3dce9cd27ac8cf5958a4c87d7233bf9efb265915..c0e0ff6a6e9c64b994c54f9e0842cfb1af05df8a 100644 --- a/src/mednet/data/image_utils.py +++ b/src/mednet/data/image_utils.py @@ -15,14 +15,14 @@ def remove_black_borders( Parameters ---------- img - A PIL image + A PIL image. threshold Threshold value from which borders are considered black. Defaults to 0. Returns ------- - A PIL image with black borders removed + A PIL image with black borders removed. """ img_array = numpy.asarray(img) diff --git a/src/mednet/data/split.py b/src/mednet/data/split.py index 2d5370fe8e2e2f98d9438c2001ded5a555800ece..5d70aafd551b18fca856f846d91818442bb63c03 100644 --- a/src/mednet/data/split.py +++ b/src/mednet/data/split.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) class JSONDatabaseSplit(DatabaseSplit): - """Defines a loader that understands a database split (train, test, etc) in + """Define a loader that understands a database split (train, test, etc) in JSON format. To create a new database split, you need to provide a JSON formatted @@ -49,7 +49,7 @@ class JSONDatabaseSplit(DatabaseSplit): } Your database split many contain any number of (raw) datasets (dictionary - keys). For simplicity, we recommend all sample entries are formatted + keys). For simplicity, we recommend to format all sample entries similarly so that raw-data-loading is simplified. Use the function :py:func:`check_database_split_loading` to test raw data loading and fine tune the dataset split, or its loading. @@ -59,12 +59,10 @@ class JSONDatabaseSplit(DatabaseSplit): actual JSON file descriptors are loaded on demand using a py:func:`functools.cached_property`. - Parameters ---------- - path - Absolute path to a JSON formatted file containing the database split to be + Absolute path to a .json formatted file containing the database split to be recognized by this object. """ @@ -75,17 +73,15 @@ class JSONDatabaseSplit(DatabaseSplit): @functools.cached_property def _datasets(self) -> DatabaseSplit: - """Datasets in a split. + """Return the DatabaseSplits. - The first call to this (cached) property will trigger full JSON file + The first call to this (cached) property will trigger full .json file loading from disk. Subsequent calls will be cached. - Returns ------- - - datasets : dict - A dictionary mapping dataset names to lists of JSON objects + DatabaseSplit + A dictionary mapping dataset names to lists of JSON objects. """ if str(self._path).endswith(".bz2"): @@ -97,27 +93,27 @@ class JSONDatabaseSplit(DatabaseSplit): return json.load(f) def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: - """Accesses dataset ``key`` from this split.""" + """Access dataset ``key`` from this split.""" return self._datasets[key] def __iter__(self): - """Iterates over the datasets.""" + """Iterate over the datasets.""" return iter(self._datasets) def __len__(self) -> int: - """How many datasets we currently have.""" + """The number of datasets we currently have.""" return len(self._datasets) class CSVDatabaseSplit(DatabaseSplit): - """Defines a loader that understands a database split (train, test, etc) in + """Define a loader that understands a database split (train, test, etc) in CSV format. To create a new database split, you need to provide one or more CSV formatted files, each representing a dataset of this split, containing the sample data (one per row). Example: - Inside the directory ``my-split/``, one can file files ``train.csv``, + Inside the directory ``my-split/``, one can find the files ``train.csv``, ``validation.csv``, and ``test.csv``. Each file has a structure similar to the following: @@ -127,19 +123,17 @@ class CSVDatabaseSplit(DatabaseSplit): sample2-value1,sample2-value2,sample2-value3 ... - Each file in the provided directory defines the dataset name on the split. + Each file in the provided directory defines the dataset name of the split. So, the file ``train.csv`` will contain the data from the ``train`` dataset, and so on. Objects of this class behave like a dictionary in which keys are dataset names in the split, and values represent samples data and meta-data. - Parameters ---------- - directory - Absolute path to a directory containing the database split layed down + Absolute path to a directory containing the database split organized as a set of CSV files, one per dataset. """ @@ -155,17 +149,14 @@ class CSVDatabaseSplit(DatabaseSplit): @functools.cached_property def _datasets(self) -> DatabaseSplit: - """Datasets in a split. + """Return the DatabaseSplits. The first call to this (cached) property will trigger all CSV file loading from disk. Subsequent calls will be cached. - Returns ------- - - datasets : dict - A dictionary mapping dataset names to lists of JSON objects + A dictionary mapping dataset names to lists of JSON objects. """ retval: dict[str, typing.Sequence[typing.Any]] = {} @@ -188,15 +179,15 @@ class CSVDatabaseSplit(DatabaseSplit): return retval def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: - """Accesses dataset ``key`` from this split.""" + """Accesse dataset ``key`` from this split.""" return self._datasets[key] def __iter__(self): - """Iterates over the datasets.""" + """Iterate over the datasets.""" return iter(self._datasets) def __len__(self) -> int: - """How many datasets we currently have.""" + """The number of datasets we currently have.""" return len(self._datasets) @@ -208,35 +199,29 @@ def check_database_split_loading( """For each dataset in the split, check if all data can be correctly loaded using the provided loader function. - This function will return the number of errors loading samples, and will + This function will return the number of errors when loading samples, and will log more detailed information to the logging stream. - Parameters ---------- - database_split - A mapping that, contains the database split. Each key represents the + A mapping that contains the database split. Each key represents the name of a dataset in the split. Each value is a (potentially complex) object that represents a single sample. - loader A loader object that knows how to handle full-samples or just labels. - limit Maximum number of samples to check (in each split/dataset combination) in this dataset. If set to zero, then check everything. - Returns ------- - - errors - Number of errors found + int + Number of errors found. """ logger.info( - "Checking if can load all samples in all datasets of this split..." + "Checking if all samples in all datasets of this split can be loaded..." ) errors = 0 for dataset, samples in database_split.items(): diff --git a/src/mednet/data/typing.py b/src/mednet/data/typing.py index c1df54c62eea0bf5005e2e3f5c0381e62b2e07a7..e5b68d982af5ae421ce94b442dbbfa5617843d4f 100644 --- a/src/mednet/data/typing.py +++ b/src/mednet/data/typing.py @@ -25,21 +25,37 @@ class RawDataLoader: """A loader object can load samples and labels from storage.""" def sample(self, _: typing.Any) -> Sample: - """Loads whole samples from media.""" + """Load whole samples from media. + + Parameters + ---------- + _ + Information about the sample to load. Implementation dependent. + """ raise NotImplementedError("You must implement the `sample()` method") def label(self, k: typing.Any) -> int | list[int]: - """Loads only sample label from media. + """Load only sample label from media. If you do not override this implementation, then, by default, this method will call :py:meth:`sample` to load the whole sample and extract the label. + + Parameters + ---------- + k + The sample to load. This is implementation-dependent. + + Returns + ------- + int | list[int] + The label corresponding to the specified sample. """ return self.sample(k)[1]["label"] Transform: typing.TypeAlias = typing.Callable[[torch.Tensor], torch.Tensor] -"""A callable, that transforms tensors into (other) tensors. +"""A callable that transforms tensors into (other) tensors. Typically used in data-processing pipelines inside pytorch. """ @@ -72,14 +88,14 @@ be assigned a different :py:class:`.RawDataLoader`. class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): - """Our own definition of a pytorch Dataset, with interesting properties. + """Our own definition of a pytorch Dataset. We iterate over Sample objects in this case. Our datasets always provide a dunder len method. """ def labels(self) -> list[int | list[int]]: - """Returns the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset.""" raise NotImplementedError("You must implement the `labels()` method") diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index cba08495a6cf3a75a9c801d65aada006d4782787..fd0da8843b5f32baf407d098eb9751dcaf1b4080 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -21,14 +21,12 @@ class LoggingCallback(lightning.pytorch.Callback): Rationale: 1. Losses are logged at the end of every batch, accumulated and handled by - the lightning framework + the lightning framework. 2. Everything else is done at the end of a training or validation epoch and mostly concerns runtime metrics such as memory and cpu/gpu utilisation. - Parameters ---------- - train_resource_monitor A monitor that watches resource usage (CPU/GPU) in a separate process and totally asynchronously with the code execution. @@ -67,15 +65,12 @@ class LoggingCallback(lightning.pytorch.Callback): This method is executed whenever you *start* training a module. - Parameters - --------- - + ---------- trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained + The lightning module that is being trained. """ self._start_training_time = time.time() @@ -96,15 +91,12 @@ class LoggingCallback(lightning.pytorch.Callback): This is executed **while** you are training. Be very succint or face the consequences of slow training! - Parameters - --------- - + ---------- trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained + The lightning module that is being trained. """ # summarizes resource usage since the last checkpoint # clears internal buffers and starts accumulating again. @@ -123,15 +115,12 @@ class LoggingCallback(lightning.pytorch.Callback): epochs happen as often as possible. You want to make this code relatively fast to avoid significative runtime slow-downs. - Parameters ---------- - trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained + The lightning module that is being trained. """ # evaluates this training epoch total time, and log it @@ -189,24 +178,18 @@ class LoggingCallback(lightning.pytorch.Callback): This is executed **while** you are training. Be very succint or face the consequences of slow training! - Parameters ---------- - trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained - + The lightning module that is being trained. outputs - The outputs of the module's ``training_step`` - + The outputs of the module's ``training_step``. batch - The data that the training step received - + The data that the training step received. batch_idx - The relative number of the batch + The relative number of the batch. """ pl_module.log( "loss/train", @@ -234,15 +217,12 @@ class LoggingCallback(lightning.pytorch.Callback): This is executed **while** you are training. Be very succint or face the consequences of slow training! - Parameters - --------- - + ---------- trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained + The lightning module that is being trained. """ # required because the validation epoch is started **within** the # training epoch START/END. @@ -265,15 +245,12 @@ class LoggingCallback(lightning.pytorch.Callback): epochs happen as often as possible. You want to make this code relatively fast to avoid significative runtime slow-downs. - Parameters ---------- - trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained + The lightning module that is being trained. """ # summarizes resource usage since the last checkpoint @@ -322,25 +299,18 @@ class LoggingCallback(lightning.pytorch.Callback): This is executed **while** you are training. Be very succint or face the consequences of slow training! - Parameters ---------- - trainer - The Lightning trainer object - + The Lightning trainer object. pl_module - The lightning module that is being trained - + The lightning module that is being trained. outputs - The outputs of the module's ``training_step`` - + The outputs of the module's ``training_step``. batch - The data that the training step received - + The data that the training step received. batch_idx - The relative number of the batch - + The relative number of the batch. dataloader_idx Index of the dataloader used during validation. Use this to figure out which dataset was used for this validation epoch. diff --git a/src/mednet/engine/device.py b/src/mednet/engine/device.py index d11aacffb0f14960a4c3434beccb6e283bd6c940..b008c4cbe952c33bf4ace294426e1362ff15b205 100644 --- a/src/mednet/engine/device.py +++ b/src/mednet/engine/device.py @@ -21,8 +21,18 @@ SupportedPytorchDevice: typing.TypeAlias = typing.Literal[ def _split_int_list(s: str) -> list[int]: - """Splits a list of integers encoded in a string (e.g. "1,2,3") into a - Python list of integers (e.g. ``[1, 2, 3]``).""" + """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``). + + Parameters + ---------- + s + A list of integers encoded in a string. + + Returns + ------- + list[int] + A Python list of integers. + """ return [int(k.strip()) for k in s.split(",")] @@ -38,10 +48,8 @@ class DeviceManager: Instances of this class also manage the environment variable ``$CUDA_VISIBLE_DEVICES`` if necessary. - Parameters ---------- - name The name of the device to use, in the form of a string defined by ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In @@ -114,7 +122,7 @@ class DeviceManager: ) def torch_device(self) -> torch.device: - """Returns a representation of the torch device to use by default. + """Return a representation of the torch device to use by default. .. warning:: @@ -122,11 +130,9 @@ class DeviceManager: device. This may impact Nvidia GPU logging in the case multiple GPU cards are used. - Returns ------- - - device + torch.device The **first** torch device (if a list of ids is set). """ @@ -144,16 +150,14 @@ class DeviceManager: ) def lightning_accelerator(self) -> tuple[str, int | list[int] | str]: - """Returns the lightning accelerator setup. + """Return the lightning accelerator setup. Returns ------- - accelerator - The lightning accelerator to use - + The lightning accelerator to use. devices - The lightning devices to use + The lightning devices to use. """ devices: int | list[int] | str = self.device_ids diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py index 829b9a6b04db10beb414bba5ed04aadb24b472b9..cd24260014c5374e692c9a10cce9427e859692dc 100644 --- a/src/mednet/engine/evaluator.py +++ b/src/mednet/engine/evaluator.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: - """Calculates the (approximate) threshold leading to the equal error rate. + """Calculate the (approximate) threshold leading to the equal error rate. Parameters ---------- @@ -32,9 +32,9 @@ def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: An iterable of multiple :py:data:`.models.typing.BinaryPrediction`'s. - Returns ------- + float The EER threshold value. """ from scipy.interpolate import interp1d @@ -51,19 +51,20 @@ def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: def _get_centered_maxf1( f1_scores: numpy.typing.NDArray, thresholds: numpy.typing.NDArray -): +) -> tuple[float, float]: """Return the centered max F1 score threshold when multiple thresholds give the same max F1 score. Parameters ---------- f1_scores - 1D array of f1 scores + 1D array of f1 scores. thresholds - 1D array of thresholds + 1D array of thresholds. Returns ------- + tuple(float, float) A tuple with the maximum F1-score and the "centered" threshold. """ maxf1 = f1_scores.max() @@ -79,7 +80,7 @@ def _get_centered_maxf1( def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float: - """Calculates the threshold leading to the maximum F1-score on a precision- + """Calculate the threshold leading to the maximum F1-score on a precision- recall curve. Parameters @@ -88,9 +89,9 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float: An iterable of multiple :py:data:`.models.typing.BinaryPrediction`'s. - Returns ------- + float The threshold value leading to the maximum F1-score on the provided set of predictions. """ @@ -117,22 +118,22 @@ def _score_plot( title: str, threshold: float, ) -> matplotlib.figure.Figure: - """Plots the normalized score distributions for all systems. + """Plot the normalized score distributions for all systems. Parameters ---------- labels - True labels (negatives and positives) for each entry in ``scores`` + True labels (negatives and positives) for each entry in ``scores``. scores - Likelihoods provided by the classification model, for each sample + Likelihoods provided by the classification model, for each sample. title Title of the plot. threshold - Shows where the threshold is in the figure - + Shows where the threshold is in the figure. Returns ------- + matplotlib.figure.Figure A single (matplotlib) plot containing the score distribution, ready to be saved to disk or displayed. """ @@ -192,27 +193,30 @@ def run_binary( dict[str, matplotlib.figure.Figure], dict[str, typing.Any], ]: - """Runs inference and calculates measures for binary classification. + """Run inference and calculates measures for binary classification. Parameters ---------- name The name of subset to load. predictions - A list of predictions to consider for measurement + A list of predictions to consider for measurement. threshold_a_priori A threshold to use, evaluated *a priori*, if must report single values. - If this value is not provided, a *a posteriori* threshold is calculated + If this value is not provided, an *a posteriori* threshold is calculated on the input scores. This is a biased estimator. - Returns ------- + tuple[ + dict[str, typing.Any], + dict[str, matplotlib.figure.Figure], + dict[str, typing.Any]] A tuple containing the following entries: * summary: A dictionary containing the performance summary on the - specified threshold - * figures: A dictionary of generated standalone figures + specified threshold. + * figures: A dictionary of generated standalone figures. * curves: A dictionary containing curves that can potentially be combined with other prediction lists to make aggregate plots. """ @@ -277,24 +281,23 @@ def run_binary( def aggregate_summaries( data: typing.Sequence[typing.Mapping[str, typing.Any]], fmt: str ) -> str: - """Tabulates summaries from multiple splits. + """Tabulate summaries from multiple splits. This function can properly tabulate the various summaries produced for all the splits in a prediction database. - Parameters ---------- data - An iterable over all summary data collected + An iterable over all summary data collected. fmt One of the formats supported by `python-tabulate <https://pypi.org/project/tabulate/>`_. - Returns ------- - A string containing the tabulated information + str + A string containing the tabulated information. """ headers = list(data[0].keys()) @@ -306,23 +309,22 @@ def aggregate_roc( data: typing.Mapping[str, typing.Any], title: str = "ROC", ) -> matplotlib.figure.Figure: - """Aggregates ROC curves from multiple splits. + """Aggregate ROC curves from multiple splits. This function produces a single ROC plot for multiple curves generated per split. - Parameters ---------- data A dictionary mapping split names to ROC curve data produced by :py:func:sklearn.metrics.roc_curve`. - title The title of the plot. Returns ------- + matplotlib.figure.Figure A figure, containing the aggregated ROC plot. """ fig, ax = plt.subplots(1, 1) @@ -389,20 +391,19 @@ def aggregate_roc( def _precision_recall_canvas() -> ( Iterator[tuple[matplotlib.figure.Figure, matplotlib.figure.Axes]] ): - """Generates a canvas to draw precision-recall curves. + """Generate a canvas to draw precision-recall curves. Works like a context manager, yielding a figure and an axes set in which the precision-recall curves should be added to. The figure already contains F1-ISO lines and is preset to a 0-1 square region. Once the context is finished, ``fig.tight_layout()`` is called. - Yields ------ figure - The figure that should be finally returned to the user + The figure that should be finally returned to the user. axes - An axis set where to precision-recall plots should be added to + An axis set where to precision-recall plots should be added to. """ fig, axes1 = plt.subplots(1) @@ -461,13 +462,12 @@ def aggregate_pr( data: typing.Mapping[str, typing.Any], title: str = "Precision-Recall Curve", ) -> matplotlib.figure.Figure: - """Aggregates PR curves from multiple splits. + """Aggregate PR curves from multiple splits. This function produces a single Precision-Recall plot for multiple curves generated per split. The plot will be annotated with F1-score iso-lines (in which the F1-score maintains the same value). - Parameters ---------- data @@ -476,9 +476,9 @@ def aggregate_pr( title The title of the plot. - Returns ------- + matplotlib.figure.Figure A figure, containing the aggregated PR plot. """ diff --git a/src/mednet/engine/loggers.py b/src/mednet/engine/loggers.py index 94c2ca3b64324fcaef701d2941bfc0ad84177310..8e14f774a44e1feaeadc79ecb8b9c0b62ba91927 100644 --- a/src/mednet/engine/loggers.py +++ b/src/mednet/engine/loggers.py @@ -15,7 +15,6 @@ class CustomTensorboardLogger(TensorBoardLogger): This implementation puts all logs inside the same directory, instead of a separate "version_n" directories, which is the default lightning behaviour. - Parameters ---------- save_dir @@ -44,7 +43,7 @@ class CustomTensorboardLogger(TensorBoardLogger): passed then logs are saved in ``/save_dir/name/version/sub_dir/``. Defaults to ``None`` in which logs are saved in ``/save_dir/name/version/``. - \**kwargs: + \**kwargs Additional arguments used by :py:class:`tensorboardX.SummaryWriter` can be passed as keyword arguments in this logger. To automatically flush to disk, ``max_queue`` sets the size of the queue for pending logs before diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py index 9b6c62a177fe3b187ddc947d8b6dd430ef066be5..ae9ef271a8e0d72275a3e70fbc57f24a3e4978f3 100644 --- a/src/mednet/engine/predictor.py +++ b/src/mednet/engine/predictor.py @@ -31,40 +31,47 @@ def run( | MultiClassPredictionSplit | None ): - """Runs inference on input data, outputs csv files with predictions. + """Run inference on input data, outputs csv files with predictions. Parameters - --------- + ---------- model Neural network model (e.g. pasa). datamodule - The lightning datamodule to use for training **and** validation + The lightning DataModule to use for training **and** validation. device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. - + or a lightning accelerator setup. Returns ------- - Depending on the return type of the datamodule's + ( + list[BinaryPrediction] + | list[MultiClassPrediction] + | list[list[BinaryPrediction]] + | list[list[MultiClassPrediction]] + | BinaryPredictionSplit + | MultiClassPredictionSplit + | None + ) + Depending on the return type of the DataModule's ``predict_dataloader()`` method: * if :py:class:`torch.utils.data.DataLoader`, then returns a - :py:class:`list` of predictions + :py:class:`list` of predictions. * if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then returns a list of lists of predictions, each list corresponding to the iteration over one of the dataloaders. * if :py:class:`dict` of :py:class:`str` to :py:class:`torch.utils.data.DataLoader`, then returns a dictionary - mapping names to lists of predictions - * if ``None``, then returns ``None`` - + mapping names to lists of predictions. + * if ``None``, then returns ``None``. Raises ------ TypeError - If the datamodule's ``predict_dataloader()`` method does not return any + If the DataModule's ``predict_dataloader()`` method does not return any of the types described above. """ diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index c711c7b3272ddeeda4e645c4d436a4a18376610d..6c8be633336b0e8a5f5150f5d2946fe14cb16fbd 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -43,33 +43,32 @@ def _calculate_road_scores( saliency_map_callable: typing.Callable, percentiles: typing.Sequence[int], ) -> tuple[float, float, float]: - """Calculates average ROAD scores for different removal percentiles. + """Calculate average ROAD scores for different removal percentiles. This function calculates ROAD scores by averaging the scores for different removal (hardcoded) percentiles, for a single input image, a - given visualization method, a target class. - + given visualization method, and a target class. Parameters ---------- model Neural network model (e.g. pasa). images - A batch of input images to use evaluating the ROAD scores. Currently, + A batch of input images to use for evaluating the ROAD scores. Currently, we only support batches with a single image. output_num Target output neuron to take into consideration when evaluating the - saliency maps and calculating ROAD scores + saliency maps and calculating ROAD scores. saliency_map_callable - A callable saliency-map generator from grad-cam + A callable saliency-map generator from grad-cam. percentiles A sequence of percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both MoRF and LeRF scores. - Returns ------- + tuple[float, float, float] A 3-tuple containing floating point numbers representing the most-relevant-first average score (``morf``), least-relevant-first average score (``lerf``) and the combined value (``(lerf-morf)/2``). @@ -89,7 +88,7 @@ def _calculate_road_scores( # current processing bottleneck. If you want to optimise anyting, look at # the evaluation of the perturbation using scipy.sparse at the # NoisyLinearImputer, part of the grad-cam package (submodule - # ``metrics.road``. + # ``metrics.road``). metric_target = [SigmoidClassifierOutputTarget(output_num)] MoRF_scores = cam_metric_ROADMoRF_avg( @@ -127,23 +126,35 @@ def _process_sample( Parameters ---------- + sample + The Sample to process. model Neural network model (e.g. pasa). device The device to process samples on. saliency_map_callable - A callable saliency-map generator from grad-cam + A callable saliency-map generator from grad-cam. target_class - Class to target for saliency estimation. Can be either set to - "all" or "highest". "highest". - positive only + Class to target for saliency estimation. Can be set to + "all" or "highest". "highest" is default, which means + only saliency maps for the class with the highest + activation will be generated. + positive_only If set, and the model chosen has a single output (binary), then - saliency maps will only be generated for samples of the positive class - + saliency maps will only be generated for samples of the positive class. percentiles A sequence of percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both MoRF and LeRF scores. + + Returns + ------- + list + A list containing the following items for a particular sample: + * The relative path to the sample. + * The label. + * An index to the specified target_class. + * The computed ROAD scores. """ name: str = sample[1]["name"][0] @@ -204,16 +215,16 @@ def run( percentiles: typing.Sequence[int], parallel: int, ) -> dict[str, list[typing.Any]]: - """Evaluates ROAD scores for all samples in a datamodule. + """Evaluate ROAD scores for all samples in a DataModule. - The ROAD algorithm was first described at [ROAD-2022]_. It estimates + The ROAD algorithm was first described in [ROAD-2022]_. It estimates explainability (in the completeness sense) of saliency maps by substituting - relevant pixels in the input image by a local average, and re-running + relevant pixels in the input image by a local average, re-running prediction on the altered image, and measuring changes in the output classification score when said perturbations are in place. By substituting - most or least relevant pixels with surrounding averages, the ROAD algorithm + the most or least relevant pixels with surrounding averages, the ROAD algorithm estimates the importance of such elements in the produced saliency map. As - 2023, this measurement technique is considered to be one of the + of 2023, this measurement technique is considered to be one of the state-of-the-art metrics of explainability. This function returns a dictionary containing most-relevant-first (remove a @@ -221,22 +232,21 @@ def run( percentile of the least relevant pixels), and combined ROAD evaluations per sample for a particular saliency mapping algorithm. - Parameters - --------- + ---------- model Neural network model (e.g. pasa). datamodule - The lightning datamodule to iterate on. + The lightning DataModule to iterate on. device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. + or a lightning accelerator setup. saliency_map_algorithm The algorithm for saliency map estimation to use. target_class (Use only with multi-label models) Which class to target for CAM - calculation. Can be either set to "all" or "highest". "highest" is + calculation. Can be set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated. positive_only @@ -250,14 +260,13 @@ def run( parallel Use multiprocessing for data processing: if set to -1, disables multiprocessing. Set to 0 to enable as many data processing instances - as processing cores as available in the system. Set to >= 1 to enable + as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data processing. - Returns ------- - - A dictionary where keys are dataset names in the provide datamodule, + dict[str, list[typing.Any]] + A dictionary where keys are dataset names in the provide DataModule, and values are lists containing sample information alongside metrics calculated: @@ -298,8 +307,8 @@ def run( raise RuntimeError( f"The number of multiprocessing instances is set to {parallel} and " f"you asked to use a GPU (device = `{device_manager.device_type}`" - f"). The currently implementation can only handle a single GPU. " - f"Either disable GPU utilisation or set the number of " + f"). The current implementation can only handle a single GPU. " + f"Either disable GPU usage, set the number of " f"multiprocessing instances to one, or disable multiprocessing " "entirely (ie. set it to -1)." ) diff --git a/src/mednet/engine/saliency/evaluator.py b/src/mednet/engine/saliency/evaluator.py index f8a9c04eb46cfc4b17003b2ccbe1873d6e987a4b..0941dc60731d9e7950f95de1aaae0e51ae2cd5a6 100644 --- a/src/mednet/engine/saliency/evaluator.py +++ b/src/mednet/engine/saliency/evaluator.py @@ -16,25 +16,24 @@ def _reconcile_metrics( completeness: list, interpretability: list, ) -> list[tuple[str, int, float, float, float]]: - """Summarizes samples into a new table containing most important scores. + """Summarize samples into a new table containing the most important scores. - It returns a list containing a table with completeness and road scorse per + It returns a list containing a table with completeness and ROAD scores per sample, for the selected dataset. Only samples for which a completness and interpretability scores are availble are returned in the reconciled list. - Parameters ---------- completeness - A dictionary containing various tables with the sample name and + A list containing various tables with the sample name and completness (ROAD) scores. interpretability - A dictionary containing various tables with the sample name and + A list containing various tables with the sample name and interpretability (Pro. Energy) scores. - Returns ------- + list[tuple[str, int, float, float, float]] A list containing a table with the sample name, target label, completeness score (Average ROAD across different ablation thresholds), interpretability score (Proportional Energy), and the ROAD-Weighted @@ -81,23 +80,23 @@ def _make_histogram( xlim: tuple[float, float] | None = None, title: None | str = None, ) -> matplotlib.figure.Figure: - """Builds an histogram of values. + """Build an histogram of values. Parameters ---------- name - Name of the variable to be histogrammed (will appear in the figure) + Name of the variable to be histogrammed (will appear in the figure). values - Values to be histogrammed + Values to be histogrammed. xlim A tuple representing the X-axis maximum and minimum to plot. If not set, then use the bin boundaries. title - A title to set on the histogram - + A title to set on the histogram. Returns ------- + matplotlib.figure.Figure A matplotlib figure containing the histogram. """ @@ -142,7 +141,7 @@ def _make_histogram( def summary_table( summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str ) -> str: - """Tabulates various summaries into one table. + """Tabulate various summaries into one table. Parameters ---------- @@ -153,9 +152,9 @@ def summary_table( One of the formats supported by `python-tabulate <https://pypi.org/project/tabulate/>`_. - Returns ------- + str A string containing the tabulated information. """ @@ -185,7 +184,7 @@ def _extract_statistics( dataset: str, xlim: tuple[float, float] | None = None, ) -> dict[str, typing.Any]: - """Extracts all meaningful statistics from a reconciled statistics set. + """Extract all meaningful statistics from a reconciled statistics set. Parameters ---------- @@ -196,18 +195,17 @@ def _extract_statistics( produced by completeness and interpretability analysis as returned by :py:func:`_reconcile_metrics`. name - The name of the variable being analysed + The name of the variable being analysed. index - Which of the indexes on the tuples containing in ``data`` that should - be extracted. + The index of the tuple contained in ``data`` that should be extracted. dataset - The name of the dataset being analysed + The name of the dataset being analysed. xlim - Limits for histogram plotting - + Limits for histogram plotting. Returns ------- + dict[str, typing.Any] A dictionary containing the following elements: * ``values``: A list of values corresponding to the index on the data @@ -247,7 +245,7 @@ def run( completeness: dict[str, list], interpretability: dict[str, list], ) -> dict[str, typing.Any]: - """Evaluates multiple saliency map algorithms and produces summarized + """Evaluate multiple saliency map algorithms and produces summarized results. Parameters @@ -261,9 +259,9 @@ def run( A dictionary mapping dataset names to tables with the sample name and interpretability (among which Prop. Energy) scores. - Returns ------- + dict[str, typing.Any] A dictionary with most important statistical values for the main completeness (AOPC-Combined), interpretability (Prop. Energy), and a combination of both (ROAD-Weighted Prop. Energy) scores. diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index cb101e98ffeca29ad9c7d8a1e5cae7492f01c381..e99f78879f6049b6520c1525ef69b2db4f89143e 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -24,7 +24,23 @@ def _create_saliency_map_callable( target_layers: list[torch.nn.Module] | None, use_cuda: bool, ): - """Creates a class activation map (CAM) instance for a given model.""" + """Create a class activation map (CAM) instance for a given model. + + Parameters + ---------- + algo_type + The algorithm to use for saliency map estimation. + model + Neural network model (e.g. pasa). + target_layers + The target layers to compute CAM for. + use_cuda + Whether to use cuda or not. + + Returns + ------- + A class activation map (CAM) instance for the given model. + """ import pytorch_grad_cam @@ -93,7 +109,7 @@ def _save_saliency_map( """Helper function to save a saliency map to disk. Parameters - --------- + ---------- output_folder Directory in which the resulting saliency maps will be saved. name @@ -117,19 +133,19 @@ def run( positive_only: bool, output_folder: pathlib.Path, ) -> None: - """Applies saliency mapping techniques on input CXR, outputs pickled - saliency maps directly to disk. + """Apply saliency mapping techniques on input CXR, outputs pickled saliency + maps directly to disk. Parameters - --------- + ---------- model Neural network model (e.g. pasa). datamodule - The lightning datamodule to iterate on. + The lightning DataModule to iterate on. device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. + or a lightning accelerator setup. saliency_map_algorithm The algorithm to use for saliency map estimation. target_class @@ -143,7 +159,7 @@ def run( a multi-class output model. output_folder Where to save all the saliency maps (this path should exist before - this function is called) + this function is called). """ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py index 89873d0e7125ef15716d062ef13413da1dc4bb6c..aeca5b6f1049a0a0d363375595452ce56d51b389 100644 --- a/src/mednet/engine/saliency/interpretability.py +++ b/src/mednet/engine/saliency/interpretability.py @@ -19,13 +19,17 @@ from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes logger = logging.getLogger(__name__) +SaliencyMap: typing.TypeAlias = ( + typing.Sequence[typing.Sequence[float]] | numpy.typing.NDArray[numpy.double] +) +BinaryMask: typing.TypeAlias = numpy.typing.NDArray[numpy.bool_] + def _ordered_connected_components( - saliency_map: typing.Sequence[typing.Sequence[float]] - | numpy.typing.NDArray[numpy.double], + saliency_map: SaliencyMap, threshold: float, -) -> list[numpy.typing.NDArray[numpy.bool_]]: - """Calculates the largest connected components available on a saliency map +) -> list[BinaryMask]: + """Calculate the largest connected components available on a saliency map and return those as individual masks. This implementation is based on [SCORECAM-2020]_: @@ -40,7 +44,6 @@ def _ordered_connected_components( 4. We histogram the labels and return one binary mask for each label, sorted by decreasing size. - Parameters ---------- saliency_map @@ -51,9 +54,9 @@ def _ordered_connected_components( map. A value of 0.2 will zero all values in the saliency map that are bellow 20% of the maximum value observed in the said map. - Returns ------- + list[BinaryMask] A list of boolean masks, one for each connected component, ordered by decreasing size. This list may be empty if the input ``saliency_map`` is all zeroes. @@ -76,18 +79,18 @@ def _ordered_connected_components( def _extract_bounding_box( - mask: numpy.typing.NDArray[numpy.bool_], + mask: BinaryMask, ) -> BoundingBox: - """Defines a bounding box surrounding a connected component mask. + """Define a bounding box surrounding a connected component mask. Parameters ---------- mask The connected component mask from whom extract the bounding box. - Returns ------- + BoundingBox A bounding box. """ x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[ @@ -100,13 +103,12 @@ def _compute_max_iou_and_ioda( detected_box: BoundingBox, gt_bboxes: BoundingBoxes, ) -> tuple[float, float]: - """Will calculate how much of detected area lies in ground truth boxes. + """Calculate how much of detected area lies in ground truth boxes. If there are multiple gt boxes, the detected area will be calculated for each gt box separately and the gt box with the highest intersecting part will be used for the calculation. - Parameters ---------- detected_box @@ -115,9 +117,9 @@ def _compute_max_iou_and_ioda( Ground-truth bounding boxes in the format ``(x, y, width, height)``. - Returns ------- + tuple[float, float] The max iou and ioda values. """ detected_area = detected_box.area() @@ -147,19 +149,17 @@ def _compute_max_iou_and_ioda( def _get_largest_bounding_boxes( - saliency_map: typing.Sequence[typing.Sequence[float]] - | numpy.typing.NDArray[numpy.double], + saliency_map: SaliencyMap, n: int, threshold: float = 0.2, ) -> list[BoundingBox]: - """Returns the N largest connected components as bounding boxes in a + """Return the N largest connected components as bounding boxes in a saliency map. The return of values is subject to the value of ``threshold`` applied, as well as on the saliency map itself. The number of objects found is also affected by those parameters. - Parameters ---------- saliency_map @@ -173,9 +173,9 @@ def _get_largest_bounding_boxes( map. A value of 0.2 will zero all values in the saliency map that are bellow 20% of the maximum value observed in the said map. - Returns ------- + list[BoundingBox] The N largest connected components as bounding boxes in a saliency map. """ @@ -192,14 +192,12 @@ def _compute_simultaneous_iou_and_ioda( detected_box: BoundingBox, gt_bboxes: BoundingBoxes, ) -> tuple[float, float]: - """Will calculate how much of detected area lies between ground truth - boxes. + """Calculate how much of detected area lies between ground truth boxes. This means that if there are multiple gt boxes, the detected area will be compared to them simultaneously (and not to each gt box separately). - Parameters ---------- detected_box @@ -208,9 +206,9 @@ def _compute_simultaneous_iou_and_ioda( Collection of bounding boxes of the ground-truth drawn as ``True`` values. - Returns ------- + tuple[float, float] The iou and ioda for the provided boxes. """ @@ -227,18 +225,45 @@ def _compute_simultaneous_iou_and_ioda( return float(iou), float(ioda) +def _compute_iou_ioda_from_largest_bbox( + gt_bboxes: BoundingBoxes, + saliency_map: SaliencyMap, +) -> tuple[float, float]: + """Calculate the metrics for a single sample. + + Parameters + ---------- + gt_bboxes + A list of ground-truth bounding boxes. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + + Returns + ------- + tuple[float, float] + A tuple containing the iou and ioda for the largest bounding box. + """ + + largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2) + detected_box = ( + largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0) + ) + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes) + return (iou, ioda) + + def _compute_avg_saliency_focus( - saliency_map: numpy.typing.NDArray[numpy.double], - gt_mask: numpy.typing.NDArray[numpy.bool_], + saliency_map: SaliencyMap, + gt_mask: BinaryMask, ) -> float: - """Integrates the saliency map over the ground-truth boxes and normalizes - by total bounding-box area. + """Integrate the saliency map over the ground-truth boxes and normalizes by + total bounding-box area. This function will integrate (sum) the value of the saliency map over the ground-truth bounding boxes and normalize it by the total area covered by all ground-truth bounding boxes. - Parameters ---------- saliency_map @@ -248,9 +273,9 @@ def _compute_avg_saliency_focus( Ground-truth mask containing the bounding boxes of the ground-truth drawn as ``True`` values. - Returns ------- + float A single floating-point number representing the Average saliency focus. """ @@ -262,10 +287,10 @@ def _compute_avg_saliency_focus( def _compute_proportional_energy( - saliency_map: numpy.typing.NDArray[numpy.double], - gt_mask: numpy.typing.NDArray[numpy.bool_], + saliency_map: SaliencyMap, + gt_mask: BinaryMask, ) -> float: - """Calculates how much activation lies within the ground truth boxes + """Calculate how much activation lies within the ground truth boxes compared to the total sum of the activations (integral). Parameters @@ -277,9 +302,9 @@ def _compute_proportional_energy( Ground-truth mask containing the bounding boxes of the ground-truth drawn as ``True`` values. - Returns ------- + float A single floating-point number representing the proportional energy. """ @@ -293,26 +318,24 @@ def _compute_proportional_energy( def _compute_binary_mask( gt_bboxes: BoundingBoxes, - saliency_map: numpy.typing.NDArray[numpy.double], -) -> numpy.typing.NDArray[numpy.bool_]: - """Computes a binary mask for the saliency map using BoundingBoxes. + saliency_map: SaliencyMap, +) -> BinaryMask: + """Compute a binary mask for the saliency map using BoundingBoxes. The binary_mask will be ON/True where the gt boxes are located. - Parameters ---------- gt_bboxes Ground-truth bounding boxes in the format ``(x, y, width, height)``. - saliency_map A real-valued saliency-map that conveys regions used for classification in the original sample. - Returns ------- + BinaryMask A numpy array of the same size as saliency_map with the value False everywhere except at the positions inside the bounding boxes, which will be True. @@ -329,9 +352,9 @@ def _compute_binary_mask( def _process_sample( gt_bboxes: BoundingBoxes, - saliency_map: numpy.typing.NDArray[numpy.double], + saliency_map: SaliencyMap, ) -> tuple[float, float]: - """Calculates the metrics for a single sample. + """Calculate the metrics for a single sample. Parameters ---------- @@ -341,39 +364,20 @@ def _process_sample( A real-valued saliency-map that conveys regions used for classification in the original sample. - Returns ------- + tuple[float, float] A tuple containing the following values: - * IoU - * IoDA * Proportional energy * Average saliency focus - * Largest detected bounding box """ - # largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2) - # detected_box = ( - # largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0) - # ) - # - # # Calculate localization metrics - # iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes) - binary_mask = _compute_binary_mask(gt_bboxes, saliency_map) return ( - # iou, - # ioda, _compute_proportional_energy(saliency_map, binary_mask), _compute_avg_saliency_focus(saliency_map, binary_mask), - # ( - # detected_box.xmin, - # detected_box.ymin, - # detected_box.width, - # detected_box.height, - # ), ) @@ -382,11 +386,11 @@ def run( target_label: int, datamodule: lightning.pytorch.LightningDataModule, ) -> dict[str, list[typing.Any]]: - """Applies visualization techniques on input CXR, outputs images with - overlaid heatmaps and csv files with measurements. + """Compute the proportional energy and average saliency focus for a given + target label in a DataModule. Parameters - --------- + ---------- input_folder Directory in which the saliency maps are stored for a specific visualization type. @@ -394,12 +398,12 @@ def run( The label to target for evaluating interpretability metrics. Samples contining any other label are ignored. datamodule - The lightning datamodule to iterate on. - + The lightning DataModule to iterate on. Returns ------- - A dictionary where keys are dataset names in the provide datamodule, + dict[str, list[typing.Any]] + A dictionary where keys are dataset names in the provided DataModule, and values are lists containing sample information alongside metrics calculated: @@ -441,7 +445,7 @@ def run( if not bboxes: logger.warning( - f"Sample `{name}` does not contdain bounding-box information. " + f"Sample `{name}` does not contain bounding-box information. " f"No localization metrics can be calculated in this case. " f"Skipping..." ) diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py index 3c0a7efe1300e069c47189274fc556f177050759..6e383801d10f0d0f4dc9875f85cc78cee0515084 100644 --- a/src/mednet/engine/saliency/viewer.py +++ b/src/mednet/engine/saliency/viewer.py @@ -53,29 +53,28 @@ def _overlay_saliency_map( ], image_weight: float, ) -> PIL.Image.Image: - """Creates an overlayed represention of the saliency map on the original + """Create an overlayed represention of the saliency map on the original image. This is a slightly modified version of the show_cam_on_image implementation in: https://github.com/jacobgil/pytorch-grad-cam, but uses matplotlib instead of opencv. - Parameters ---------- image - The input imge that will be overlayed with the saliency map + The input image that will be overlayed with the saliency map. saliencies - The saliency map that will be overlaid on the (raw) image + The saliency map that will be overlaid on the (raw) image. colormap - The name of the (matplotlib) colormap to be used + The name of the (matplotlib) colormap to be used. image_weight The final result is ``image_weight * image + (1-image_weight) * saliency_map``. - Returns ------- + PIL.Image.Image A modified version of the input ``image`` with the overlaid saliency map. """ @@ -114,14 +113,14 @@ def _overlay_bounding_box( color: str, width: int, ) -> PIL.Image.Image: - """Draws ground-truth on the input image. + """Draw ground-truth on the input image. Parameters ---------- image - The input imge that will be overlayed with the saliency map + The input image that will be overlayed with the saliency map. bbox - The bounding box to draw on the input image + The bounding box to draw on the input image. color The color to use for drawing the bounding box. Any of the colours in :any:`PIL.ImageColor.colormap` are accepted. @@ -129,9 +128,9 @@ def _overlay_bounding_box( The width of the bounding box, in pixels. A larger value creates a bounding box that is thicker, towards the outside of the boxed area. - Returns ------- + PIL.Image.Image A modified version of the input ``image`` with the ground-truth drawn on the top. """ @@ -150,23 +149,23 @@ def _process_sample( saliencies: numpy.typing.NDArray[numpy.double], ground_truth: BoundingBoxes, ) -> PIL.Image.Image: - """Generates an overlayed representation of the original sample and - saliency maps. + """Generate an overlayed representation of the original sample and saliency + maps. Parameters ---------- raw_data The raw data representing the input sample that will be overlayed with - saliency maps and annotations + saliency maps and annotations. saliencies - The saliency map recovered from the model, that will be inprinted on - the raw_data + The saliency map recovered from the model, that will be imprinted on + the raw_data. ground_truth - Ground-truth annotations that may be inprinted on the final image - + Ground-truth annotations that may be imprinted on the final image. Returns ------- + PIL.Image.Image An image with the original raw data overlayed with the different elements as selected by the user. """ @@ -196,12 +195,12 @@ def run( show_groundtruth: bool, threshold: float, ): - """Overlays saliency maps on CXR to output final images with heatmaps. + """Overlay saliency maps on CXR to output final images with heatmaps. Parameters ---------- datamodule - The lightning datamodule to iterate on. + The Lightning DataModule to iterate on. input_folder Directory in which the saliency maps are stored for a specific visualization type. @@ -209,12 +208,12 @@ def run( The label to target for evaluating interpretability metrics. Samples contining any other label are ignored. output_folder - Directory in which the resulting visualisations will be saved. + Directory in which the resulting visualizations will be saved. show_groundtruth - If set, inprint ground truth labels over the original image and + If set, imprint ground truth labels over the original image and saliency maps. threshold : float - The pixel values above ``threshold``% of max value are kept in the + The pixel values above ``threshold`` % of max value are kept in the original saliency map. Everything else is set to zero. The value proposed on [SCORECAM-2020]_ is 0.2. Use this value if unsure. """ diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index d4b39c4b224d2499cc74743dbaeb334f05653eec..3da73a1124d7c94704d0a6b2228eedd75664777d 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -30,25 +30,19 @@ def save_model_summary( output_folder: pathlib.Path, model: torch.nn.Module, ) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]: - """Saves a little summary of the model in a txt file. + """Save a little summary of the model in a txt file. Parameters ---------- - output_folder - output path - + Directory in which to save the summary. model - Network (e.g. driu, hed, unet) - + Instance of the model for which to save the summary. Returns ------- - summary - The model summary in a text format - - total_parameters - The number of parameters of the model + tuple[lightning.pytorch.callbacks.ModelSummary, int] + A tuple with the model summary in a text format and number of parameters of the model. """ summary_path = output_folder / "model-summary.txt" logger.info(f"Saving model summary at {summary_path}...") @@ -70,20 +64,17 @@ def static_information_to_csv( device_type: SupportedPytorchDevice, model_size: int, ) -> None: - """Saves the static information in a CSV file. + """Save the static information in a CSV file. Parameters ---------- - static_logfile_name The static file name which is a join between the output folder and - "constants.csv" - + "constants.csv". device_type - The type of device we are using - + The type of device we are using. model_size - The size of the model we will be training + The size of the model we will be training. """ if static_logfile_name.exists(): backup = static_logfile_name.parent / (static_logfile_name.name + "~") @@ -124,21 +115,17 @@ def run( batch_chunk_count: int, checkpoint: pathlib.Path | None, ): - """Fits a CNN model using supervised learning and save it to disk. + """Fit a CNN model using supervised learning and save it to disk. This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training. - Parameters ---------- - model Neural network model (e.g. pasa). - datamodule - The lightning datamodule to use for training **and** validation - + The lightning DataModule to use for training **and** validation. validation_period Number of epochs after which validation happens. By default, we run validation after every training epoch (period=1). You can change this @@ -148,31 +135,25 @@ def run( triggers the overriding of latest checkpoint), and that this process is independent of validation runs, evaluation of the 'best' model obtained so far based on those will be influenced by this setting. - device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. - + or a lightning accelerator setup. max_epochs The maximum number of epochs to train for. - output_folder Folder in which the results will be saved. - monitoring_interval Interval, in seconds (or fractions), through which we should monitor resources during training. - batch_chunk_count If this number is different than 1, then each batch will be divided in this number of chunks. Gradients will be accumulated to perform each mini-batch. This is particularly interesting when one has limited RAM on the GPU, but would like to keep training with larger batches. One exchanges for longer processing times in this case. - checkpoint - Path to an optional checkpoint file to load + Path to an optional checkpoint file to load. """ os.makedirs(output_folder, exist_ok=True) diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 5e5b3e7eee3d810aaa4c8dfed31ad3d9fa5414b5..0b2d9e180f0333da40f2d568ca2202b3ddb43b9b 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -35,7 +35,6 @@ class Alexnet(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be @@ -45,21 +44,16 @@ class Alexnet(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - optimizer_type - The type of optimizer to use for training - + The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. - augmentation_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. - pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. - num_classes Number of outputs (classes) for this model. """ @@ -118,32 +112,33 @@ class Alexnet(pl.LightningModule): return x def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning to restore your model. - - If you saved something with on_save_checkpoint() this is your chance to - restore this. + """Called by Lightning when saving a checkpoint to give you a chance to + store anything else you might want to save. Use on_load_checkpoint() to + restore what additional data is saved here. Parameters ---------- checkpoint - Loaded checkpoint + The checkpoint to save. """ checkpoint["normalizer"] = self.normalizer def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. + """Called by Lightning to restore your model. + + If you saved something with on_save_checkpoint() this is your chance to + restore this. Parameters ---------- checkpoint - Loaded checkpoint + The loaded checkpoint. """ logger.info("Restoring normalizer from checkpoint.") self.normalizer = checkpoint["normalizer"] def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initializes the normalizer for the current model. + """Initialize the normalizer for the current model. This function is NOOP if ``pretrained = True`` (normalizer set to imagenet weights, during contruction). diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index 333edb11e2112db9ef1f1bf37b2d333f5b418de6..2bd6a2dd01ecf8389b1f32a4e783b0c4315af26e 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -33,7 +33,6 @@ class Densenet(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be @@ -43,21 +42,16 @@ class Densenet(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - optimizer_type - The type of optimizer to use for training - + The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. - augmentation_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. - pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. - num_classes Number of outputs (classes) for this model. """ @@ -118,32 +112,33 @@ class Densenet(pl.LightningModule): return x def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning to restore your model. - - If you saved something with on_save_checkpoint() this is your chance to - restore this. + """Called by Lightning when saving a checkpoint to give you a chance to + store anything else you might want to save. Use on_load_checkpoint() to + restore what additional data is saved here. Parameters ---------- checkpoint - Loaded checkpoint + The checkpoint to save. """ checkpoint["normalizer"] = self.normalizer def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. + """Called by Lightning to restore your model. + + If you saved something with on_save_checkpoint() this is your chance to + restore this. Parameters ---------- checkpoint - Loaded checkpoint + The loaded checkpoint. """ logger.info("Restoring normalizer from checkpoint.") self.normalizer = checkpoint["normalizer"] def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initializes the normalizer for the current model. + """Initialize the normalizer for the current model. This function is NOOP if ``pretrained = True`` (normalizer set to imagenet weights, during contruction). diff --git a/src/mednet/models/logistic_regression.py b/src/mednet/models/logistic_regression.py index 6a88d9675e09d3d6f0fb1e0a4d221ac948c0ae26..fd3281ec720ce4b926e5ecdc7b8208365769422e 100644 --- a/src/mednet/models/logistic_regression.py +++ b/src/mednet/models/logistic_regression.py @@ -23,7 +23,6 @@ class LogisticRegression(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be @@ -33,13 +32,10 @@ class LogisticRegression(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - optimizer_type - The type of optimizer to use for training - + The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. - input_size The number of inputs this classifer shall process. """ diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index 8cf79b7757f3f677d151870f5fd1e571501ed673..8051f63b27ad7d43b04ca0316738a48c336326a3 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -15,30 +15,24 @@ logger = logging.getLogger(__name__) def _get_label_weights( dataloader: torch.utils.data.DataLoader, ) -> torch.Tensor: - """Computes the weights of each class of a DataLoader. + """Compute the weights of each class of a DataLoader. This function inputs a pytorch DataLoader and computes the ratio between number of negative and positive samples (scalar). The weight can be used to adjust minimisation criteria to in cases there is a huge data imbalance. - If - It returns a vector with weights (inverse counts) for each label. - Parameters ---------- - dataloader A DataLoader from which to compute the positive weights. Entries must be a dictionary which must contain a ``label`` key. - Returns ------- - - positive_weights - the positive weight of each class in the dataset given as input + torch.Tensor + The positive weight of each class in the dataset given as input. """ targets = torch.tensor( @@ -75,17 +69,20 @@ def _get_label_weights( def make_balanced_bcewithlogitsloss( dataloader: DataLoader, ) -> torch.nn.BCEWithLogitsLoss: - """Returns a balanced binary-cross-entropy loss. + """Return a balanced binary-cross-entropy loss. The loss is weighted using the ratio between positives and total examples available. + Parameters + ---------- + dataloader + The DataLoader to use to compute the BCE weights. Returns ------- - - loss - An instance of the weighted loss + torch.nn.BCEWithLogitsLoss + An instance of the weighted loss. """ weights = _get_label_weights(dataloader) diff --git a/src/mednet/models/mlp.py b/src/mednet/models/mlp.py index ac59ad6f6302355c352995cff9b4ee7208ed46e6..831b7385d71dc43d696ae07d62c923a76bb23653 100644 --- a/src/mednet/models/mlp.py +++ b/src/mednet/models/mlp.py @@ -22,7 +22,6 @@ class MultiLayerPerceptron(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be @@ -32,16 +31,12 @@ class MultiLayerPerceptron(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - optimizer_type - The type of optimizer to use for training - + The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. - input_size The number of inputs this classifer shall process. - hidden_size The number of neurons on the single hidden layer. """ diff --git a/src/mednet/models/normalizer.py b/src/mednet/models/normalizer.py index 576f21cc6db52c0a6f7bfbe9c59009194afc9e14..6e09f51e3972a98e4f4c3e4cc359e44b4516aada 100644 --- a/src/mednet/models/normalizer.py +++ b/src/mednet/models/normalizer.py @@ -13,23 +13,20 @@ import tqdm def make_z_normalizer( dataloader: torch.utils.data.DataLoader, ) -> torchvision.transforms.Normalize: - """Computes mean and standard deviation from a dataloader. + """Compute mean and standard deviation from a dataloader. This function will input a dataloader, and compute the mean and standard deviation by image channel. It will work for both monochromatic, and color inputs with 2, 3 or more color planes. - Parameters ---------- - - dataloader: - A torch Dataloader from which to compute the mean and std - + dataloader + A torch Dataloader from which to compute the mean and std. Returns ------- - An initialized normalizer + An initialized normalizer. """ # Peek the number of channels of batches in the data loader @@ -58,15 +55,14 @@ def make_z_normalizer( def make_imagenet_normalizer() -> torchvision.transforms.Normalize: - """Returns the stock ImageNet normalisation weights from torchvision. + """Return the stock ImageNet normalisation weights from torchvision. The weights are wrapped in a torch module. This normalizer only works for **RGB (color) images**. - Returns ------- - An initialized normalizer + An initialized normalizer. """ return torchvision.transforms.Normalize( diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index a7b8ae62c0c56bc2c5172058fdd3382c987514a1..9007ab1b8d7e5c79b29bc65d4a996783bfc7889e 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -29,7 +29,6 @@ class Pasa(pl.LightningModule): This network has a linear output. You should use losses with ``WithLogit`` instead of cross-entropy versions when training. - Parameters ---------- train_loss @@ -39,7 +38,6 @@ class Pasa(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be @@ -49,17 +47,13 @@ class Pasa(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - optimizer_type - The type of optimizer to use for training - + The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. - augmentation_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. - num_classes Number of outputs (classes) for this model. """ @@ -204,37 +198,38 @@ class Pasa(pl.LightningModule): return x def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning to restore your model. - - If you saved something with on_save_checkpoint() this is your chance to - restore this. + """Called by Lightning when saving a checkpoint to give you a chance to + store anything else you might want to save. Use on_load_checkpoint() to + restore what additional data is saved here. Parameters ---------- checkpoint - Loaded checkpoint + The checkpoint to save. """ checkpoint["normalizer"] = self.normalizer def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. + """Called by Lightning to restore your model. + + If you saved something with on_save_checkpoint() this is your chance to + restore this. Parameters ---------- checkpoint - Loaded checkpoint + The loaded checkpoint. """ logger.info("Restoring normalizer from checkpoint.") self.normalizer = checkpoint["normalizer"] def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initializes the input normalizer for the current model. + """Initialize the input normalizer for the current model. Parameters ---------- dataloader - A torch Dataloader from which to compute the mean and std + A torch Dataloader from which to compute the mean and std. """ from .normalizer import make_z_normalizer diff --git a/src/mednet/models/separate.py b/src/mednet/models/separate.py index 9568721169a58a69c1f54f57bcca05398cb02e3c..9e5c0ee7b1c604521b2ce868e96837373fa967e3 100644 --- a/src/mednet/models/separate.py +++ b/src/mednet/models/separate.py @@ -14,24 +14,24 @@ from .typing import BinaryPrediction, MultiClassPrediction def _as_predictions( samples: typing.Iterable[Sample], ) -> list[BinaryPrediction | MultiClassPrediction]: - """Takes a list of separated batch predictions and transform into a list of - formal predictions. + """Take a list of separated batch predictions and transforms it into a list + of formal predictions. Parameters ---------- samples A sequence of samples as returned by :py:func:`separate`. - Returns ------- + list[BinaryPrediction | MultiClassPrediction] A list of typed predictions that can be saved to disk. """ return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: - """Separates a collated batch reconstituting its samples. + """Separate a collated batch, reconstituting its samples. This function implements the inverse of :py:func:`torch.utils.data.default_collate`, and can separate, into @@ -42,12 +42,10 @@ def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: dimension, via :py:func:`torch.flatten`) * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]`` - Parameters ---------- batch - A batch, as output by torch model forwarding - + A batch, as output by torch model forwarding. Returns ------- diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py index 8ff0af726aada9256603882dd8a6c0aeb6d34cad..4ee73548a5b458ed6e5e043fb5922dcb752cb126 100644 --- a/src/mednet/models/transforms.py +++ b/src/mednet/models/transforms.py @@ -9,37 +9,35 @@ import torchvision.transforms.functional def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: - """Converts an image in grayscale to RGB. + """Convert an image in grayscale to RGB. If the image is already in RGB format, then this is a NOOP - the same tensor is returned (no cloning). If the image is in grayscale format - (number of bands = 1), then triplicate that band 3 times (a new copy is - returned in this case). - + (number of color channels = 1), then replicate it to obtain 3 color channels + (a new copy is returned in this case). Parameters ---------- - img The tensor to be transformed. Expected to be in the form: ``[..., [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). Returns ------- - - img - transformed tensor where the 3rd dimension from the last is 3. + torch.Tensor + Transformed tensor with 3 identical color channels. """ if img.ndim < 3: raise TypeError( - f"Input image tensor should have at least 3 dimensions," - f"but found {img.ndim}" + f"Input image tensor should have at least 3 dimensions, " + f"but found {img.ndim}. If a grayscale image was provided, " + f"ensure to include a channel dimension of size 1 ( i.e: [1, height, width])." ) if img.shape[-3] not in (1, 3): raise TypeError( - f"Input image tensor should have 1 or 3 planes," - f"but found {img.shape[-3]}" + f"Input image tensor should have 1 or 3 color channels," + f"but found {img.shape[-3]}." ) if img.shape[-3] == 3: @@ -52,11 +50,12 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: - """Converts an image in RGB to grayscale. + """Convert an image in RGB to grayscale. If the image is already in grayscale format, then this is a NOOP - the same - tensor is returned (no cloning). If the image is in RGB format, then - compresses the color planes into grayscale following this equation: + tensor is returned (no cloning). If the image is in RGB format + (number of color channels = 3), then compresses the color channels into + a single grayscale channel following this equation: .. math:: @@ -64,24 +63,22 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: A new tensor is returned in this case. - Parameters ---------- - img The tensor to be transformed. Expected to be in the form: ``[..., [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). Returns ------- - - img - transformed tensor where the 3rd dimension from the last is 3. + torch.Tensor + Transformed tensor with a single (grayscale) color channel. """ if img.ndim < 3: raise TypeError( - f"Input image tensor should have at least 3 dimensions," - f"but found {img.ndim}" + f"Input image tensor should have at least 3 dimensions, " + f"but found {img.ndim}. If a grayscale image was provided, " + f"ensure to include a channel dimension of size 1 ( i.e: [1, height, width])." ) if img.shape[-3] not in (1, 3): @@ -98,13 +95,7 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: class RGB(torch.nn.Module): - """Converts an image in grayscale to RGB. - - If the image is already in RGB format, then this is a NOOP - the same - tensor is returned (no cloning). If the image is in grayscale format - (number of bands = 1), then triplicate that band 3 times (a new copy is - returned in this case). - """ + """Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model transform.""" def __init__(self): super().__init__() @@ -114,18 +105,7 @@ class RGB(torch.nn.Module): class Grayscale(torch.nn.Module): - """Converts an image in RGB to grayscale. - - If the image is already in grayscale format, then this is a NOOP - the same - tensor is returned (no cloning). If the image is in RGB format, then - compresses the color planes into grayscale following this equation: - - .. math:: - - grayscale = (0.2989 * r + 0.587 * g + 0.114 * b) - - A new tensor is returned in this case. - """ + """Wrapper class around :py:func:`rgb_to_grayscale` to be used as a model transform.""" def __init__(self): super().__init__() diff --git a/src/mednet/scripts/cli.py b/src/mednet/scripts/cli.py index e8f7ba244526e807bd8d63a290c7010b530baafe..3c44ea0a847f9c7ee17e67ad0dc19ee628a6bd51 100644 --- a/src/mednet/scripts/cli.py +++ b/src/mednet/scripts/cli.py @@ -40,7 +40,7 @@ cli.add_command( context_settings=dict(help_option_names=["-?", "-h", "--help"]), ) def saliency(): - """Sub-commands to generate, evaluate and view saliency maps.""" + """The sub-commands to generate, evaluate and view saliency maps.""" pass diff --git a/src/mednet/scripts/click.py b/src/mednet/scripts/click.py index 07dfe697c5f0e35356f75958c949c97aaef36aa8..8cb7c392e50c074c4b6b2523bee17f0953d75acb 100644 --- a/src/mednet/scripts/click.py +++ b/src/mednet/scripts/click.py @@ -8,20 +8,19 @@ from clapper.click import ConfigCommand as _BaseConfigCommand class ConfigCommand(_BaseConfigCommand): - """A click command-class that has the properties of - :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog - formatting.""" + """A click command-class that has the properties of :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog formatting.""" def format_epilog( self, _: click.core.Context, formatter: click.formatting.HelpFormatter ) -> None: - """Formats the command epilog during --help. - - Arguments: - - _: The current parsing context - - formatter: The formatter to use for printing text + """Format the command epilog during --help. + + Parameters + ---------- + _ + The current parsing context. + formatter + The formatter to use for printing text. """ if self.epilog: diff --git a/src/mednet/scripts/config.py b/src/mednet/scripts/config.py index 14abfe9aef185dd47e25cf22542731da72fb66fc..5091cd28d9faf62f9e7dcc988a23868e85768dec 100644 --- a/src/mednet/scripts/config.py +++ b/src/mednet/scripts/config.py @@ -16,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.group(cls=AliasedGroup) def config(): - """Commands for listing, describing and copying configuration resources.""" + """Command for listing, describing and copying configuration resources.""" pass @@ -42,8 +42,8 @@ def config(): """ ) @verbosity_option(logger=logger) -def list(verbose) -> None: - """Lists configuration files installed.""" +def list(verbose) -> None: # numpydoc ignore=PR01 + """List configuration files installed.""" entry_points = importlib.metadata.entry_points().select( group="mednet.config" ) @@ -98,7 +98,7 @@ def list(verbose) -> None: epilog="""Examples: \b - 1. Describes the Montgomery dataset configuration: + 1. Describe the Montgomery dataset configuration: .. code:: sh @@ -106,7 +106,7 @@ def list(verbose) -> None: \b - 2. Describes the Montgomery dataset configuration and lists its + 2. Describe the Montgomery dataset configuration and lists its contents: .. code:: sh @@ -121,8 +121,8 @@ def list(verbose) -> None: nargs=-1, ) @verbosity_option(logger=logger) -def describe(name, verbose) -> None: - """Describes a specific configuration file.""" +def describe(name, verbose) -> None: # numpydoc ignore=PR01 + """Describe a specific configuration file.""" entry_points = importlib.metadata.entry_points().select( group="mednet.config" ) @@ -152,7 +152,7 @@ def describe(name, verbose) -> None: epilog="""Examples: \b - 1. Makes a copy of one of the stock configuration files locally, so it can be + 1. Make a copy of one of the stock configuration files locally, so it can be adapted: .. code:: sh @@ -172,7 +172,7 @@ def describe(name, verbose) -> None: nargs=1, ) @verbosity_option(logger=logger, expose_value=False) -def copy(source, destination) -> None: +def copy(source, destination) -> None: # numpydoc ignore=PR01 """Copy a specific configuration resource so it can be modified locally.""" import shutil diff --git a/src/mednet/scripts/database.py b/src/mednet/scripts/database.py index 683cad2ede520e1cae8442cdd784fa0984a8687c..e968ef26e8f8b0c6422c28814fb8828c0507b45c 100644 --- a/src/mednet/scripts/database.py +++ b/src/mednet/scripts/database.py @@ -11,16 +11,16 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def _get_raw_databases() -> dict[str, dict[str, str]]: - """Returns a list of all supported (raw) databases. + """Return a list of all supported (raw) databases. Returns ------- - d + dict[str, dict[str, str]] Dictionary where keys are database names, and values are dictionaries containing two string keys: * ``module``: the full Pythonic module name (e.g. - ``mednet.data.montgomery``) + ``mednet.data.montgomery``). * ``datadir``: points to the user-configured data directory for the current dataset, if set, or ``None`` otherwise. """ @@ -57,7 +57,7 @@ def _get_raw_databases() -> dict[str, dict[str, str]]: @click.group(cls=AliasedGroup) def database() -> None: - """Commands for listing and verifying databases installed.""" + """Command for listing and verifying databases installed.""" pass @@ -90,7 +90,7 @@ def database() -> None: ) @verbosity_option(logger=logger, expose_value=False) def list(): - """Lists all supported and configured databases.""" + """List all supported and configured databases.""" config = _get_raw_databases() click.echo("Available databases:") @@ -131,8 +131,8 @@ def list(): default=0, ) @verbosity_option(logger=logger, expose_value=False) -def check(split, limit): - """Checks file access on one or more datamodules.""" +def check(split, limit): # numpydoc ignore=PR01 + """Check file access on one or more DataModules.""" import importlib.metadata import sys @@ -189,5 +189,5 @@ def check(split, limit): ) else: click.secho( - f"Found {errors} errors loading datamodule `{split}`.", fg="red" + f"Found {errors} errors loading DataModule `{split}`.", fg="red" ) diff --git a/src/mednet/scripts/evaluate.py b/src/mednet/scripts/evaluate.py index 2a5a85cb931fd1d28066493b4b37aeb752186235..3c21ef632a25bcac0e39ba7d7678dc5737b2637f 100644 --- a/src/mednet/scripts/evaluate.py +++ b/src/mednet/scripts/evaluate.py @@ -22,13 +22,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -1. Runs evaluation on an existing prediction output: +1. Run evaluation on an existing prediction output: .. code:: sh mednet evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results -2. Runs evaluation on an existing prediction output, tune threshold a priori on the `validation` set: +2. Run evaluation on an existing prediction output, tune threshold a priori on the `validation` set: .. code:: sh @@ -38,7 +38,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--predictions", "-p", - help="Path where predictions are currently stored", + help="Directory in which predictions are currently stored", required=True, type=click.Path( file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path @@ -48,7 +48,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output-folder", "-o", - help="Path where to store the analysis result (created if does not exist)", + help="Directory in which to store the analysis result (created if does not exist)", required=False, default="results", type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), @@ -78,8 +78,8 @@ def evaluate( output_folder: pathlib.Path, threshold: str | float, **_, # ignored -) -> None: - """Evaluates predictions (from a model) on a classification task.""" +) -> None: # numpydoc ignore=PR01 + """Evaluate predictions (from a model) on a classification task.""" import json import typing @@ -114,7 +114,7 @@ def evaluate( raise click.BadParameter( f"""The value of --threshold=`{threshold}` does not match one of the database split names ({', '.join(predict_data.keys())}) - or can be converted to float. Check your input.""" + or can not be converted to a float. Check your input.""" ) results: dict[ diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py index 151a2356f55db173839aab44a3ec46a942223b1c..38d88fd94c32fee58c8959ea3ed767c59ff0fd2f 100644 --- a/src/mednet/scripts/experiment.py +++ b/src/mednet/scripts/experiment.py @@ -21,9 +21,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") epilog="""Examples: \b - 1. Trains a pasa model with montgomery dataset, on the CPU, for only two + 1. Train a pasa model with montgomery dataset, on the CPU, for only two epochs, then runs inference and evaluation on stock datasets, report - performance as a table and a figure: + performance as a table and figures: .. code:: sh @@ -50,12 +50,11 @@ def experiment( monitoring_interval, balance_classes, **_, -): - """Runs a complete experiment, from training, to prediction and evaluation. +): # numpydoc ignore=PR01 + """Run a complete experiment, from training, to prediction and evaluation. This script is just a wrapper around the individual scripts for training, - running prediction, evaluating and comparing model performance. It - organises the output in a preset way:: + running prediction, and evaluating. It organises the output in a preset way:: \b └─ <output-folder>/ diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index 828ba4ff9304db9352e73aea1885229f361b0432..68dd8da7e5f1e28b03d21746b9550039f4caa440 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -19,13 +19,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -1. Runs prediction on an existing datamodule configuration: +1. Run prediction on an existing DataModule configuration: .. code:: sh mednet predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json -2. Enables multi-processing data loading with 6 processes: +2. Enable multi-processing data loading with 6 processes: .. code:: sh @@ -36,8 +36,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output", "-o", - help="""Path where to store the JSON predictions for all samples in the - input datamodule (leading directories are created if they do not not + help="""Path to a .json file in which to save predictions for all samples in the + input DataModule (leading directories are created if they do not exist).""", required=True, default="results", @@ -49,7 +49,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--model", "-m", - help="""A lightining module instance implementing the network architecture + help="""A lightning module instance implementing the network architecture (not the weights, necessarily) to be used for prediction.""", required=True, cls=ResourceOption, @@ -57,9 +57,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--datamodule", "-d", - help="""A lighting data module that will be asked for prediction data - loaders. Typically, this includes all configured splits in a datamodule, - however this is not a requirement. A datamodule that returns a single + help="""A lightning DataModule that will be asked for prediction data + loaders. Typically, this includes all configured splits in a DataModule, + however this is not a requirement. A DataModule that returns a single dataloader for prediction (wrapped in a dictionary) is acceptable.""", required=True, cls=ResourceOption, @@ -106,7 +106,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "-P", help="""Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data - loading instances as processing cores as available in the system. Set to + loading instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading.""", type=click.IntRange(min=-1), show_default=True, @@ -124,9 +124,8 @@ def predict( weight, parallel, **_, -) -> None: - """Runs inference (generates scores) on all input images, using a pre- - trained model.""" +) -> None: # numpydoc ignore=PR01 + """Run inference (generates scores) on all input images, using a pre-trained model.""" import json import shutil diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py index 41b877d4decb63febb77fc2b20b8965887a4bb89..a6e30bbea94e3420927c439bd5710ebfeb5853a6 100644 --- a/src/mednet/scripts/saliency/completeness.py +++ b/src/mednet/scripts/saliency/completeness.py @@ -21,7 +21,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -1. Calculates the ROAD scores for an existing dataset configuration and stores them in .csv files: +1. Calculate the ROAD scores for an existing dataset configuration and stores them in .json files: .. code:: sh @@ -32,7 +32,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--model", "-m", - help="""A lightining module instance implementing the network architecture + help="""A lightning module instance implementing the network architecture (not the weights, necessarily) to be used for inference. Currently, only supports pasa and densenet models.""", required=True, @@ -41,17 +41,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--datamodule", "-d", - help="""A lighting data module that will be asked for prediction data - loaders. Typically, this includes all configured splits in a datamodule, - however this is not a requirement. A datamodule that returns a single - dataloader for prediction (wrapped in a dictionary) is acceptable.""", + help="""A lightning DataModule that will be asked for prediction DataLoaders. + Typically, this includes all configured splits in a DataModule, + however this is not a requirement. A DataModule that returns a single + DataLoader for prediction (wrapped in a dictionary) is acceptable.""", required=True, cls=ResourceOption, ) @click.option( "--output-json", "-o", - help="""Path where to store the output JSON file containing all + help="""Directory in which to store the output .json file containing all measures.""", required=True, type=click.Path( @@ -172,8 +172,8 @@ def completeness( positive_only, percentile, **_, -) -> None: - """Evaluates saliency map algorithm completeness using RemOve And Debias +) -> None: # numpydoc ignore=PR01 + """Evaluate saliency map algorithm completeness using RemOve And Debias (ROAD). For the selected saliency map algorithm, evaluates the completeness of @@ -188,15 +188,15 @@ def completeness( 2023, this measurement technique is considered to be one of the state-of-the-art metrics of explainability. - This program outputs a JSON file containing the ROAD evaluations (using + This program outputs a .json file containing the ROAD evaluations (using most-relevant-first, or MoRF, and least-relevant-first, or LeRF for each - sample in the datamodule. Values for MoRF and LeRF represent averages by + sample in the DataModule. Values for MoRF and LeRF represent averages by removing 20, 40, 60 and 80% of most or least relevant pixels respectively from the image, and averaging results for all these percentiles. .. note:: - This application is relatively slow when processing a large datamodule + This application is relatively slow when processing a large DataModule with many (positive) samples. """ import json diff --git a/src/mednet/scripts/saliency/evaluate.py b/src/mednet/scripts/saliency/evaluate.py index f84ee73e1da3bdc89bb4dc3509769e91219e3b72..26a9b8d1f6ce75b3a4202240541266e0fafa38c8 100644 --- a/src/mednet/scripts/saliency/evaluate.py +++ b/src/mednet/scripts/saliency/evaluate.py @@ -24,7 +24,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -1. Tabulates and generates plots for two saliency map algorithms: +1. Tabulate and generates plots for two saliency map algorithms: .. code:: sh @@ -65,7 +65,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output-folder", "-o", - help="Path where to store the analysis result (created if does not exist)", + help="Directory in which to store the analysis result (created if does not exist)", required=False, default="results", type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), @@ -76,8 +76,8 @@ def evaluate( entry, output_folder, **_, # ignored -) -> None: - """Calculates summary statistics for a saliency map algorithm.""" +) -> None: # numpydoc ignore=PR01 + """Calculate summary statistics for a saliency map algorithm.""" import json from matplotlib.backends.backend_pdf import PdfPages diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index bb1bf4a03ada352c2d2978feca834496afac007f..583c250c7061181298d6a5ed0bdb27f12826211d 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -21,7 +21,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -1. Generates saliency maps for all prediction dataloaders on a datamodule, +1. Generate saliency maps for all prediction dataloaders on a DataModule, using a pre-trained DenseNet model, and saves them as numpy-pickeled objects on the output directory: @@ -34,7 +34,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--model", "-m", - help="""A lightining module instance implementing the network architecture + help="""A lightning module instance implementing the network architecture (not the weights, necessarily) to be used for inference. Currently, only supports pasa and densenet models.""", required=True, @@ -43,9 +43,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--datamodule", "-d", - help="""A lighting data module that will be asked for prediction data - loaders. Typically, this includes all configured splits in a datamodule, - however this is not a requirement. A datamodule that returns a single + help="""A lightning DataModule that will be asked for prediction data + loaders. Typically, this includes all configured splits in a DataModule, + however this is not a requirement. A DataModule that returns a single dataloader for prediction (wrapped in a dictionary) is acceptable.""", required=True, cls=ResourceOption, @@ -53,7 +53,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output-folder", "-o", - help="Path where to store saliency maps (created if does not exist)", + help="Directory in which to store saliency maps (created if does not exist)", required=True, type=click.Path( exists=False, @@ -86,7 +86,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--weight", "-w", - help="""Path or URL to pretrained model file (`.ckpt` extension), + help="""Path or URL to a pretrained model file (`.ckpt` extension), corresponding to the architecture set with `--model`. Optionally, you may also pass a directory containing the result of a training session, in which case either the best (lowest validation) or latest model will be loaded.""", @@ -105,7 +105,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "-P", help="""Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data - loading instances as processing cores as available in the system. Set to + loading instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading.""", type=click.IntRange(min=-1), show_default=True, @@ -161,8 +161,8 @@ def generate( target_class, positive_only, **_, -) -> None: - """Generates saliency maps for locations on input images that affected the +) -> None: # numpydoc ignore=PR01 + """Generate saliency maps for locations on input images that affected the prediction. The quality of saliency information depends on the saliency map diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py index 1a77cbd9320cc7f1b36d30f5c26408b87c900d44..0d1b84f7692a1c218cca603d7e9306b617a8d1c5 100644 --- a/src/mednet/scripts/saliency/interpretability.py +++ b/src/mednet/scripts/saliency/interpretability.py @@ -30,9 +30,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--datamodule", "-d", - help="""A lighting data module that will be asked for prediction data - loaders. Typically, this includes all configured splits in a datamodule, - however this is not a requirement. A datamodule that returns a single + help="""A lightning DataModule that will be asked for prediction data + loaders. Typically, this includes all configured splits in a DataModule, + however this is not a requirement. A DataModule that returns a single dataloader for prediction (wrapped in a dictionary) is acceptable.""", required=True, cls=ResourceOption, @@ -40,7 +40,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--input-folder", "-i", - help="""Path where to load saliency maps from. You can generate saliency + help="""Path from where to load saliency maps. You can generate saliency maps with ``mednet saliency generate``.""", required=True, type=click.Path( @@ -66,8 +66,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output-json", "-o", - help="""Path where to store the output JSON file containing all - measures.""", + help="""Path to the .json file in which all measures will be saved.""", required=True, type=click.Path( file_okay=True, @@ -84,8 +83,8 @@ def interpretability( target_label, output_json, **_, -) -> None: - """Evaluates saliency map agreement with annotations (human +) -> None: # numpydoc ignore=PR01 + """Evaluate saliency map agreement with annotations (human interpretability). The evaluation happens by comparing saliency maps with ground-truth @@ -97,9 +96,8 @@ def interpretability( For obvious reasons, this evaluation is limited to datasets that contain built-in annotations which corroborate classification. - - As a result of the evaluation, this application creates a single JSON file - that resembles the original datamodule, with added information containing + As a result of the evaluation, this application creates a single .json file + that resembles the original DataModule, with added information containing the following measures, for each sample: * Proportional Energy: A measure that compares (UNthresholed) saliency maps @@ -108,7 +106,7 @@ def interpretability( of the activations. * Average Saliency Focus: estimates how much of the ground truth bounding boxes area is covered by the activations. It is similar to the - proportional energy measure in the sense it does not need explicit + proportional energy measure in the sense that it does not need explicit thresholding. """ diff --git a/src/mednet/scripts/saliency/view.py b/src/mednet/scripts/saliency/view.py index a583637d5bd1e29e473087b50864fb3c997a71fd..f44be79f8f4999fcc69fc38da368b53869fe696e 100644 --- a/src/mednet/scripts/saliency/view.py +++ b/src/mednet/scripts/saliency/view.py @@ -18,7 +18,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: +1. Generate visualizations in the form of heatmaps from existing saliency maps for a dataset configuration: .. code:: sh @@ -28,21 +28,21 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--model", "-m", - help="A lightining module instance implementing the network to be used for applying the necessary data transformations.", + help="A lightning module instance implementing the network to be used for applying the necessary data transformations.", required=True, cls=ResourceOption, ) @click.option( "--datamodule", "-d", - help="A lighting data module containing the training, validation and test sets.", + help="A lightning DataModule containing the training, validation and test sets.", required=True, cls=ResourceOption, ) @click.option( "--input-folder", "-i", - help="Path to the folder containing the saliency maps for a specific visualization type.", + help="Path to the directory containing the saliency maps for a specific visualization type.", required=True, type=click.Path( file_okay=False, @@ -56,7 +56,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output-folder", "-o", - help="Path where to store the ROAD scores (created if does not exist)", + help="Directory in which to store the visualizations (created if does not exist)", required=True, type=click.Path( file_okay=False, @@ -97,8 +97,8 @@ def view( show_groundtruth, threshold, **_, -) -> None: - """Generates heatmaps for input CXRs based on existing saliency maps.""" +) -> None: # numpydoc ignore=PR01 + """Generate heatmaps for input CXRs based on existing saliency maps.""" from ...engine.saliency.viewer import run from ..utils import save_sh_command diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 7c329359a61c7e14d72b360db58ff3852e4141af..1dc78e4a68e690e7a311b31966270b2de4eedadf 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -16,19 +16,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def reusable_options(f): - """Options that can be re-used by top-level scripts (i.e. ``experiment``). + """The options that can be re-used by top-level scripts (i.e. + ``experiment``). This decorator equips the target function ``f`` with all (reusable) ``train`` script options. - Parameters ---------- f The target function to equip with options. This function must have parameters that accept such options. - Returns ------- The decorated version of function ``f`` @@ -37,7 +36,7 @@ def reusable_options(f): @click.option( "--output-folder", "-o", - help="Path where to store results (created if does not exist)", + help="Directory in which to store results (created if does not exist)", required=True, type=click.Path( file_okay=False, @@ -51,14 +50,14 @@ def reusable_options(f): @click.option( "--model", "-m", - help="A lightining module instance implementing the network to be trained", + help="A lightning module instance implementing the network to be trained", required=True, cls=ResourceOption, ) @click.option( "--datamodule", "-d", - help="A lighting data module containing the training and validation sets.", + help="A lightning DataModule containing the training and validation sets.", required=True, cls=ResourceOption, ) @@ -87,10 +86,10 @@ def reusable_options(f): "memory requirements for the network). The number of samples " "loaded for every iteration will be batch-size/batch-chunk-count. " "batch-size needs to be divisible by batch-chunk-count, otherwise an " - "error will be raised. This parameter is used to reduce number of " + "error will be raised. This parameter is used to reduce the number of " "samples loaded in each iteration, in order to reduce the memory usage " - "in exchange for processing time (more iterations). This is specially " - "interesting whe one is running with GPUs with limited RAM. The " + "in exchange for processing time (more iterations). This is especially " + "interesting when one is training on GPUs with limited RAM. The " "default of 1 forces the whole batch to be processed at once. Otherwise " "the batch is broken into batch-chunk-count pieces, and gradients are " "accumulated to complete each batch.", @@ -103,10 +102,10 @@ def reusable_options(f): @click.option( "--drop-incomplete-batch/--no-drop-incomplete-batch", "-D", - help="If set, then may drop the last batch in an epoch, in case it is " + help="If set, the last batch in an epoch will be dropped if " "incomplete. If you set this option, you should also consider " "increasing the total number of epochs of training, as the total number " - "of training steps may be reduced", + "of training steps may be reduced.", required=True, show_default=True, default=False, @@ -117,7 +116,7 @@ def reusable_options(f): "-e", help="""Number of epochs (complete training set passes) to train for. If continuing from a saved checkpoint, ensure to provide a greater - number of epochs than that saved on the checkpoint to be loaded.""", + number of epochs than was saved in the checkpoint to be loaded.""", show_default=True, required=True, default=1000, @@ -132,7 +131,7 @@ def reusable_options(f): change this to make validation more sparse, by increasing the validation period. Notice that this affects checkpoint saving. While checkpoints are created after every training step (the last training - step always triggers the overriding of latest checkpoint), and that + step always triggers the overriding of latest checkpoint), and this process is independent of validation runs, evaluation of the 'best' model obtained so far based on those will be influenced by this setting.""", @@ -204,8 +203,8 @@ def reusable_options(f): @click.option( "--balance-classes/--no-balance-classes", "-B/-N", - help="""If set, then balances weights of the random sampler during - training, so that samples from all sample classes are picked picked + help="""If set, balances weights of the random sampler during + training so that samples from all sample classes are picked equitably.""", required=True, show_default=True, @@ -224,7 +223,7 @@ def reusable_options(f): cls=ConfigCommand, epilog="""Examples: -1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``): +1. Train a pasa model with the montgomery dataset, on a GPU (``cuda:0``): .. code:: sh @@ -249,14 +248,13 @@ def train( monitoring_interval, balance_classes, **_, -) -> None: - """Trains an CNN to perform image classification. +) -> None: # numpydoc ignore=PR01 + """Train an CNN to perform image classification. Training is performed for a configurable number of epochs, and - generates at least a final_model.ckpt. It may also generate a - number of intermediate checkpoints. Checkpoints are model files - (.ckpt files) that are stored during the training and useful to - resume the procedure in case it stops abruptly. + generates checkpoints. Checkpoints are model files with a .ckpt + extension that are used in subsequent tasks or from which training + can be resumed. """ import os @@ -296,9 +294,9 @@ def train( # If asked, rebalances the loss criterion based on the relative proportion # of class examples available in the training set. Also affects the - # validation loss if a validation set is available on the data module. + # validation loss if a validation set is available on the DataModule. if balance_classes: - logger.info("Applying datamodule train sampler balancing...") + logger.info("Applying DataModule train sampler balancing...") datamodule.balance_sampler_by_class = True # logger.info("Applying train/valid loss balancing...") # model.balance_losses_by_class(datamodule) diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 7736d2cb9b362b6dbc5ba1d920bae9ff084a2a2e..9b7a514486e5faeaa50e78f7d6e0f47769b38e04 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -34,31 +34,27 @@ def create_figures( "percent-usage/gpu/*", ], ) -> list: - """Generates figures for each metric in the dataframe. + """Generate figures for each metric in the dataframe. - Each row of the dataframe correspond to an epoch and each column to a metric. + Each row of the dataframe corresponds to an epoch and each column to a metric. It is assumed that some metric names are of the form <metric>/<subset>. All subsets for a metric will be displayed on the same figure. - Parameters ---------- - - data: + data A dictionary where keys represent all scalar names, and values correspond to a tuple that contains an array with epoch numbers (when - values were taken), when the monitored values themselves. These lists + values were taken), and the monitored values themselves. These lists are pre-sorted by epoch number. - groups: - A list of scalar globs we are interested on the existing tensorboard - data, for plotting. Values with multiple matches are drawn on the same - plot. Values that do not exist are ignored. - + groups + A list of scalar globs present in the existing tensorboard data that + we are interested in for plotting. Values with multiple matches are + drawn on the same plot. Values that do not exist are ignored. Returns ------- - - figures: + list List of matplotlib figures, one per metric. """ import fnmatch @@ -111,21 +107,23 @@ def create_figures( epilog="""Examples: \b - 1. Analyzes a training log and produces various plots: + 1. Analyze a training log and produces various plots: .. code:: sh mednet train-analysis -vv results/logs """, ) -@click.argument( - "logdir", +@click.option( + "--logdir", + help="Path to the directory containing the Tensorboard training logs", + required=True, type=click.Path(dir_okay=True, exists=True, path_type=pathlib.Path), ) @click.option( "--output", "-o", - help="Name of the output file to dump (multi-page PDF)", + help="Name of the output file to create (multi-page .pdf)", required=True, show_default=True, default="trainlog.pdf", @@ -135,9 +133,8 @@ def create_figures( def train_analysis( logdir: pathlib.Path, output: pathlib.Path, -) -> None: - """Creates a plot for each metric in the training logs and saves them in a - pdf file.""" +) -> None: # numpydoc ignore=PR01 + """Create a plot for each metric in the training logs and saves them in a .pdf file.""" import matplotlib.pyplot as plt diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py index 511fc8841c1b285b0db8b3cf2bf12b3257fb328a..7c62d84e715123b11442717bdce103479aef1cec 100644 --- a/src/mednet/scripts/utils.py +++ b/src/mednet/scripts/utils.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) def save_sh_command(path: pathlib.Path) -> None: - """Records command-line to reproduce this script. + """Record command-line to reproduce this script. This function can record the current command-line used to call the script being run. It creates an executable ``bash`` script setting up the current @@ -21,10 +21,8 @@ def save_sh_command(path: pathlib.Path) -> None: records further information on the date and time the script was run and the version of the package. - Parameters ---------- - path Path to a file where the commands to reproduce the current run will be recorded. Parent directories will be created if they do not exist. An diff --git a/src/mednet/utils/checkpointer.py b/src/mednet/utils/checkpointer.py index 543f7c7008e2aef6403babdc55cd4f391d36770b..51d9192fdd6e7379a75879c98af0d1f78244da6a 100644 --- a/src/mednet/utils/checkpointer.py +++ b/src/mednet/utils/checkpointer.py @@ -24,7 +24,7 @@ def _get_checkpoint_from_alias( path: pathlib.Path, alias: typing.Literal["best", "periodic"], ) -> pathlib.Path: - """Gets an existing checkpoint file path. + """Get an existing checkpoint file path. This function can search for names matching the checkpoint alias "stem" (ie. the prefix), and then assumes a dash "-" and a number follows that @@ -35,21 +35,18 @@ def _get_checkpoint_from_alias( If only one file is present matching the alias characteristics, then it is returned. - Parameters ---------- path - Folder in which may contain checkpoint + Folder in which may contain checkpoint. alias Can be one of "best" or "periodic". - Returns ------- Path to the requested checkpoint, or ``None``, if no checkpoint file matching specifications is found on the provided path. - Raises ------ FileNotFoundError @@ -91,8 +88,8 @@ def _get_checkpoint_from_alias( def get_checkpoint_to_resume_training( path: pathlib.Path, -): - """Returns the best checkpoint file path to resume training from. +) -> pathlib.Path: + """Return the best checkpoint file path to resume training from. Parameters ---------- @@ -100,11 +97,10 @@ def get_checkpoint_to_resume_training( The base directory containing either the "periodic" checkpoint to start the training session from. - Returns ------- - Path to a checkpoint file that exists on disk - + pathlib.Path + Path to a checkpoint file that exists on disk. Raises ------ @@ -117,8 +113,8 @@ def get_checkpoint_to_resume_training( def get_checkpoint_to_run_inference( path: pathlib.Path, -): - """Returns the best checkpoint file path to run inference with. +) -> pathlib.Path: + """Return the best checkpoint file path to run inference with. Parameters ---------- @@ -126,11 +122,10 @@ def get_checkpoint_to_run_inference( The base directory containing either the "best", "last" or "periodic" checkpoint to start the training session from. - Returns ------- - Path to a checkpoint file that exists on disk - + pathlib.Path + Path to a checkpoint file that exists on disk. Raises ------ diff --git a/src/mednet/utils/rc.py b/src/mednet/utils/rc.py index 1a99df49783454a28a7e6dd9692056a5660198b3..fcc1659d1f5304c44e0561a7e1e70c6bd905db7a 100644 --- a/src/mednet/utils/rc.py +++ b/src/mednet/utils/rc.py @@ -6,5 +6,10 @@ from clapper.rc import UserDefaults def load_rc() -> UserDefaults: - """Returns global configuration variables.""" + """Return global configuration variables. + + Returns + ------- + The user defaults read from the user .toml configuration file. + """ return UserDefaults("mednet.toml") diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py index 01995cacb3b11a05444761a0f6f26d96ad10ae98..a17d3cbd2750e01650e3d7ddc89bb5f379df0157 100644 --- a/src/mednet/utils/resources.py +++ b/src/mednet/utils/resources.py @@ -32,23 +32,19 @@ GB = float(2**30) def run_nvidia_smi( query: typing.Sequence[str], ) -> dict[str, str | float] | None: - """Returns GPU information from query. + """Return GPU information from query. For a comprehensive list of options and help, execute ``nvidia-smi --help-query-gpu`` on a host with a GPU - Parameters ---------- - query - A list of query strings as defined by ``nvidia-smi --help-query-gpu`` - + A list of query strings as defined by ``nvidia-smi --help-query-gpu``. Returns ------- - - data + dict[str, str | float] | None A dictionary containing the queried parameters (``rename`` versions). If ``nvidia-smi`` is not available, returns ``None``. Percentage information is left alone, memory information is transformed to @@ -80,27 +76,21 @@ def run_nvidia_smi( def run_powermetrics( time_window_ms: int = 500, key: str | None = None ) -> dict[str, typing.Any] | None: - """Returns GPU information from the system. + """Return GPU information from the system. For a comprehensive list of options and help, execute ``man powermetrics`` on a Mac computer with Apple silicon. - Parameters ---------- - time_window_ms The amount of time, in milliseconds, to collect usage information on the GPU. - key If specified returns only a sub-key of the dictionary. - Returns ------- - - data A dictionary containing the GPU information. """ @@ -145,13 +135,13 @@ def run_powermetrics( def cuda_constants() -> dict[str, str | int | float] | None: - """Returns GPU (static) information using nvidia-smi. + """Return GPU (static) information using nvidia-smi. See :py:func:`run_nvidia_smi` for operational details. Returns ------- - + dict[str, str | int | float] | None If ``nvidia-smi`` is not available, returns ``None``, otherwise, we return a dictionary containing the following ``nvidia-smi`` query information, in this order: @@ -172,12 +162,11 @@ def cuda_constants() -> dict[str, str | int | float] | None: def mps_constants() -> dict[str, str | int | float] | None: - """Returns GPU (static) information using `/usr/bin/powermetrics`. + """Return GPU (static) information using `/usr/bin/powermetrics`. Returns ------- - - data : :py:class:`tuple`, None + dict[str, str | int | float] If ``nvidia-smi`` is not available, returns ``None``, otherwise, we return a dictionary containing the following ``nvidia-smi`` query information, in this order: @@ -203,14 +192,14 @@ def mps_constants() -> dict[str, str | int | float] | None: def cuda_log() -> dict[str, float] | None: - """Returns GPU information about current non-static status using nvidia- + """Return GPU information about current non-static status using nvidia- smi. See :py:func:`run_nvidia_smi` for operational details. Returns ------- - + dict[str, float] | None If ``nvidia-smi`` is not available, returns ``None``, otherwise, we return a dictionary containing the following ``nvidia-smi`` query information, in this order: @@ -243,19 +232,18 @@ def cuda_log() -> dict[str, float] | None: def mps_log() -> dict[str, float] | None: - """Returns GPU information about current non-static status using ``sudo + """Return GPU information about current non-static status using ``sudo powermetrics``. Returns ------- - If ``sudo powermetrics`` is not executable (or is not configured for passwordless execution), returns ``None``, otherwise, we return a dictionary containing the following query information, in this order: * ``freq_hz`` as ``frequency-MHz/gpu`` * 100 * (1 - ``idle_ratio``), as ``percent-usage/gpu``, - (:py:class:`float`, in percent) + (:py:class:`float`, in percent). """ result = run_powermetrics(500, key="gpu") @@ -270,16 +258,16 @@ def mps_log() -> dict[str, float] | None: def cpu_constants() -> dict[str, int | float]: - """Returns static CPU information about the current system. + """Return static CPU information about the current system. Returns ------- - - A dictionary containing these entries: + dict[str, int | float] + An ordered dictionary (organized as 2-tuples) containing these entries: 0. ``cpu_memory_total`` (:py:class:`float`): total memory available, in gigabytes - 1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available + 1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available. """ return { "memory-total-GB/cpu": psutil.virtual_memory().total / GB, @@ -288,13 +276,12 @@ def cpu_constants() -> dict[str, int | float]: class CPULogger: - """Logs CPU information using :py:mod:`psutil` + """Log CPU information using :py:mod:`psutil`. Parameters ---------- - pid - Process identifier of the main process (parent process) to observe + Process identifier of the main process (parent process) to observe. """ def __init__(self, pid: int | None = None): @@ -304,12 +291,11 @@ class CPULogger: [k.cpu_percent(interval=None) for k in self.cluster] def log(self) -> dict[str, int | float]: - """Returns current process cluster information. + """Return current process cluster information. Returns ------- - - data + dict[str, int | float] An ordered dictionary containing these entries: 0. ``cpu_memory_used`` (:py:class:`float`): total memory used from @@ -376,16 +362,13 @@ class _InformationGatherer: Parameters ---------- - device_type String representation of one of the supported pytorch device types triggering the correct readout of resource usage. - main_pid - The main process identifier to monitor - + The main process identifier to monitor. logger - A logger to be used for logging messages + A logger to be used for logging messages. """ def __init__( @@ -422,7 +405,7 @@ class _InformationGatherer: self.data: dict[str, list[int | float]] = {k: [] for k in keys} def acc(self) -> None: - """Accumulates another measurement.""" + """Accumulate another measurement.""" for k, v in self.cpu_logger.log().items(): self.data[k].append(v) @@ -443,12 +426,18 @@ class _InformationGatherer: pass def clear(self) -> None: - """Clears accumulated data.""" + """Clear accumulated data.""" for k in self.data.keys(): self.data[k] = [] def summary(self) -> dict[str, list[int | float]]: - """Returns the current data.""" + """Return the current data. + + Returns + ------- + dict[str, list[int | float]] + A dictionary with a list of resources and their corresponding values. + """ if len(next(iter(self.data.values()))) == 0: self.logger.error("CPU/GPU logger was not able to collect any data") return self.data @@ -467,29 +456,22 @@ def _monitor_worker( Parameters ========== - interval Number of seconds to wait between each measurement (maybe a floating - point number as accepted by :py:func:`time.sleep`) - + point number as accepted by :py:func:`time.sleep`). device_type String representation of one of the supported pytorch device types triggering the correct readout of resource usage. - main_pid - The main process identifier to monitor - + The main process identifier to monitor. stop - Event that indicates if we should continue running or stop - + Event that indicates if we should continue running or stop. summary_event - Event that indicates if we should produce a summary - + Event that indicates if we should produce a summary. queue - A queue, to send monitoring information back to the spawner - + A queue, to send monitoring information back to the spawner. logging_level - The logging level to use for logging from launched processes + The logging level to use for logging from launched processes. """ logger = multiprocessing.log_to_stderr(level=logging_level) ra = _InformationGatherer(device_type, main_pid, logger) @@ -517,20 +499,16 @@ class ResourceMonitor: Parameters ---------- - interval Number of seconds to wait between each measurement (maybe a floating - point number as accepted by :py:func:`time.sleep`) - + point number as accepted by :py:func:`time.sleep`). device_type String representation of one of the supported pytorch device types triggering the correct readout of resource usage. - main_pid - The main process identifier to monitor - + The main process identifier to monitor. logging_level - The logging level to use for logging from launched processes + The logging level to use for logging from launched processes. """ def __init__( @@ -567,11 +545,11 @@ class ResourceMonitor: self.data: dict[str, int | float] | None = None def __enter__(self) -> None: - """Starts the monitoring process.""" + """Start the monitoring process.""" self.monitor.start() def checkpoint(self, remove_last_n: int | None = None) -> None: - """Forces the monitoring process to yield data and clear the internal + """Force the monitoring process to yield data and clear the internal accumulator. Parameters @@ -609,8 +587,7 @@ class ResourceMonitor: self.data[k] = 0.0 def __exit__(self, *_) -> None: - """Stops the monitoring process and returns the summary of - observations.""" + """Stop the monitoring process and returns the summary of observations.""" self.stop_event.set() self.monitor.join() diff --git a/src/mednet/utils/summary.py b/src/mednet/utils/summary.py index 2f7d468c5a26439c3415326be383c9df9af66a09..bff705e30b557a0314dfef929535671e3dad7f81 100644 --- a/src/mednet/utils/summary.py +++ b/src/mednet/utils/summary.py @@ -6,11 +6,13 @@ from functools import reduce +import torch + from torch.nn.modules.module import _addindent # ignore this space! -def _repr(model): +def _repr(model: torch.nn.Module) -> tuple[str, int]: # We treat the extra repr like the sub-module, one item per line extra_lines = [] extra_repr = model.extra_repr() @@ -43,22 +45,17 @@ def _repr(model): return main_str, total_params -def summary(model): - """Counts the number of parameters in each model layer. +def summary(model: torch.nn.Module) -> tuple[str, int]: + """Count the number of parameters in each model layer. Parameters ---------- - - model : :py:class:`torch.nn.Module` - model to summarize + model + Model to summarize. Returns ------- - - repr : str - a multiline string representation of the network - - nparam : int - number of parameters + tuple[int, str] + A tuple containing a multiline string representation of the network and the number of parameters. """ return _repr(model) diff --git a/src/mednet/utils/tensorboard.py b/src/mednet/utils/tensorboard.py index e41b2c07a9f1a3ca9ab50d11b8c5a754d6b117b9..56a81c4497bd8552b15ce4468a4fb4fa8a0e1ce9 100644 --- a/src/mednet/utils/tensorboard.py +++ b/src/mednet/utils/tensorboard.py @@ -12,21 +12,20 @@ from tensorboard.backend.event_processing.event_accumulator import ( def scalars_to_dict( logdir: pathlib.Path, ) -> dict[str, tuple[list[int], list[float]]]: - """Returns scalars stored in tensorboard event files. + """Return scalars stored in tensorboard event files. This method will gather all tensorboard event files produced by a training run, and will return a dictionary with all collected scalars, ready for plotting. - Parameters ---------- logdir Directory containing the event files. - Returns ------- + dict[str, tuple[list[int], list[float]]] A dictionary where keys represent all scalar names, and values correspond to a tuple that contains an array with epoch numbers (when values were taken), when the monitored values themselves. The lists diff --git a/tests/conftest.py b/tests/conftest.py index 60d00e284ba5656189bbb05b03a3deb5afb549a9..b4f92331138424dc528900644bcef76e8f4c26ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,12 +17,29 @@ from mednet.data.typing import DatabaseSplit @pytest.fixture def datadir(request) -> pathlib.Path: - """Returns the directory in which the test is sitting.""" + """Return the directory in which the test is sitting. Check the pytest documentation for more information. + + Parameters + ---------- + request + Information of the requesting test function. + + Returns + ------- + pathlib.Path + The directory in which the test is sitting. + """ return pathlib.Path(request.module.__file__).parents[0] / "data" def pytest_configure(config): - """This function is run once for pytest setup.""" + """This function is run once for pytest setup. + + Parameters + ---------- + config + Configuration values. Check the pytest documentation for more information. + """ config.addinivalue_line( "markers", "skip_if_rc_var_not_set(name): this mark skips the test if a certain " @@ -37,6 +54,11 @@ def pytest_runtest_setup(item): The test is run if this function returns ``None``. To skip a test, call ``pytest.skip()``, specifying a reason. + + Parameters + ---------- + item + A test invocation item. Check the pytest documentation for more information. """ from mednet.utils.rc import load_rc @@ -76,8 +98,13 @@ def temporary_basedir(tmp_path_factory): def pytest_sessionstart(session: pytest.Session) -> None: - """Presets the session start to ensure the Montgomery dataset is always - available.""" + """Preset the session start to ensure the Montgomery dataset is always available. + + Parameters + ---------- + session + The session to use. + """ from mednet.utils.rc import load_rc @@ -129,27 +156,19 @@ class DatabaseCheckers: prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], ): - """Runs a simple consistence check on the data split. + """Run a simple consistence check on the data split. Parameters ---------- - - make_split - A database specific function that takes a split name and returns - the loaded database split. - - split_filename - This is the split we will check - - lenghts + split + An instance of DatabaseSplit. + lengths A dictionary that contains keys matching those of the split (this will be checked). The values of the dictionary should correspond to the sizes of each of the datasets in the split. - prefixes Each file named in a split should start with at least one of these prefixes. - possible_labels These are the list of possible labels contained in any split. """ @@ -179,21 +198,20 @@ class DatabaseCheckers: prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], ): - """Checks the consistence of an individual (loaded) batch. + """Check the consistence of an individual (loaded) batch. Parameters ---------- batch The loaded batch to be checked. - - size - The mini-batch size - + batch_size + The mini-batch size. + color_planes + The number of color planes in the images. prefixes Each file named in a split should start with at least one of these prefixes. - possible_labels These are the list of possible labels contained in any split. """ diff --git a/tests/test_cli.py b/tests/test_cli.py index 6648d1354cce704798a342a00b7554c0b90338d0..4828e1d09ca51d1b06c5cb3c29fb1f8b62ddccd0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -87,7 +87,7 @@ def test_config_describe_montgomery(): runner = CliRunner() result = runner.invoke(describe, ["montgomery"]) _assert_exit_0(result) - assert "Montgomery datamodule for TB detection." in result.output + assert "Montgomery DataModule for TB detection." in result.output def test_database_help(): @@ -224,7 +224,7 @@ def test_train_pasa_montgomery(temporary_basedir): r"^Writing command-line for reproduction at .*$": 1, r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Applying datamodule train sampler balancing...$": 1, + r"^Applying DataModule train sampler balancing...$": 1, r"^Balancing samples from dataset using metadata targets `label`$": 1, r"^Training for at most 1 epochs.$": 1, r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1, @@ -321,7 +321,7 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): r"^Writing command-line for reproduction at .*$": 1, r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Applying datamodule train sampler balancing...$": 1, + r"^Applying DataModule train sampler balancing...$": 1, r"^Balancing samples from dataset using metadata targets `label`$": 1, r"^Training for at most 2 epochs.$": 1, r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 71c37e1615b48eb96955a82782188d5223c2a296..b1bbcc1f9f9a1d6473fd380a7f360fa298ae2183 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -153,16 +153,18 @@ def check_loaded_batch( batch_size: int, prefixes: typing.Sequence[str], ): - """Checks the consistence of an individual (loaded) batch. + """Check the consistence of an individual (loaded) batch. Parameters ---------- batch The loaded batch to be checked. - - size - The mini-batch size + batch_size + The mini-batch size. + prefixes + Each file named in a split should start with at least one of these + prefixes. """ assert len(batch) == 2 # data, metadata