From 0152c9c37ec9720d3d43f2229b67652dd30cd202 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 14 Dec 2023 14:52:31 +0100 Subject: [PATCH] [utils.checkpointer] Refactor checkpoint saving and loading --- src/ptbench/data/datamodule.py | 5 + src/ptbench/engine/trainer.py | 52 +++++---- src/ptbench/scripts/experiment.py | 15 ++- src/ptbench/scripts/train.py | 81 ++++++++------ src/ptbench/utils/checkpointer.py | 172 +++++++++++++++++++++--------- tests/test_cli.py | 61 +++++++---- 6 files changed, 253 insertions(+), 133 deletions(-) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 3d7e1a64..b954a000 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -580,8 +580,10 @@ class ConcatDataModule(lightning.LightningDataModule): if value < 0: num_workers = 0 + else: num_workers = value or multiprocessing.cpu_count() + self._dataloader_multiproc["num_workers"] = num_workers if num_workers > 0 and sys.platform == "darwin": @@ -589,6 +591,9 @@ class ConcatDataModule(lightning.LightningDataModule): "multiprocessing_context" ] = multiprocessing.get_context("spawn") + # keep workers hanging around if we have multiple + self._dataloader_multiproc["persistent_workers"] = True + @property def model_transforms(self) -> list[Transform] | None: """Transforms required to fit data into the model. diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index fccf47b8..b4f878fe 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -13,6 +13,7 @@ import lightning.pytorch.callbacks import lightning.pytorch.loggers import torch.nn +from ..utils.checkpointer import CHECKPOINT_ALIASES from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from .callbacks import LoggingCallback from .device import DeviceManager @@ -47,13 +48,13 @@ def save_model_summary( summary_path = output_folder / "model-summary.txt" logger.info(f"Saving model summary at {summary_path}...") with summary_path.open("w") as f: - summary = lightning.pytorch.utilities.model_summary.ModelSummary( + summary = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore model, max_depth=-1 ) f.write(str(summary)) return ( summary, - lightning.pytorch.utilities.model_summary.ModelSummary( + lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore model ).total_parameters, ) @@ -99,13 +100,13 @@ def static_information_to_csv( def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, - checkpoint_period: int, + validation_period: int, device_manager: DeviceManager, max_epochs: int, output_folder: pathlib.Path, monitoring_interval: int | float, batch_chunk_count: int, - checkpoint: str | None, + checkpoint: pathlib.Path | None, ): """Fits a CNN model using supervised learning and save it to disk. @@ -122,9 +123,15 @@ def run( datamodule The lightning datamodule to use for training **and** validation - checkpoint_period - Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do - not save intermediary checkpoints. + validation_period + Number of epochs after which validation happens. By default, we run + validation after every training epoch (period=1). You can change this + to make validation more sparse, by increasing the validation period. + Notice that this affects checkpoint saving. 
While checkpoints are + created after every training step (the last training step always + triggers the overriding of latest checkpoint), and that this process is + independent of validation runs, evaluation of the 'best' model obtained + so far based on those will be influenced by this setting. device_manager An internal device representation, to be used for training and @@ -177,17 +184,22 @@ def run( logging_level=logging.ERROR, ) - checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( - output_folder, - "model_lowest_valid_loss", - save_last=True, + # This checkpointer will operate at the end of every validation epoch + # (which happens at each checkpoint period), it will then save the lowest + # validation loss model observed. It will also save the last trained model + checkpoint_minvalloss_callback = lightning.pytorch.callbacks.ModelCheckpoint( + dirpath=output_folder, + filename=CHECKPOINT_ALIASES["best"], + save_last=True, # will (re)create the last trained model, at every iteration monitor="loss/validation", mode="min", - save_on_train_epoch_end=True, - every_n_epochs=checkpoint_period, + save_on_train_epoch_end=True, # run checks at the end of validation + every_n_epochs=validation_period, # frequency at which it would check the "monitor" + enable_version_counter=False, # no versioning of aliased checkpoints ) - - checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch" + checkpoint_minvalloss_callback.CHECKPOINT_NAME_LAST = CHECKPOINT_ALIASES[ # type: ignore + "periodic" + ] # write static information to a CSV file static_information_to_csv( @@ -204,9 +216,13 @@ def run( max_epochs=max_epochs, accumulate_grad_batches=batch_chunk_count, logger=tensorboard_logger, - check_val_every_n_epoch=1, + check_val_every_n_epoch=validation_period, log_every_n_steps=len(datamodule.train_dataloader()), - callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], + callbacks=[ + LoggingCallback(resource_monitor), + checkpoint_minvalloss_callback, + ], ) - _ = trainer.fit(model, datamodule, ckpt_path=checkpoint) + checkpoint_str = checkpoint if checkpoint is None else str(checkpoint) + _ = trainer.fit(model, datamodule, ckpt_path=checkpoint_str) diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py index 55c24260..1d520fe3 100644 --- a/src/ptbench/scripts/experiment.py +++ b/src/ptbench/scripts/experiment.py @@ -42,7 +42,7 @@ def experiment( batch_chunk_count, drop_incomplete_batch, datamodule, - checkpoint_period, + validation_period, device, cache_samples, seed, @@ -84,7 +84,7 @@ def experiment( batch_chunk_count=batch_chunk_count, drop_incomplete_batch=drop_incomplete_batch, datamodule=datamodule, - checkpoint_period=checkpoint_period, + validation_period=validation_period, device=device, cache_samples=cache_samples, seed=seed, @@ -111,13 +111,12 @@ def experiment( logger.info("Started predicting") - from .predict import predict + from ..utils.checkpointer import get_checkpoint_to_run_inference + + model_file = get_checkpoint_to_run_inference(train_output_folder) + logger.info(f"Found `{str(model_file)}`. 
Continuing...") - # preferably, we use the best model on the validation set - # otherwise, we get the last saved model - model_file = train_output_folder / "model_lowest_valid_loss.ckpt" - if not model_file.exists(): - model_file = train_output_folder / "model_final_epoch.ckpt" + from .predict import predict predictions_output = output_folder / "predictions.json" diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index c5c2a3ac..e8149a41 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -125,16 +125,21 @@ def reusable_options(f): cls=ResourceOption, ) @click.option( - "--checkpoint-period", + "--validation-period", "-p", - help="""Number of epochs after which a checkpoint is saved. A value of - zero will disable check-pointing. If checkpointing is enabled and - training stops, it is automatically resumed from the last saved - checkpoint if training is restarted with the same configuration.""", + help="""Number of epochs after which validation happens. By default, + we run validation after every training epoch (period=1). You can + change this to make validation more sparse, by increasing the + validation period. Notice that this affects checkpoint saving. While + checkpoints are created after every training step (the last training + step always triggers the overriding of latest checkpoint), and that + this process is independent of validation runs, evaluation of the + 'best' model obtained so far based on those will be influenced by this + setting.""", show_default=True, - required=False, - default=None, - type=click.IntRange(min=0), + required=True, + default=1, + type=click.IntRange(min=1), cls=ResourceOption, ) @click.option( @@ -183,27 +188,19 @@ def reusable_options(f): "--monitoring-interval", "-I", help="""Time between checks for the use of resources during each training - epoch. An interval of 5 seconds, for example, will lead to CPU and GPU - resources being probed every 5 seconds during each training epoch. - Values registered in the training logs correspond to averages (or maxima) - observed through possibly many probes in each epoch. Notice that setting a - very small value may cause the probing process to become extremely busy, - potentially biasing the overall perception of resource usage.""", + epoch, in seconds. An interval of 5 seconds, for example, will lead to + CPU and GPU resources being probed every 5 seconds during each training + epoch. Values registered in the training logs correspond to averages + (or maxima) observed through possibly many probes in each epoch. + Notice that setting a very small value may cause the probing process to + become extremely busy, potentially biasing the overall perception of + resource usage.""", type=click.FloatRange(min=0.1), show_default=True, required=True, default=5.0, cls=ResourceOption, ) - @click.option( - "--resume-from", - help="""Which checkpoint to resume training from. If set, can be one of - `best`, `last`, or a path to a model checkpoint.""", - type=click.STRING, - required=False, - default=None, - cls=ResourceOption, - ) @click.option( "--balance-classes/--no-balance-classes", "-B/-N", @@ -244,13 +241,12 @@ def train( batch_chunk_count, drop_incomplete_batch, datamodule, - checkpoint_period, + validation_period, device, cache_samples, seed, parallel, monitoring_interval, - resume_from, balance_classes, **_, ) -> None: @@ -263,20 +259,31 @@ def train( resume the procedure in case it stops abruptly. 
""" + import os + import torch from lightning.pytorch import seed_everything from ..engine.device import DeviceManager from ..engine.trainer import run - from ..utils.checkpointer import get_checkpoint + from ..utils.checkpointer import get_checkpoint_to_resume_training from .utils import save_sh_command + checkpoint_file = None + if os.path.isdir(output_folder): + try: + checkpoint_file = get_checkpoint_to_resume_training(output_folder) + except FileNotFoundError: + logger.info( + f"Folder {output_folder} already exists, but I did not" + f" find any usable checkpoint file to resume training" + f" from. Starting from scratch..." + ) + save_sh_command(output_folder / "command.sh") seed_everything(seed) - checkpoint_file = get_checkpoint(output_folder, resume_from) - # reset datamodule with user configurable options datamodule.set_chunk_size(batch_size, batch_chunk_count) datamodule.drop_incomplete_batch = drop_incomplete_batch @@ -307,25 +314,31 @@ def train( arguments["epoch"] = 0 if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): - # Sets the model normalizer with the unaugmented-train-subset. - # this call may be a NOOP, if the model was pre-trained and expects - # different weights for the normalisation layer. + # Sets the model normalizer with the unaugmented-train-subset if we are + # starting from scratch and/or the model does not contain its own + # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This + # call may be a NOOP, if the model comes from outside this framework, + # and expects different weights for the normalisation layer. if hasattr(model, "set_normalizer"): model.set_normalizer(datamodule.unshuffled_train_dataloader()) else: logger.warning( - f"Model {model.name} has no 'set_normalizer' method. Skipping." + f"Model {model.name} has no `set_normalizer` method. " + "Skipping normalization setup (unsupported external model)." ) else: # Normalizer will be loaded during model.on_load_checkpoint checkpoint = torch.load(checkpoint_file) start_epoch = checkpoint["epoch"] - logger.info(f"Resuming from epoch {start_epoch}...") + logger.info( + f"Resuming from epoch {start_epoch} " + f"(checkpoint file: `{str(checkpoint_file)}`)..." + ) run( model=model, datamodule=datamodule, - checkpoint_period=checkpoint_period, + validation_period=validation_period, device_manager=DeviceManager(device), max_epochs=epochs, output_folder=output_folder, diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py index 8f6685b8..988e611e 100644 --- a/src/ptbench/utils/checkpointer.py +++ b/src/ptbench/utils/checkpointer.py @@ -4,76 +4,146 @@ import logging import pathlib +import re import typing logger = logging.getLogger(__name__) -def get_checkpoint( - output_folder: pathlib.Path, - resume_from: typing.Literal["last", "best"] | str | None, -) -> str | None: - """Gets a checkpoint file. +CHECKPOINT_ALIASES = { + "best": "model-at-lowest-validation-loss-{epoch}", + "periodic": "model-at-{epoch}", +} +"""Standard paths where checkpoints may be (if produced with this +framework).""" - Can return the best or last checkpoint, or a checkpoint at a specific path. - Ensures the checkpoint exists, raising an error if it is not the case. +CHECKPOINT_EXTENSION = ".ckpt" - If ``resume_from`` is ``None``, checks the output directory if a "last" - checkpoint file already exists and returns it. If no checkpoint is found, - returns ``None``. - ``resume_from`` can also be a path to an existing checkpoint file. 
In this
-    case, we check it and return if it exists.
+def _get_checkpoint_from_alias(
+    path: pathlib.Path,
+    alias: typing.Literal["best", "periodic"],
+) -> pathlib.Path:
+    """Gets an existing checkpoint file path.
+
+    This function searches for file names matching the checkpoint alias "stem"
+    (i.e. the prefix), and then assumes a dash "-" and a number follow that
+    prefix before the expected file extension.  The number is parsed and
+    considered to be an epoch number.  The latest file (the file containing
+    the highest epoch number) is returned.
+
+    If only one file is present matching the alias characteristics, then it is
+    returned.
 
     Parameters
     ----------
-    output_folder
-        Folder in which checkpoints are stored.
-    resume_from
-        Which model to get. Can be one of "best", "last", or a path to a checkpoint.
-        If ``None``, gets the last checkpoint if it exists, otherwise returns
-        ``None`` (signal to start from scratch).
+    path
+        Folder which may contain checkpoint files.
+    alias
+        Can be one of "best" or "periodic".
 
     Returns
     -------
-        Path to the requested checkpoint (as a plain string) or ``None`` (start
-        from scratch).
+        Path to the requested checkpoint.  If several files match the alias,
+        the one with the highest epoch number is returned.
 
     Raises
     ------
     FileNotFoundError
-        In case a required file cannot be found.
+        In case it cannot find any file on the provided path matching the given
+        specifications.
     """
-    # standard paths where checkpoints may be (if produced with this framework)
-    last_path = output_folder / "model_final_epoch.ckpt"
-    best_path = output_folder / "model_lowest_valid_loss.ckpt"
-
-    if resume_from in ("last", "best"):
-        use_file = last_path if resume_from == "last" else best_path
-        if use_file.is_file():
-            logger.info(f"Found checkpoint at `{str(use_file)}`")
-            return str(use_file)
-        else:
-            raise FileNotFoundError(
-                f"Could not find a checkpoint file at `{str(use_file)}`"
-            )
-
-    elif resume_from is None:
-        # use-case: user is re-starting a crashed/cancelled job
-        if last_path.is_file():
-            logger.info(f"Found checkpoint at `{str(last_path)}`")
-            return str(last_path)
-        else:
-            return None
-
-    elif isinstance(resume_from, str):
-        if pathlib.Path(resume_from).is_file():
-            logger.info(f"Found checkpoint at `{resume_from}`")
-            return resume_from
-        else:
-            raise FileNotFoundError(
-                f"Could not find a checkpoint file at `{resume_from}`"
-            )
+
+    template = path / (CHECKPOINT_ALIASES[alias] + CHECKPOINT_EXTENSION)
+
+    if template.exists():
+        return template
+
+    # otherwise, we see if we are looking for a template instead, in which case
+    # we must pick the latest.
+    assert "{epoch}" in str(
+        template
+    ), f"Template `{str(template)}` does not contain the keyword `{{epoch}}`"
+
+    pattern = re.compile(
+        template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)")
+    )
+    highest = -1
+    for f in template.parent.iterdir():
+        match = pattern.match(f.name)
+        if match is not None:
+            value = int(match.group("epoch"))
+            if value > highest:
+                highest = value
+
+    if highest != -1:
+        return template.with_name(
+            template.name.replace("{epoch}", f"epoch={highest}")
+        )
+
+    raise FileNotFoundError(
+        f"A file matching `{str(template)}` specifications was not found"
+    )
+
+
+def get_checkpoint_to_resume_training(
+    path: pathlib.Path,
+):
+    """Returns the checkpoint file path from which to resume training.
+
+    Parameters
+    ----------
+    path
+        The base directory which may contain a "periodic" checkpoint to
+        resume the training session from.
+
+
+    Returns
+    -------
+        Path to a checkpoint file that exists on disk
+
+
+    Raises
+    ------
+    FileNotFoundError
+        If no such checkpoint can be found in the provided directory.
+    """
+
+    return _get_checkpoint_from_alias(path, "periodic")
+
+
+def get_checkpoint_to_run_inference(
+    path: pathlib.Path,
+):
+    """Returns the best checkpoint file path to run inference with.
+
+    Parameters
+    ----------
+    path
+        The base directory which may contain the "best" or "periodic"
+        checkpoint to run inference with.
+
+
+    Returns
+    -------
+        Path to a checkpoint file that exists on disk
+
+
+    Raises
+    ------
+    FileNotFoundError
+        If none of the checkpoints can be found in the provided directory.
+    """
+
+    try:
+        return _get_checkpoint_from_alias(path, "best")
+    except FileNotFoundError:
+        logger.error(
+            "Did not find lowest-validation-loss model to run inference "
+            "from. Trying to search for the last periodically saved model..."
+        )
+
+    return _get_checkpoint_from_alias(path, "periodic")
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4d408bac..f90462c5 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -193,11 +193,15 @@ def test_compare_vis_help():
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery(temporary_basedir):
     from ptbench.scripts.train import train
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
 
     runner = CliRunner()
 
     with stdout_logging() as buf:
-        output_folder = str(temporary_basedir / "results")
+        output_folder = temporary_basedir / "results"
         result = runner.invoke(
             train,
             [
@@ -206,17 +210,17 @@ def test_train_pasa_montgomery(temporary_basedir):
                 "-vv",
                 "--epochs=1",
                 "--batch-size=1",
-                f"--output-folder={output_folder}",
+                f"--output-folder={str(output_folder)}",
             ],
         )
         _assert_exit_0(result)
 
-        assert os.path.exists(
-            os.path.join(output_folder, "model_final_epoch.ckpt")
-        )
-        assert os.path.exists(
-            os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-        )
+        # asserts checkpoints are there, or raises FileNotFoundError
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        best = _get_checkpoint_from_alias(output_folder, "best")
+        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+
         assert os.path.exists(os.path.join(output_folder, "constants.csv"))
         assert (
             len(
@@ -254,10 +258,14 @@
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
     from ptbench.scripts.train import train
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
 
     runner = CliRunner()
 
-    output_folder = str(temporary_basedir / "results/pasa_checkpoint")
+    output_folder = temporary_basedir / "results" / "pasa_checkpoint"
     result0 = runner.invoke(
         train,
         [
@@ -266,15 +274,17 @@
             "-vv",
             "--epochs=1",
             "--batch-size=1",
-            f"--output-folder={output_folder}",
+            f"--output-folder={str(output_folder)}",
         ],
     )
     _assert_exit_0(result0)
 
-    assert os.path.exists(os.path.join(output_folder, "model_final_epoch.ckpt"))
-    assert os.path.exists(
-        os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-    )
+    # asserts checkpoints are there, or raises FileNotFoundError
+    last = _get_checkpoint_from_alias(output_folder, "periodic")
+    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+    best =
_get_checkpoint_from_alias(output_folder, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( len( @@ -301,12 +311,11 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): ) _assert_exit_0(result) - assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.ckpt") - ) - assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.ckpt") - ) + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( @@ -348,11 +357,19 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_predict_pasa_montgomery(temporary_basedir, datadir): from ptbench.scripts.predict import predict + from ptbench.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) runner = CliRunner() with stdout_logging() as buf: - output = str(temporary_basedir / "predictions") + output = temporary_basedir / "predictions" + last = _get_checkpoint_from_alias( + temporary_basedir / "results", "periodic" + ) + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) result = runner.invoke( predict, [ @@ -360,7 +377,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): "montgomery", "-vv", "--batch-size=1", - f"--weight={str(temporary_basedir / 'results' / 'model_final_epoch.ckpt')}", + f"--weight={str(last)}", f"--output={output}", ], ) -- GitLab
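
For reference, here is a self-contained sketch (standard library only, no ptbench import) of the alias-resolution logic this patch adds to src/ptbench/utils/checkpointer.py. It assumes Lightning's default expansion of "{epoch}" into "epoch=N" file names; the `resolve` helper and the simulated file names are illustrative only, and the real `_get_checkpoint_from_alias` additionally short-circuits when a literal (un-templated) file name already exists.

    import pathlib
    import re
    import tempfile

    # same templates the patch stores in CHECKPOINT_ALIASES / CHECKPOINT_EXTENSION
    ALIASES = {
        "best": "model-at-lowest-validation-loss-{epoch}",
        "periodic": "model-at-{epoch}",
    }
    EXTENSION = ".ckpt"


    def resolve(path: pathlib.Path, alias: str) -> pathlib.Path:
        """Returns the checkpoint with the highest epoch number for ``alias``."""
        template = path / (ALIASES[alias] + EXTENSION)
        # Lightning expands "{epoch}" to "epoch=<number>" in saved file names
        pattern = re.compile(
            template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)")
        )
        epochs = [
            int(m.group("epoch"))
            for f in path.iterdir()
            if (m := pattern.match(f.name)) is not None
        ]
        if not epochs:
            raise FileNotFoundError(f"no file matching `{template}` was found")
        return template.with_name(
            template.name.replace("{epoch}", f"epoch={max(epochs)}")
        )


    if __name__ == "__main__":
        with tempfile.TemporaryDirectory() as d:
            folder = pathlib.Path(d)
            # simulate files a ModelCheckpoint callback could leave behind
            for name in (
                "model-at-epoch=0.ckpt",
                "model-at-epoch=9.ckpt",
                "model-at-lowest-validation-loss-epoch=7.ckpt",
            ):
                (folder / name).touch()
            print(resolve(folder, "periodic").name)  # model-at-epoch=9.ckpt
            print(resolve(folder, "best").name)  # model-at-lowest-validation-loss-epoch=7.ckpt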
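
Likewise, a minimal sketch of how a single validation-period value drives both Lightning's validation scheduling and the checkpoint callback configured in src/ptbench/engine/trainer.py. The `make_trainer` helper is illustrative (not part of ptbench), assumes lightning >= 2.0 (where `ModelCheckpoint` accepts `enable_version_counter`), and assumes the attached LightningModule logs a metric named "loss/validation".

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint


    def make_trainer(output_folder, validation_period: int, max_epochs: int) -> Trainer:
        # one callback produces both checkpoints: the "best" one (tracked via
        # the validation loss) and the "periodic"/last one, refreshed at the
        # end of training epochs
        checkpointer = ModelCheckpoint(
            dirpath=output_folder,
            filename="model-at-lowest-validation-loss-{epoch}",  # "best" alias
            save_last=True,  # also keep the most recently trained model
            monitor="loss/validation",
            mode="min",
            save_on_train_epoch_end=True,
            every_n_epochs=validation_period,  # how often "monitor" may be re-checked
            enable_version_counter=False,  # keep stable, alias-like file names
        )
        checkpointer.CHECKPOINT_NAME_LAST = "model-at-{epoch}"  # "periodic" alias

        return Trainer(
            max_epochs=max_epochs,
            check_val_every_n_epoch=validation_period,  # validation frequency
            callbacks=[checkpointer],
        )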