diff --git a/src/ptbench/engine/device.py b/src/ptbench/engine/device.py
index 253bba0d9da3bedca6010b2ad87937b9da4f08e0..2eeef34a96156083df564a20746e447f2e577afe 100644
--- a/src/ptbench/engine/device.py
+++ b/src/ptbench/engine/device.py
@@ -128,7 +128,7 @@ class DeviceManager:
                 f"Unexpected device type {self.device_type} lacks support"
             )
 
-    def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
+    def lightning_accelerator(self) -> tuple[str, int | list[int] | str]:
         """Returns the lightning accelerator setup.
 
         Returns
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 10121af1091a0e3a186efd4c338f4af0ad4b9cf5..052bb885549087bf64c02e5be05ee1debaabaaca 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -14,14 +14,16 @@ import torch.nn
 
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
+from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
 
 def save_model_summary(
-    output_folder: str, model: torch.nn.Module
+    output_folder: str,
+    model: torch.nn.Module,
 ) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
-    """Save a little summary of the model in a txt file.
+    """Saves a little summary of the model in a txt file.
 
     Parameters
     ----------
@@ -32,13 +34,14 @@ def save_model_summary(
     model
         Network (e.g. driu, hed, unet)
 
+
     Returns
     -------
-    summary:
-        The model summary in a text format.
+    summary
+        The model summary in a text format
 
-    total_parameters:
-        The number of parameters of the model.
+    total_parameters
+        The number of parameters of the model
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
@@ -94,15 +97,15 @@ def static_information_to_csv(
 
 
 def run(
-    model,
-    datamodule,
-    checkpoint_period,
-    device_manager,
-    arguments,
-    output_folder,
-    monitoring_interval,
-    batch_chunk_count,
-    checkpoint,
+    model: lightning.pytorch.LightningModule,
+    datamodule: lightning.pytorch.LightningDataModule,
+    checkpoint_period: int,
+    device_manager: DeviceManager,
+    max_epochs: int,
+    output_folder: str,
+    monitoring_interval: int | float,
+    batch_chunk_count: int,
+    checkpoint: str,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -113,48 +116,40 @@ def run(
 
     Parameters
     ----------
 
-    model : :py:class:`torch.nn.Module`
+    model
         Neural network model (e.g. pasa).
 
-    data_loader : :py:class:`torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
-
-    valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model and enable automatic checkpointing.
-        If ``None``, then do not validate it.
-
-    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model, however **does not affect** automatic
-        checkpointing. If empty, then does not log anything else. Otherwise,
-        an extra column with the loss of every dataset in this list is kept on
-        the final training log.
+    datamodule
+        The lightning datamodule to use for training **and** validation
 
-    checkpoint_period : int
+    checkpoint_period
         Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
         not save intermediary checkpoints.
 
-    device_manager : DeviceManager
-        A device, to be used for training.
+    device_manager
+        An internal device representation, to be used for training and
+        validation. This representation can be converted into a pytorch device
+        or a torch lightning accelerator setup.
 
-    arguments : dict
-        Start and end epochs:
+    max_epochs
+        The maximum number of epochs to train for.
 
-    output_folder : str
+    output_folder
         Directory in which the results will be saved.
 
-    monitoring_interval : int, float
+    monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.
 
-    batch_chunk_count: int
+    batch_chunk_count
         If this number is different than 1, then each batch will be divided in
         this number of chunks. Gradients will be accumulated to perform each
         mini-batch. This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches. One
         exchanges for longer processing times in this case.
-    """
-    max_epoch = arguments["max_epoch"]
+    checkpoint
+    """
 
     os.makedirs(output_folder, exist_ok=True)
 
@@ -198,7 +193,7 @@ def run(
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        max_epochs=max_epoch,
+        max_epochs=max_epochs,
         accumulate_grad_batches=batch_chunk_count,
         logger=[csv_logger, tensorboard_logger],
         check_val_every_n_epoch=1,
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 664b8b1ad1ae38625a22af4a1092b15da20b2727..bffeebdb5d2cae835b7cd17706ebac4b93ef35fe 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -229,8 +229,7 @@ def train(
     procedure in case it stops abruptly.
     """
 
-    import torch.cuda
-    import torch.nn
+    import torch
 
     from lightning.pytorch import seed_everything
 
@@ -276,25 +275,20 @@ def train(
             "Skipping sample class/dataset ownership balancing on user request"
         )
 
-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
-
+    logger.info(f"Training for at most {epochs} epochs.")
    # We only load the checkpoint to get some information about its state. The
    # actual loading of the model is done in trainer.fit()
    if checkpoint_file is not None:
        checkpoint = torch.load(checkpoint_file)
-        arguments["epoch"] = checkpoint["epoch"]
-
-    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
-    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
+        start_epoch = checkpoint["epoch"]
+        logger.info(f"Resuming from epoch {start_epoch}...")
 
     run(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
         device_manager=DeviceManager(device),
-        arguments=arguments,
+        max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
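For reference, a minimal usage sketch of the refactored run() entry point after this patch, mirroring the updated call site in src/ptbench/scripts/train.py. The make_model()/make_datamodule() factories, the "cpu" device string, and the paths below are illustrative placeholders, not part of the patch:

# Hypothetical call to ptbench.engine.trainer.run() using the new keyword
# signature (max_epochs replaces the old arguments["max_epoch"] dict entry).
from ptbench.engine.device import DeviceManager
from ptbench.engine.trainer import run

model = make_model()            # stand-in for a lightning.pytorch.LightningModule (e.g. pasa)
datamodule = make_datamodule()  # stand-in for a lightning.pytorch.LightningDataModule

run(
    model=model,
    datamodule=datamodule,
    checkpoint_period=1,                  # save a checkpoint every epoch; 0 disables intermediary checkpoints
    device_manager=DeviceManager("cpu"),  # wraps the device string, as done in the train script
    max_epochs=10,                        # previously arguments["max_epoch"]
    output_folder="results",              # illustrative output directory
    monitoring_interval=5.0,              # resource monitoring period, in seconds
    batch_chunk_count=1,                  # values > 1 enable gradient accumulation
    checkpoint="results/last.ckpt",       # hypothetical checkpoint path to resume from
)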
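Similarly, a short sketch of how the tightened return annotation of DeviceManager.lightning_accelerator(), which no longer admits None in the devices slot, is consumed when building the lightning Trainer inside run(), as shown in the trainer.py hunk above. The "cpu" constructor argument and the max_epochs value are assumptions for illustration:

# The accelerator/devices pair produced by the DeviceManager feeds the
# lightning Trainer directly; devices is now always an int, a list of ints,
# or a string, never None.
import lightning.pytorch

from ptbench.engine.device import DeviceManager

device_manager = DeviceManager("cpu")
accelerator, devices = device_manager.lightning_accelerator()

trainer = lightning.pytorch.Trainer(
    accelerator=accelerator,
    devices=devices,
    max_epochs=10,
)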