diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 94fbcb00bacbb0cadce116e5a944c4abf57cdbb3..dbc14ebef400b23dbacd90c458d6108b4f76c69e 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -2,8 +2,13 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import reusable_options
-from mednet.libs.common.scripts.train import train as train_script
+from mednet.libs.common.scripts.train import (
+    get_checkpoint_file,
+    load_checkpoint,
+    reusable_options,
+    save_json_data,
+    setup_datamodule,
+)
 
 logger = setup("mednet", format="%(levelname)s: %(message)s")
 
@@ -48,20 +53,68 @@ def train(
     extension that are used in subsequent tasks or from which training
     can be resumed.
     """
-    train_script(
+    from lightning.pytorch import seed_everything
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.engine.trainer import run
+
+    seed_everything(seed)
+    device_manager = DeviceManager(device)
+
+    # reset datamodule with user configurable options
+    setup_datamodule(
+        datamodule,
+        model,
+        batch_size,
+        batch_chunk_count,
+        drop_incomplete_batch,
+        cache_samples,
+        parallel,
+    )
+
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set. Also affects the
+    # validation loss if a validation set is available on the DataModule.
+    if balance_classes:
+        logger.info("Applying DataModule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request",
+        )
+
+    checkpoint_file = get_checkpoint_file(output_folder)
+    load_checkpoint(checkpoint_file, datamodule, model)
+
+    logger.info(f"Training for at most {epochs} epochs.")
+
+    # stores all information we can think of, to reproduce this later
+    save_json_data(
+        datamodule,
         model,
         output_folder,
+        device_manager,
         epochs,
         batch_size,
         accumulate_grad_batches,
         drop_incomplete_batch,
-        datamodule,
         validation_period,
-        device,
         cache_samples,
         seed,
         parallel,
         monitoring_interval,
         balance_classes,
-        **_,
+    )
+
+    run(
+        model=model,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device_manager=device_manager,
+        max_epochs=epochs,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint_file,
     )
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index 2228725b673068f8417fcb1a8b6fc829357dab39..32d7bd9481b0d973d0f0eb6279836d584578d84b 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -215,84 +215,51 @@ def reusable_options(f):
     return wrapper_reusable_options
 
 
-def train(
-    model,
-    output_folder,
-    epochs,
-    batch_size,
-    accumulate_grad_batches,
-    drop_incomplete_batch,
-    datamodule,
-    validation_period,
-    device,
-    cache_samples,
-    seed,
-    parallel,
-    monitoring_interval,
-    balance_classes,
-    **_,
-) -> None:  # numpydoc ignore=PR01
-    """Train an CNN to perform image classification.
+def get_checkpoint_file(results_dir) -> pathlib.Path | None:
+    """Return the path of the latest checkpoint if it exists.
 
-    Training is performed for a configurable number of epochs, and
-    generates checkpoints. Checkpoints are model files with a .ckpt
-    extension that are used in subsequent tasks or from which training
-    can be resumed.
-
-    """
+    Parameters
+    ----------
+    results_dir
+        Directory in which results are saved.
 
-    import torch
-    from lightning.pytorch import seed_everything
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.engine.trainer import run
+    Returns
+    -------
+    Path to the latest checkpoint
+    """
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_resume_training,
     )
-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
 
     checkpoint_file = None
-    if output_folder.is_dir():
+    if results_dir.is_dir():
         try:
-            checkpoint_file = get_checkpoint_to_resume_training(output_folder)
+            checkpoint_file = get_checkpoint_to_resume_training(results_dir)
         except FileNotFoundError:
             logger.info(
-                f"Folder {output_folder} already exists, but I did not"
+                f"Folder {results_dir} already exists, but I did not"
                 f" find any usable checkpoint file to resume training"
                 f" from. Starting from scratch...",
             )
 
-    seed_everything(seed)
-
-    # reset datamodule with user configurable options
-    datamodule.drop_incomplete_batch = drop_incomplete_batch
-    datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
+    return checkpoint_file
 
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
 
-    # If asked, rebalances the loss criterion based on the relative proportion
-    # of class examples available in the training set. Also affects the
-    # validation loss if a validation set is available on the DataModule.
-    if balance_classes:
-        logger.info("Applying train/valid loss balancing...")
-        model.balance_losses(datamodule)
-    else:
-        logger.info(
-            "Skipping sample class/dataset ownership balancing on user request",
-        )
+def load_checkpoint(checkpoint_file, datamodule, model):
+    """Load the checkpoint.
 
-    logger.info(f"Training for at most {epochs} epochs.")
+    Parameters
+    ----------
+    checkpoint_file
+        Path to the checkpoint.
+    datamodule
+        Instance of a Datamodule, used to set the model's normalizer.
+    model
+        The model corresponding to the checkpoint.
+
+ """ - arguments = {} - arguments["max_epoch"] = epochs - arguments["epoch"] = 0 + import torch if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): # Sets the model normalizer with the unaugmented-train-subset if we are @@ -316,9 +282,51 @@ def train( f"(checkpoint file: `{str(checkpoint_file)}`)...", ) - device_manager = DeviceManager(device) - # stores all information we can think of, to reproduce this later +def setup_datamodule( + datamodule, + model, + batch_size, + batch_chunk_count, + drop_incomplete_batch, + cache_samples, + parallel, +) -> None: # numpydoc ignore=PR01 + """Configure and set up the datamodule.""" + datamodule.set_chunk_size(batch_size, batch_chunk_count) + datamodule.drop_incomplete_batch = drop_incomplete_batch + datamodule.cache_samples = cache_samples + datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="fit") + + +def save_json_data( + datamodule, + model, + output_folder, + device_manager, + epochs, + batch_size, + accumulate_grad_batches, + drop_incomplete_batch, + validation_period, + cache_samples, + seed, + parallel, + monitoring_interval, + balance_classes, +) -> None: # numpydoc ignore=PR01 + """Save training hyperparameters into a .json file.""" + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) + json_data: dict[str, typing.Any] = execution_metadata() json_data.update(device_properties(device_manager.device_type)) json_data.update( @@ -341,15 +349,3 @@ def train( json_data.update(model_summary(model)) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} save_json_with_backup(output_folder / "meta.json", json_data) - - run( - model=model, - datamodule=datamodule, - validation_period=validation_period, - device_manager=device_manager, - max_epochs=epochs, - output_folder=output_folder, - monitoring_interval=monitoring_interval, - accumulate_grad_batches=accumulate_grad_batches, - checkpoint=checkpoint_file, - ) diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 933231c6d6aa3a6a96088d61f42407b041ac1da8..2632369c874fda1aad7b792c6f6a0a655b0cfe3b 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -2,14 +2,19 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from mednet.libs.common.scripts.train import reusable_options -from mednet.libs.common.scripts.train import train as train_script +from mednet.libs.common.scripts.train import ( + get_checkpoint_file, + load_checkpoint, + reusable_options, + save_json_data, + setup_datamodule, +) logger = setup("mednet", format="%(levelname)s: %(message)s") @click.command( - entry_point_group="mednet.libs.segmentation.config", + entry_point_group="mednet.libs.classification.config", cls=ConfigCommand, epilog="""Examples: @@ -39,27 +44,62 @@ def train( balance_classes, **_, ) -> None: # numpydoc ignore=PR01 - """Train a model to perform image segmentation. + """Train an CNN to perform image classification. Training is performed for a configurable number of epochs, and generates checkpoints. Checkpoints are model files with a .ckpt extension that are used in subsequent tasks or from which training can be resumed. 
""" - train_script( + from lightning.pytorch import seed_everything + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.engine.trainer import run + + seed_everything(seed) + device_manager = DeviceManager(device) + + # reset datamodule with user configurable options + setup_datamodule( + datamodule, + model, + batch_size, + batch_chunk_count, + drop_incomplete_batch, + cache_samples, + parallel, + ) + + checkpoint_file = get_checkpoint_file(output_folder) + load_checkpoint(checkpoint_file, datamodule, model) + + logger.info(f"Training for at most {epochs} epochs.") + + # stores all information we can think of, to reproduce this later + save_json_data( + datamodule, model, output_folder, + device_manager, epochs, batch_size, batch_chunk_count, drop_incomplete_batch, - datamodule, validation_period, - device, cache_samples, seed, parallel, monitoring_interval, balance_classes, - **_, + ) + + run( + model=model, + datamodule=datamodule, + validation_period=validation_period, + device_manager=device_manager, + max_epochs=epochs, + output_folder=output_folder, + monitoring_interval=monitoring_interval, + batch_chunk_count=batch_chunk_count, + checkpoint=checkpoint_file, )