diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index ddacf63cdd86cbe0e045835556f48aa5c8c99bb2..05b37febe6410e50f09b86e99226e2c86fe5ab7c 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -462,6 +462,7 @@ def run(
     output_folder,
     monitoring_interval,
     batch_chunk_count,
+    checkpoint,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -560,7 +561,7 @@ def run(
             callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
         )
 
-        _ = trainer.fit(model, data_loader, valid_loader)
+        _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
 
     """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
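
Note: Lightning's `Trainer.fit()` accepts `ckpt_path` natively; when it points at a `.ckpt` file, model weights, optimizer and scheduler state, and the epoch counter are all restored before training continues, while `None` keeps the existing train-from-scratch behaviour. A minimal, self-contained sketch of that mechanism (the model and data below are illustrative stand-ins, not ptbench objects):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Illustrative stand-in for the ptbench model."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


data = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=4
)
trainer = Trainer(max_epochs=5, default_root_dir="results", logger=False)

# ckpt_path=None trains from scratch; pointing it at an existing .ckpt
# file restores weights, optimizer state and the epoch counter first.
trainer.fit(TinyModel(), data, ckpt_path=None)
```
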
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 12e80ae131e4a82c53e0cd7a042006cfd3bb707e..ea82f5a7b91f272893812474d6b2a19551e71600 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -254,13 +254,6 @@ def set_reproducible_cuda():
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--weight",
-    "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--normalization",
     "-n",
@@ -287,6 +280,14 @@ def set_reproducible_cuda():
     default=5.0,
     cls=ResourceOption,
 )
+@click.option(
+    "--resume-from",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
+    type=str,
+    required=False,
+    default=None,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
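
The option is a plain string defaulting to `None`, so omitting the flag preserves the current train-from-scratch behaviour. A quick, hedged way to check that the flag is registered, using click's test runner (assumes `ptbench` is importable; `--help` short-circuits before any required arguments are checked):

```python
from click.testing import CliRunner

from ptbench.scripts.train import train

runner = CliRunner()
result = runner.invoke(train, ["--help"])

# On the real command line, ``--resume-from last`` reaches train() as the
# string "last"; leaving the flag out leaves ``resume_from`` as None.
assert "--resume-from" in result.output
```
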
@@ -303,9 +304,9 @@ def train(
     device,
     seed,
     parallel,
-    weight,
     normalization,
     monitoring_interval,
+    resume_from,
     **_,
 ):
     """Trains an CNN to perform tuberculosis detection.
@@ -483,6 +484,39 @@ def train(
     arguments["epoch"] = 0
     arguments["max_epoch"] = epochs
 
+    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
+    best_checkpoint_path = os.path.join(
+        output_folder, "model_lowest_valid_loss.ckpt"
+    )
+
+    if resume_from == "last":
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {last_checkpoint_path}"
+            )
+
+    elif resume_from == "best":
+        if os.path.isfile(best_checkpoint_path):
+            checkpoint = best_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {best_checkpoint_path}"
+            )
+
+    elif resume_from is None:
+        checkpoint = None
+
+    else:
+        if os.path.isfile(resume_from):
+            checkpoint = resume_from
+            logger.info(f"Resuming training from checkpoint {resume_from}")
+        else:
+            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
@@ -498,4 +532,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint,
     )
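
The two filenames hard-coded in the resolution logic above have to match whatever the training loop's checkpoint callbacks actually write. For illustration only, a `ModelCheckpoint` configuration that would produce them might look like the sketch below; the real callback setup lives elsewhere in ptbench and may differ, including the monitored metric name:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

output_folder = "results"  # illustrative

# Would write <output_folder>/model_final_epoch.ckpt at every epoch end.
last_checkpoint = ModelCheckpoint(
    dirpath=output_folder,
    filename="model_final_epoch",
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

# Would keep <output_folder>/model_lowest_valid_loss.ckpt for the lowest
# validation loss observed so far (metric name assumed).
best_checkpoint = ModelCheckpoint(
    dirpath=output_folder,
    filename="model_lowest_valid_loss",
    monitor="validation_loss",
    mode="min",
)
```
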