diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 3da73a1124d7c94704d0a6b2228eedd75664777d..7fd8f2a7a4d30fcf1beae5ed55870d3b2b056f33 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -2,108 +2,22 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import csv import logging import os import pathlib -import shutil import lightning.pytorch import lightning.pytorch.callbacks import lightning.pytorch.loggers -import torch.nn from ..utils.checkpointer import CHECKPOINT_ALIASES -from ..utils.resources import ( - ResourceMonitor, - cpu_constants, - cuda_constants, - mps_constants, -) +from ..utils.resources import ResourceMonitor from .callbacks import LoggingCallback -from .device import DeviceManager, SupportedPytorchDevice +from .device import DeviceManager logger = logging.getLogger(__name__) -def save_model_summary( - output_folder: pathlib.Path, - model: torch.nn.Module, -) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]: - """Save a little summary of the model in a txt file. - - Parameters - ---------- - output_folder - Directory in which to save the summary. - model - Instance of the model for which to save the summary. - - Returns - ------- - tuple[lightning.pytorch.callbacks.ModelSummary, int] - A tuple with the model summary in a text format and number of parameters of the model. - """ - summary_path = output_folder / "model-summary.txt" - logger.info(f"Saving model summary at {summary_path}...") - with summary_path.open("w") as f: - summary = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore - model, max_depth=-1 - ) - f.write(str(summary)) - return ( - summary, - lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore - model - ).total_parameters, - ) - - -def static_information_to_csv( - static_logfile_name: pathlib.Path, - device_type: SupportedPytorchDevice, - model_size: int, -) -> None: - """Save the static information in a CSV file. - - Parameters - ---------- - static_logfile_name - The static file name which is a join between the output folder and - "constants.csv". - device_type - The type of device we are using. - model_size - The size of the model we will be training. - """ - if static_logfile_name.exists(): - backup = static_logfile_name.parent / (static_logfile_name.name + "~") - shutil.copy(static_logfile_name, backup) - - with static_logfile_name.open("w", newline="") as f: - logdata: dict[str, int | float | str] = {} - logdata.update(cpu_constants()) - - match device_type: - case "cpu": - pass - case "cuda": - results = cuda_constants() - if results is not None: - logdata.update(results) - case "mps": - results = mps_constants() - if results is not None: - logdata.update(results) - case _: - pass - - logdata["number-of-model-parameters"] = model_size - logwriter = csv.DictWriter(f, fieldnames=logdata.keys()) - logwriter.writeheader() - logwriter.writerow(logdata) - - def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, @@ -158,9 +72,6 @@ def run( os.makedirs(output_folder, exist_ok=True) - # Save model summary - _, no_of_parameters = save_model_summary(output_folder, model) - from .loggers import CustomTensorboardLogger log_dir = "logs" @@ -205,13 +116,6 @@ def run( "periodic" ] - # write static information to a CSV file - static_information_to_csv( - output_folder / "constants.csv", - device_manager.device_type, - no_of_parameters, - ) - with train_resource_monitor, validation_resource_monitor: accelerator, devices = device_manager.lightning_accelerator() trainer = lightning.pytorch.Trainer( diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index 68dd8da7e5f1e28b03d21746b9550039f4caa440..216e166e0e9410186351375397eba16963b465e8 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -129,10 +129,17 @@ def predict( import json import shutil + import typing from ..engine.device import DeviceManager from ..engine.predictor import run from ..utils.checkpointer import get_checkpoint_to_run_inference + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel @@ -147,7 +154,23 @@ def predict( logger.info(f"Loading checkpoint from `{weight}`...") model = type(model).load_from_checkpoint(weight, strict=False) - predictions = run(model, datamodule, DeviceManager(device)) + device_manager = DeviceManager(device) + + # register metadata + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + database_split=datamodule.split_name, + model_name=model.name, + ) + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output.with_suffix(".meta.json"), json_data) + + predictions = run(model, datamodule, device_manager) output.parent.mkdir(parents=True, exist_ok=True) if output.exists(): diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 1dc78e4a68e690e7a311b31966270b2de4eedadf..9280f1dd02c454feb8212e06f29b47117c1b7660 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -4,6 +4,7 @@ import functools import pathlib +import typing import click @@ -266,7 +267,12 @@ def train( from ..engine.device import DeviceManager from ..engine.trainer import run from ..utils.checkpointer import get_checkpoint_to_resume_training - from .utils import save_sh_command + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) checkpoint_file = None if os.path.isdir(output_folder): @@ -279,7 +285,6 @@ def train( f" from. Starting from scratch..." ) - save_sh_command(output_folder / "command.sh") seed_everything(seed) # reset datamodule with user configurable options @@ -333,11 +338,37 @@ def train( f"(checkpoint file: `{str(checkpoint_file)}`)..." ) + device_manager = DeviceManager(device) + + # stores all information we can think of, to reproduce this later + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + split_name=datamodule.split_name, + epochs=epochs, + batch_size=batch_size, + batch_chunk_count=batch_chunk_count, + drop_incomplete_batch=drop_incomplete_batch, + validation_period=validation_period, + cache_samples=cache_samples, + seed=seed, + parallel=parallel, + monitoring_interval=monitoring_interval, + balance_classes=balance_classes, + model_name=model.name, + ) + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output_folder / "meta.json", json_data) + run( model=model, datamodule=datamodule, validation_period=validation_period, - device_manager=DeviceManager(device), + device_manager=device_manager, max_epochs=epochs, output_folder=output_folder, monitoring_interval=monitoring_interval, diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py index 7c62d84e715123b11442717bdce103479aef1cec..a4e5432e229c69e5d8dbd5224a63ca90b96aa4b2 100644 --- a/src/mednet/scripts/utils.py +++ b/src/mednet/scripts/utils.py @@ -1,64 +1,160 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -import importlib.metadata +"""Utilities for command-line scripts.""" + +import json import logging -import os import pathlib import shutil -import sys -import time + +import lightning.pytorch +import lightning.pytorch.callbacks +import torch.nn + +from ..engine.device import SupportedPytorchDevice logger = logging.getLogger(__name__) -def save_sh_command(path: pathlib.Path) -> None: - """Record command-line to reproduce this script. +def model_summary( + model: torch.nn.Module, +) -> dict[str, int | list[tuple[str, str, int]]]: + """Save a little summary of the model in a txt file. + + Parameters + ---------- + model + Instance of the model for which to save the summary. + + Returns + ------- + tuple[lightning.pytorch.callbacks.ModelSummary, int] + A tuple with the model summary in a text format and number of parameters of the model. + """ + + s = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore + model + ) + + return dict( + model_summary=list(zip(s.layer_names, s.layer_types, s.param_nums)), + model_size=s.total_parameters, + ) + - This function can record the current command-line used to call the script - being run. It creates an executable ``bash`` script setting up the current - working directory and activating a conda environment, if needed. It - records further information on the date and time the script was run and the - version of the package. +def device_properties( + device_type: SupportedPytorchDevice, +) -> dict[str, int | float | str]: + """Generate information concerning hardware properties. + + Parameters + ---------- + device_type + The type of compute device we are using. + + Returns + ------- + Static properties of the current machine. + """ + + from ..utils.resources import cpu_constants, cuda_constants, mps_constants + + retval: dict[str, int | float | str] = {} + retval.update(cpu_constants()) + + match device_type: + case "cpu": + pass + case "cuda": + results = cuda_constants() + if results is not None: + retval.update(results) + case "mps": + results = mps_constants() + if results is not None: + retval.update(results) + case _: + pass + + return retval + + +def execution_metadata() -> dict[str, int | float | str]: + """Produce metadata concerning the running script, in the form of a dictionary. + + This function returns potentially useful metadata concerning program + execution. It contains a certain number of preset variables. + + Returns + ------- + A dictionary that contains the following fields: + + * ``package-name``: current package name (e.g. ``mednet``) + * ``package-version``: current package version (e.g. ``1.0.0b0``) + * ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``) + * ``user``: username (e.g. ``johndoe``) + * ``conda-env``: if set, the name of the current conda environment + * ``path``: current path when executing the command + * ``command-line``: the command-line that is being run + * ``hostname``: machine hostname (e.g. ``localhost``) + * ``platform``: machine platform (e.g. ``darwin``) + """ + + import importlib.metadata + import os + import sys + + args = [] + for k in sys.argv: + if " " in k: + args.append(f"'{k}'") + else: + args.append(k) + + data = { + "package-name": __package__.split(".")[0], + "package-version": importlib.metadata.version( + __package__.split(".")[0] + ), + "date": __import__("datetime") + .datetime.now() + .astimezone() + .replace(microsecond=0) + .isoformat(), + "user": __import__("getpass").getuser(), + "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""), + "path": os.path.realpath(os.curdir), + "command-line": " ".join(args), + "hostname": __import__("platform").node(), + "platform": sys.platform, + } + + return data + + +def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None: + """Save a dictionary into a JSON file with path checking and backup. + + This function will save a dictionary into a JSON file. It will check to + the existence of the directory leading to the file and create it if + necessary. If the file already exists on the destination folder, it is + backed-up before a new file is created with the new contents. Parameters ---------- path - Path to a file where the commands to reproduce the current run will be - recorded. Parent directories will be created if they do not exist. An - existing copy will be backed-up if it exists. + The full path where to save the JSON data. + data + The data to save on the JSON file. """ - logger.info(f"Writing command-line for reproduction at `{path}`...") + logger.info(f"Writing run metadata at `{path}`...") - # create parent directories path.parent.mkdir(parents=True, exist_ok=True) - - # backup if exists if path.exists(): backup = path.parent / (path.name + "~") shutil.copy(path, backup) - # write the file - package = __name__.split(".", 1)[0] - version = importlib.metadata.version(package) - with path.open("w") as f: - f.write("#!/usr/bin/env sh\n") - f.write(f"# date: {time.asctime()}\n") - f.write(f"# version: {version} ({package})\n") - f.write(f"# platform: {sys.platform}\n") - f.write("\n") - args = [] - for k in sys.argv: - if " " in k: - args.append(f'"{k}"') - else: - args.append(k) - if os.environ.get("CONDA_DEFAULT_ENV") is not None: - f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n") - f.write(f"# cd {os.path.realpath(os.curdir)}\n") - f.write(" ".join(args) + "\n") - - # make it executable - path.chmod(0o755) + json.dump(data, f, indent=2)