diff --git a/conda/meta.yaml b/conda/meta.yaml index a1bed5c468fbf89b641800779d7aab4f79888415..33434fb6550f0a393de6e6487f79e637238be3cc 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -25,6 +25,7 @@ requirements: - pip - clapper {{ clapper }} - click {{ click }} + - credible {{ credible }} - grad-cam {{ grad_cam }} - matplotlib {{ matplotlib }} - numpy {{ numpy }} @@ -44,6 +45,7 @@ requirements: - python >=3.10 - {{ pin_compatible('clapper') }} - {{ pin_compatible('click') }} + - {{ pin_compatible('credible') }} - {{ pin_compatible('grad-cam', max_pin='x.x') }} - {{ pin_compatible('matplotlib') }} - {{ pin_compatible('numpy') }} diff --git a/pyproject.toml b/pyproject.toml index 55db0f21b1b3c5415e41416f8a2ff074924c880b..0457cdff94604b93b689e4cf0176ac2080238010 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ classifiers = [ dependencies = [ "clapper", "click", + "credible", "numpy", "scipy", "scikit-image", diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index 68a7b7a3bdfea302ecd41bc750355b354b85c5f2..cae64f2c53b2fda0992aea1f662307ddab27d616 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -141,4 +141,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index 08a507223436be949b4ae03de1c2671060697c1c..2fc0567bbbefd65ec4c96fc19592c1ec1db7215c 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_ """ import importlib.resources +import os from ....config.data.shenzhen.datamodule import RawDataLoader from ....data.datamodule import CachingDataModule @@ -82,4 
+83,6 @@ class DataModule(CachingDataModule): raw_data_loader=RawDataLoader( config_variable=CONFIGURATION_KEY_DATADIR ), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index 86e9fdb7058c3e25351212b93be7a8167c9ecc03..5ed7fa50e325810f3b8523a883aa5f9f8c1c301b 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -143,4 +143,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index fa83fdde5165e24c0d08b0b9c123b87bd98e3805..6df353ad70d701d8585be9b46dc4ea540adccdb4 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -1,6 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +"""Aggregated DataModule composed of Montgomery and Shenzhen databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -38,5 +41,7 @@ class DataModule(ConcatDataModule): (montgomery_split["test"], montgomery_loader), (shenzhen_split["test"], shenzhen_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index 92617fddc7f93608d22297d173abf47df0a46a2b..fc9af897b9501adcf40e46a0c841d06aefec214b 100644 --- 
a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -1,7 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR @@ -14,7 +16,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split class DataModule(ConcatDataModule): - """Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets. + """Aggregated DataModule composed of Montgomery, Shenzhen and Indian + datasets. Parameters ---------- @@ -47,5 +50,7 @@ class DataModule(ConcatDataModule): (shenzhen_split["test"], shenzhen_loader), (indian_split["test"], indian_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 7cff19e0efb6b9e4417992a866b2da315f81845e..bbfda89bd3d54d695f82004d79dc59c8566af21a 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" +import os + from ....data.datamodule import ConcatDataModule from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import RawDataLoader as IndianLoader @@ -58,5 +60,11 @@ class DataModule(ConcatDataModule): (indian_split["test"], indian_loader), 
(padchest_split["test"], padchest_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(split_filename)[0] + + "+" + + os.path.splitext(padchest_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 358648d83f0f5cf824604950d7221b779803e3c0..f8f82c967289219ad0adac7646874d01dcffc1c3 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" +import os + from ....data.datamodule import ConcatDataModule from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import RawDataLoader as IndianLoader @@ -57,5 +59,11 @@ class DataModule(ConcatDataModule): (indian_split["test"], indian_loader), (tbx11k_split["test"], tbx11k_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(split_filename)[0] + + "+" + + os.path.splitext(tbx11k_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index 5967ee63152c6e6dac0aa9b1fb6bd13d3f24a438..26596b74c427e712d7ef3b2b141d4508d1eae395 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -192,4 +192,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index 
2c793c7980156f14f0ab4946652bd1f627cd22c0..6cc383405093db4adca141ad6001bab9572fefa9 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -1,6 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +"""Aggregated DataModule composed of NIH-CXR-14 and PadChest databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader @@ -42,5 +45,11 @@ class DataModule(ConcatDataModule): # there is no test set on padchest # (padchest_split["test"], padchest_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(cxr14_split_filename)[0] + + "+" + + os.path.splitext(padchest_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index 778d505f96813f28b348b9a251d91f4213871d77..d146fc0357659134151084d765945b96a8f8a305 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -341,4 +341,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index 6853ebe56718c38db9e1e1d72a2c08ded11627dd..81e48f9b70dc887a97a312a0d8062afc809682e2 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -155,4 +155,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git 
a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index 67846f6c7c0842c18f24d83a4018beb1ccc908b9..14b09e7f23e7c72779f43f7f9bb8503ca9954414 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -136,4 +136,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index 9c76bb5af431793e2668ef4a181e4e98b28e5f25..1735607ee72cb7f1298bca1e1adca57bb5d41d6a 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -355,4 +355,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 4f79857b88eeb7a7fada2bf40f29ebb9dcfe7209..71fead5a2d275b2aa33f28dcfc9c894ac687a2b2 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule): Entries named ``monitor-...`` will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. + database_name + The name of the database, or aggregated database containing the + raw-samples served by this data module. + split_name + The name of the split used to group the samples into the various + datasets for training, validation and testing. cache_samples If set, then issue raw data loading during ``prepare_data()``, and serves samples from CPU memory. 
Otherwise, loads samples from disk on @@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule): def __init__( self, splits: ConcatDatabaseSplit, + database_name: str = "", + split_name: str = "", cache_samples: bool = False, balance_sampler_by_class: bool = False, batch_size: int = 1, @@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule): self.set_chunk_size(batch_size, batch_chunk_count) self.splits = splits + self.database_name = database_name + self.split_name = split_name for dataset_name, split_loaders in splits.items(): count = sum([len(k) for k, _ in split_loaders]) - logger.info(f"Dataset `{dataset_name}` contains {count} samples") + logger.info( + f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) " + f"contains {count} samples" + ) self.cache_samples = cache_samples self._train_sampler = None diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py index cd24260014c5374e692c9a10cce9427e859692dc..64bbd76e11fabe603ab2a05ffd8adbe271b78371 100644 --- a/src/mednet/engine/evaluator.py +++ b/src/mednet/engine/evaluator.py @@ -10,6 +10,7 @@ import typing from collections.abc import Iterable, Iterator +import credible.curves import matplotlib.figure import numpy import numpy.typing @@ -239,7 +240,7 @@ def run_binary( # point measures on threshold summary = dict( - split=name, + num_samples=len(y_labels), threshold=use_threshold, threshold_a_posteriori=(threshold_a_priori is None), precision=sklearn.metrics.precision_score( @@ -248,13 +249,20 @@ def run_binary( recall=sklearn.metrics.recall_score( y_labels, y_predictions, pos_label=pos_label ), + f1_score=sklearn.metrics.f1_score( + y_labels, y_predictions, pos_label=pos_label + ), + average_precision_score=sklearn.metrics.average_precision_score( + y_labels, y_predictions, pos_label=pos_label + ), specificity=sklearn.metrics.recall_score( y_labels, y_predictions, pos_label=neg_label ), - accuracy=sklearn.metrics.accuracy_score(y_labels, 
y_predictions), - f1_score=sklearn.metrics.f1_score( - y_labels, y_predictions, pos_label=pos_label + auc_score=sklearn.metrics.roc_auc_score( + y_labels, + y_predictions, ), + accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions), ) # figures: score distributions @@ -279,7 +287,7 @@ def run_binary( def aggregate_summaries( - data: typing.Sequence[typing.Mapping[str, typing.Any]], fmt: str + data: typing.Mapping[str, typing.Mapping[str, typing.Any]], fmt: str ) -> str: """Tabulate summaries from multiple splits. @@ -300,8 +308,14 @@ def aggregate_summaries( A string containing the tabulated information. """ - headers = list(data[0].keys()) - table = [[k[h] for h in headers] for k in data] + example = next(iter(data.values())) + headers = list(example.keys()) + table = [[k[h] for h in headers] for k in data.values()] + + # add subset names + headers = ["subset"] + headers + table = [[name] + k for name, k in zip(data.keys(), table)] + return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f") @@ -471,7 +485,7 @@ def aggregate_pr( Parameters ---------- data - A dictionary mapping split names to ROC curve data produced by + A dictionary mapping split names to Precision-Recall curve data produced by :py:func:sklearn.metrics.precision_recall_curve`. title The title of the plot. 
@@ -503,8 +517,8 @@ def aggregate_pr( legend = [] for name, (prec, recall, _) in data.items(): - _auc = sklearn.metrics.auc(recall, prec) - label = f"{name} (AUC={_auc:.2f})" + _ap = credible.curves.average_metric([prec, recall]) + label = f"{name} (AP={_ap:.2f})" color = next(colorcycler) style = next(linecycler) diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 3da73a1124d7c94704d0a6b2228eedd75664777d..7fd8f2a7a4d30fcf1beae5ed55870d3b2b056f33 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -2,108 +2,22 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import csv import logging import os import pathlib -import shutil import lightning.pytorch import lightning.pytorch.callbacks import lightning.pytorch.loggers -import torch.nn from ..utils.checkpointer import CHECKPOINT_ALIASES -from ..utils.resources import ( - ResourceMonitor, - cpu_constants, - cuda_constants, - mps_constants, -) +from ..utils.resources import ResourceMonitor from .callbacks import LoggingCallback -from .device import DeviceManager, SupportedPytorchDevice +from .device import DeviceManager logger = logging.getLogger(__name__) -def save_model_summary( - output_folder: pathlib.Path, - model: torch.nn.Module, -) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]: - """Save a little summary of the model in a txt file. - - Parameters - ---------- - output_folder - Directory in which to save the summary. - model - Instance of the model for which to save the summary. - - Returns - ------- - tuple[lightning.pytorch.callbacks.ModelSummary, int] - A tuple with the model summary in a text format and number of parameters of the model. 
- """ - summary_path = output_folder / "model-summary.txt" - logger.info(f"Saving model summary at {summary_path}...") - with summary_path.open("w") as f: - summary = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore - model, max_depth=-1 - ) - f.write(str(summary)) - return ( - summary, - lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore - model - ).total_parameters, - ) - - -def static_information_to_csv( - static_logfile_name: pathlib.Path, - device_type: SupportedPytorchDevice, - model_size: int, -) -> None: - """Save the static information in a CSV file. - - Parameters - ---------- - static_logfile_name - The static file name which is a join between the output folder and - "constants.csv". - device_type - The type of device we are using. - model_size - The size of the model we will be training. - """ - if static_logfile_name.exists(): - backup = static_logfile_name.parent / (static_logfile_name.name + "~") - shutil.copy(static_logfile_name, backup) - - with static_logfile_name.open("w", newline="") as f: - logdata: dict[str, int | float | str] = {} - logdata.update(cpu_constants()) - - match device_type: - case "cpu": - pass - case "cuda": - results = cuda_constants() - if results is not None: - logdata.update(results) - case "mps": - results = mps_constants() - if results is not None: - logdata.update(results) - case _: - pass - - logdata["number-of-model-parameters"] = model_size - logwriter = csv.DictWriter(f, fieldnames=logdata.keys()) - logwriter.writeheader() - logwriter.writerow(logdata) - - def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, @@ -158,9 +72,6 @@ def run( os.makedirs(output_folder, exist_ok=True) - # Save model summary - _, no_of_parameters = save_model_summary(output_folder, model) - from .loggers import CustomTensorboardLogger log_dir = "logs" @@ -205,13 +116,6 @@ def run( "periodic" ] - # write static information to a CSV file - 
static_information_to_csv( - output_folder / "constants.csv", - device_manager.device_type, - no_of_parameters, - ) - with train_resource_monitor, validation_resource_monitor: accelerator, devices = device_manager.lightning_accelerator() trainer = lightning.pytorch.Trainer( diff --git a/src/mednet/scripts/evaluate.py b/src/mednet/scripts/evaluate.py index 3c21ef632a25bcac0e39ba7d7678dc5737b2637f..49f8ad4e89e3f6cd41c69550e31a58c663dc5afb 100644 --- a/src/mednet/scripts/evaluate.py +++ b/src/mednet/scripts/evaluate.py @@ -26,13 +26,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - mednet evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results + mednet evaluate -vv --predictions=path/to/predictions.json --output=evaluation.json 2. Run evaluation on an existing prediction output, tune threshold a priori on the `validation` set: .. code:: sh - mednet evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results --threshold=validation + mednet evaluate -vv --predictions=path/to/predictions.json --output=evaluation.json --threshold=validation """, ) @click.option( @@ -46,13 +46,16 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @click.option( - "--output-folder", + "--output", "-o", - help="Directory in which to store the analysis result (created if does not exist)", - required=False, - default="results", - type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), + help="""Path to a JSON file in which to save evaluation results + (leading directories are created if they do not exist).""", + required=True, + default="evaluation.json", cls=ResourceOption, + type=click.Path( + file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path + ), ) @click.option( "--threshold", @@ -75,7 +78,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @verbosity_option(logger=logger, 
cls=ResourceOption, expose_value=False) def evaluate( predictions: pathlib.Path, - output_folder: pathlib.Path, + output: pathlib.Path, threshold: str | float, **_, # ignored ) -> None: # numpydoc ignore=PR01 @@ -94,10 +97,16 @@ def evaluate( aggregate_summaries, run_binary, ) + from .utils import execution_metadata, save_json_with_backup with predictions.open("r") as f: predict_data = json.load(f) + # register metadata + json_data: dict[str, typing.Any] = execution_metadata() + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output.with_suffix(".meta.json"), json_data) + if threshold in predict_data: # it is the name of a split # first run evaluation for reference dataset @@ -133,40 +142,40 @@ def evaluate( threshold_a_priori=use_threshold, ) - rows = [v[0] for v in results.values()] - table = aggregate_summaries(rows, fmt="rst") - click.echo(table) - - if output_folder is not None: - output_folder.mkdir(parents=True, exist_ok=True) + data = {k: v[0] for k, v in results.items()} + logger.info(f"Saving evaluation results at `{output}`...") + with output.open("w") as f: + json.dump(data, f, indent=2) - table_path = output_folder / "summary.rst" - - logger.info(f"Saving measures at `{table_path}`...") - with table_path.open("w") as f: - f.write(table) - - figure_path = output_folder / "plots.pdf" - logger.info(f"Saving figures at `{figure_path}`...") - - with PdfPages(figure_path) as pdf: - pr_curves = { - k: v[2]["precision_recall"] for k, v in results.items() - } - pr_fig = aggregate_pr(pr_curves) - pdf.savefig(pr_fig) - - roc_curves = {k: v[2]["roc"] for k, v in results.items()} - roc_fig = aggregate_roc(roc_curves) - pdf.savefig(roc_fig) - - # order ready-to-save figures by type instead of split - figures = {k: v[1] for k, v in results.items()} - keys = next(iter(figures.values())).keys() - figures_by_type = { - k: [v[k] for v in figures.values()] for k in keys - } + # dump evaluation results in RST format to screen and 
file + table = aggregate_summaries(data, fmt="rst") + click.echo(table) - for group_figures in figures_by_type.values(): - for f in group_figures: - pdf.savefig(f) + table_path = output.with_suffix(".rst") + logger.info( + f"Saving evaluation results in table format at `{table_path}`..." + ) + with table_path.open("w") as f: + f.write(table) + + # dump evaluation plots in file + figure_path = output.with_suffix(".pdf") + logger.info(f"Saving evaluation figures at `{figure_path}`...") + + with PdfPages(figure_path) as pdf: + pr_curves = {k: v[2]["precision_recall"] for k, v in results.items()} + pr_fig = aggregate_pr(pr_curves) + pdf.savefig(pr_fig) + + roc_curves = {k: v[2]["roc"] for k, v in results.items()} + roc_fig = aggregate_roc(roc_curves) + pdf.savefig(roc_fig) + + # order ready-to-save figures by type instead of split + figures = {k: v[1] for k, v in results.items()} + keys = next(iter(figures.values())).keys() + figures_by_type = {k: [v[k] for v in figures.values()] for k in keys} + + for group_figures in figures_by_type.values(): + for f in group_figures: + pdf.savefig(f) diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py index 275bc8b8bb91417a6467bcc62b4b7af708253533..e19f59efd4ec7120fbde8957b8b9d112c39ebe02 100644 --- a/src/mednet/scripts/experiment.py +++ b/src/mednet/scripts/experiment.py @@ -63,10 +63,6 @@ def experiment( ├── predictions.json # the prediction outputs for the sets └── evaluation/ # the outputs of the evaluations for the sets """ - from .utils import save_sh_command - - save_sh_command(output_folder / "command.sh") - logger.info("Started training") from .train import train @@ -128,18 +124,18 @@ def experiment( from .evaluate import evaluate - evaluations_folder = output_folder / "evaluation" + evaluation_output = output_folder / "evaluation.json" ctx.invoke( evaluate, - output_folder=evaluations_folder, predictions=predictions_output, + output=evaluation_output, threshold="validation", ) 
logger.info("Ended evaluating") - logger.info("Started generating saliencies") + logger.info("Started generating saliency maps") from .saliency.generate import generate @@ -153,9 +149,9 @@ def experiment( output_folder=saliencies_gen_folder, ) - logger.info("Ended generating saliencies") + logger.info("Ended generating saliency maps") - logger.info("Started viewing saliencies") + logger.info("Started generating saliency images") from .saliency.view import view @@ -169,4 +165,4 @@ def experiment( output_folder=saliencies_view_folder, ) - logger.info("Ended viewing saliencies") + logger.info("Ended generating saliency images") diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index 68dd8da7e5f1e28b03d21746b9550039f4caa440..e1f38d48fccbb3f07b57439f034cd04d7ad6423a 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -36,11 +36,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output", "-o", - help="""Path to a .json file in which to save predictions for all samples in the + help="""Path to a JSON file in which to save predictions for all samples in the input DataModule (leading directories are created if they do not exist).""", required=True, - default="results", + default="predictions.json", cls=ResourceOption, type=click.Path( file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path @@ -129,10 +129,17 @@ def predict( import json import shutil + import typing from ..engine.device import DeviceManager from ..engine.predictor import run from ..utils.checkpointer import get_checkpoint_to_run_inference + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel @@ -147,7 +154,23 @@ def predict( logger.info(f"Loading checkpoint from `{weight}`...") model = type(model).load_from_checkpoint(weight, strict=False) - predictions = 
run(model, datamodule, DeviceManager(device)) + device_manager = DeviceManager(device) + + # register metadata + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + database_split=datamodule.split_name, + model_name=model.name, + ) + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output.with_suffix(".meta.json"), json_data) + + predictions = run(model, datamodule, device_manager) output.parent.mkdir(parents=True, exist_ok=True) if output.exists(): diff --git a/src/mednet/scripts/saliency/view.py b/src/mednet/scripts/saliency/view.py index f44be79f8f4999fcc69fc38da368b53869fe696e..542d177bb757a7984c0e3b89308bcb998b304aa3 100644 --- a/src/mednet/scripts/saliency/view.py +++ b/src/mednet/scripts/saliency/view.py @@ -101,7 +101,6 @@ def view( """Generate heatmaps for input CXRs based on existing saliency maps.""" from ...engine.saliency.viewer import run - from ..utils import save_sh_command assert ( input_folder != output_folder @@ -114,8 +113,6 @@ def view( logger.info(f"Output folder: {output_folder}") os.makedirs(output_folder, exist_ok=True) - save_sh_command(output_folder / "command.sh") - datamodule.set_chunk_size(1, 1) datamodule.drop_incomplete_batch = False # datamodule.cache_samples = cache_samples diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 1dc78e4a68e690e7a311b31966270b2de4eedadf..9280f1dd02c454feb8212e06f29b47117c1b7660 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -4,6 +4,7 @@ import functools import pathlib +import typing import click @@ -266,7 +267,12 @@ def train( from ..engine.device import DeviceManager from ..engine.trainer import run from ..utils.checkpointer import get_checkpoint_to_resume_training - from .utils import save_sh_command + from .utils import 
( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) checkpoint_file = None if os.path.isdir(output_folder): @@ -279,7 +285,6 @@ def train( f" from. Starting from scratch..." ) - save_sh_command(output_folder / "command.sh") seed_everything(seed) # reset datamodule with user configurable options @@ -333,11 +338,37 @@ def train( f"(checkpoint file: `{str(checkpoint_file)}`)..." ) + device_manager = DeviceManager(device) + + # stores all information we can think of, to reproduce this later + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + split_name=datamodule.split_name, + epochs=epochs, + batch_size=batch_size, + batch_chunk_count=batch_chunk_count, + drop_incomplete_batch=drop_incomplete_batch, + validation_period=validation_period, + cache_samples=cache_samples, + seed=seed, + parallel=parallel, + monitoring_interval=monitoring_interval, + balance_classes=balance_classes, + model_name=model.name, + ) + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output_folder / "meta.json", json_data) + run( model=model, datamodule=datamodule, validation_period=validation_period, - device_manager=DeviceManager(device), + device_manager=device_manager, max_epochs=epochs, output_folder=output_folder, monitoring_interval=monitoring_interval, diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 9b7a514486e5faeaa50e78f7d6e0f47769b38e04..0dc23f324c1a71b5705137f001c30a0fb8c7403d 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -18,7 +18,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def create_figures( data: dict[str, tuple[list[int], list[float]]], groups: list[str] = [ - 
"total-execution-time-seconds", "loss/*", "learning-rate", "memory-used-GB/cpu/*" "rss-GB/cpu/*", @@ -30,7 +29,6 @@ def create_figures( "memory-percent/gpu/*", "memory-used-GB/gpu/*", "memory-free-GB/gpu/*", - "memory-free-GB/gpu/*", "percent-usage/gpu/*", ], ) -> list: @@ -116,6 +114,7 @@ def create_figures( ) @click.option( "--logdir", + "-l", help="Path to the directory containing the Tensorboard training logs", required=True, type=click.Path(dir_okay=True, exists=True, path_type=pathlib.Path), diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py index 7c62d84e715123b11442717bdce103479aef1cec..6524ed9d3f7a413790a0f2ae115fca5260131348 100644 --- a/src/mednet/scripts/utils.py +++ b/src/mednet/scripts/utils.py @@ -1,64 +1,170 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -import importlib.metadata +"""Utilities for command-line scripts.""" + +import json import logging -import os import pathlib import shutil -import sys -import time + +import lightning.pytorch +import lightning.pytorch.callbacks +import torch.nn + +from ..engine.device import SupportedPytorchDevice logger = logging.getLogger(__name__) -def save_sh_command(path: pathlib.Path) -> None: - """Record command-line to reproduce this script. +def model_summary( + model: torch.nn.Module, +) -> dict[str, int | list[tuple[str, str, int]]]: + """Save a little summary of the model in a txt file. + + Parameters + ---------- + model + Instance of the model for which to save the summary. + + Returns + ------- + tuple[lightning.pytorch.callbacks.ModelSummary, int] + A tuple with the model summary in a text format and number of parameters of the model. 
+ """ + + s = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore + model + ) + + return dict( + model_summary=list(zip(s.layer_names, s.layer_types, s.param_nums)), + model_size=s.total_parameters, + ) - This function can record the current command-line used to call the script - being run. It creates an executable ``bash`` script setting up the current - working directory and activating a conda environment, if needed. It - records further information on the date and time the script was run and the - version of the package. + +def device_properties( + device_type: SupportedPytorchDevice, +) -> dict[str, int | float | str]: + """Generate information concerning hardware properties. + + Parameters + ---------- + device_type + The type of compute device we are using. + + Returns + ------- + Static properties of the current machine. + """ + + from ..utils.resources import cpu_constants, cuda_constants, mps_constants + + retval: dict[str, int | float | str] = {} + retval.update(cpu_constants()) + + match device_type: + case "cpu": + pass + case "cuda": + results = cuda_constants() + if results is not None: + retval.update(results) + case "mps": + results = mps_constants() + if results is not None: + retval.update(results) + case _: + pass + + return retval + + +def execution_metadata() -> dict[str, int | float | str]: + """Produce metadata concerning the running script, in the form of a + dictionary. + + This function returns potentially useful metadata concerning program + execution. It contains a certain number of preset variables. + + Returns + ------- + A dictionary that contains the following fields: + + * ``package-name``: current package name (e.g. ``mednet``) + * ``package-version``: current package version (e.g. ``1.0.0b0``) + * ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``) + * ``user``: username (e.g. 
``johndoe``) + * ``conda-env``: if set, the name of the current conda environment + * ``path``: current path when executing the command + * ``command-line``: the command-line that is being run + * ``hostname``: machine hostname (e.g. ``localhost``) + * ``platform``: machine platform (e.g. ``darwin``) + """ + + import importlib.metadata + import importlib.util + import os + import sys + + args = [] + for k in sys.argv: + if " " in k: + args.append(f"'{k}'") + else: + args.append(k) + + # current date time, in ISO8601 format + datetime = __import__("datetime").datetime.now().astimezone().isoformat() + + # collects dependence information + package_name = __package__.split(".")[0] + requires = importlib.metadata.requires(package_name) or [] + dependence_names = [k.split()[0] for k in requires] + dependencies = { + k: importlib.metadata.version(k) # version number as str + for k in dependence_names + if importlib.util.find_spec(k) is not None # if is installed + } + + data = { + "datetime": datetime, + "package-name": __package__.split(".")[0], + "package-version": importlib.metadata.version(package_name), + "dependencies": dependencies, + "user": __import__("getpass").getuser(), + "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""), + "path": os.path.realpath(os.curdir), + "command-line": " ".join(args), + "hostname": __import__("platform").node(), + "platform": sys.platform, + } + + return data + + +def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None: + """Save a dictionary into a JSON file with path checking and backup. + + This function will save a dictionary into a JSON file. It will check for + the existence of the directory leading to the file and create it if + necessary. If the file already exists on the destination folder, it is + backed-up before a new file is created with the new contents. Parameters ---------- path - Path to a file where the commands to reproduce the current run will be - recorded.
Parent directories will be created if they do not exist. An - existing copy will be backed-up if it exists. + The full path where to save the JSON data. + data + The data to save on the JSON file. """ - logger.info(f"Writing command-line for reproduction at `{path}`...") + logger.info(f"Writing run metadata at `{path}`...") - # create parent directories path.parent.mkdir(parents=True, exist_ok=True) - - # backup if exists if path.exists(): backup = path.parent / (path.name + "~") shutil.copy(path, backup) - # write the file - package = __name__.split(".", 1)[0] - version = importlib.metadata.version(package) - with path.open("w") as f: - f.write("#!/usr/bin/env sh\n") - f.write(f"# date: {time.asctime()}\n") - f.write(f"# version: {version} ({package})\n") - f.write(f"# platform: {sys.platform}\n") - f.write("\n") - args = [] - for k in sys.argv: - if " " in k: - args.append(f'"{k}"') - else: - args.append(k) - if os.environ.get("CONDA_DEFAULT_ENV") is not None: - f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n") - f.write(f"# cd {os.path.realpath(os.curdir)}\n") - f.write(" ".join(args) + "\n") - - # make it executable - path.chmod(0o755) + json.dump(data, f, indent=2) diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py index a17d3cbd2750e01650e3d7ddc89bb5f379df0157..5f71f57473c7d5768c00c9d08770c568cdcf4090 100644 --- a/src/mednet/utils/resources.py +++ b/src/mednet/utils/resources.py @@ -186,8 +186,7 @@ def mps_constants() -> dict[str, str | int | float] | None: return { "apple-processor-model": name, - "number-cpu-cores": multiprocessing.cpu_count(), - "number-gpu-cores": no_gpu_cores, + "number-of-cores/gpu": no_gpu_cores, } @@ -271,7 +270,7 @@ def cpu_constants() -> dict[str, int | float]: """ return { "memory-total-GB/cpu": psutil.virtual_memory().total / GB, - "count/cpu": psutil.cpu_count(logical=True), + "number-of-cores/cpu": psutil.cpu_count(logical=True), } diff --git a/tests/test_cli.py b/tests/test_cli.py index 
3382c88f4022e15378e8a830e40ef829667b975f..381ef539bc7a4dec82e468456124fa42086be7e5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,8 +4,6 @@ """Tests for our CLI applications.""" import contextlib -import glob -import os import re import pytest @@ -209,26 +207,20 @@ def test_train_pasa_montgomery(temporary_basedir): best = _get_checkpoint_from_alias(output_folder, "best") assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) - assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( - len( - glob.glob( - os.path.join(output_folder, "logs", "events.out.tfevents.*") - ) - ) + len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1 ) - assert os.path.exists(os.path.join(output_folder, "model-summary.txt")) + assert (output_folder / "meta.json").exists() keywords = { - r"^Writing command-line for reproduction at .*$": 1, r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Applying DataModule train sampler balancing...$": 1, r"^Balancing samples from dataset using metadata targets `label`$": 1, r"^Training for at most 1 epochs.$": 1, r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1, - r"^Saving model summary at.*$": 1, + r"^Writing run metadata at.*$": 1, r"^Dataset `train` is already setup. Not re-instantiating it.$": 1, r"^Dataset `validation` is already setup. 
Not re-instantiating it.$": 1, } @@ -273,18 +265,11 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): best = _get_checkpoint_from_alias(output_folder, "best") assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) - assert os.path.exists(os.path.join(output_folder, "constants.csv")) + assert (output_folder / "meta.json").exists() assert ( - len( - glob.glob( - os.path.join(output_folder, "logs", "events.out.tfevents.*") - ) - ) - == 1 + len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1 ) - assert os.path.exists(os.path.join(output_folder, "model-summary.txt")) - with stdout_logging() as buf: result = runner.invoke( train, @@ -304,28 +289,20 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) best = _get_checkpoint_from_alias(output_folder, "best") - assert os.path.exists(os.path.join(output_folder, "constants.csv")) - + assert (output_folder / "meta.json").exists() assert ( - len( - glob.glob( - os.path.join(output_folder, "logs", "events.out.tfevents.*") - ) - ) + len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 2 ) - assert os.path.exists(os.path.join(output_folder, "model-summary.txt")) - keywords = { - r"^Writing command-line for reproduction at .*$": 1, r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Applying DataModule train sampler balancing...$": 1, r"^Balancing samples from dataset using metadata targets `label`$": 1, r"^Training for at most 2 epochs.$": 1, r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, - r"^Saving model summary at.*$": 1, + r"^Writing run metadata at.*$": 1, r"^Dataset `train` is already setup. Not re-instantiating it.$": 1, r"^Dataset `validation` is already setup. 
Not re-instantiating it.$": 1, r"^Restoring normalizer from checkpoint.$": 1, @@ -342,7 +319,7 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_predict_pasa_montgomery(temporary_basedir, datadir): +def test_predict_pasa_montgomery(temporary_basedir): from mednet.scripts.predict import predict from mednet.utils.checkpointer import ( CHECKPOINT_EXTENSION, @@ -352,7 +329,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output = temporary_basedir / "predictions" + output = temporary_basedir / "predictions.json" last = _get_checkpoint_from_alias( temporary_basedir / "results", "periodic" ) @@ -365,12 +342,12 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): "-vv", "--batch-size=1", f"--weight={str(last)}", - f"--output={output}", + f"--output={str(output)}", ], ) _assert_exit_0(result) - assert os.path.exists(output) + assert output.exists() keywords = { r"^Loading dataset: * without caching. 
Trade-off: CPU RAM usage: less | Disk I/O: more$": 3, @@ -400,30 +377,32 @@ def test_evaluate_pasa_montgomery(temporary_basedir): runner = CliRunner() with stdout_logging() as buf: - prediction_folder = str(temporary_basedir / "predictions") - output_folder = str(temporary_basedir / "evaluations") + prediction_path = temporary_basedir / "predictions.json" + output_path = temporary_basedir / "evaluation.json" result = runner.invoke( evaluate, [ "-vv", "montgomery", - f"--predictions={prediction_folder}", - f"--output-folder={output_folder}", + f"--predictions={str(prediction_path)}", + f"--output={str(output_path)}", "--threshold=test", ], ) _assert_exit_0(result) - assert os.path.exists(os.path.join(output_folder, "plots.pdf")) - assert os.path.exists(os.path.join(output_folder, "summary.rst")) + assert output_path.exists() + assert output_path.with_suffix(".meta.json").exists() + assert output_path.with_suffix(".rst").exists() + assert output_path.with_suffix(".pdf").exists() keywords = { r"^Setting --threshold=.*$": 1, r"^Analyzing split `train`...$": 1, r"^Analyzing split `validation`...$": 1, r"^Analyzing split `test`...$": 1, - r"^Saving measures at .*$": 1, - r"^Saving figures at .*$": 1, + r"^Saving evaluation results .*$": 2, + r"^Saving evaluation figures at .*$": 1, } buf.seek(0) logging_output = buf.read() @@ -442,7 +421,7 @@ def test_experiment(temporary_basedir): runner = CliRunner() - output_folder = str(temporary_basedir / "experiment") + output_folder = temporary_basedir / "experiment" num_epochs = 2 result = runner.invoke( experiment, @@ -451,80 +430,59 @@ def test_experiment(temporary_basedir): "pasa", "montgomery", f"--epochs={num_epochs}", - f"--output-folder={output_folder}", + f"--output-folder={str(output_folder)}", ], ) _assert_exit_0(result) - assert os.path.exists(os.path.join(output_folder, "command.sh")) - assert os.path.exists(os.path.join(output_folder, "predictions.json")) - assert os.path.exists(os.path.join(output_folder, "model", 
"command.sh")) - assert os.path.exists(os.path.join(output_folder, "model", "constants.csv")) - assert os.path.exists( - os.path.join( - output_folder, "model", f"model-at-epoch={num_epochs-1}.ckpt" - ) - ) + assert (output_folder / "model" / "meta.json").exists() + assert ( + output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt" + ).exists() + assert (output_folder / "predictions.json").exists() + assert (output_folder / "predictions.meta.json").exists() + # Need to glob because we cannot be sure of the checkpoint with lowest validation loss assert ( len( - glob.glob( - os.path.join( - output_folder, - "model", - "model-at-lowest-validation-loss-epoch=*.ckpt", + list( + (output_folder / "model").glob( + "model-at-lowest-validation-loss-epoch=*.ckpt" ) ) ) == 1 ) - assert os.path.exists( - os.path.join(output_folder, "model", "model-summary.txt") - ) - assert os.path.exists(os.path.join(output_folder, "model", "trainlog.pdf")) + assert (output_folder / "model" / "trainlog.pdf").exists() assert ( len( - glob.glob( - os.path.join( - output_folder, "model", "logs", "events.out.tfevents.*" - ) + list( + (output_folder / "model" / "logs").glob("events.out.tfevents.*") ) ) == 1 ) - assert os.path.exists( - os.path.join(output_folder, "evaluation", "plots.pdf") - ) - assert os.path.exists( - os.path.join(output_folder, "evaluation", "summary.rst") - ) - assert os.path.exists(os.path.join(output_folder, "gradcam", "saliencies")) + assert (output_folder / "evaluation.json").exists() + assert (output_folder / "evaluation.meta.json").exists() + assert (output_folder / "evaluation.rst").exists() + assert (output_folder / "evaluation.pdf").exists() + assert (output_folder / "gradcam" / "saliencies").exists() assert ( len( - glob.glob( - os.path.join( - output_folder, - "gradcam", - "saliencies", - "CXR_png", - "MCUCXR_*.npy", + list( + (output_folder / "gradcam" / "saliencies" / "CXR_png").glob( + "MCUCXR_*.npy" ) ) ) == 138 ) - assert os.path.exists( - 
os.path.join(output_folder, "gradcam", "visualizations") - ) + assert (output_folder / "gradcam" / "visualizations").exists() assert ( len( - glob.glob( - os.path.join( - output_folder, - "gradcam", - "visualizations", - "CXR_png", - "MCUCXR_*.png", + list( + (output_folder / "gradcam" / "visualizations" / "CXR_png").glob( + "MCUCXR_*.png" ) ) )