From 53d2415ae1bf24c38c021b56320df15dcdf6f78d Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Thu, 6 Jun 2024 10:10:46 +0200 Subject: [PATCH] [mednet] Pre-commit fixes --- helpers/generate_histograms.py | 4 +- .../montgomery_shenzhen_indian/datamodule.py | 2 +- .../datamodule.py | 2 +- .../data/nih_cxr14_padchest/datamodule.py | 1 + src/mednet/libs/classification/scripts/cli.py | 8 +- .../libs/classification/tests/conftest.py | 16 +- .../tests/test_cli_classification.py | 4 +- src/mednet/libs/common/data/datamodule.py | 147 +----------------- src/mednet/libs/common/models/normalizer.py | 4 +- src/mednet/libs/common/scripts/config.py | 4 +- src/mednet/libs/common/scripts/database.py | 4 +- src/mednet/libs/common/tests/conftest.py | 1 - src/mednet/libs/common/utils/gitlab.py | 4 +- .../libs/segmentation/engine/evaluator.py | 7 +- src/mednet/libs/segmentation/models/lwnet.py | 24 +-- .../libs/segmentation/models/separate.py | 3 +- .../libs/segmentation/scripts/evaluate.py | 8 +- .../libs/segmentation/scripts/experiment.py | 8 +- .../libs/segmentation/tests/conftest.py | 24 +-- .../libs/segmentation/tests/test_measures.py | 48 ++---- src/mednet/libs/segmentation/utils/measure.py | 4 +- 21 files changed, 49 insertions(+), 278 deletions(-) diff --git a/helpers/generate_histograms.py b/helpers/generate_histograms.py index 8d4f83b5..57c9f342 100644 --- a/helpers/generate_histograms.py +++ b/helpers/generate_histograms.py @@ -299,9 +299,7 @@ def main( if not output_dir.exists(): output_dir.mkdir(parents=True, exist_ok=True) - output_file = ( - output_dir / f"histograms{model_name}_{k}_{fold_name}.json" - ) + output_file = output_dir / f"histograms{model_name}_{k}_{fold_name}.json" output_json_data = {} diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py index b09fa338..2f6caeac 100644 --- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py @@ -11,8 +11,8 @@ from mednet.libs.common.data.datamodule import ConcatDataModule from mednet.libs.common.data.split import make_split from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR -from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader +from ..indian.datamodule import DataModule as IndianDataModule from ..montgomery.datamodule import ( ClassificationRawDataLoader as MontgomeryLoader, ) diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 53294527..c83680ec 100644 --- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -11,8 +11,8 @@ from mednet.libs.common.data.datamodule import ConcatDataModule from mednet.libs.common.data.split import make_split from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR -from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader +from ..indian.datamodule import DataModule as IndianDataModule from ..montgomery.datamodule import ( 
ClassificationRawDataLoader as MontgomeryLoader, ) diff --git a/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py index d5a44ce6..e1907a35 100644 --- a/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py @@ -11,6 +11,7 @@ from mednet.libs.common.data.split import make_split from ..nih_cxr14.datamodule import ClassificationRawDataLoader as CXR14Loader from ..padchest.datamodule import ClassificationRawDataLoader as PadchestLoader + class DataModule(ConcatDataModule): """Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) datasets. diff --git a/src/mednet/libs/classification/scripts/cli.py b/src/mednet/libs/classification/scripts/cli.py index b7839d0d..34ba1ff3 100644 --- a/src/mednet/libs/classification/scripts/cli.py +++ b/src/mednet/libs/classification/scripts/cli.py @@ -17,18 +17,14 @@ def classification(): pass -classification.add_command( - importlib.import_module("..config", package=__name__).config -) +classification.add_command(importlib.import_module("..config", package=__name__).config) classification.add_command( importlib.import_module("..database", package=__name__).database, ) classification.add_command( importlib.import_module("..predict", package=__name__).predict ) -classification.add_command( - importlib.import_module("..train", package=__name__).train -) +classification.add_command(importlib.import_module("..train", package=__name__).train) classification.add_command( importlib.import_module( "mednet.libs.common.scripts.train_analysis", diff --git a/src/mednet/libs/classification/tests/conftest.py b/src/mednet/libs/classification/tests/conftest.py index ecd74256..40e5f9d0 100644 --- a/src/mednet/libs/classification/tests/conftest.py +++ b/src/mednet/libs/classification/tests/conftest.py @@ -70,8 +70,7 @@ def pytest_runtest_setup(item): # iterates over all markers for the item being examined, get the first # argument and accumulate these names rc_names = [ - mark.args[0] - for mark in item.iter_markers(name="skip_if_rc_var_not_set") + mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set") ] # checks all names mentioned are set in ~/.config/mednet.libs.classification.toml, otherwise, @@ -197,10 +196,7 @@ class DatabaseCheckers: assert "name" in batch[1] assert all( - [ - any([k.startswith(j) for j in prefixes]) - for k in batch[1]["name"] - ], + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], ) # use the code below to view generated images @@ -227,9 +223,7 @@ class DatabaseCheckers: for idx, rs in enumerate(raw_samples): raw_samples_indices[rs[0]] = idx - for ref_hist_path, ref_hist_data in ref_histogram_splits[ - split_name - ]: + for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]: # Get index in the dataset that will return the data # corresponding to the specified sample name dataset_sample_index = raw_samples_indices[ref_hist_path] @@ -257,9 +251,7 @@ class DatabaseCheckers: # reference and check the similarity within a certain # threshold pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) - assert ( - 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 - ) + assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 else: assert histogram == ref_hist_data diff --git a/src/mednet/libs/classification/tests/test_cli_classification.py 
b/src/mednet/libs/classification/tests/test_cli_classification.py index 8c76cda2..8eb8a93c 100644 --- a/src/mednet/libs/classification/tests/test_cli_classification.py +++ b/src/mednet/libs/classification/tests/test_cli_classification.py @@ -441,9 +441,7 @@ def test_experiment(temporary_basedir): _assert_exit_0(result) assert (output_folder / "model" / "meta.json").exists() - assert ( - output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt" - ).exists() + assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() assert (output_folder / "predictions" / "predictions.json").exists() assert (output_folder / "predictions" / "predictions.meta.json").exists() diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py index 32e40ca0..dff5a24f 100644 --- a/src/mednet/libs/common/data/datamodule.py +++ b/src/mednet/libs/common/data/datamodule.py @@ -139,9 +139,7 @@ class _DelayedLoadingDataset(Dataset): def __getitem__(self, key: int) -> Sample: tensor, metadata = self.loader.sample(self.raw_dataset[key]) - return self.transform(tensor), transform_tvtensors( - metadata, self.transform - ) + return self.transform(tensor), transform_tvtensors(metadata, self.transform) def __len__(self): return len(self.raw_dataset) @@ -177,9 +175,7 @@ def _apply_loader_and_transforms( """ sample = load(info) - return model_transform(sample[0]), transform_tvtensors( - sample[1], model_transform - ) + return model_transform(sample[0]), transform_tvtensors(sample[1], model_transform) class _CachedDataset(Dataset): @@ -306,145 +302,6 @@ class _ConcatDataset(Dataset): yield from dataset -def _make_balanced_random_sampler( - dataset: Dataset, - target: str = "target", -) -> torch.utils.data.WeightedRandomSampler: - """Generate a pytorch sampler that samples according to class - probabilities. - - This function takes as input a torch Dataset, and computes the weights to - balance each class in the dataset, and the datasets themselves if one - passes a :py:class:`torch.utils.data.ConcatDataset`. - - In this implementation, we balance **both** class and dataset-origin - probabilities, what you expect for a truly *equitable* random sampler. - - Take this example for illustration: - - * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1 - * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1 - - So: - - | Dataset | Target | Samples | Weight | Normalised weight | - +---------+--------+---------+--------+-------------------+ - | 1 | 0 | 9 | 1/9 | 1/36 | - | 1 | 1 | 1 | 1/1 | 1/4 | - | 2 | 0 | 3 | 1/3 | 1/12 | - | 2 | 1 | 3 | 1/3 | 1/12 | - - Legend: - - * Weight: the weights computed by this method - * Normalised weight: the weight per sample used by the random sampler, - after normalising the weights by the sum of all weights in the - concatenated dataset, such that the sum of all normalized weights times - the number of samples is 1. - - The properties of this algorithm are as follows: - - 1. The probability of picking a sample from any target is the same (0.5 in - this case). To verify this, notice that the probability of picking a - sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`. - 2. The probability of picking a sample with ``target=0`` from Dataset 2 is - 3 times higher than those from Dataset 1. As there are 3 times fewer - samples in Dataset 2 with ``target=0``, this makes choosing samples from - Dataset 1 proportionally less likely. - 3. 
The probability of picking a sample with ``target=1`` from Dataset 2 is - 3 times lower than those from Dataset 1. As there are 3 times fewer - samples in Dataset 1 with ``target=1``, this makes choosing samples from - Dataset 2 proportionally less likely. - - This function assumes targets are stored on a dictionary entry named - ``target`` inside the metadata information for the - :py:data:`.typing.Sample`, and that its value is an integer. - - We then instantiate a pytorch sampler using the inverse probabilities (the - more samples in a class, the less likely it becomes to be sampled. - - Parameters - ---------- - dataset - An instance of torch Dataset. - :py:class:`torch.utils.data.ConcatDataset` are supported. - target - The name of a metadata key pointing to an integer property that allows - balancing the dataset. - - Returns - ------- - A sampler, to be used in a dataloader equipped with the same dataset - used to calculate the relative sample weights. - - Raises - ------ - RuntimeError - If requested to balance a dataset (single, not-concatenated) without an - existing target. - """ - - def _calculate_weights(targets: list[int]) -> list[float]: - counts = collections.Counter(targets) - weights = {k: 1.0 / v for k, v in counts.items()} - return [weights[k] for k in targets] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - # There are two possible cases: targets/no-targets - metadata_example = dataset.datasets[0][0][1] - if target in metadata_example and isinstance( - metadata_example[target], - int, - ): - # there are integer targets, let's balance with those - logger.info( - f"Balancing sample selection probabilities **and** " - f"concatenated-datasets using metadata targets `{target}`", - ) - targets = [ - k - for ds in dataset.datasets - for k in typing.cast(Dataset, ds).targets() - ] - weights = _calculate_weights(targets) # type: ignore - else: - logger.warning( - f"Balancing samples **and** concatenated-datasets " - f"by using dataset totals as `{target}: int` is not true", - ) - weights = [ - k - for ds in dataset.datasets - for k in len(typing.cast(typing.Sized, ds)) - * [1.0 / len(typing.cast(typing.Sized, ds))] - ] - - pass - - else: - metadata_example = dataset[0][1] - if target in metadata_example and isinstance( - metadata_example[target], - int, - ): - logger.info( - f"Balancing samples from dataset using metadata " - f"targets `{target}`", - ) - weights = _calculate_weights(dataset.targets()) # type: ignore - else: - raise RuntimeError( - f"Cannot balance samples with multiple class targets " - f"({target}: list[int]) or without metadata targets `{target}`", - ) - - return torch.utils.data.WeightedRandomSampler( - weights, - len(weights), - replacement=True, - ) - - class ConcatDataModule(lightning.LightningDataModule): """A conveninent DataModule with dictionary split loading, mini- batching, parallelisation and caching, all in one. diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py index 603de9c1..2eccf051 100644 --- a/src/mednet/libs/common/models/normalizer.py +++ b/src/mednet/libs/common/models/normalizer.py @@ -40,9 +40,7 @@ def make_z_normalizer( try: target = batch[1]["target"][0].item() if not isinstance(target, int): - logger.info( - "Targets are not Integer type, skipping z-normalization." 
- ) + logger.info("Targets are not Integer type, skipping z-normalization.") return None except RuntimeError: logger.info("Targets are not Integer type, skipping z-normalization.") diff --git a/src/mednet/libs/common/scripts/config.py b/src/mednet/libs/common/scripts/config.py index a813e90f..d4a8062b 100644 --- a/src/mednet/libs/common/scripts/config.py +++ b/src/mednet/libs/common/scripts/config.py @@ -93,9 +93,7 @@ def describe(name, entry_point_group, verbose) -> None: # numpydoc ignore=PR01 click.echo(inspect.getdoc(mod)) -def copy( - source, destination, entry_point_group -) -> None: # numpydoc ignore=PR01 +def copy(source, destination, entry_point_group) -> None: # numpydoc ignore=PR01 """Copy a specific configuration resource so it can be modified locally.""" import shutil diff --git a/src/mednet/libs/common/scripts/database.py b/src/mednet/libs/common/scripts/database.py index 63f1dde1..07c6f1f8 100644 --- a/src/mednet/libs/common/scripts/database.py +++ b/src/mednet/libs/common/scripts/database.py @@ -42,9 +42,7 @@ def check(entry_point_group, fold, limit): # numpydoc ignore=PR01 click.secho(f"Checking fold `{fold}`...", fg="yellow") try: - module = importlib.metadata.entry_points(group=entry_point_group)[ - fold - ].module + module = importlib.metadata.entry_points(group=entry_point_group)[fold].module except KeyError: raise Exception(f"Could not find database fold `{fold}`") diff --git a/src/mednet/libs/common/tests/conftest.py b/src/mednet/libs/common/tests/conftest.py index 07dbdce9..e9c23ab2 100644 --- a/src/mednet/libs/common/tests/conftest.py +++ b/src/mednet/libs/common/tests/conftest.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import json import pathlib import pytest diff --git a/src/mednet/libs/common/utils/gitlab.py b/src/mednet/libs/common/utils/gitlab.py index d0b111dd..0baeba83 100644 --- a/src/mednet/libs/common/utils/gitlab.py +++ b/src/mednet/libs/common/utils/gitlab.py @@ -70,9 +70,7 @@ def sanitize_filename(tmpdir: pathlib.Path, path: pathlib.Path) -> pathlib.Path: return path absolute_sanitized_filename = tmpdir / sanitized_filename - logger.info( - f"Sanitazing filename `{path}` -> `{absolute_sanitized_filename}`" - ) + logger.info(f"Sanitazing filename `{path}` -> `{absolute_sanitized_filename}`") shutil.copy2(path, absolute_sanitized_filename) return absolute_sanitized_filename diff --git a/src/mednet/libs/segmentation/engine/evaluator.py b/src/mednet/libs/segmentation/engine/evaluator.py index 4cf5ecf0..01917104 100644 --- a/src/mednet/libs/segmentation/engine/evaluator.py +++ b/src/mednet/libs/segmentation/engine/evaluator.py @@ -144,8 +144,7 @@ def _sample_measures( step_size = 1.0 / steps data = [ - (index, threshold) - + sample_measures_for_threshold(pred, gt, mask, threshold) + (index, threshold) + sample_measures_for_threshold(pred, gt, mask, threshold) for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size)) ] @@ -499,9 +498,7 @@ def run( output_folder.mkdir(parents=True, exist_ok=True) measures_path = output_folder / f"{name}.csv" - logger.info( - f"Saving measures over all input images at {measures_path}..." 
- ) + logger.info(f"Saving measures over all input images at {measures_path}...") measures.to_csv(measures_path) return maxf1_threshold diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 8bae2956..b9519e73 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -66,15 +66,11 @@ class ConvBlock(torch.nn.Module): else: self.pool = False - block.append( - torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad) - ) + block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad)) block.append(torch.nn.ReLU()) block.append(torch.nn.BatchNorm2d(out_c)) - block.append( - torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad) - ) + block.append(torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad)) block.append(torch.nn.ReLU()) block.append(torch.nn.BatchNorm2d(out_c)) @@ -106,14 +102,10 @@ class UpsampleBlock(torch.nn.Module): super().__init__() block = [] if up_mode == "transp_conv": - block.append( - torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2) - ) + block.append(torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)) elif up_mode == "up_conv": block.append( - torch.nn.Upsample( - mode="bilinear", scale_factor=2, align_corners=False - ) + torch.nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False) ) block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1)) else: @@ -141,9 +133,7 @@ class ConvBridgeBlock(torch.nn.Module): pad = (k_sz - 1) // 2 block = [] - block.append( - torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad) - ) + block.append(torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad)) block.append(torch.nn.ReLU()) block.append(torch.nn.BatchNorm2d(channels)) @@ -384,6 +374,4 @@ class LittleWNet(Model): return separate((probabilities, batch[1])) def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), **self._optimizer_arguments - ) + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py index 716addc6..4f2628f8 100644 --- a/src/mednet/libs/segmentation/models/separate.py +++ b/src/mednet/libs/segmentation/models/separate.py @@ -53,8 +53,7 @@ def separate(batch: Sample) -> list[SegmentationPrediction]: # as of now, this is really simple - to be made more complex upon need. 
metadata = [ - {key: value[i] for key, value in batch[1].items()} - for i in range(len(batch[0])) + {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0])) ] return _as_predictions(zip(batch[0], metadata)) diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py index 84ba542c..0fdddc9a 100644 --- a/src/mednet/libs/segmentation/scripts/evaluate.py +++ b/src/mednet/libs/segmentation/scripts/evaluate.py @@ -176,9 +176,7 @@ def evaluate( # we try to convert it to float first threshold = float(threshold) if threshold < 0.0 or threshold > 1.0: - raise ValueError( - "Float thresholds must be within range [0.0, 1.0]" - ) + raise ValueError("Float thresholds must be within range [0.0, 1.0]") except ValueError: if threshold not in splits: raise ValueError( @@ -205,9 +203,7 @@ def evaluate( # Compute threshold on specified split if isinstance(threshold, str): logger.info(f"Evaluating threshold on '{threshold}' split") - threshold = run( - threshold, predict_data[threshold], output_folder, steps=steps - ) + threshold = run(threshold, predict_data[threshold], output_folder, steps=steps) logger.info(f"Set --threshold={threshold:.5f}") for split_name, predictions in predict_data.items(): diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index b909e26f..4c906c02 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -92,9 +92,7 @@ def experiment( train_stop_timestamp = datetime.now() logger.info(f"Ended training in {train_stop_timestamp}") - logger.info( - f"Training runtime: {train_stop_timestamp-train_start_timestamp}" - ) + logger.info(f"Training runtime: {train_stop_timestamp-train_start_timestamp}") logger.info("Started train analysis") from mednet.libs.common.scripts.train_analysis import train_analysis @@ -128,9 +126,7 @@ def experiment( predict_stop_timestamp = datetime.now() logger.info(f"Ended prediction in {predict_stop_timestamp}") - logger.info( - f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}" - ) + logger.info(f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}") evaluation_start_timestamp = datetime.now() logger.info(f"Started evaluation at {evaluation_start_timestamp}") diff --git a/src/mednet/libs/segmentation/tests/conftest.py b/src/mednet/libs/segmentation/tests/conftest.py index 648d9674..5f8efecc 100644 --- a/src/mednet/libs/segmentation/tests/conftest.py +++ b/src/mednet/libs/segmentation/tests/conftest.py @@ -70,8 +70,7 @@ def pytest_runtest_setup(item): # iterates over all markers for the item being examined, get the first # argument and accumulate these names rc_names = [ - mark.args[0] - for mark in item.iter_markers(name="skip_if_rc_var_not_set") + mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set") ] # checks all names mentioned are set in ~/.config/mednet.libs.segmentation.toml, otherwise, @@ -182,25 +181,18 @@ class DatabaseCheckers: assert len(batch[1]) in [2, 3] # target, Optional(mask), name assert "target" in batch[1] - assert all( - [isinstance(target, torch.Tensor) for target in batch[1]["target"]] - ) + assert all([isinstance(target, torch.Tensor) for target in batch[1]["target"]]) if expected_num_targets: assert len(batch[1]["target"]) == expected_num_targets if "mask" in batch[1]: - assert all( - [isinstance(mask, torch.Tensor) for mask in batch[1]["mask"]] - ) + assert 
all([isinstance(mask, torch.Tensor) for mask in batch[1]["mask"]]) assert "name" in batch[1] if prefixes is not None: assert all( - [ - any([k.startswith(j) for j in prefixes]) - for k in batch[1]["name"] - ], + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], ) @staticmethod @@ -222,9 +214,7 @@ class DatabaseCheckers: for idx, rs in enumerate(raw_samples): raw_samples_indices[rs[0]] = idx - for ref_hist_path, ref_hist_data in ref_histogram_splits[ - split_name - ]: + for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]: # Get index in the dataset that will return the data # corresponding to the specified sample name dataset_sample_index = raw_samples_indices[ref_hist_path] @@ -252,9 +242,7 @@ class DatabaseCheckers: # reference and check the similarity within a certain # threshold pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) - assert ( - 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 - ) + assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 else: assert histogram == ref_hist_data diff --git a/src/mednet/libs/segmentation/tests/test_measures.py b/src/mednet/libs/segmentation/tests/test_measures.py index 5d1b4069..a9fdceb0 100644 --- a/src/mednet/libs/segmentation/tests/test_measures.py +++ b/src/mednet/libs/segmentation/tests/test_measures.py @@ -54,9 +54,7 @@ class TestFrequentist(unittest.TestCase): def test_f1(self): p, r, s, a, j, f1 = base_measures(self.tp, self.fp, self.tn, self.fn) - self.assertEqual( - (2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1 - ) + self.assertEqual((2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1) self.assertAlmostEqual((2 * p * r) / (p + r), f1) # base definition @@ -133,27 +131,17 @@ class TestBayesian: fn = random.randint(100000, 1000000) _prec, _rec, _spec, _acc, _jac, _f1 = base_measures(tp, fp, tn, fn) - prec, rec, spec, acc, jac, f1 = bayesian_measures( - tp, fp, tn, fn, 0.5, 0.95 - ) + prec, rec, spec, acc, jac, f1 = bayesian_measures(tp, fp, tn, fn, 0.5, 0.95) # Notice that for very large k and l, the base frequentist measures # should be approximately the same as the bayesian mean and mode # extracted from the beta posterior. We test that here. 
- assert numpy.isclose( - _prec, prec[0] - ), f"freq: {_prec} <> bays: {prec[0]}" - assert numpy.isclose( - _prec, prec[1] - ), f"freq: {_prec} <> bays: {prec[1]}" + assert numpy.isclose(_prec, prec[0]), f"freq: {_prec} <> bays: {prec[0]}" + assert numpy.isclose(_prec, prec[1]), f"freq: {_prec} <> bays: {prec[1]}" assert numpy.isclose(_rec, rec[0]), f"freq: {_rec} <> bays: {rec[0]}" assert numpy.isclose(_rec, rec[1]), f"freq: {_rec} <> bays: {rec[1]}" - assert numpy.isclose( - _spec, spec[0] - ), f"freq: {_spec} <> bays: {spec[0]}" - assert numpy.isclose( - _spec, spec[1] - ), f"freq: {_spec} <> bays: {spec[1]}" + assert numpy.isclose(_spec, spec[0]), f"freq: {_spec} <> bays: {spec[0]}" + assert numpy.isclose(_spec, spec[1]), f"freq: {_spec} <> bays: {spec[1]}" assert numpy.isclose(_acc, acc[0]), f"freq: {_acc} <> bays: {acc[0]}" assert numpy.isclose(_acc, acc[1]), f"freq: {_acc} <> bays: {acc[1]}" assert numpy.isclose(_jac, jac[0]), f"freq: {_jac} <> bays: {jac[0]}" @@ -186,21 +174,11 @@ class TestBayesian: def test_auc(): # basic tests assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0, 1.0]), 1.0) - assert math.isclose( - auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001 - ) + assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001) # reversing tht is also true assert math.isclose(auc([0.0, 0.5, 1.0][::-1], [1.0, 1.0, 1.0][::-1]), 1.0) @@ -222,9 +200,7 @@ def test_auc(): def test_auc_raises_value_error(): - with pytest.raises( - ValueError, match=r".*neither increasing nor decreasing.*" - ): + with pytest.raises(ValueError, match=r".*neither increasing nor decreasing.*"): # x is **not** monotonically increasing or decreasing assert math.isclose(auc([0.0, 0.5, 0.0], [1.0, 1.0, 1.0]), 1.0) diff --git a/src/mednet/libs/segmentation/utils/measure.py b/src/mednet/libs/segmentation/utils/measure.py index 1b56f315..8cbd22fb 100644 --- a/src/mednet/libs/segmentation/utils/measure.py +++ b/src/mednet/libs/segmentation/utils/measure.py @@ -370,9 +370,7 @@ def auc(x: numpy.ndarray, y: numpy.ndarray): x = x[::-1] y = y[::-1] else: - raise ValueError( - "x is neither increasing nor decreasing " f": {x}." - ) + raise ValueError("x is neither increasing nor decreasing " f": {x}.") y_interp = numpy.interp( numpy.arange(0, 1, 0.001), -- GitLab
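For reference, the class- and dataset-balanced weighting that the removed `_make_balanced_random_sampler` documented above can be sketched in a few lines. The snippet below is a standalone illustration and is not part of this patch: the toy targets mirror the worked example in the removed docstring (Dataset 1: 9 samples with target=0 and 1 with target=1; Dataset 2: 3 and 3), and the helper name `inverse_frequency_weights` is hypothetical.

import collections

import torch


def inverse_frequency_weights(targets: list[int]) -> list[float]:
    """Weight each sample by 1 / (number of samples sharing its target)."""
    counts = collections.Counter(targets)
    return [1.0 / counts[t] for t in targets]


# Toy example from the removed docstring:
# Dataset 1: 9 samples with target=0, 1 sample with target=1
# Dataset 2: 3 samples with target=0, 3 samples with target=1
dataset1_targets = [0] * 9 + [1] * 1
dataset2_targets = [0] * 3 + [1] * 3

# Weights are computed per dataset and then concatenated, so both the class
# and the dataset of origin are balanced (1/9 for each dataset-1/target-0
# sample, 1/1 for the single dataset-1/target-1 sample, 1/3 for every
# dataset-2 sample).
weights = inverse_frequency_weights(dataset1_targets) + inverse_frequency_weights(
    dataset2_targets
)

sampler = torch.utils.data.WeightedRandomSampler(
    weights, num_samples=len(weights), replacement=True
)

# Sanity check: after normalisation the two targets are equally likely,
# since 9 * (1/36) + 3 * (1/12) = 0.5 for target=0.
total = sum(weights)
all_targets = dataset1_targets + dataset2_targets
p_target0 = sum(w for w, t in zip(weights, all_targets) if t == 0) / total
print(p_target0)  # -> 0.5

With replacement enabled, a DataLoader driven by this sampler draws targets 0 and 1 with equal probability while down-weighting the over-represented dataset, matching the normalised weights (1/36, 1/4, 1/12) in the docstring table of the removed function.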