diff --git a/src/ptbench/engine/device.py b/src/ptbench/engine/device.py
index 253bba0d9da3bedca6010b2ad87937b9da4f08e0..2eeef34a96156083df564a20746e447f2e577afe 100644
--- a/src/ptbench/engine/device.py
+++ b/src/ptbench/engine/device.py
@@ -128,7 +128,7 @@ class DeviceManager:
             f"Unexpected device type {self.device_type} lacks support"
         )
 
-    def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
+    def lightning_accelerator(self) -> tuple[str, int | list[int] | str]:
         """Returns the lightning accelerator setup.
 
         Returns
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 10121af1091a0e3a186efd4c338f4af0ad4b9cf5..052bb885549087bf64c02e5be05ee1debaabaaca 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -14,14 +14,16 @@ import torch.nn
 
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
+from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
 
 def save_model_summary(
-    output_folder: str, model: torch.nn.Module
+    output_folder: str,
+    model: torch.nn.Module,
 ) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
-    """Save a little summary of the model in a txt file.
+    """Saves a little summary of the model in a txt file.
 
     Parameters
     ----------
@@ -32,13 +34,14 @@ def save_model_summary(
     model
         Network (e.g. driu, hed, unet)
 
+
     Returns
     -------
-    summary:
-        The model summary in a text format.
+    summary
+        The model summary in a text format
 
-    total_parameters:
-        The number of parameters of the model.
+    total_parameters
+        The number of parameters of the model
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
@@ -94,15 +97,15 @@ def static_information_to_csv(
 
 
 def run(
-    model,
-    datamodule,
-    checkpoint_period,
-    device_manager,
-    arguments,
-    output_folder,
-    monitoring_interval,
-    batch_chunk_count,
-    checkpoint,
+    model: lightning.pytorch.LightningModule,
+    datamodule: lightning.pytorch.LightningDataModule,
+    checkpoint_period: int,
+    device_manager: DeviceManager,
+    max_epochs: int,
+    output_folder: str,
+    monitoring_interval: int | float,
+    batch_chunk_count: int,
+    checkpoint: str,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -113,48 +116,40 @@ def run(
     Parameters
     ----------
 
-    model : :py:class:`torch.nn.Module`
+    model
         Neural network model (e.g. pasa).
 
-    data_loader : :py:class:`torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
-
-    valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model and enable automatic checkpointing.
-        If ``None``, then do not validate it.
-
-    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model, however **does not affect** automatic
-        checkpointing. If empty, then does not log anything else.  Otherwise,
-        an extra column with the loss of every dataset in this list is kept on
-        the final training log.
+    datamodule
+        The lightning datamodule to use for training **and** validation
 
-    checkpoint_period : int
+    checkpoint_period
         Save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints.
 
-    device_manager : DeviceManager
-        A device, to be used for training.
+    device_manager
+        An internal device representation, to be used for training and
+        validation.  This representation can be converted into a pytorch device
+        or a torch lightning accelerator setup.
 
-    arguments : dict
-        Start and end epochs:
+    max_epochs
+        The maximum number of epochs to train for.
 
-    output_folder : str
+    output_folder
         Directory in which the results will be saved.
 
-    monitoring_interval : int, float
+    monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.
 
-    batch_chunk_count: int
+    batch_chunk_count
         If this number is different than 1, then each batch will be divided in
         this number of chunks.  Gradients will be accumulated to perform each
         mini-batch.   This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches.  One
         exchanges for longer processing times in this case.
-    """
 
-    max_epoch = arguments["max_epoch"]
+    checkpoint
+    """
 
     os.makedirs(output_folder, exist_ok=True)
 
@@ -198,7 +193,7 @@ def run(
         trainer = lightning.pytorch.Trainer(
             accelerator=accelerator,
             devices=devices,
-            max_epochs=max_epoch,
+            max_epochs=max_epochs,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
             check_val_every_n_epoch=1,
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 664b8b1ad1ae38625a22af4a1092b15da20b2727..bffeebdb5d2cae835b7cd17706ebac4b93ef35fe 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -229,8 +229,7 @@ def train(
     procedure in case it stops abruptly.
     """
 
-    import torch.cuda
-    import torch.nn
+    import torch
 
     from lightning.pytorch import seed_everything
 
@@ -276,25 +275,20 @@ def train(
             "Skipping sample class/dataset ownership balancing on user request"
         )
 
-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
-
+    logger.info(f"Training for at most {epochs} epochs.")
     # We only load the checkpoint to get some information about its state. The
     # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
-        arguments["epoch"] = checkpoint["epoch"]
-
-    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
-    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
+        start_epoch = checkpoint["epoch"]
+        logger.info(f"Resuming from epoch {start_epoch}...")
 
     run(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
         device_manager=DeviceManager(device),
-        arguments=arguments,
+        max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,