diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 965a89cd7ee82f5c68ea079368b7a414dd264a95..41ade3f310d43d12a5ed113d633e27b73b6cb29f 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -512,15 +512,12 @@ def run(
     valid_loader,
     extra_valid_loaders,
     optimizer,
-    criterion,
-    checkpointer,
     checkpoint_period,
     device,
     arguments,
     output_folder,
     monitoring_interval,
     batch_chunk_count,
-    criterion_valid,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -549,12 +546,6 @@ def run(
 
     optimizer : :py:mod:`torch.optim`
 
-    criterion : :py:class:`torch.nn.modules.loss._Loss`
-        loss function
-
-    checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer`
-        checkpointer implementation
-
     checkpoint_period : int
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints
@@ -578,9 +569,6 @@ def run(
         mini-batch.  This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches.  One
         exchanges for longer processing times in this case.
-
-    criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
-        specific loss function for the validation set
     """
 
     max_epoch = arguments["max_epoch"]
@@ -621,7 +609,7 @@ def run(
         ],
     )
 
-    _ = trainer.fit(model, data_loader)
+    _ = trainer.fit(model, data_loader, valid_loader)
 
     """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index bafeb0303899086cd1b00cfe52d4d0896f7045b8..12e80ae131e4a82c53e0cd7a042006cfd3bb707e 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -8,6 +8,7 @@ import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
+from pytorch_lightning import seed_everything
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -53,42 +54,6 @@ def setup_pytorch_device(name):
     return torch.device(name)
 
 
-def set_seeds(value, all_gpus):
-    """Sets up all relevant random seeds (numpy, python, cuda)
-
-    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
-    ``True`` to force all GPU seeds to be initialized.
-
-    Reference: `PyTorch page for reproducibility
-    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
-
-
-    Parameters
-    ----------
-
-    value : int
-        The random seed value to use
-
-    all_gpus : :py:class:`bool`, Optional
-        If set, then reset the seed on all GPUs available at once.  This is
-        normally **not** what you want if running on a single GPU
-    """
-    import random
-
-    import numpy.random
-    import torch
-    import torch.cuda
-
-    random.seed(value)
-    numpy.random.seed(value)
-    torch.manual_seed(value)
-    torch.cuda.manual_seed(value)  # noop if cuda not available
-
-    # set seeds for all gpus
-    if all_gpus:
-        torch.cuda.manual_seed_all(value)  # noop if cuda not available
-
-
 def set_reproducible_cuda():
     """Turns-off all CUDA optimizations that would affect reproducibility.
 
@@ -252,13 +217,14 @@ def set_reproducible_cuda():
     "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
-    required=True,
-    default=0,
+    required=False,
+    default=None,
    type=click.IntRange(min=0),
    cls=ResourceOption,
 )
 @click.option(
     "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
@@ -288,6 +254,13 @@ def set_reproducible_cuda():
     default=-1,
     cls=ResourceOption,
 )
+@click.option(
+    "--weight",
+    "-w",
+    help="Path or URL to pretrained model file (.pth extension)",
+    required=False,
+    cls=ResourceOption,
+)
 @click.option(
     "--normalization",
     "-n",
@@ -330,6 +303,7 @@ def train(
     device,
     seed,
     parallel,
+    weight,
     normalization,
     monitoring_interval,
     **_,
@@ -354,11 +328,10 @@ def train(
 
     from ..configs.datasets import get_positive_weights, get_samples_weights
     from ..engine.trainer import run
-    from ..utils.checkpointer import Checkpointer
 
     device = setup_pytorch_device(device)
 
-    set_seeds(seed, all_gpus=False)
+    seed_everything(seed)
 
     use_dataset = dataset
     validation_dataset = None
@@ -418,9 +391,6 @@ def train(
 
     # Create weighted random sampler
     train_samples_weights = get_samples_weights(use_dataset)
-    train_samples_weights = train_samples_weights.to(
-        device=device, non_blocking=torch.cuda.is_available()
-    )
     train_sampler = WeightedRandomSampler(
         train_samples_weights, len(train_samples_weights), replacement=True
     )
@@ -428,10 +398,7 @@ def train(
 
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
         positive_weights = get_positive_weights(use_dataset)
-        positive_weights = positive_weights.to(
-            device=device, non_blocking=torch.cuda.is_available()
-        )
-        criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
+        model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
 
@@ -454,10 +421,9 @@ def train(
             or criterion_valid is None
         ):
             positive_weights = get_positive_weights(validation_dataset)
-            positive_weights = positive_weights.to(
-                device=device, non_blocking=torch.cuda.is_available()
+            model.criterion_valid = BCEWithLogitsLoss(
+                pos_weight=positive_weights
             )
-            criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
         else:
             logger.warning("Weighted valid criterion not supported")
 
@@ -513,14 +479,8 @@ def train(
     )
     logger.info(f"Z-normalization with mean {mean} and std {std}")
 
-    # Checkpointer
-    checkpointer = Checkpointer(model, optimizer, path=output_folder)
-
-    # Initialize epoch information
     arguments = {}
     arguments["epoch"] = 0
-    extra_checkpoint_data = checkpointer.load()
-    arguments.update(extra_checkpoint_data)
     arguments["max_epoch"] = epochs
 
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
@@ -532,13 +492,10 @@ def train(
         valid_loader=valid_loader,
         extra_valid_loaders=extra_valid_loaders,
         optimizer=optimizer,
-        criterion=criterion,
-        checkpointer=checkpointer,
         checkpoint_period=checkpoint_period,
         device=device,
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
-        criterion_valid=criterion_valid,
     )