diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index d09161d0408118a3fd2dccefc1b082319f99051b..3ee0b92164b5531b65049b94e71b01b07e2ad27e 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -11,19 +11,20 @@ Screening and Visualization".
 Reference: [PASA-2019]_
 """
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.pasa import build_pasa
+from ...models.pasa import PASA
 
 # config
-lr = 8e-5
-
-# model
-model = build_pasa()
+optimizer_configs = {"lr": 8e-5}
 
 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d730cf68c95e69fe4191dd5e6bdf95ed8b9ee596
--- /dev/null
+++ b/src/ptbench/engine/callbacks.py
@@ -0,0 +1,104 @@
+import csv
+import time
+
+import numpy
+
+from pytorch_lightning import Callback
+from pytorch_lightning.callbacks import BasePredictionWriter
+
+
+# This ensures CSVLogger logs training and evaluation metrics on the same line
+# CSVLogger only accepts numerical values, not strings
+class LoggingCallback(Callback):
+    def __init__(self, resource_monitor):
+        super().__init__()
+        self.training_loss = []
+        self.validation_loss = []
+        self.start_training_time = 0
+        self.start_epoch_time = 0
+
+        self.resource_monitor = resource_monitor
+        self.max_queue_retries = 2
+
+    def on_train_start(self, trainer, pl_module):
+        self.start_training_time = time.time()
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        self.start_epoch_time = time.time()
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        self.training_loss.append(outputs["loss"].item())
+
+    def on_validation_batch_end(
+        self, trainer, pl_module, outputs, batch, batch_idx
+    ):
+        self.validation_loss.append(outputs["validation_loss"].item())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        self.resource_monitor.trigger_summary()
+
+        self.epoch_time = time.time() - self.start_epoch_time
+        eta_seconds = self.epoch_time * (
+            trainer.max_epochs - trainer.current_epoch
+        )
+        current_time = time.time() - self.start_training_time
+
+        self.log("total_time", current_time)
+        self.log("eta", eta_seconds)
+        self.log("loss", numpy.average(self.training_loss))
+        self.log("learning_rate", pl_module.lr)
+        self.log("validation_loss", numpy.average(self.validation_loss))
+
+        queue_retries = 0
+        # In case the resource monitor takes longer to fetch data from the queue, we wait
+        # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
+        while (
+            self.resource_monitor.data is None
+            and queue_retries < self.max_queue_retries
+        ):
+            queue_retries = queue_retries + 1
+            print(
+                f"Monitor queue is empty, retrying in {self.resource_monitor.interval}s"
+            )
+            time.sleep(self.resource_monitor.interval)
+
+        if queue_retries >= self.max_queue_retries:
+            print(
+                f"Unable to fetch monitoring information from queue after {queue_retries} retries"
+            )
+
+        assert self.resource_monitor.q.empty()
+
+        for metric_name, metric_value in self.resource_monitor.data:
+            self.log(metric_name, metric_value)
+
+        self.resource_monitor.data = None
+
+        self.training_loss = []
+        self.validation_loss = []
+
+
+class PredictionsWriter(BasePredictionWriter):
+    def __init__(self, logfile_name, logfile_fields, write_interval):
+        super().__init__(write_interval)
+        self.logfile_name = logfile_name
+        self.logfile_fields = logfile_fields
+
+    def write_on_epoch_end(
+        self, trainer, pl_module, predictions, batch_indices
+    ):
+        with open(self.logfile_name, "w") as logfile:
+            logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
+            logwriter.writeheader()
+
+            # We should only get a single epoch here
+            for epoch in predictions:
+                for prediction in epoch:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+            logfile.flush()
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index fd48d4f1fdd82eea05d28dfff7d035e5694a08fd..965a89cd7ee82f5c68ea079368b7a414dd264a95 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -9,16 +9,18 @@ import logging
 import os
 import shutil
 import sys
-import time
 
 import numpy
 import torch
 
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
+from pytorch_lightning.utilities.model_summary import ModelSummary
 from tqdm import tqdm
 
-# from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
-from ..utils.summary import summary
+from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
 
@@ -127,10 +129,9 @@ def save_model_summary(output_folder, model):
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        r, n = summary(model)
-        logger.info(f"Model has {n} parameters...")
-        f.write(r)
-    return r, n
+        summary = str(ModelSummary(model, max_depth=-1))
+        f.write(summary)
+    return summary
 
 
 def static_information_to_csv(static_logfile_name, device, n):
@@ -582,7 +583,6 @@ def run(
         specific loss function for the validation set
     """
 
