diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 6003eaffbdd14960e0b4298e366a991c4dd33324..f5bd3a7afa0b19e8ad8650cb325c7fc5ba79a166 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -2,14 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import os
-
 import click
+import torch
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from pytorch_lightning import seed_everything
 
+from ..utils.checkpointer import get_checkpoint
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
@@ -432,41 +433,21 @@ def train(
 
     arguments = {}
     arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
 
-    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
-    best_checkpoint_path = os.path.join(
-        output_folder, "model_lowest_valid_loss.ckpt"
-    )
-
-    if resume_from == "last":
-        if os.path.isfile(last_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {last_checkpoint_path}"
-            )
-
-    elif resume_from == "best":
-        if os.path.isfile(best_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {best_checkpoint_path}"
-            )
-
-    elif resume_from is None:
-        checkpoint = None
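+    # Resolve "last", "best", an explicit path or None into the checkpoint
+    # file to resume from, if any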
+    checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    else:
-        if os.path.isfile(resume_from):
-            checkpoint = resume_from
-            logger.info(f"Resuming training from checkpoint {resume_from}")
-        else:
-            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+    # We only load the checkpoint to get some information about its state.
+    # The actual loading of the model is done in trainer.fit().
+    if checkpoint_file is not None:
+        checkpoint = torch.load(checkpoint_file)
+        arguments["epoch"] = checkpoint["epoch"]
 
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
+    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     run(
         model=model,
@@ -479,5 +460,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
-        checkpoint=checkpoint,
+        checkpoint=checkpoint_file,
     )
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30cba9444504a2abb350f4f7d9ca513bde9dde3
--- /dev/null
+++ b/src/ptbench/utils/checkpointer.py
@@ -0,0 +1,48 @@
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def get_checkpoint(output_folder, resume_from):
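+    """Gets the checkpoint file from which to resume training.
+
+    ``resume_from`` may be ``"last"``, ``"best"``, ``None``, or the path to
+    an existing checkpoint.  Returns the path of the selected checkpoint, or
+    ``None`` to start from scratch.  Raises ``FileNotFoundError`` if the
+    requested checkpoint cannot be found.
+    """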
+    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
+    best_checkpoint_path = os.path.join(
+        output_folder, "model_lowest_valid_loss.ckpt"
+    )
+
+    if resume_from == "last":
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint_file = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {last_checkpoint_path}"
+            )
+
+    elif resume_from == "best":
+        if os.path.isfile(best_checkpoint_path):
+            checkpoint_file = best_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {best_checkpoint_path}"
+            )
+
+    elif resume_from is None:
+        checkpoint_file = None
+
+    else:
+        if os.path.isfile(resume_from):
+            checkpoint_file = resume_from
+            logger.info(f"Resuming training from checkpoint {resume_from}")
+        else:
+            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+
+    return checkpoint_file