Skip to content
Snippets Groups Projects
Commit 229dd23c authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[train-analysis] Fix input parameter issue

parent b53404c5
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -101,7 +101,7 @@ def experiment(
ctx.invoke(
train_analysis,
logdir=train_output_folder / "logs",
output=output_folder / "trainlog.pdf",
output_folder=output_folder / "trainlog.pdf",
)
logger.info("Ended train analysis")
......
......@@ -9,8 +9,6 @@ import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from .click import ConfigCommand
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
......@@ -186,7 +184,6 @@ def _create_figures(
@click.command(
cls=ConfigCommand,
epilog="""Examples:
\b
......@@ -194,7 +191,7 @@ def _create_figures(
.. code:: sh
mednet train-analysis -vv --log-dir=results/logs --output=trainlog.pdf
mednet train-analysis -vv results/logs
""",
)
@click.option(
......@@ -203,19 +200,15 @@ def _create_figures(
help="Path to the directory containing the Tensorboard training logs",
required=True,
type=click.Path(dir_okay=True, exists=True, path_type=pathlib.Path),
default="results/logs",
cls=ResourceOption,
)
@click.option(
"--output",
"--output-folder",
"-o",
help="""Path to a PDF file in which to store results. (leading
directories are created if they do not exist).""",
help="Directory in which to store results (created if does not exist)",
required=True,
cls=ResourceOption,
type=click.Path(
file_okay=True,
dir_okay=False,
file_okay=False,
dir_okay=True,
writable=True,
path_type=pathlib.Path,
),
......@@ -224,19 +217,21 @@ def _create_figures(
@verbosity_option(logger=logger, expose_value=False)
def train_analysis(
logdir: pathlib.Path,
output: pathlib.Path,
**_,
output_folder: pathlib.Path,
) -> None: # numpydoc ignore=PR01
"""Create a plot for each metric in the training logs and saves them in a .pdf file."""
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from mednet.libs.common.utils.tensorboard import scalars_to_dict
train_log_filename = "trainlog.pdf"
train_log_file = pathlib.Path(output_folder) / train_log_filename
data = scalars_to_dict(logdir)
output.parent.mkdir(parents=True, exist_ok=True)
train_log_file.parent.mkdir(parents=True, exist_ok=True)
with PdfPages(output) as pdf:
with PdfPages(train_log_file) as pdf:
for figure in _create_figures(data):
pdf.savefig(figure)
plt.close(figure)
......@@ -3,14 +3,11 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""
import pathlib
import typing
import torch
Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
"""Definition of a lightning checkpoint."""
SegmentationPrediction: typing.TypeAlias = tuple[
str, torch.Tensor, torch.Tensor, torch.Tensor
]
SegmentationPrediction: typing.TypeAlias = tuple[pathlib.Path, pathlib.Path]
"""The sample name, the target, and the predicted value."""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment