diff --git a/helpers/generate_histograms.py b/helpers/generate_histograms.py index 8d4f83b515d3ad2f0ff51fff382fa43234529556..57c9f342258674b91caedae6bb82d2cda5931da9 100644 --- a/helpers/generate_histograms.py +++ b/helpers/generate_histograms.py @@ -299,9 +299,7 @@ def main( if not output_dir.exists(): output_dir.mkdir(parents=True, exist_ok=True) - output_file = ( - output_dir / f"histograms{model_name}_{k}_{fold_name}.json" - ) + output_file = output_dir / f"histograms{model_name}_{k}_{fold_name}.json" output_json_data = {} diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py index b09fa338731907c6a5fcd0196eccff6d57013c2b..2f6caeac988d8b27639ed573dd7a63e96708226e 100644 --- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py @@ -11,8 +11,8 @@ from mednet.libs.common.data.datamodule import ConcatDataModule from mednet.libs.common.data.split import make_split from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR -from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader +from ..indian.datamodule import DataModule as IndianDataModule from ..montgomery.datamodule import ( ClassificationRawDataLoader as MontgomeryLoader, ) diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 532945277debb3558fae1a97c0a26d434690cea4..c83680ecb89efe2ab7b51accaa21ab9e476ceb5f 100644 --- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -11,8 +11,8 @@ from mednet.libs.common.data.datamodule import ConcatDataModule from mednet.libs.common.data.split import make_split from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR -from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader +from ..indian.datamodule import DataModule as IndianDataModule from ..montgomery.datamodule import ( ClassificationRawDataLoader as MontgomeryLoader, ) diff --git a/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py index d5a44ce64e5060f69083f95824a5995f63246e05..e1907a35947bebc9f278266ee8bb6a3f16cc1876 100644 --- a/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py @@ -11,6 +11,7 @@ from mednet.libs.common.data.split import make_split from ..nih_cxr14.datamodule import ClassificationRawDataLoader as CXR14Loader from ..padchest.datamodule import ClassificationRawDataLoader as PadchestLoader + class DataModule(ConcatDataModule): """Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) datasets. diff --git a/src/mednet/libs/classification/scripts/cli.py b/src/mednet/libs/classification/scripts/cli.py index b7839d0dd1bf78eae8eb81e4fdbd8078cc9bec0d..34ba1ff38d08568dd96f3df61e85b2f70669d079 100644 --- a/src/mednet/libs/classification/scripts/cli.py +++ b/src/mednet/libs/classification/scripts/cli.py @@ -17,18 +17,14 @@ def classification(): pass -classification.add_command( - importlib.import_module("..config", package=__name__).config -) +classification.add_command(importlib.import_module("..config", package=__name__).config) classification.add_command( importlib.import_module("..database", package=__name__).database, ) classification.add_command( importlib.import_module("..predict", package=__name__).predict ) -classification.add_command( - importlib.import_module("..train", package=__name__).train -) +classification.add_command(importlib.import_module("..train", package=__name__).train) classification.add_command( importlib.import_module( "mednet.libs.common.scripts.train_analysis", diff --git a/src/mednet/libs/classification/tests/conftest.py b/src/mednet/libs/classification/tests/conftest.py index ecd74256dbc8824d70e3962140498a88a313e827..40e5f9d0bde84d4cdc24707a0563cc03e97d2722 100644 --- a/src/mednet/libs/classification/tests/conftest.py +++ b/src/mednet/libs/classification/tests/conftest.py @@ -70,8 +70,7 @@ def pytest_runtest_setup(item): # iterates over all markers for the item being examined, get the first # argument and accumulate these names rc_names = [ - mark.args[0] - for mark in item.iter_markers(name="skip_if_rc_var_not_set") + mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set") ] # checks all names mentioned are set in ~/.config/mednet.libs.classification.toml, otherwise, @@ -197,10 +196,7 @@ class DatabaseCheckers: assert "name" in batch[1] assert all( - [ - any([k.startswith(j) for j in prefixes]) - for k in batch[1]["name"] - ], + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], ) # use the code below to view generated images @@ -227,9 +223,7 @@ class DatabaseCheckers: for idx, rs in enumerate(raw_samples): raw_samples_indices[rs[0]] = idx - for ref_hist_path, ref_hist_data in ref_histogram_splits[ - split_name - ]: + for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]: # Get index in the dataset that will return the data # corresponding to the specified sample name dataset_sample_index = raw_samples_indices[ref_hist_path] @@ -257,9 +251,7 @@ class DatabaseCheckers: # reference and check the similarity within a certain # threshold pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) - assert ( - 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 - ) + assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 else: assert histogram == ref_hist_data diff --git a/src/mednet/libs/classification/tests/test_cli_classification.py b/src/mednet/libs/classification/tests/test_cli_classification.py index 8c76cda2b0f8146d5a1e5c9a2024e25dc2a8857a..8eb8a93c69488d1d4e37bf0650e21e1c84e12ecf 100644 --- a/src/mednet/libs/classification/tests/test_cli_classification.py +++ b/src/mednet/libs/classification/tests/test_cli_classification.py @@ -441,9 +441,7 @@ def test_experiment(temporary_basedir): _assert_exit_0(result) assert (output_folder / "model" / "meta.json").exists() - assert ( - output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt" - ).exists() + assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() assert (output_folder / "predictions" / "predictions.json").exists() assert (output_folder / "predictions" / "predictions.meta.json").exists() diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py index 32e40ca0eaf919e501089eddbfda1768a5a450eb..dff5a24f0b0a95b2c9c1b36e1f962da14da32e4c 100644 --- a/src/mednet/libs/common/data/datamodule.py +++ b/src/mednet/libs/common/data/datamodule.py @@ -139,9 +139,7 @@ class _DelayedLoadingDataset(Dataset): def __getitem__(self, key: int) -> Sample: tensor, metadata = self.loader.sample(self.raw_dataset[key]) - return self.transform(tensor), transform_tvtensors( - metadata, self.transform - ) + return self.transform(tensor), transform_tvtensors(metadata, self.transform) def __len__(self): return len(self.raw_dataset) @@ -177,9 +175,7 @@ def _apply_loader_and_transforms( """ sample = load(info) - return model_transform(sample[0]), transform_tvtensors( - sample[1], model_transform - ) + return model_transform(sample[0]), transform_tvtensors(sample[1], model_transform) class _CachedDataset(Dataset): @@ -306,145 +302,6 @@ class _ConcatDataset(Dataset): yield from dataset -def _make_balanced_random_sampler( - dataset: Dataset, - target: str = "target", -) -> torch.utils.data.WeightedRandomSampler: - """Generate a pytorch sampler that samples according to class - probabilities. - - This function takes as input a torch Dataset, and computes the weights to - balance each class in the dataset, and the datasets themselves if one - passes a :py:class:`torch.utils.data.ConcatDataset`. - - In this implementation, we balance **both** class and dataset-origin - probabilities, what you expect for a truly *equitable* random sampler. - - Take this example for illustration: - - * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1 - * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1 - - So: - - | Dataset | Target | Samples | Weight | Normalised weight | - +---------+--------+---------+--------+-------------------+ - | 1 | 0 | 9 | 1/9 | 1/36 | - | 1 | 1 | 1 | 1/1 | 1/4 | - | 2 | 0 | 3 | 1/3 | 1/12 | - | 2 | 1 | 3 | 1/3 | 1/12 | - - Legend: - - * Weight: the weights computed by this method - * Normalised weight: the weight per sample used by the random sampler, - after normalising the weights by the sum of all weights in the - concatenated dataset, such that the sum of all normalized weights times - the number of samples is 1. - - The properties of this algorithm are as follows: - - 1. The probability of picking a sample from any target is the same (0.5 in - this case). To verify this, notice that the probability of picking a - sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`. - 2. The probability of picking a sample with ``target=0`` from Dataset 2 is - 3 times higher than those from Dataset 1. As there are 3 times fewer - samples in Dataset 2 with ``target=0``, this makes choosing samples from - Dataset 1 proportionally less likely. - 3. The probability of picking a sample with ``target=1`` from Dataset 2 is - 3 times lower than those from Dataset 1. As there are 3 times fewer - samples in Dataset 1 with ``target=1``, this makes choosing samples from - Dataset 2 proportionally less likely. - - This function assumes targets are stored on a dictionary entry named - ``target`` inside the metadata information for the - :py:data:`.typing.Sample`, and that its value is an integer. - - We then instantiate a pytorch sampler using the inverse probabilities (the - more samples in a class, the less likely it becomes to be sampled. - - Parameters - ---------- - dataset - An instance of torch Dataset. - :py:class:`torch.utils.data.ConcatDataset` are supported. - target - The name of a metadata key pointing to an integer property that allows - balancing the dataset. - - Returns - ------- - A sampler, to be used in a dataloader equipped with the same dataset - used to calculate the relative sample weights. - - Raises - ------ - RuntimeError - If requested to balance a dataset (single, not-concatenated) without an - existing target. - """ - - def _calculate_weights(targets: list[int]) -> list[float]: - counts = collections.Counter(targets) - weights = {k: 1.0 / v for k, v in counts.items()} - return [weights[k] for k in targets] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - # There are two possible cases: targets/no-targets - metadata_example = dataset.datasets[0][0][1] - if target in metadata_example and isinstance( - metadata_example[target], - int, - ): - # there are integer targets, let's balance with those - logger.info( - f"Balancing sample selection probabilities **and** " - f"concatenated-datasets using metadata targets `{target}`", - ) - targets = [ - k - for ds in dataset.datasets - for k in typing.cast(Dataset, ds).targets() - ] - weights = _calculate_weights(targets) # type: ignore - else: - logger.warning( - f"Balancing samples **and** concatenated-datasets " - f"by using dataset totals as `{target}: int` is not true", - ) - weights = [ - k - for ds in dataset.datasets - for k in len(typing.cast(typing.Sized, ds)) - * [1.0 / len(typing.cast(typing.Sized, ds))] - ] - - pass - - else: - metadata_example = dataset[0][1] - if target in metadata_example and isinstance( - metadata_example[target], - int, - ): - logger.info( - f"Balancing samples from dataset using metadata " - f"targets `{target}`", - ) - weights = _calculate_weights(dataset.targets()) # type: ignore - else: - raise RuntimeError( - f"Cannot balance samples with multiple class targets " - f"({target}: list[int]) or without metadata targets `{target}`", - ) - - return torch.utils.data.WeightedRandomSampler( - weights, - len(weights), - replacement=True, - ) - - class ConcatDataModule(lightning.LightningDataModule): """A conveninent DataModule with dictionary split loading, mini- batching, parallelisation and caching, all in one. diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py index 603de9c132d0a51a1c2cd56958574fdc874bb56f..2eccf0511239d1557557ce6ad16b888291040f2a 100644 --- a/src/mednet/libs/common/models/normalizer.py +++ b/src/mednet/libs/common/models/normalizer.py @@ -40,9 +40,7 @@ def make_z_normalizer( try: target = batch[1]["target"][0].item() if not isinstance(target, int): - logger.info( - "Targets are not Integer type, skipping z-normalization." - ) + logger.info("Targets are not Integer type, skipping z-normalization.") return None except RuntimeError: logger.info("Targets are not Integer type, skipping z-normalization.") diff --git a/src/mednet/libs/common/scripts/config.py b/src/mednet/libs/common/scripts/config.py index a813e90f717e48eb3258a7557750627ab0c1c0e4..d4a8062b60c9962c5284817b24444e4dc2232220 100644 --- a/src/mednet/libs/common/scripts/config.py +++ b/src/mednet/libs/common/scripts/config.py @@ -93,9 +93,7 @@ def describe(name, entry_point_group, verbose) -> None: # numpydoc ignore=PR01 click.echo(inspect.getdoc(mod)) -def copy( - source, destination, entry_point_group -) -> None: # numpydoc ignore=PR01 +def copy(source, destination, entry_point_group) -> None: # numpydoc ignore=PR01 """Copy a specific configuration resource so it can be modified locally.""" import shutil diff --git a/src/mednet/libs/common/scripts/database.py b/src/mednet/libs/common/scripts/database.py index 63f1dde1c17c28778288151d0db06fa76ab7a378..07c6f1f8d63fbfafa082f7a9cba2e834b2158f99 100644 --- a/src/mednet/libs/common/scripts/database.py +++ b/src/mednet/libs/common/scripts/database.py @@ -42,9 +42,7 @@ def check(entry_point_group, fold, limit): # numpydoc ignore=PR01 click.secho(f"Checking fold `{fold}`...", fg="yellow") try: - module = importlib.metadata.entry_points(group=entry_point_group)[ - fold - ].module + module = importlib.metadata.entry_points(group=entry_point_group)[fold].module except KeyError: raise Exception(f"Could not find database fold `{fold}`") diff --git a/src/mednet/libs/common/tests/conftest.py b/src/mednet/libs/common/tests/conftest.py index 07dbdce9d5e028a9335bfaa6080e07991f91612a..e9c23ab2fa4c458078c42f5d0c5cdd3045cc0d9a 100644 --- a/src/mednet/libs/common/tests/conftest.py +++ b/src/mednet/libs/common/tests/conftest.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import json import pathlib import pytest diff --git a/src/mednet/libs/common/utils/gitlab.py b/src/mednet/libs/common/utils/gitlab.py index d0b111ddcf1f54ec995ffb0544c0199aa3d4f7b4..0baeba83012905eff8b43e504404c52dd9584749 100644 --- a/src/mednet/libs/common/utils/gitlab.py +++ b/src/mednet/libs/common/utils/gitlab.py @@ -70,9 +70,7 @@ def sanitize_filename(tmpdir: pathlib.Path, path: pathlib.Path) -> pathlib.Path: return path absolute_sanitized_filename = tmpdir / sanitized_filename - logger.info( - f"Sanitazing filename `{path}` -> `{absolute_sanitized_filename}`" - ) + logger.info(f"Sanitazing filename `{path}` -> `{absolute_sanitized_filename}`") shutil.copy2(path, absolute_sanitized_filename) return absolute_sanitized_filename diff --git a/src/mednet/libs/segmentation/engine/evaluator.py b/src/mednet/libs/segmentation/engine/evaluator.py index 4cf5ecf02a29a486cb834f5d76c935ac36431f66..019171043df6577cb75463bc98126838ebccd761 100644 --- a/src/mednet/libs/segmentation/engine/evaluator.py +++ b/src/mednet/libs/segmentation/engine/evaluator.py @@ -144,8 +144,7 @@ def _sample_measures( step_size = 1.0 / steps data = [ - (index, threshold) - + sample_measures_for_threshold(pred, gt, mask, threshold) + (index, threshold) + sample_measures_for_threshold(pred, gt, mask, threshold) for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size)) ] @@ -499,9 +498,7 @@ def run( output_folder.mkdir(parents=True, exist_ok=True) measures_path = output_folder / f"{name}.csv" - logger.info( - f"Saving measures over all input images at {measures_path}..." - ) + logger.info(f"Saving measures over all input images at {measures_path}...") measures.to_csv(measures_path) return maxf1_threshold diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 8bae2956f333216f75a21e4109fdf63fc1a93851..b9519e734331788443d0205dde13844713eb9d2c 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -66,15 +66,11 @@ class ConvBlock(torch.nn.Module): else: self.pool = False - block.append( - torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad) - ) + block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad)) block.append(torch.nn.ReLU()) block.append(torch.nn.BatchNorm2d(out_c)) - block.append( - torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad) - ) + block.append(torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad)) block.append(torch.nn.ReLU()) block.append(torch.nn.BatchNorm2d(out_c)) @@ -106,14 +102,10 @@ class UpsampleBlock(torch.nn.Module): super().__init__() block = [] if up_mode == "transp_conv": - block.append( - torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2) - ) + block.append(torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)) elif up_mode == "up_conv": block.append( - torch.nn.Upsample( - mode="bilinear", scale_factor=2, align_corners=False - ) + torch.nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False) ) block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1)) else: @@ -141,9 +133,7 @@ class ConvBridgeBlock(torch.nn.Module): pad = (k_sz - 1) // 2 block = [] - block.append( - torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad) - ) + block.append(torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad)) block.append(torch.nn.ReLU()) block.append(torch.nn.BatchNorm2d(channels)) @@ -384,6 +374,4 @@ class LittleWNet(Model): return separate((probabilities, batch[1])) def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), **self._optimizer_arguments - ) + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py index 716addc67de95b60978526cc7c9a6bf92b408a81..4f2628f8b41d4b1fea238b8cbda441200cfe109e 100644 --- a/src/mednet/libs/segmentation/models/separate.py +++ b/src/mednet/libs/segmentation/models/separate.py @@ -53,8 +53,7 @@ def separate(batch: Sample) -> list[SegmentationPrediction]: # as of now, this is really simple - to be made more complex upon need. metadata = [ - {key: value[i] for key, value in batch[1].items()} - for i in range(len(batch[0])) + {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0])) ] return _as_predictions(zip(batch[0], metadata)) diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py index 84ba542c6dff4a4547846c80513be9735201b386..0fdddc9adea3a9cabe4a0d70e08fb27414214372 100644 --- a/src/mednet/libs/segmentation/scripts/evaluate.py +++ b/src/mednet/libs/segmentation/scripts/evaluate.py @@ -176,9 +176,7 @@ def evaluate( # we try to convert it to float first threshold = float(threshold) if threshold < 0.0 or threshold > 1.0: - raise ValueError( - "Float thresholds must be within range [0.0, 1.0]" - ) + raise ValueError("Float thresholds must be within range [0.0, 1.0]") except ValueError: if threshold not in splits: raise ValueError( @@ -205,9 +203,7 @@ def evaluate( # Compute threshold on specified split if isinstance(threshold, str): logger.info(f"Evaluating threshold on '{threshold}' split") - threshold = run( - threshold, predict_data[threshold], output_folder, steps=steps - ) + threshold = run(threshold, predict_data[threshold], output_folder, steps=steps) logger.info(f"Set --threshold={threshold:.5f}") for split_name, predictions in predict_data.items(): diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index b909e26f3eb9ed28d01d2fbfc89080dc4e06f2f0..4c906c023f6e2d6e4e63f3973bb7c676b73040f6 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -92,9 +92,7 @@ def experiment( train_stop_timestamp = datetime.now() logger.info(f"Ended training in {train_stop_timestamp}") - logger.info( - f"Training runtime: {train_stop_timestamp-train_start_timestamp}" - ) + logger.info(f"Training runtime: {train_stop_timestamp-train_start_timestamp}") logger.info("Started train analysis") from mednet.libs.common.scripts.train_analysis import train_analysis @@ -128,9 +126,7 @@ def experiment( predict_stop_timestamp = datetime.now() logger.info(f"Ended prediction in {predict_stop_timestamp}") - logger.info( - f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}" - ) + logger.info(f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}") evaluation_start_timestamp = datetime.now() logger.info(f"Started evaluation at {evaluation_start_timestamp}") diff --git a/src/mednet/libs/segmentation/tests/conftest.py b/src/mednet/libs/segmentation/tests/conftest.py index 648d96740cb362bf0999230cb2cfd6c694223979..5f8efecc9ba6cd913b9de9d92453d7ced866e0de 100644 --- a/src/mednet/libs/segmentation/tests/conftest.py +++ b/src/mednet/libs/segmentation/tests/conftest.py @@ -70,8 +70,7 @@ def pytest_runtest_setup(item): # iterates over all markers for the item being examined, get the first # argument and accumulate these names rc_names = [ - mark.args[0] - for mark in item.iter_markers(name="skip_if_rc_var_not_set") + mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set") ] # checks all names mentioned are set in ~/.config/mednet.libs.segmentation.toml, otherwise, @@ -182,25 +181,18 @@ class DatabaseCheckers: assert len(batch[1]) in [2, 3] # target, Optional(mask), name assert "target" in batch[1] - assert all( - [isinstance(target, torch.Tensor) for target in batch[1]["target"]] - ) + assert all([isinstance(target, torch.Tensor) for target in batch[1]["target"]]) if expected_num_targets: assert len(batch[1]["target"]) == expected_num_targets if "mask" in batch[1]: - assert all( - [isinstance(mask, torch.Tensor) for mask in batch[1]["mask"]] - ) + assert all([isinstance(mask, torch.Tensor) for mask in batch[1]["mask"]]) assert "name" in batch[1] if prefixes is not None: assert all( - [ - any([k.startswith(j) for j in prefixes]) - for k in batch[1]["name"] - ], + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], ) @staticmethod @@ -222,9 +214,7 @@ class DatabaseCheckers: for idx, rs in enumerate(raw_samples): raw_samples_indices[rs[0]] = idx - for ref_hist_path, ref_hist_data in ref_histogram_splits[ - split_name - ]: + for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]: # Get index in the dataset that will return the data # corresponding to the specified sample name dataset_sample_index = raw_samples_indices[ref_hist_path] @@ -252,9 +242,7 @@ class DatabaseCheckers: # reference and check the similarity within a certain # threshold pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) - assert ( - 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 - ) + assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 else: assert histogram == ref_hist_data diff --git a/src/mednet/libs/segmentation/tests/test_measures.py b/src/mednet/libs/segmentation/tests/test_measures.py index 5d1b40699250368664a2dc1ec1c2f0fe25db05fb..a9fdceb02a077a7dcd3b59bba91f271f6a45dcb0 100644 --- a/src/mednet/libs/segmentation/tests/test_measures.py +++ b/src/mednet/libs/segmentation/tests/test_measures.py @@ -54,9 +54,7 @@ class TestFrequentist(unittest.TestCase): def test_f1(self): p, r, s, a, j, f1 = base_measures(self.tp, self.fp, self.tn, self.fn) - self.assertEqual( - (2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1 - ) + self.assertEqual((2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1) self.assertAlmostEqual((2 * p * r) / (p + r), f1) # base definition @@ -133,27 +131,17 @@ class TestBayesian: fn = random.randint(100000, 1000000) _prec, _rec, _spec, _acc, _jac, _f1 = base_measures(tp, fp, tn, fn) - prec, rec, spec, acc, jac, f1 = bayesian_measures( - tp, fp, tn, fn, 0.5, 0.95 - ) + prec, rec, spec, acc, jac, f1 = bayesian_measures(tp, fp, tn, fn, 0.5, 0.95) # Notice that for very large k and l, the base frequentist measures # should be approximately the same as the bayesian mean and mode # extracted from the beta posterior. We test that here. - assert numpy.isclose( - _prec, prec[0] - ), f"freq: {_prec} <> bays: {prec[0]}" - assert numpy.isclose( - _prec, prec[1] - ), f"freq: {_prec} <> bays: {prec[1]}" + assert numpy.isclose(_prec, prec[0]), f"freq: {_prec} <> bays: {prec[0]}" + assert numpy.isclose(_prec, prec[1]), f"freq: {_prec} <> bays: {prec[1]}" assert numpy.isclose(_rec, rec[0]), f"freq: {_rec} <> bays: {rec[0]}" assert numpy.isclose(_rec, rec[1]), f"freq: {_rec} <> bays: {rec[1]}" - assert numpy.isclose( - _spec, spec[0] - ), f"freq: {_spec} <> bays: {spec[0]}" - assert numpy.isclose( - _spec, spec[1] - ), f"freq: {_spec} <> bays: {spec[1]}" + assert numpy.isclose(_spec, spec[0]), f"freq: {_spec} <> bays: {spec[0]}" + assert numpy.isclose(_spec, spec[1]), f"freq: {_spec} <> bays: {spec[1]}" assert numpy.isclose(_acc, acc[0]), f"freq: {_acc} <> bays: {acc[0]}" assert numpy.isclose(_acc, acc[1]), f"freq: {_acc} <> bays: {acc[1]}" assert numpy.isclose(_jac, jac[0]), f"freq: {_jac} <> bays: {jac[0]}" @@ -186,21 +174,11 @@ class TestBayesian: def test_auc(): # basic tests assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0, 1.0]), 1.0) - assert math.isclose( - auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001 - ) - assert math.isclose( - auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001 - ) + assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001) + assert math.isclose(auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001) # reversing tht is also true assert math.isclose(auc([0.0, 0.5, 1.0][::-1], [1.0, 1.0, 1.0][::-1]), 1.0) @@ -222,9 +200,7 @@ def test_auc(): def test_auc_raises_value_error(): - with pytest.raises( - ValueError, match=r".*neither increasing nor decreasing.*" - ): + with pytest.raises(ValueError, match=r".*neither increasing nor decreasing.*"): # x is **not** monotonically increasing or decreasing assert math.isclose(auc([0.0, 0.5, 0.0], [1.0, 1.0, 1.0]), 1.0) diff --git a/src/mednet/libs/segmentation/utils/measure.py b/src/mednet/libs/segmentation/utils/measure.py index 1b56f31520c098d9c68096ec93d35a87dac55591..8cbd22fbd7135c638b8f1c614b14104e62528f34 100644 --- a/src/mednet/libs/segmentation/utils/measure.py +++ b/src/mednet/libs/segmentation/utils/measure.py @@ -370,9 +370,7 @@ def auc(x: numpy.ndarray, y: numpy.ndarray): x = x[::-1] y = y[::-1] else: - raise ValueError( - "x is neither increasing nor decreasing " f": {x}." - ) + raise ValueError("x is neither increasing nor decreasing " f": {x}.") y_interp = numpy.interp( numpy.arange(0, 1, 0.001),