diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
index e983fe7028d858118efb9ba41e560aee8e95845a..49ccf8bfb217e411004228b4acf7c924e3ffec66 100644
--- a/src/ptbench/data/shenzhen/loader.py
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -82,7 +82,8 @@ class RawDataLoader(_BaseRawDataLoader):
         tensor = self.transform(
             load_pil_baw(os.path.join(self.datadir, sample[0]))
         )
-        return tensor, dict(label=sample[1])  # type: ignore[arg-type]
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
 
     def label(self, sample: tuple[str, int]) -> int:
         """Loads a single image sample label from the disk.
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
index f45f601d6e84a04e4610319d1f27631ff32ee69c..211b49236cae22af68ad61d38849429dedb606d7 100644
--- a/src/ptbench/data/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -18,7 +18,7 @@ from torchvision import transforms
 
 from ..datamodule import CachingDataModule
 from ..split import JSONDatabaseSplit
-from .raw_data_loader import raw_data_loader
+from .loader import RawDataLoader
 
 datamodule = CachingDataModule(
     database_split=JSONDatabaseSplit(
@@ -26,16 +26,10 @@ datamodule = CachingDataModule(
             "default.json.bz2"
         )
     ),
-    raw_data_loader=raw_data_loader,
-    cache_samples=False,
-    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    raw_data_loader=RawDataLoader(),
     model_transforms=[
         transforms.ToPILImage(),
         transforms.Lambda(lambda x: x.convert("RGB")),
         transforms.ToTensor(),
     ],
-    # batch_size = 1,
-    # batch_chunk_count = 1,
-    # drop_incomplete_batch = False,
-    # parallel = -1,
 )
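
With the configuration trimmed to its essentials, per-run tuning happens on
the datamodule instance rather than at construction time.  A hedged sketch,
assuming set_chunk_size() takes a (batch-size, chunk-count) pair as in the
predict.py hunk at the end of this diff:

    from ptbench.data.shenzhen.rgb import datamodule

    # batch size of 8, a single chunk per batch (assumed semantics)
    datamodule.set_chunk_size(8, 1)
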
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
index 344c1294df6777f88acdde23dec51d40fc51e31e..f0e54f1afa701f0b7f9c691bbc7dfd51aad86e8f 100644
--- a/src/ptbench/data/typing.py
+++ b/src/ptbench/data/typing.py
@@ -73,3 +73,6 @@ DataLoader = torch.utils.data.DataLoader[Sample]
 
 We iterate over Sample objects in this case.
 """
+
+Checkpoint = dict[str, typing.Any]
+"""Definition of a lightning checkpoint."""
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 350140a8516ddef43e89323e65b746ca7c479182..8774f9c45c248414618a24592cab7ee687e74dbf 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -398,26 +398,27 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         predictions: typing.Sequence[typing.Any],
         batch_indices: typing.Sequence[typing.Any] | None,
     ) -> None:
-        for dataloader_idx, dataloader_results in enumerate(predictions):
-            dataloader_name = list(
-                trainer.datamodule.predict_dataloader().keys()
-            )[dataloader_idx].replace("_loader", "")
+        dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[
+            0
+        ]
 
-            logfile = os.path.join(
-                self.output_dir, dataloader_name, "predictions.csv"
-            )
-            os.makedirs(os.path.dirname(logfile), exist_ok=True)
-
-            with open(logfile, "w") as l_f:
-                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
-                logwriter.writeheader()
-
-                for prediction in dataloader_results:
-                    logwriter.writerow(
-                        {
-                            "filename": prediction[0],
-                            "likelihood": prediction[1].numpy(),
-                            "ground_truth": prediction[2].numpy(),
-                        }
-                    )
-                l_f.flush()
+        logfile = os.path.join(
+            self.output_dir, f"predictions_{dataloader_name}_set.csv"
+        )
+        os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+        logger.info(f"Saving predictions in {logfile}.")
+
+        with open(logfile, "w") as l_f:
+            logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+            logwriter.writeheader()
+
+            for prediction in predictions:
+                logwriter.writerow(
+                    {
+                        "filename": prediction[0],
+                        "likelihood": prediction[1].numpy(),
+                        "ground_truth": prediction[2].numpy(),
+                    }
+                )
+            l_f.flush()
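
The writer now emits a single flat CSV per prediction run.  A hedged sketch of
reading it back, assuming a dataloader named "test" (so the file would be
predictions_test_set.csv):

    import csv

    with open("output/predictions_test_set.csv") as f:
        for row in csv.DictReader(f):
            # field names match logfile_fields set up in predictor.run()
            print(row["filename"], row["likelihood"], row["ground_truth"])
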
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 5dcbb79c9fd0a8c32f9d269f8302b888da56be84..6bb6e275d304406182449c77a68a8a8e62406719 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -5,15 +5,22 @@
 import logging
 import os
 
+import lightning.pytorch
+
 from lightning.pytorch import Trainer
 
-from ..utils.accelerator import AcceleratorProcessor
 from .callbacks import PredictionsWriter
+from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
 
-def run(model, datamodule, accelerator, output_folder, grad_cams=False):
+def run(
+    model: lightning.pytorch.LightningModule,
+    datamodule: lightning.pytorch.LightningDataModule,
+    device_manager: DeviceManager,
+    output_folder: str,
+):
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
@@ -21,11 +28,13 @@ def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     model : :py:class:`torch.nn.Module`
         Neural network model (e.g. pasa).
 
-    data_loader : py:class:`torch.torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
+    datamodule
+        The lightning datamodule holding the data to run predictions on.
 
-    accelerator : str
-        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
+    device_manager
+        An internal device representation, to be used for inference.  This
+        representation can be converted into a pytorch device or a torch
+        lightning accelerator setup.
 
     output_folder : str
         Directory in which the results will be saved.
@@ -44,19 +53,11 @@ def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
 
-    accelerator_processor = AcceleratorProcessor(accelerator)
-
-    if accelerator_processor.device is None:
-        devices = "auto"
-    else:
-        devices = accelerator_processor.device
-
-    logger.info(f"Device: {devices}")
-
     logfile_fields = ("filename", "likelihood", "ground_truth")
 
+    accelerator, devices = device_manager.lightning_accelerator()
     trainer = Trainer(
-        accelerator=accelerator_processor.accelerator,
+        accelerator=accelerator,
         devices=devices,
         callbacks=[
             PredictionsWriter(
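
Driving the predictor directly now mirrors what the predict command does at
the end of this diff.  A hedged sketch, with model and datamodule assumed to
be constructed elsewhere:

    from ptbench.engine.device import DeviceManager
    from ptbench.engine.predictor import run

    # DeviceManager turns a "cpu"/"cuda:0" string into the accelerator/devices
    # pair handed to the Trainer above
    predictions = run(model, datamodule, DeviceManager("cpu"), "results")
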
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index a878a076037925879c072bffb87d23a3e1ce7b0d..c3aadc6d0466777fb0adcb62eb7d20887cf36b20 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -8,13 +8,12 @@ import typing
 import lightning.pytorch as pl
 import torch
 import torch.nn
-import torch.nn.functional as F
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import Checkpoint, DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -61,10 +60,10 @@ class Alexnet(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module,
-        validation_loss: torch.nn.Module | None,
-        optimizer_type: type[torch.optim.Optimizer],
-        optimizer_arguments: dict[str, typing.Any],
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
     ):
@@ -105,6 +104,32 @@ class Alexnet(pl.LightningModule):
 
         return x
 
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning to restore your model.
+
+        If you saved something with on_save_checkpoint() this is your chance to restore this.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning when saving a checkpoint to give you a chance to
+        store anything else you might want to save.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the normalizer for the current model.
 
@@ -214,7 +239,7 @@ class Alexnet(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["names"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
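
With the new hooks, the normalizer survives a checkpoint round trip.  A hedged
sketch of the flow (trainer and train_loader assumed available):

    model = Alexnet()  # defaults introduced by this diff
    model.set_normalizer(train_loader)
    trainer.save_checkpoint("alexnet.ckpt")  # triggers on_save_checkpoint()
    # on_load_checkpoint() re-attaches the stored normalizer afterwards
    restored = Alexnet.load_from_checkpoint("alexnet.ckpt", strict=False)
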
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 8eba3b53410da4874d8b59c48c73e13bec1cb703..021f6ce2c6f5cfb3ad3819144a442744577d5eaa 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -8,13 +8,12 @@ import typing
 import lightning.pytorch as pl
 import torch
 import torch.nn
-import torch.nn.functional as F
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import Checkpoint, DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -59,12 +58,12 @@ class Densenet(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module,
-        validation_loss: torch.nn.Module | None,
-        optimizer_type: type[torch.optim.Optimizer],
-        optimizer_arguments: dict[str, typing.Any],
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
-        pretrained: bool= False,
+        pretrained: bool = False,
     ):
         super().__init__()
 
@@ -98,13 +97,38 @@ class Densenet(pl.LightningModule):
         )
 
     def forward(self, x):
-        
         x = self.normalizer(x)  # type: ignore
 
         x = self.model_ft(x)
 
         return x
 
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning to restore your model.
+
+        If you saved something with on_save_checkpoint() this is your chance to restore this.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning when saving a checkpoint to give you a chance to
+        store anything else you might want to save.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the normalizer for the current model.
 
@@ -216,7 +240,7 @@ class Densenet(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["names"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d6dd23ee02923490247d03b75fbf2c167aef57dd..5dd1c33c19c7a5b22e2b37bbfeb9943322f7f16f 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import Checkpoint, DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -58,10 +58,10 @@ class Pasa(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module,
-        validation_loss: torch.nn.Module | None,
-        optimizer_type: type[torch.optim.Optimizer],
-        optimizer_arguments: dict[str, typing.Any],
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
     ):
         super().__init__()
@@ -185,10 +185,29 @@ class Pasa(pl.LightningModule):
 
         return x
 
-    def on_save_checkpoint(self, checkpoint):
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning to restore your model.
+
+        If you saved something with on_save_checkpoint() this is your chance to restore this.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
         checkpoint["normalizer"] = self.normalizer
 
-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning when saving a checkpoint to give you a chance to
+        store anything else you might want to save.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
         logger.info("Restoring normalizer from checkpoint.")
         self.normalizer = checkpoint["normalizer"]
 
@@ -289,7 +308,7 @@ class Pasa(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["names"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
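
The renamed "name" key closes the loop with the Shenzhen loader change at the
top of this diff.  A hedged sketch of the batch layout predict_step() now
consumes; the per-sample return values are presumably the (filename,
likelihood, ground_truth) triples unpacked by PredictionsWriter above:

    # illustrative batch, matching the collated RawDataLoader.sample() output
    batch = (images, {"label": labels, "name": names})
    predictions = model.predict_step(batch, batch_idx=0)
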
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index a78d74b41d8f75b3f4466a890f206e4b2503a84c..f73baabab5d2f09f5d2fd6d802c429ec63a2cc6f 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--accelerator",
-    "-a",
-    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
+    "--device",
+    "-d",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
     default="cpu",
@@ -77,22 +77,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--grad-cams",
-    "-g",
-    help="If set, generate grad cams for each prediction (must use batch of 1)",
-    is_flag=True,
-    cls=ResourceOption,
-)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def predict(
     output_folder,
     model,
     datamodule,
     batch_size,
-    accelerator,
+    device,
     weight,
-    grad_cams,
     **_,
 ) -> None:
     """Predicts Tuberculosis presence (probabilities) on input images."""
@@ -103,12 +95,11 @@ def predict(
 
     from matplotlib.backends.backend_pdf import PdfPages
 
+    from ..engine.device import DeviceManager
     from ..engine.predictor import run
     from ..utils.plot import relevance_analysis_plot
 
-    datamodule = datamodule(
-        batch_size=batch_size,
-    )
+    datamodule.set_chunk_size(batch_size, 1)
 
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
@@ -128,4 +119,4 @@ def predict(
         )
         pdf.close()
 
-    _ = run(model, datamodule, accelerator, output_folder, grad_cams)
+    _ = run(model, datamodule, DeviceManager(device), output_folder)