diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py index 2afcdac5cd45605516e725a6d7c3fb3b57f2387b..91eaa595985cbd8371952007c44f9981284a522d 100644 --- a/src/mednet/data/augmentations.py +++ b/src/mednet/data/augmentations.py @@ -40,14 +40,14 @@ def _elastic_deformation_on_image( This implementation is based on 2 scipy functions (:py:func:`scipy.ndimage.gaussian_filter` and :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it - requires data is moved off the current running device and then back. + requires data to be moved off the current running device and then back. Parameters ---------- img - The input image to apply elastic deformation at. This image should + The input image to apply elastic deformation to. This image should always have this shape: ``[C, H, W]``. It should always represent a tensor on the CPU. @@ -74,7 +74,7 @@ def _elastic_deformation_on_image( ------- tensor - A tensor on the CPU. + The image with elastic deformation applied, as a tensor on the CPU. """ if random.random() < p: @@ -148,14 +148,14 @@ def _elastic_deformation_on_batch( This implementation is based on 2 scipy functions (:py:func:`scipy.ndimage.gaussian_filter` and :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it - requires data is moved off the current running device and then back. + requires data to be moved off the current running device and then back. Parameters ---------- img - The input image to apply elastic deformation at. This image should + The input image to apply elastic deformation to. This image should always have this shape: ``[C, H, W]``. It should always represent a tensor on the CPU. @@ -182,7 +182,7 @@ def _elastic_deformation_on_batch( ------- tensor - A tensor on the CPU. + A batch of images with elastic deformation applied, as a tensor on the CPU. 
""" # transforms our custom functions into simpler callables partial = functools.partial( @@ -210,12 +210,12 @@ class ElasticDeformation: This implementation is based on 2 scipy functions (:py:func:`scipy.ndimage.gaussian_filter` and :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it - requires data is moved off the current running device and then back. + requires data to be moved off the current running device and then back. .. warning:: Furthermore, this transform is not scriptable and therefore cannot run - on a CUDA or MPS device. Applying it, effectively creates a bottleneck + on a CUDA or MPS device. Applying it effectively creates a bottleneck in model training. Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0 @@ -225,6 +225,7 @@ class ElasticDeformation: ---------- alpha + A multiplier for the gaussian filter outputs. sigma Standard deviation for Gaussian kernel. diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index b7d8660989ca8c5abec00c756a61d656bc98f5b1..5c8e30586c54e4a23b819e549dc794ddc8a35b52 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -65,7 +65,7 @@ def _sample_size_bytes(s: Sample) -> int: class _DelayedLoadingDataset(Dataset): """A list that loads its samples on demand. - This list mimics a pytorch Dataset, except raw data loading is done + This list mimics a pytorch Dataset, except that raw data loading is done on-the-fly, as the samples are requested through the bracket operator. @@ -152,7 +152,7 @@ class _CachedDataset(Dataset): """Basically, a list of preloaded samples. This dataset will load all samples from the raw dataset during construction - instead of delaying that to the indexing. Beyong raw-data-loading, + instead of delaying that to the indexing. Beyond raw-data-loading, ``transforms`` given upon construction contribute to the cached samples. @@ -305,20 +305,20 @@ def _make_balanced_random_sampler( this case). 
To verify this, notice that the probability of picking a sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`. 2. The probability of picking a sample with ``target=0`` from Dataset 2 is - 3 times higher than those from Dataset 1. As there are 3 times less + 3 times higher than those from Dataset 1. As there are 3 times fewer samples in Dataset 2 with ``target=0``, this makes choosing samples from Dataset 1 proportionally less likely. 3. The probability of picking a sample with ``target=1`` from Dataset 2 is - 3 times lower than those from Dataset 1. As there are 3 times less + 3 times lower than those from Dataset 1. As there are 3 times fewer samples in Dataset 1 with ``target=1``, this makes choosing samples from Dataset 2 proportionally less likely. This function assumes targets are stored on a dictionary entry named ``target`` inside the metadata information for the - :py:data:`.typing.Sample`, and that its value is integer. + :py:data:`.typing.Sample`, and that its value is an integer. We then instantiate a pytorch sampler using the inverse probabilities (the - more samples of a class, the less likely it becomes to be sampled. + more samples in a class, the less likely it becomes to be sampled). Parameters @@ -409,7 +409,7 @@ class ConcatDataModule(lightning.LightningDataModule): Instances of this class can load and concatenate an arbitrary number of data-split (a.k.a. protocol) definitions for (possibly disjoint) databases, and can manage raw data-loading from disk. An optional caching mechanism - stores the data at associated CPU memory, which can improve data serving + stores the data in associated CPU memory, which can improve data serving while training and evaluating models. This datamodule defines basic operations to handle data loading and @@ -435,7 +435,7 @@ class ConcatDataModule(lightning.LightningDataModule): .. 
tip:: - To check the split and the loader function works correctly, you may + To check the split and that the loader function works correctly, you may use :py:func:`.split.check_database_split_loading`. This class expects at least one entry called ``train`` to exist in the @@ -449,7 +449,7 @@ class ConcatDataModule(lightning.LightningDataModule): serves samples from CPU memory. Otherwise, loads samples from disk on demand. Running from CPU memory will offer increased speeds in exchange for CPU memory. Sufficient CPU memory must be available before you set - this attribute to ``True``. It is typicall useful for relatively small + this attribute to ``True``. It is typically useful for relatively small datasets. balance_sampler_by_class @@ -462,7 +462,7 @@ class ConcatDataModule(lightning.LightningDataModule): memory requirements for the network). If the number of samples in the batch is larger than the total number of samples available for training, this value is truncated. If this number is smaller, then - batches of the specified size are created and fed to the network until + batches of the specified size are created and fed to the network until there are no more new samples to feed (epoch is finished). If the total number of training samples is not a multiple of the batch-size, the last batch will be smaller than the first, unless @@ -474,18 +474,18 @@ class ConcatDataModule(lightning.LightningDataModule): requirements for the network). The number of samples loaded for every iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` needs to be divisible by ``batch_chunk_count``, otherwise an error will - be raised. This parameter is used to reduce number of samples loaded in + be raised. This parameter is used to reduce the number of samples loaded in each iteration, in order to reduce the memory usage in exchange for - processing time (more iterations). This is specially interesting whe - one is running with GPUs with limited RAM. 
The default of 1 forces the + processing time (more iterations). This is especially interesting when + one is running on GPUs with limited RAM. The default of 1 forces the whole batch to be processed at once. Otherwise the batch is broken into batch-chunk-count pieces, and gradients are accumulated to complete each batch. drop_incomplete_batch - If set, then may drop the last batch in an epoch, in case it is + If set, then may drop the last batch in an epoch in case it is incomplete. If you set this option, you should also consider - increasing the total number of epochs of training, as the total number + increasing the total number of training epochs, as the total number of training steps may be reduced. parallel @@ -600,7 +600,7 @@ class ConcatDataModule(lightning.LightningDataModule): """Transforms required to fit data into the model. A list of transforms (torch modules) that will be applied after - raw- data-loading. and just before data is fed into the model or + raw-data-loading, and just before data is fed into the model or eventual data-augmentation transformations for all data loaders produced by this data module. This part of the pipeline receives data as output by the raw-data-loader, or model-related @@ -619,7 +619,7 @@ class ConcatDataModule(lightning.LightningDataModule): # datasets that have been setup() for the current stage are reset if value != old_value and len(self._datasets): logger.warning( - f"Reseting {len(self._datasets)} loaded datasets due " + f"Resetting {len(self._datasets)} loaded datasets due " "to changes in model-transform properties. If you were caching " "data loading, this will (eventually) trigger a reload." ) @@ -678,8 +678,8 @@ class ConcatDataModule(lightning.LightningDataModule): needs to be divisible by ``batch_chunk_count``, otherwise an error will be raised. This parameter is used to reduce number of samples loaded in each iteration, in order to reduce the memory usage in exchange for - processing time (more iterations). 
This is specially interesting whe - one is running with GPUs with limited RAM. The default of 1 forces the + processing time (more iterations). This is especially interesting when + one is running on GPUs with limited RAM. The default of 1 forces the whole batch to be processed at once. Otherwise the batch is broken into batch-chunk-count pieces, and gradients are accumulated to complete each batch. @@ -774,7 +774,7 @@ class ConcatDataModule(lightning.LightningDataModule): Parameters ---------- stage - Name of the stage to which the setup is applicable. Can be one of + Name of the stage in which the setup is applicable. Can be one of ``fit``, ``validate``, ``test`` or ``predict``. Each stage typically uses the following data loaders: @@ -800,19 +800,19 @@ class ConcatDataModule(lightning.LightningDataModule): self._setup_dataset(k) def teardown(self, stage: str) -> None: - """Unset-up datasets for different tasks on the pipeline. + """Unsets-up datasets for different tasks on the pipeline. This method unsets (unload, remove from memory, etc) all datasets required for a particular ``stage`` (fit, validate, test, predict). - If you have set ``cache_samples``, samples are loaded, this may + If you have set ``cache_samples``, samples are loaded and this may effectivley release all the associated memory. Parameters ---------- stage - Name of the stage to which the teardown is applicable. Can be one of + Name of the stage in which the teardown is applicable. Can be one of ``fit``, ``validate``, ``test`` or ``predict``. Each stage typically uses the following data loaders: @@ -928,7 +928,7 @@ class CachingDataModule(ConcatDataModule): An object instance that can load samples and labels from storage. **kwargs - List if named parameters matching those of + List of named parameters matching those of :py:class:`ConcatDataModule`, other than ``splits``. 
""" diff --git a/src/mednet/data/split.py b/src/mednet/data/split.py index 2d5370fe8e2e2f98d9438c2001ded5a555800ece..ddae2a1cf14f99fc318d0c43d1a9ea2e94324920 100644 --- a/src/mednet/data/split.py +++ b/src/mednet/data/split.py @@ -49,7 +49,7 @@ class JSONDatabaseSplit(DatabaseSplit): } Your database split many contain any number of (raw) datasets (dictionary - keys). For simplicity, we recommend all sample entries are formatted + keys). For simplicity, we recommend to format all sample entries similarly so that raw-data-loading is simplified. Use the function :py:func:`check_database_split_loading` to test raw data loading and fine tune the dataset split, or its loading. @@ -117,7 +117,7 @@ class CSVDatabaseSplit(DatabaseSplit): formatted files, each representing a dataset of this split, containing the sample data (one per row). Example: - Inside the directory ``my-split/``, one can file files ``train.csv``, + Inside the directory ``my-split/``, one can find the files ``train.csv``, ``validation.csv``, and ``test.csv``. Each file has a structure similar to the following: @@ -127,7 +127,7 @@ class CSVDatabaseSplit(DatabaseSplit): sample2-value1,sample2-value2,sample2-value3 ... - Each file in the provided directory defines the dataset name on the split. + Each file in the provided directory defines the dataset name of the split. So, the file ``train.csv`` will contain the data from the ``train`` dataset, and so on. @@ -139,7 +139,7 @@ class CSVDatabaseSplit(DatabaseSplit): ---------- directory - Absolute path to a directory containing the database split layed down + Absolute path to a directory containing the database split organized as a set of CSV files, one per dataset. 
""" @@ -196,7 +196,7 @@ class CSVDatabaseSplit(DatabaseSplit): return iter(self._datasets) def __len__(self) -> int: - """How many datasets we currently have.""" + """The number of datasets we currently have.""" return len(self._datasets) @@ -208,7 +208,7 @@ def check_database_split_loading( """For each dataset in the split, check if all data can be correctly loaded using the provided loader function. - This function will return the number of errors loading samples, and will + This function will return the number of errors when loading samples, and will log more detailed information to the logging stream. @@ -216,7 +216,7 @@ def check_database_split_loading( ---------- database_split - A mapping that, contains the database split. Each key represents the + A mapping that contains the database split. Each key represents the name of a dataset in the split. Each value is a (potentially complex) object that represents a single sample. @@ -236,7 +236,7 @@ def check_database_split_loading( Number of errors found """ logger.info( - "Checking if can load all samples in all datasets of this split..." + "Checking if all samples in all datasets of this split can be loaded..." ) errors = 0 for dataset, samples in database_split.items(): diff --git a/src/mednet/data/typing.py b/src/mednet/data/typing.py index c1df54c62eea0bf5005e2e3f5c0381e62b2e07a7..3102ecbc9ba608b1df866a134eda914b8f93e6ea 100644 --- a/src/mednet/data/typing.py +++ b/src/mednet/data/typing.py @@ -39,7 +39,7 @@ class RawDataLoader: Transform: typing.TypeAlias = typing.Callable[[torch.Tensor], torch.Tensor] -"""A callable, that transforms tensors into (other) tensors. +"""A callable that transforms tensors into (other) tensors. Typically used in data-processing pipelines inside pytorch. """ @@ -72,7 +72,7 @@ be assigned a different :py:class:`.RawDataLoader`. class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): - """Our own definition of a pytorch Dataset, with interesting properties. 
+ """Our own definition of a pytorch Dataset. We iterate over Sample objects in this case. Our datasets always provide a dunder len method. diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py index 9b6c62a177fe3b187ddc947d8b6dd430ef066be5..3125ef873d2dee3e460b77571a913c67c4f0dec2 100644 --- a/src/mednet/engine/predictor.py +++ b/src/mednet/engine/predictor.py @@ -42,7 +42,7 @@ def run( device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. + or a lightning accelerator setup. Returns diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index 8d781cc18bf88e5725792d8c176fb43a26e2e07c..95a2c7fc8332f750541f5040445d0dfe869d92d9 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -47,7 +47,7 @@ def _calculate_road_scores( This function calculates ROAD scores by averaging the scores for different removal (hardcoded) percentiles, for a single input image, a - given visualization method, a target class. + given visualization method, and a target class. Parameters @@ -55,7 +55,7 @@ def _calculate_road_scores( model Neural network model (e.g. pasa). images - A batch of input images to use evaluating the ROAD scores. Currently, + A batch of input images to use for evaluating the ROAD scores. Currently, we only support batches with a single image. output_num Target output neuron to take into consideration when evaluating the @@ -89,7 +89,7 @@ def _calculate_road_scores( # current processing bottleneck. If you want to optimise anyting, look at # the evaluation of the perturbation using scipy.sparse at the # NoisyLinearImputer, part of the grad-cam package (submodule - # ``metrics.road``. + # ``metrics.road``). 
metric_target = [SigmoidClassifierOutputTarget(output_num)] MoRF_scores = cam_metric_ROADMoRF_avg( @@ -134,12 +134,13 @@ def _process_sample( saliency_map_callable A callable saliency-map generator from grad-cam target_class - Class to target for saliency estimation. Can be either set to - "all" or "highest". "highest". + Class to target for saliency estimation. Can be set to + "all" or "highest". "highest" is default, which means + only saliency maps for the class with the highest + activation will be generated. positive only If set, and the model chosen has a single output (binary), then saliency maps will only be generated for samples of the positive class - percentiles A sequence of percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both @@ -206,12 +207,12 @@ def run( ) -> dict[str, list[typing.Any]]: """Evaluates ROAD scores for all samples in a datamodule. - The ROAD algorithm was first described at [ROAD-2022]_. It estimates + The ROAD algorithm was first described in [ROAD-2022]_. It estimates explainability (in the completeness sense) of saliency maps by substituting - relevant pixels in the input image by a local average, and re-running + relevant pixels in the input image by a local average, re-running prediction on the altered image, and measuring changes in the output classification score when said perturbations are in place. By substituting - most or least relevant pixels with surrounding averages, the ROAD algorithm + the most or least relevant pixels with surrounding averages, the ROAD algorithm estimates the importance of such elements in the produced saliency map. As of 2023, this measurement technique is considered to be one of the state-of-the-art metrics of explainability. @@ -231,12 +232,12 @@ def run( device_manager An internal device representation, to be used for training and validation. 
This representation can be converted into a pytorch device - or a torch lightning accelerator setup. + or a lightning accelerator setup. saliency_map_algorithm The algorithm for saliency map estimation to use. target_class (Use only with multi-label models) Which class to target for CAM - calculation. Can be either set to "all" or "highest". "highest" is + calculation. Can be set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated. positive_only @@ -298,8 +299,8 @@ def run( raise RuntimeError( f"The number of multiprocessing instances is set to {parallel} and " f"you asked to use a GPU (device = `{device_manager.device_type}`" - f"). The currently implementation can only handle a single GPU. " - f"Either disable GPU utilisation or set the number of " + f"). The current implementation can only handle a single GPU. " + f"Either disable GPU usage, set the number of " f"multiprocessing instances to one, or disable multiprocessing " "entirely (ie. set it to -1)." ) diff --git a/src/mednet/engine/saliency/evaluator.py b/src/mednet/engine/saliency/evaluator.py index aae1accaae37899ad84dc59a03498ed2ecb8547e..b2ae1d53ca9f2254d8f1aac18f8d9dc676e88a41 100644 --- a/src/mednet/engine/saliency/evaluator.py +++ b/src/mednet/engine/saliency/evaluator.py @@ -27,10 +27,10 @@ def _reconcile_metrics( Parameters ---------- completeness - A dictionary containing various tables with the sample name and + A list containing various tables with the sample name and completness (ROAD) scores. interpretability - A dictionary containing various tables with the sample name and + A list containing various tables with the sample name and interpretability (Pro. Energy) scores. 
diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index cb101e98ffeca29ad9c7d8a1e5cae7492f01c381..bae230507af682f716ee34544eed92e93735b959 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -129,7 +129,7 @@ def run( device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. + or a lightning accelerator setup. saliency_map_algorithm The algorithm to use for saliency map estimation. target_class diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py index 95e184f4f9aee8a2cbc12f1a989adae565a490a0..a24c5f3fb8f2c3e74e471bcd49c478d92767e95c 100644 --- a/src/mednet/engine/saliency/interpretability.py +++ b/src/mednet/engine/saliency/interpretability.py @@ -441,7 +441,7 @@ def run( if not bboxes: logger.warning( - f"Sample `{name}` does not contdain bounding-box information. " + f"Sample `{name}` does not contain bounding-box information. " f"No localization metrics can be calculated in this case. " f"Skipping..." 
) diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py index 3c0a7efe1300e069c47189274fc556f177050759..fc03073a94044b78e707f328ab3b5c1a04990143 100644 --- a/src/mednet/engine/saliency/viewer.py +++ b/src/mednet/engine/saliency/viewer.py @@ -64,7 +64,7 @@ def _overlay_saliency_map( Parameters ---------- image - The input imge that will be overlayed with the saliency map + The input image that will be overlaid with the saliency map saliencies The saliency map that will be overlaid on the (raw) image colormap @@ -119,7 +119,7 @@ def _overlay_bounding_box( Parameters ---------- image - The input imge that will be overlayed with the saliency map + The input image that will be overlaid with the saliency map bbox The bounding box to draw on the input image color @@ -159,10 +159,10 @@ def _process_sample( The raw data representing the input sample that will be overlayed with saliency maps and annotations saliencies - The saliency map recovered from the model, that will be inprinted on + The saliency map recovered from the model, that will be imprinted on the raw_data ground_truth - Ground-truth annotations that may be inprinted on the final image + Ground-truth annotations that may be imprinted on the final image Returns @@ -209,9 +209,9 @@ def run( The label to target for evaluating interpretability metrics. Samples contining any other label are ignored. output_folder - Directory in which the resulting visualisations will be saved. + Directory in which the resulting visualizations will be saved. show_groundtruth - If set, inprint ground truth labels over the original image and + If set, imprint ground truth labels over the original image and saliency maps. 
threshold : float The pixel values above ``threshold``% of max value are kept in the diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index d4b39c4b224d2499cc74743dbaeb334f05653eec..24c5dc1cfcf8c98ad9a3e2d21a9962a450640312 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -152,7 +152,7 @@ def run( device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device - or a torch lightning accelerator setup. + or a lightning accelerator setup. max_epochs The maximum number of epochs to train for.