-    start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
     check_gpu(device)
@@ -590,9 +590,40 @@ def run(
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    r, n = save_model_summary(output_folder, model)
+    _ = save_model_summary(output_folder, model)
 
-    # write static information to a CSV file
+    csv_logger = CSVLogger(output_folder, "logs_csv")
+    tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
+
+    resource_monitor = ResourceMonitor(
+        interval=5.0,
+        has_gpu=(device.type == "cuda"),
+        main_pid=os.getpid(),
+        logging_level=logging.ERROR,
+    )
+
+    with resource_monitor:
+        trainer = Trainer(
+            accelerator="auto",
+            devices="auto",
+            max_epochs=max_epoch,
+            logger=[csv_logger, tensorboard_logger],
+            check_val_every_n_epoch=1,
+            callbacks=[
+                LoggingCallback(resource_monitor),
+                ModelCheckpoint(
+                    output_folder,
+                    monitor="validation_loss",
+                    mode="min",
+                    save_on_train_epoch_end=False,
+                    every_n_epochs=checkpoint_period,
+                ),
+            ],
+        )
+
+        _ = trainer.fit(model, data_loader)
+
+    """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
 
     static_information_to_csv(static_logfile_name, device, n)
@@ -710,4 +741,4 @@ def run(
         total_training_time = time.time() - start_training_time
         logger.info(
             f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
-        )
+        )"""
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 10e6cedeb672094f62f5c1d16dbaeb9d6983ce34..4fd816e17ab4a3dc85ef6769c9d2325657b91310 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -2,23 +2,49 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from collections import OrderedDict
-
+import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from .normalizer import TorchVisionNormalizer
 
-
-class PASA(nn.Module):
+colors = [
+    [(47, 79, 79), "Cardiomegaly"],
+    [(255, 0, 0), "Emphysema"],
+    [(0, 128, 0), "Pleural effusion"],
+    [(0, 0, 128), "Hernia"],
+    [(255, 84, 0), "Infiltration"],
+    [(222, 184, 135), "Mass"],
+    [(0, 255, 0), "Nodule"],
+    [(0, 191, 255), "Atelectasis"],
+    [(0, 0, 255), "Pneumothorax"],
+    [(255, 0, 255), "Pleural thickening"],
+    [(255, 255, 0), "Pneumonia"],
+    [(126, 0, 255), "Fibrosis"],
+    [(255, 20, 147), "Edema"],
+    [(0, 255, 180), "Consolidation"],
+]
+
+
+class PASA(pl.LightningModule):
     """PASA module.
 
     Based on paper by [PASA-2019]_.
     """
 
-    def __init__(self):
+    def __init__(self, criterion, criterion_valid, optimizer, optimizer_params):
         super().__init__()
+
+        self.save_hyperparameters()
+
+        self.name = "pasa"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -82,6 +108,8 @@ class PASA(nn.Module):
         tensor : :py:class:`torch.Tensor`
 
         """
+        x = self.normalizer(x)
+
         # First convolution block
         _x = x
         x = F.relu(self.batchNorm2d_4(self.fc1(x)))  # 1st convolution
@@ -127,21 +155,41 @@ class PASA(nn.Module):
 
         return x
 
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
 
-def build_pasa():
-    """Build pasa CNN.
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
 
-    Returns
-    -------
+        # Forward pass on the network
+        outputs = self(images)
 
-    module : :py:class:`torch.nn.Module`
-    """
-    model = PASA()
-    model = [
-        ("normalizer", TorchVisionNormalizer(nb_channels=1)),
-        ("model", model),
-    ]
-    model = nn.Sequential(OrderedDict(model))
-
-    model.name = "pasa"
-    return model
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_params
+        )
+
+        return optimizer
diff --git a/src/ptbench/utils/resources.py b/src/ptbench/utils/resources.py
index be23ee452a1823555220c5d92d80a2f7c6a9223f..fa0ac3dd2b2332a8d938850d97739b457eeed13a 100644
--- a/src/ptbench/utils/resources.py
+++ b/src/ptbench/utils/resources.py
@@ -233,6 +233,7 @@ class CPULogger:
         cpu_percent = []
         open_files = []
         gone = set()
+
         for k in self.cluster:
             try:
                 memory_info.append(k.memory_info())
@@ -243,7 +244,6 @@ class CPULogger:
                 # it is too late to update any intermediate list
                 # at this point, but ensures to update counts later on
                 gone.add(k)
-
         return (
             ("cpu_memory_used", psutil.virtual_memory().used / GB),
             ("cpu_rss", sum([k.rss for k in memory_info]) / GB),
@@ -288,6 +288,10 @@ class _InformationGatherer:
             for i, k in enumerate(gpu_log()):
                 self.data[i + self.cpu_keys_len].append(k[1])
 
+    def clear(self):
+        """Clears accumulated data."""
+        self.data = [[] for _ in self.keys]
+
     def summary(self):
         """Returns the current data."""
         if len(self.data[0]) == 0:
@@ -298,7 +302,9 @@ class _InformationGatherer:
         return tuple(retval)
 
 
-def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
+def _monitor_worker(
+    interval, has_gpu, main_pid, stop, summary_event, queue, logging_level
+):
     """A monitoring worker that measures resources and returns lists.
 
     Parameters
@@ -329,6 +335,12 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
     while not stop.is_set():
         try:
             ra.acc()  # guarantees at least an entry will be available
+
+            if summary_event.is_set():
+                queue.put(ra.summary())
+                ra.clear()
+                summary_event.clear()
+
             time.sleep(interval)
         except Exception:
             logger.warning(
@@ -337,8 +349,6 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
             )
             time.sleep(0.5)  # wait half a second, and try again!
 
-    queue.put(ra.summary())
-
 
 class ResourceMonitor:
     """An external, non-blocking CPU/GPU resource monitor.
@@ -364,7 +374,8 @@ class ResourceMonitor:
         self.interval = interval
         self.has_gpu = has_gpu
         self.main_pid = main_pid
-        self.event = multiprocessing.Event()
+        self.stop_event = multiprocessing.Event()
+        self.summary_event = multiprocessing.Event()
         self.q = multiprocessing.Queue()
         self.logging_level = logging_level
 
@@ -375,7 +386,8 @@ class ResourceMonitor:
                 self.interval,
                 self.has_gpu,
                 self.main_pid,
-                self.event,
+                self.stop_event,
+                self.summary_event,
                 self.q,
                 self.logging_level,
             ),
@@ -390,19 +402,9 @@ class ResourceMonitor:
     def __enter__(self):
         """Starts the monitoring process."""
         self.monitor.start()
-        return self
 
-    def __exit__(self, *exc):
-        """Stops the monitoring process and returns the summary of
-        observations."""
-
-        self.event.set()
-        self.monitor.join()
-        if self.monitor.exitcode != 0:
-            logger.error(
-                f"CPU/GPU resource monitor process exited with code "
-                f"{self.monitor.exitcode}.  Check logs for errors!"
-            )
+    def trigger_summary(self):
+        self.summary_event.set()
 
         try:
             data = self.q.get(timeout=2 * self.interval)
@@ -426,3 +428,15 @@ class ResourceMonitor:
                 else:
                     summary.append((k, 0.0))
             self.data = tuple(summary)
+
+    def __exit__(self, *exc):
+        """Stops the monitoring process and returns the summary of
+        observations."""
+
+        self.stop_event.set()
+        self.monitor.join()
+        if self.monitor.exitcode != 0:
+            logger.error(
+                f"CPU/GPU resource monitor process exited with code "
+                f"{self.monitor.exitcode}.  Check logs for errors!"
+            )