diff --git a/src/mednet/libs/classification/scripts/evaluate.py b/src/mednet/libs/classification/scripts/evaluate.py index 50c2e56a7dcb37051a78265ef27eb33bec9fd442..d0553bad4a4284b35032da0b00dc89410c4b1913 100644 --- a/src/mednet/libs/classification/scripts/evaluate.py +++ b/src/mednet/libs/classification/scripts/evaluate.py @@ -137,8 +137,19 @@ def evaluate( # register metadata json_data: dict[str, typing.Any] = execution_metadata() + json_data.update( + dict( + predictions=str(predictions), + output_folder=str(output_folder), + threshold=threshold, + binning=binning, + plot=plot, + ), + ) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data) + evaluation_meta = evaluation_file.with_suffix(".meta.json") + logger.info(f"Saving evaluation metadata at `{str(evaluation_meta)}`...") + save_json_with_backup(evaluation_meta, json_data) if threshold in predict_data: # it is the name of a split @@ -161,7 +172,7 @@ def evaluate( results: dict[str, dict[str, typing.Any]] = dict() for k, v in predict_data.items(): - logger.info(f"Analyzing split `{k}`...") + logger.info(f"Computing performance on split `{k}`...") results[k] = run_binary( name=k, predictions=v, @@ -170,7 +181,7 @@ def evaluate( ) # records full result analysis to a JSON file - logger.info(f"Saving evaluation results at `{evaluation_file}`...") + logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...") with evaluation_file.open("w") as f: json.dump(results, f, indent=2, cls=NumpyJSONEncoder) @@ -190,11 +201,10 @@ def evaluate( with table_path.open("w") as f: f.write(table) - # dump evaluation plots in file - figure_path = evaluation_file.with_suffix(".pdf") - logger.info(f"Saving evaluation figures at `{figure_path}`...") - if plot: + figure_path = evaluation_file.with_suffix(".pdf") + logger.info(f"Saving evaluation figures at `{str(figure_path)}`...") + with PdfPages(figure_path) as pdf: pr_curves = {k: v["curves"]["precision_recall"] for k, v in results.items()} pr_fig = aggregate_pr(pr_curves) diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py index 5dcfe0c9d400117178d2f314c23fc083fe24b0c7..690bc74f5f4a561e35daaa2ba7ec2a8b9a2cd4cf 100644 --- a/src/mednet/libs/classification/scripts/experiment.py +++ b/src/mednet/libs/classification/scripts/experiment.py @@ -61,8 +61,8 @@ def experiment( └─ <output-folder>/ ├── model/ # the generated model will be here - ├── predictions.json # the prediction outputs for the sets - └── evaluation/ # the outputs of the evaluations for the sets + ├── predictions.json # the prediction outputs + ├── evaluation.json # the evaluation outputs """ experiment_start_timestamp = datetime.now() @@ -112,11 +112,9 @@ def experiment( from .predict import predict - predictions_output = output_folder / "predictions" - ctx.invoke( predict, - output_folder=predictions_output, + output_folder=output_folder, model=model, datamodule=datamodule, device=device, @@ -134,9 +132,9 @@ def experiment( from .evaluate import evaluate - predictions_file = predictions_output / "predictions.json" + predictions_file = output_folder / "predictions.json" - with (predictions_output / "predictions.json").open() as pf: + with (output_folder / "predictions.json").open() as pf: splits = json.load(pf).keys() if "validation" in splits: @@ -159,56 +157,6 @@ def experiment( f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}" ) - 
saliency_map_generation_start_timestamp = datetime.now() - logger.info( - f"Started saliency map generation at {saliency_map_generation_start_timestamp}" - ) - - from .saliency.generate import generate - - saliencies_gen_folder = output_folder / "gradcam" / "saliencies" - - ctx.invoke( - generate, - model=model, - datamodule=datamodule, - weight=train_output_folder, - output_folder=saliencies_gen_folder, - ) - - saliency_map_generation_stop_timestamp = datetime.now() - logger.info( - f"Ended saliency map generation at {saliency_map_generation_stop_timestamp}" - ) - logger.info( - f"Saliency map generation runtime: {saliency_map_generation_stop_timestamp-saliency_map_generation_start_timestamp}" - ) - - saliency_images_generation_start_timestamp = datetime.now() - logger.info( - f"Started generating saliency images at {saliency_images_generation_start_timestamp}" - ) - - from .saliency.view import view - - saliencies_view_folder = output_folder / "gradcam" / "visualizations" - - ctx.invoke( - view, - model=model, - datamodule=datamodule, - input_folder=saliencies_gen_folder, - output_folder=saliencies_view_folder, - ) - - saliency_images_generation_stop_timestamp = datetime.now() - logger.info( - f"Ended saliency images generation at {saliency_images_generation_stop_timestamp}" - ) - logger.info( - f"Saliency images generation runtime: {saliency_images_generation_stop_timestamp-saliency_images_generation_start_timestamp}" - ) - experiment_stop_timestamp = datetime.now() logger.info( f"Total experiment runtime: {experiment_stop_timestamp-experiment_start_timestamp}" diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py index 191cde3884a606c1d379e017aa6d49aefc2acd08..dc579d225c17a4ee846ddb34b66f75e3dd33aafd 100644 --- a/src/mednet/libs/segmentation/scripts/evaluate.py +++ b/src/mednet/libs/segmentation/scripts/evaluate.py @@ -14,6 +14,9 @@ from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE logger = setup("mednet") +# avoids X11/graphical desktop requirement when creating plots +__import__("matplotlib").use("agg") + def validate_threshold(threshold: float | str, splits: list[str]): """Validate the user threshold selection and returns parsed threshold. @@ -89,17 +92,6 @@ def validate_threshold(threshold: float | str, splits: list[str]): default="results", cls=ResourceOption, ) -# @click.option( -# "--second-annotator", -# "-a", -# help="""A datamodule containing annotations from another annotator, that -# will be compared to the ground-truth (reference annotator) in each -# sample.""", -# required=False, -# default=None, -# cls=ResourceOption, -# show_default=True, -# ) @click.option( "--threshold", "-t", @@ -138,14 +130,43 @@ def validate_threshold(threshold: float | str, splits: list[str]): required=True, cls=ResourceOption, ) +@click.option( + "--compare-annotator", + "-a", + help="""Path to a JSON file as produced by the CLI ``dump-annotations``, + containing splits and sample lists with associated HDF5 files where we can + find pre-processed annotation masks. These annotations will be compared + with the target annotations on the main predictions. 
In this case, a row + is added for each available split in the evaluation table.""", + required=False, + default=None, + type=click.Path( + file_okay=True, + dir_okay=False, + writable=False, + path_type=pathlib.Path, + ), + cls=ResourceOption, +) +@click.option( + "--plot/--no-plot", + "-P", + help="""If set, then also produces figures containing the plots of + performance curves and score histograms.""", + required=True, + show_default=True, + default=True, + cls=ResourceOption, +) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def evaluate( predictions: pathlib.Path, output_folder: pathlib.Path, threshold: str | float, metric: str, - # second_annotator, steps: int, + compare_annotator: pathlib.Path, + plot: bool, **_, # ignored ): # numpydoc ignore=PR01 """Evaluate predictions (from a model) on a segmentation task.""" @@ -185,10 +206,14 @@ def evaluate( threshold=threshold, metric=metric, steps=steps, + compare_annotator=str(compare_annotator), + plot=plot, ), ) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data) + evaluation_meta = evaluation_file.with_suffix(".meta.json") + logger.info(f"Saving evaluation metadata at `{str(evaluation_meta)}`...") + save_json_with_backup(evaluation_meta, json_data) threshold = validate_threshold(threshold, predict_data) threshold_list = numpy.arange( @@ -211,7 +236,7 @@ def evaluate( if isinstance(threshold, str): # Compute threshold on specified split, if required - logger.info(f"Evaluating threshold on `{threshold}` split using " f"`{metric}`") + logger.info(f"Evaluating threshold on split `{threshold}` using " f"`{metric}`") metric_list = compute_metric( eval_json_data[threshold]["counts"].values(), name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric)), @@ -230,14 +255,12 @@ def evaluate( threshold_index = (numpy.abs(threshold_list - threshold)).argmin() logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}") - logger.info("Tabulating performance summary...") - table_format = "rst" - output_table = output_folder / "evaluation.rst" metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE)) table_headers = ["Dataset", "threshold"] + metrics_available + ["auroc", "avgprec"] table_data = [] for split_name in predict_data.keys(): + logger.info("Computing performance on split `{split_name}`...") counts = list(eval_json_data[split_name]["counts"].values()) base_metrics = all_metrics(*counts[threshold_index]) table_data.append([split_name, threshold_list[threshold_index]] + base_metrics) @@ -266,10 +289,11 @@ def evaluate( # records full result analysis to a JSON file evaluation_file = output_folder / "evaluation.json" - logger.info(f"Saving evaluation results at `{evaluation_file}`...") + logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...") with evaluation_file.open("w") as f: json.dump(eval_json_data, f, indent=2, cls=NumpyJSONEncoder) + table_format = "rst" table = tabulate.tabulate( table_data, table_headers, @@ -278,41 +302,44 @@ def evaluate( stralign="right", ) click.echo(table) - logger.info(f"Saving table at {output_table}...") + + output_table = output_folder / "evaluation.rst" + logger.info(f"Saving tabulated performance summary at `{str(output_table)}`...") output_table.parent.mkdir(parents=True, exist_ok=True) with output_table.open("w") as f: f.write(table) - logger.info("Plotting performance curves...") - output_figure = output_folder / "evaluation.pdf" - logger.info(f"Saving figures 
at {output_figure}...") - with matplotlib.backends.backend_pdf.PdfPages(output_figure) as pdf: - with credible.plot.tight_layout( - ("False Positive Rate", "True Positive Rate"), "ROC" - ) as ( - fig, - ax, - ): - for split_name, data in eval_json_data.items(): - ax.plot( - data["curves"]["roc"]["fpr"], - data["curves"]["roc"]["tpr"], - label=f"{split_name} (AUC: {data['auc_score']:.2f})", - ) - ax.legend(loc="best", fancybox=True, framealpha=0.7) - pdf.savefig(fig) - - with credible.plot.tight_layout_f1iso( - ("Recall", "Precision"), "Precison-Recall" - ) as ( - fig, - ax, - ): - for split_name, data in eval_json_data.items(): - ax.plot( - data["curves"]["precision_recall"]["precision"], - data["curves"]["precision_recall"]["recall"], - label=f"{split_name} (AP: {data['average_precision_score']:.2f})", - ) - ax.legend(loc="best", fancybox=True, framealpha=0.7) - pdf.savefig(fig) + if plot: + figure_path = evaluation_file.with_suffix(".pdf") + logger.info(f"Saving evaluation figures at `{str(figure_path)}`...") + + with matplotlib.backends.backend_pdf.PdfPages(figure_path) as pdf: + with credible.plot.tight_layout( + ("False Positive Rate", "True Positive Rate"), "ROC" + ) as ( + fig, + ax, + ): + for split_name, data in eval_json_data.items(): + ax.plot( + data["curves"]["roc"]["fpr"], + data["curves"]["roc"]["tpr"], + label=f"{split_name} (AUC: {data['auc_score']:.2f})", + ) + ax.legend(loc="best", fancybox=True, framealpha=0.7) + pdf.savefig(fig) + + with credible.plot.tight_layout_f1iso( + ("Recall", "Precision"), "Precison-Recall" + ) as ( + fig, + ax, + ): + for split_name, data in eval_json_data.items(): + ax.plot( + data["curves"]["precision_recall"]["precision"], + data["curves"]["precision_recall"]["recall"], + label=f"{split_name} (AP: {data['average_precision_score']:.2f})", + ) + ax.legend(loc="best", fancybox=True, framealpha=0.7) + pdf.savefig(fig) diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index f3a6c006dc8927278ce0f4013cbf7222c539a308..523e09fb448da82d40cec6c60c60a7baeb244ab5 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -60,8 +60,8 @@ def experiment( \b └─ <output-folder>/ ├── model/ # the generated model will be here - ├── predictions # the prediction outputs for the sets - └── evaluation/ # the outputs of the evaluations for the sets + ├── predictions.json # the prediction outputs + └── evaluation.json # the evaluation outputs """ experiment_start_timestamp = datetime.now() @@ -110,11 +110,9 @@ def experiment( from .predict import predict - predictions_output = output_folder / "predictions" - ctx.invoke( predict, - output_folder=predictions_output, + output_folder=output_folder, model=model, datamodule=datamodule, device=device, @@ -132,11 +130,9 @@ def experiment( from .evaluate import evaluate - evaluation_output = output_folder / "evaluation" - - predictions_file = predictions_output / "predictions.json" + predictions_file = output_folder / "predictions.json" - with (predictions_output / "predictions.json").open() as pf: + with (predictions_file).open() as pf: splits = json.load(pf).keys() if "validation" in splits: @@ -149,7 +145,7 @@ def experiment( ctx.invoke( evaluate, predictions=predictions_file, - output_folder=evaluation_output, + output_folder=output_folder, threshold=evaluation_threshold, # metric="f1", # steps=100, diff --git a/tests/classification/test_cli.py b/tests/classification/test_cli.py index 
c9c7dfd39cf6ab8637e6a74a765dbe78f15cc854..8f408319841c74933a6598ae8a0b9f24f2fb7de7 100644 --- a/tests/classification/test_cli.py +++ b/tests/classification/test_cli.py @@ -185,7 +185,7 @@ def test_upload_help(): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_train_pasa_montgomery(temporary_basedir): +def test_train_pasa_montgomery(session_tmp_path): from mednet.libs.classification.scripts.train import train from mednet.libs.common.utils.checkpointer import ( CHECKPOINT_EXTENSION, @@ -195,7 +195,7 @@ def test_train_pasa_montgomery(temporary_basedir): runner = CliRunner() with stdout_logging() as buf: - output_folder = temporary_basedir / "classification" / "results" + output_folder = session_tmp_path / "classification-standalone" result = runner.invoke( train, [ @@ -241,8 +241,8 @@ def test_train_pasa_montgomery(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): - from mednet.libs.classification.scripts.train import train +def test_predict_pasa_montgomery(session_tmp_path): + from mednet.libs.classification.scripts.predict import predict from mednet.libs.common.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, @@ -250,62 +250,36 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): runner = CliRunner() - output_folder = temporary_basedir / "classification" / "results" / "pasa_checkpoint" - result0 = runner.invoke( - train, - [ - "pasa", - "montgomery", - "-vv", - "--epochs=1", - "--batch-size=1", - f"--output-folder={str(output_folder)}", - ], - ) - _assert_exit_0(result0) - - # asserts checkpoints are there, or raises FileNotFoundError - last = _get_checkpoint_from_alias(output_folder, "periodic") - assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) - best = _get_checkpoint_from_alias(output_folder, "best") - assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) - - assert (output_folder / "meta.json").exists() - assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1 - with stdout_logging() as buf: + output_folder = session_tmp_path / "classification-standalone" + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) result = runner.invoke( - train, + predict, [ "pasa", "montgomery", "-vv", - "--epochs=2", "--batch-size=1", - f"--output-folder={output_folder}", + f"--weight={str(last)}", + f"--output-folder={str(output_folder)}", ], ) _assert_exit_0(result) - # asserts checkpoints are there, or raises FileNotFoundError - last = _get_checkpoint_from_alias(output_folder, "periodic") - assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) - best = _get_checkpoint_from_alias(output_folder, "best") - - assert (output_folder / "meta.json").exists() - assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 2 + assert (output_folder / "predictions.meta.json").exists() + assert (output_folder / "predictions.json").exists() keywords = { - r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Applying train/valid loss balancing...$": 1, - r"^Training for at most 2 epochs.$": 1, - r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, - r"^Writing run metadata at.*$": 1, - r"^Dataset `train` is already setup. 
Not re-instantiating it.$": 1, - r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1, + r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3, + r"^Loading checkpoint from .*$": 1, r"^Restoring normalizer from checkpoint.$": 1, + r"^Running prediction on `train` split...$": 1, + r"^Running prediction on `validation` split...$": 1, + r"^Running prediction on `test` split...$": 1, + r"^Predictions saved to .*$": 1, } + buf.seek(0) logging_output = buf.read() @@ -319,47 +293,37 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_predict_pasa_montgomery(temporary_basedir, datadir): - from mednet.libs.classification.scripts.predict import predict - from mednet.libs.common.utils.checkpointer import ( - CHECKPOINT_EXTENSION, - _get_checkpoint_from_alias, - ) +def test_evaluate_pasa_montgomery(session_tmp_path): + from mednet.libs.classification.scripts.evaluate import evaluate runner = CliRunner() with stdout_logging() as buf: - output = temporary_basedir / "classification" / "predictions" - last = _get_checkpoint_from_alias( - temporary_basedir / "classification" / "results", - "periodic", - ) - assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + output_folder = session_tmp_path / "classification-standalone" result = runner.invoke( - predict, + evaluate, [ - "pasa", - "montgomery", "-vv", - "--batch-size=1", - f"--weight={str(last)}", - f"--output-folder={str(output)}", + f"--predictions={str(output_folder / 'predictions.json')}", + f"--output-folder={str(output_folder)}", + "--threshold=test", ], ) _assert_exit_0(result) - assert output.exists() + assert (output_folder / "evaluation.json").exists() + assert (output_folder / "evaluation.meta.json").exists() + assert (output_folder / "evaluation.rst").exists() + assert (output_folder / "evaluation.pdf").exists() keywords = { - r"^Loading dataset: * without caching. 
Trade-off: CPU RAM usage: less | Disk I/O: more$": 3, - r"^Loading checkpoint from .*$": 1, - r"^Restoring normalizer from checkpoint.$": 1, - r"^Running prediction on `train` split...$": 1, - r"^Running prediction on `validation` split...$": 1, - r"^Running prediction on `test` split...$": 1, - r"^Predictions saved to .*$": 1, + r"^Saving evaluation metadata at .*$": 1, + r"^Setting --threshold=.*$": 1, + r"^Computing performance on split .*...$": 3, + r"^Saving evaluation results at .*$": 1, + r"^Saving evaluation results in table format at .*$": 1, + r"^Saving evaluation figures at .*$": 1, } - buf.seek(0) logging_output = buf.read() @@ -373,39 +337,69 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_evaluate_pasa_montgomery(temporary_basedir): - from mednet.libs.classification.scripts.evaluate import evaluate +def test_train_pasa_montgomery_from_checkpoint(tmp_path): + from mednet.libs.classification.scripts.train import train + from mednet.libs.common.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) runner = CliRunner() - with stdout_logging() as buf: - prediction_path = temporary_basedir / "classification" / "predictions" - predictions_file = prediction_path / "predictions.json" - evaluation_path = temporary_basedir / "classification" / "evaluations" + result0 = runner.invoke( + train, + [ + "pasa", + "montgomery", + "-vv", + "--epochs=1", + "--batch-size=1", + f"--output-folder={str(tmp_path)}", + ], + ) + _assert_exit_0(result0) + + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(tmp_path, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(tmp_path, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + assert (tmp_path / "meta.json").exists() + assert len(list((tmp_path / "logs").glob("events.out.tfevents.*"))) == 1 + + with stdout_logging() as buf: result = runner.invoke( - evaluate, + train, [ - "-vv", + "pasa", "montgomery", - f"--predictions={predictions_file}", - f"--output-folder={evaluation_path}", - "--threshold=test", + "-vv", + "--epochs=2", + "--batch-size=1", + f"--output-folder={tmp_path}", ], ) _assert_exit_0(result) - assert (evaluation_path / "evaluation.json").exists() - assert (evaluation_path / "evaluation.meta.json").exists() - assert (evaluation_path / "evaluation.pdf").exists() - assert (evaluation_path / "evaluation.rst").exists() + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(tmp_path, "periodic") + assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(tmp_path, "best") + + assert (tmp_path / "meta.json").exists() + assert len(list((tmp_path / "logs").glob("events.out.tfevents.*"))) == 2 + keywords = { - r"^Setting --threshold=.*$": 1, - r"^Analyzing split `train`...$": 1, - r"^Analyzing split `validation`...$": 1, - r"^Analyzing split `test`...$": 1, - r"^Saving evaluation results .*$": 2, - r"^Saving evaluation figures at .*$": 1, + r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Loading dataset:`validation` without caching. 
Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Applying train/valid loss balancing...$": 1, + r"^Training for at most 2 epochs.$": 1, + r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, + r"^Writing run metadata at.*$": 1, + r"^Dataset `train` is already setup. Not re-instantiating it.$": 1, + r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1, + r"^Restoring normalizer from checkpoint.$": 1, } buf.seek(0) logging_output = buf.read() @@ -420,12 +414,11 @@ def test_evaluate_pasa_montgomery(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_experiment(temporary_basedir): +def test_experiment(tmp_path): from mednet.libs.classification.scripts.experiment import experiment runner = CliRunner() - output_folder = temporary_basedir / "classification" / "experiment" num_epochs = 2 result = runner.invoke( experiment, @@ -434,61 +427,39 @@ def test_experiment(temporary_basedir): "pasa", "montgomery", f"--epochs={num_epochs}", - f"--output-folder={str(output_folder)}", + f"--output-folder={str(tmp_path)}", ], ) _assert_exit_0(result) - assert (output_folder / "model" / "meta.json").exists() - assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() - assert (output_folder / "predictions" / "predictions.json").exists() - assert (output_folder / "predictions" / "predictions.meta.json").exists() + assert (tmp_path / "model" / "meta.json").exists() + assert (tmp_path / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() + assert (tmp_path / "predictions.json").exists() + assert (tmp_path / "predictions.meta.json").exists() # Need to glob because we cannot be sure of the checkpoint with lowest validation loss assert ( len( list( - (output_folder / "model").glob( - "model-at-lowest-validation-loss-epoch=*.ckpt", - ), - ), + (tmp_path / "model").glob( + "model-at-lowest-validation-loss-epoch=*.ckpt" + ) + ) ) == 1 ) - assert (output_folder / "model" / "trainlog.pdf").exists() + assert (tmp_path / "model" / "trainlog.pdf").exists() assert ( len( list( - (output_folder / "model" / "logs").glob( + (tmp_path / "model" / "logs").glob( "events.out.tfevents.*", ), ), ) == 1 ) - assert (output_folder / "evaluation.json").exists() - assert (output_folder / "evaluation.meta.json").exists() - assert (output_folder / "evaluation.rst").exists() - assert (output_folder / "evaluation.pdf").exists() - assert (output_folder / "gradcam" / "saliencies").exists() - assert ( - len( - list( - (output_folder / "gradcam" / "saliencies" / "CXR_png").glob( - "MCUCXR_*.npy", - ), - ), - ) - == 138 - ) - assert (output_folder / "gradcam" / "visualizations").exists() - assert ( - len( - list( - (output_folder / "gradcam" / "visualizations" / "CXR_png").glob( - "MCUCXR_*.png", - ), - ), - ) - == 58 - ) + assert (tmp_path / "evaluation.json").exists() + assert (tmp_path / "evaluation.meta.json").exists() + assert (tmp_path / "evaluation.rst").exists() + assert (tmp_path / "evaluation.pdf").exists() diff --git a/tests/conftest.py b/tests/conftest.py index 36e5fa5b1fa427a98eb8d4c390a9eb371e5724a4..98d102e7a4721bb9c87531049956f0d195645597 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,7 +95,7 @@ def rc_variable_set(name): @pytest.fixture(scope="session") -def temporary_basedir(tmp_path_factory): +def session_tmp_path(tmp_path_factory): return tmp_path_factory.mktemp("test-cli") diff --git a/tests/segmentation/test_cli.py b/tests/segmentation/test_cli.py index 
926556bc90356dcfe3a02b8f2f3533054a071278..5f96456b767a1c6935aae4f462519b4ce3c8f96f 100644 --- a/tests/segmentation/test_cli.py +++ b/tests/segmentation/test_cli.py @@ -153,7 +153,7 @@ def test_evaluate_help(): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.drive") -def test_train_lwnet_drive(temporary_basedir): +def test_train_lwnet_drive(session_tmp_path): from mednet.libs.common.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, @@ -163,7 +163,7 @@ def test_train_lwnet_drive(temporary_basedir): runner = CliRunner() with stdout_logging() as buf: - output_folder = temporary_basedir / "segmentation" / "results" + output_folder = session_tmp_path / "segmentation-standalone" result = runner.invoke( train, [ @@ -206,68 +206,44 @@ def test_train_lwnet_drive(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.drive") -def test_train_lwnet_drive_from_checkpoint(temporary_basedir): +def test_predict_lwnet_drive(session_tmp_path): from mednet.libs.common.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) - from mednet.libs.segmentation.scripts.train import train + from mednet.libs.segmentation.scripts.predict import predict runner = CliRunner() - output_folder = temporary_basedir / "segmentation" / "results" / "lwnet_checkpoint" - result0 = runner.invoke( - train, - [ - "lwnet", - "drive", - "-vv", - "--epochs=1", - "--batch-size=1", - f"--output-folder={str(output_folder)}", - ], - ) - _assert_exit_0(result0) - - # asserts checkpoints are there, or raises FileNotFoundError - last = _get_checkpoint_from_alias(output_folder, "periodic") - assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) - best = _get_checkpoint_from_alias(output_folder, "best") - assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) - - assert (output_folder / "meta.json").exists() - assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1 - with stdout_logging() as buf: + output_folder = session_tmp_path / "segmentation-standalone" + last_ckpt = _get_checkpoint_from_alias(output_folder, "periodic") + assert last_ckpt.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) result = runner.invoke( - train, + predict, [ "lwnet", "drive", "-vv", - "--epochs=2", "--batch-size=1", - f"--output-folder={output_folder}", + f"--weight={str(last_ckpt)}", + f"--output-folder={str(output_folder)}", ], ) _assert_exit_0(result) - # asserts checkpoints are there, or raises FileNotFoundError - last = _get_checkpoint_from_alias(output_folder, "periodic") - assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) - best = _get_checkpoint_from_alias(output_folder, "best") - - assert (output_folder / "meta.json").exists() - assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 2 + assert (output_folder / "predictions.meta.json").exists() + assert (output_folder / "predictions.json").exists() keywords = { - r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Training for at most 2 epochs.$": 1, - r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, - r"^Writing run metadata at.*$": 1, - r"^Dataset `train` is already setup. Not re-instantiating it.$": 3, + r"^Loading dataset: * without caching. 
Trade-off: CPU RAM usage: less | Disk I/O: more$": 2, + r"^Loading checkpoint from .*$": 1, r"^Restoring normalizer from checkpoint.$": 1, + r"^Running prediction on `train` split...$": 1, + r"^Running prediction on `test` split...$": 1, + r"^Predictions saved to .*$": 1, } + buf.seek(0) logging_output = buf.read() @@ -281,46 +257,38 @@ def test_train_lwnet_drive_from_checkpoint(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.drive") -def test_predict_lwnet_drive(temporary_basedir, datadir): - from mednet.libs.common.utils.checkpointer import ( - CHECKPOINT_EXTENSION, - _get_checkpoint_from_alias, - ) - from mednet.libs.segmentation.scripts.predict import predict +def test_evaluate_lwnet_drive(session_tmp_path): + from mednet.libs.segmentation.scripts.evaluate import evaluate runner = CliRunner() with stdout_logging() as buf: - output = temporary_basedir / "segmentation" / "predictions" - last = _get_checkpoint_from_alias( - temporary_basedir / "segmentation" / "results", - "periodic", - ) - assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + output_folder = session_tmp_path / "segmentation-standalone" result = runner.invoke( - predict, + evaluate, [ - "lwnet", - "drive", "-vv", - "--batch-size=1", - f"--weight={str(last)}", - f"--output-folder={str(output)}", + f"--predictions={str(output_folder / 'predictions.json')}", + f"--output-folder={str(output_folder)}", + "--threshold=test", ], ) _assert_exit_0(result) - assert output.exists() + assert (output_folder / "evaluation.json").exists() + assert (output_folder / "evaluation.meta.json").exists() + assert (output_folder / "evaluation.rst").exists() + assert (output_folder / "evaluation.pdf").exists() keywords = { - r"^Loading dataset: * without caching. 
Trade-off: CPU RAM usage: less | Disk I/O: more$": 2, - r"^Loading checkpoint from .*$": 1, - r"^Restoring normalizer from checkpoint.$": 1, - r"^Running prediction on `train` split...$": 1, - r"^Running prediction on `test` split...$": 1, - r"^Predictions saved to .*$": 1, + r"^Saving evaluation metadata at .*$": 1, + r"^Counting true/false positive/negatives at split.*$": 2, + r"^Evaluating threshold on split .*$": 1, + r"^Computing performance on split .*...$": 2, + r"^Saving evaluation results at .*$": 1, + r"^Saving tabulated performance summary at .*$": 1, + r"^Saving evaluation figures at .*$": 1, } - buf.seek(0) logging_output = buf.read() @@ -334,41 +302,66 @@ def test_predict_lwnet_drive(temporary_basedir, datadir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.drive") -def test_evaluate_lwnet_drive(temporary_basedir): - from mednet.libs.segmentation.scripts.evaluate import evaluate +def test_train_lwnet_drive_from_checkpoint(tmp_path): + from mednet.libs.common.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) + from mednet.libs.segmentation.scripts.train import train runner = CliRunner() + result0 = runner.invoke( + train, + [ + "lwnet", + "drive", + "-vv", + "--epochs=1", + "--batch-size=1", + f"--output-folder={str(tmp_path)}", + ], + ) + _assert_exit_0(result0) + + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(tmp_path, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(tmp_path, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + + assert (tmp_path / "meta.json").exists() + assert len(list((tmp_path / "logs").glob("events.out.tfevents.*"))) == 1 + with stdout_logging() as buf: - prediction_path = temporary_basedir / "segmentation" / "predictions" - predictions_file = prediction_path / "predictions.json" - evaluation_path = temporary_basedir / "segmentation" / "evaluations" result = runner.invoke( - evaluate, + train, [ - "-vv", + "lwnet", "drive", - f"--predictions={predictions_file}", - f"--output-folder={evaluation_path}", - "--threshold=test", + "-vv", + "--epochs=2", + "--batch-size=1", + f"--output-folder={tmp_path}", ], ) _assert_exit_0(result) - assert (evaluation_path / "evaluation.json").exists() - assert (evaluation_path / "evaluation.meta.json").exists() - assert (evaluation_path / "evaluation.pdf").exists() - assert (evaluation_path / "evaluation.rst").exists() + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(tmp_path, "periodic") + assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(tmp_path, "best") + + assert (tmp_path / "meta.json").exists() + assert len(list((tmp_path / "logs").glob("events.out.tfevents.*"))) == 2 keywords = { + r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Training for at most 2 epochs.$": 1, + r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, r"^Writing run metadata at.*$": 1, - r"^Counting true/false positive/negatives at split.*$": 2, - r"^Evaluating threshold on.*$": 1, - r"^Tabulating performance summary...": 1, - r"^Saving evaluation results at.*$": 1, - r"^Saving table at .*$": 1, - r"^Plotting performance curves...": 1, - r"^Saving figures at .*$": 1, + r"^Dataset `train` is already setup. 
Not re-instantiating it.$": 3, + r"^Restoring normalizer from checkpoint.$": 1, } buf.seek(0) logging_output = buf.read() @@ -383,12 +376,11 @@ def test_evaluate_lwnet_drive(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.drive") -def test_experiment(temporary_basedir): +def test_experiment(tmp_path): from mednet.libs.segmentation.scripts.experiment import experiment runner = CliRunner() - output_folder = temporary_basedir / "segmentation" / "experiment" num_epochs = 2 result = runner.invoke( experiment, @@ -397,39 +389,40 @@ def test_experiment(temporary_basedir): "lwnet", "drive", f"--epochs={num_epochs}", - f"--output-folder={str(output_folder)}", + f"--output-folder={str(tmp_path)}", ], ) _assert_exit_0(result) - assert (output_folder / "model" / "meta.json").exists() - assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() - assert (output_folder / "predictions" / "predictions.json").exists() - assert (output_folder / "predictions" / "predictions.meta.json").exists() + assert (tmp_path / "model" / "meta.json").exists() + assert (tmp_path / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() # Need to glob because we cannot be sure of the checkpoint with lowest validation loss assert ( len( list( - (output_folder / "model").glob( + (tmp_path / "model").glob( "model-at-lowest-validation-loss-epoch=*.ckpt", ), ), ) == 1 ) - assert (output_folder / "model" / "trainlog.pdf").exists() + assert (tmp_path / "model" / "trainlog.pdf").exists() assert ( len( list( - (output_folder / "model" / "logs").glob( + (tmp_path / "model" / "logs").glob( "events.out.tfevents.*", ), ), ) == 1 ) - assert (output_folder / "evaluation" / "evaluation.json").exists() - assert (output_folder / "evaluation" / "evaluation.meta.json").exists() - assert (output_folder / "evaluation" / "evaluation.pdf").exists() - assert (output_folder / "evaluation" / "evaluation.rst").exists() + + assert (tmp_path / "predictions.json").exists() + assert (tmp_path / "predictions.meta.json").exists() + assert (tmp_path / "evaluation.json").exists() + assert (tmp_path / "evaluation.meta.json").exists() + assert (tmp_path / "evaluation.pdf").exists() + assert (tmp_path / "evaluation.rst").exists()
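Both evaluate commands accept a split name for --threshold (the experiment scripts pass --threshold=validation when available, the tests use --threshold=test): the probability cut-off is tuned on that split by sweeping a grid of candidate thresholds and keeping the one that maximizes the chosen metric, and the resulting value is then applied to every other split. A minimal standalone sketch of that idea, with made-up counts and an assumed F1 helper rather than the mednet implementation:

# Minimal sketch: tune a probability threshold on one split by sweeping a grid
# and keeping the value that maximizes a metric, then reuse it elsewhere.
# The counts and the f1 helper below are illustrative, not the mednet code.
import numpy


def f1(tp: int, fp: int, tn: int, fn: int) -> float:
    # F1 from confusion-matrix counts; max() guards against division by zero
    return 2 * tp / max(2 * tp + fp + fn, 1)


steps = 100
threshold_list = numpy.arange(0.0, 1.0, 1 / steps)

# counts[i] = (tp, fp, tn, fn) measured on the chosen split at threshold_list[i]
counts = [(80, 30, 70, 20)] * len(threshold_list)  # stand-in values

metric_list = [f1(*c) for c in counts]
best_index = int(numpy.argmax(metric_list))
print(f"Setting --threshold={threshold_list[best_index]:.4f}")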
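The segmentation evaluate script now forces matplotlib's non-interactive "agg" backend so figures can be rendered without an X11/graphical desktop, and both evaluate scripts write their performance curves into a single multi-page PDF through PdfPages. A minimal standalone sketch of that pattern, using made-up curve data and file names (not mednet code):

# Minimal sketch: select the non-interactive "agg" backend before pyplot is
# imported, then write several figures into one PDF file.
import matplotlib

matplotlib.use("agg")  # same effect as the __import__("matplotlib").use("agg") line above

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# stands in for the per-split ROC curves stored in evaluation.json
curves = {
    "train": ([0.0, 0.2, 1.0], [0.0, 0.8, 1.0]),
    "test": ([0.0, 0.3, 1.0], [0.0, 0.7, 1.0]),
}

with PdfPages("evaluation.pdf") as pdf:
    fig, ax = plt.subplots()
    for split_name, (fpr, tpr) in curves.items():
        ax.plot(fpr, tpr, label=split_name)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="best", fancybox=True, framealpha=0.7)
    pdf.savefig(fig)  # each savefig() call appends one page to the PDF
    plt.close(fig)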
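The new --plot/--no-plot switch follows click's paired on/off flag convention, which exposes both spellings on the command line and passes a single boolean to the command. A small self-contained sketch, assuming a bare click command rather than the ResourceOption-based mednet CLI:

# Small self-contained sketch of click's paired on/off flag, as used for
# --plot/--no-plot above (plain click command, no ResourceOption machinery).
import click


@click.command()
@click.option(
    "--plot/--no-plot",
    "-P",
    default=True,
    show_default=True,
    help="If set, also produce PDF figures with performance curves.",
)
def evaluate(plot: bool) -> None:
    click.echo(f"plotting enabled: {plot}")


if __name__ == "__main__":
    evaluate()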
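The renamed session_tmp_path fixture is session-scoped, so the train, predict and evaluate tests in each test_cli.py can share one output directory: the predict test reuses the checkpoint written by the train test, and the evaluate test reuses the resulting predictions.json. A minimal sketch of that chaining, with placeholder tests and files standing in for the real CLI invocations:

# Minimal sketch of the session-scoped fixture pattern: one shared directory
# lets later tests reuse artifacts written by earlier ones in the same session.
# In the real suite the fixture lives in tests/conftest.py; the tests below are
# placeholders for the actual CLI invocations.
import pytest


@pytest.fixture(scope="session")
def session_tmp_path(tmp_path_factory):
    # one directory shared by every test in the session
    return tmp_path_factory.mktemp("test-cli")


def test_train(session_tmp_path):
    out = session_tmp_path / "classification-standalone"
    out.mkdir(parents=True, exist_ok=True)
    (out / "model.ckpt").write_text("weights")  # stand-in for a real checkpoint


def test_predict(session_tmp_path):
    # runs after test_train (pytest keeps definition order within a module)
    # and therefore sees the checkpoint it wrote
    assert (session_tmp_path / "classification-standalone" / "model.ckpt").exists()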