Commit 2fbbd899 authored by André Anjos

[ptbench.engine.trainer] Implement type hints

parent b596fe59
Part of merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -128,7 +128,7 @@ class DeviceManager:
                 f"Unexpected device type {self.device_type} lacks support"
             )

-    def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
+    def lightning_accelerator(self) -> tuple[str, int | list[int] | str]:
         """Returns the lightning accelerator setup.

         Returns
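The hunk above narrows the return annotation: the devices element of the tuple can no longer be None. A minimal sketch of what such a mapping could look like, written as a free-standing helper since the actual method body is collapsed in the diff; everything here is an assumption based on the visible signature:

def lightning_accelerator(device_type: str) -> tuple[str, int | list[int] | str]:
    # Map a pytorch-style device string onto the (accelerator, devices)
    # pair expected by lightning.pytorch.Trainer. Returning "auto" stands
    # in for the None that the old annotation allowed.
    if device_type == "cpu":
        return "cpu", "auto"
    if device_type.startswith("cuda"):
        _, _, index = device_type.partition(":")
        return "gpu", [int(index)] if index else "auto"
    raise RuntimeError(f"Unexpected device type {device_type} lacks support")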
@@ -14,14 +14,16 @@ import torch.nn
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants

 from .callbacks import LoggingCallback
 from .device import DeviceManager

 logger = logging.getLogger(__name__)


 def save_model_summary(
-    output_folder: str, model: torch.nn.Module
+    output_folder: str,
+    model: torch.nn.Module,
 ) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
-    """Save a little summary of the model in a txt file.
+    """Saves a little summary of the model in a txt file.

     Parameters
     ----------
@@ -32,13 +34,14 @@ def save_model_summary(
     model
         Network (e.g. driu, hed, unet)

     Returns
     -------

-    summary:
-        The model summary in a text format.
+    summary
+        The model summary in a text format

-    total_parameters:
-        The number of parameters of the model.
+    total_parameters
+        The number of parameters of the model
     """

     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
@@ -94,15 +97,15 @@ def static_information_to_csv(
 def run(
-    model,
-    datamodule,
-    checkpoint_period,
-    device_manager,
-    arguments,
-    output_folder,
-    monitoring_interval,
-    batch_chunk_count,
-    checkpoint,
+    model: lightning.pytorch.LightningModule,
+    datamodule: lightning.pytorch.LightningDataModule,
+    checkpoint_period: int,
+    device_manager: DeviceManager,
+    max_epochs: int,
+    output_folder: str,
+    monitoring_interval: int | float,
+    batch_chunk_count: int,
+    checkpoint: str,
 ):
     """Fits a CNN model using supervised learning and saves it to disk.
@@ -113,48 +116,40 @@ def run(
     Parameters
     ----------

-    model : :py:class:`torch.nn.Module`
+    model
         Neural network model (e.g. pasa).

-    data_loader : :py:class:`torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
-
-    valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model and enable automatic checkpointing.
-        If ``None``, then do not validate it.
-
-    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model, however **does not affect** automatic
-        checkpointing. If empty, then does not log anything else. Otherwise,
-        an extra column with the loss of every dataset in this list is kept on
-        the final training log.
+    datamodule
+        The lightning datamodule to use for training **and** validation

-    checkpoint_period : int
+    checkpoint_period
         Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
         not save intermediary checkpoints.

-    device_manager : DeviceManager
-        A device, to be used for training.
+    device_manager
+        An internal device representation, to be used for training and
+        validation. This representation can be converted into a pytorch device
+        or a torch lightning accelerator setup.

-    arguments : dict
-        Start and end epochs:
+    max_epochs
+        The maximum number of epochs to train for.

-    output_folder : str
+    output_folder
         Directory in which the results will be saved.

-    monitoring_interval : int, float
+    monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.

-    batch_chunk_count: int
+    batch_chunk_count
         If this number is different than 1, then each batch will be divided in
         this number of chunks. Gradients will be accumulated to perform each
         mini-batch. This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches. One
         exchanges for longer processing times in this case.

-    """
-    max_epoch = arguments["max_epoch"]
+    checkpoint
+    """

     os.makedirs(output_folder, exist_ok=True)
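The batch_chunk_count parameter documented in this docstring is gradient accumulation: the next hunk forwards it to the Trainer as accumulate_grad_batches. A minimal sketch of the behaviour it buys, written in plain pytorch with illustrative names:

import torch

def step_with_chunks(model, optimizer, loss_fn, batch, batch_chunk_count: int):
    # Split one batch into chunks and accumulate gradients over them, so a
    # large effective batch fits in limited GPU memory, traded for time.
    inputs, targets = batch
    optimizer.zero_grad()
    for x, y in zip(inputs.chunk(batch_chunk_count), targets.chunk(batch_chunk_count)):
        # Scale each chunk's loss so the summed gradients match the
        # full-batch gradient.
        loss = loss_fn(model(x), y) / batch_chunk_count
        loss.backward()  # .backward() adds into the existing .grad buffers
    optimizer.step()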
@@ -198,7 +193,7 @@ def run(
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        max_epochs=max_epoch,
+        max_epochs=max_epochs,
         accumulate_grad_batches=batch_chunk_count,
         logger=[csv_logger, tensorboard_logger],
         check_val_every_n_epoch=1,
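Taken together with the device.py hunk above, the wiring could look like the following self-contained sketch; the DeviceManager call, the logger choices, and the concrete values are assumptions based on the diff, not code from the commit:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

output_folder = "results"  # illustrative value
# e.g. the pair produced by DeviceManager("cuda:0").lightning_accelerator()
accelerator, devices = "gpu", [0]

trainer = Trainer(
    accelerator=accelerator,
    devices=devices,
    max_epochs=100,
    accumulate_grad_batches=2,
    logger=[
        CSVLogger(output_folder, name="logs"),
        TensorBoardLogger(output_folder, name="logs"),
    ],
    check_val_every_n_epoch=1,
)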
@@ -229,8 +229,7 @@ def train(
     procedure in case it stops abruptly.
     """

-    import torch.cuda
-    import torch.nn
+    import torch

     from lightning.pytorch import seed_everything
@@ -276,25 +275,20 @@ def train(
             "Skipping sample class/dataset ownership balancing on user request"
         )

-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
+    logger.info(f"Training for at most {epochs} epochs.")

+    # We only load the checkpoint to get some information about its state. The
+    # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
-        arguments["epoch"] = checkpoint["epoch"]
-        logger.info("Training for {} epochs".format(arguments["max_epoch"]))
-        logger.info("Continuing from epoch {}".format(arguments["epoch"]))
+        start_epoch = checkpoint["epoch"]
+        logger.info(f"Resuming from epoch {start_epoch}...")

     run(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
         device_manager=DeviceManager(device),
-        arguments=arguments,
+        max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
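The new resume logic only peeks at the checkpoint for logging; restoring weights and optimizer state is left to Lightning itself. A hedged sketch of that pattern (fit_with_resume is an illustrative helper, not part of the commit):

import torch
from lightning.pytorch import Trainer

def fit_with_resume(trainer: Trainer, model, datamodule, checkpoint_file: str | None):
    if checkpoint_file is not None:
        # Load on CPU just to report the epoch; trainer.fit() performs the
        # actual state restoration from the same file.
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        print(f"Resuming from epoch {checkpoint['epoch']}...")
    # ckpt_path=None simply starts training from scratch.
    trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_file)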