diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py
index 6759490854fd05ac8d8d3a9eec5b1494e0cfb0f2..5d612b2a18146ba306b419a808936e9c3c7042f7 100644
--- a/src/ptbench/configs/models/densenet.py
+++ b/src/ptbench/configs/models/densenet.py
@@ -6,20 +6,30 @@
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
 
 from ...models.densenet import Densenet
 
-# config
-optimizer_configs = {"lr": 0.0001}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 0.0001}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Densenet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
index 2d93c07edb4c0824caa8149737ec42783966ad4c..f45f601d6e84a04e4610319d1f27631ff32ee69c 100644
--- a/src/ptbench/data/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -2,81 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
-* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
-
-from clapper.logging import setup
-
-from ....data import return_subsets
-from ....data.base_datamodule import BaseDataModule
-from ....data.dataset import JSONProtocol
-from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
+See :py:mod:`ptbench.data.shenzhen` for dataset details.
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        multiproc_kwargs=None,
-        data_transforms=[],
-        model_transforms=[],
-        train_transforms=[],
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-        self.cache_samples = cache_samples
-        self.has_setup_fit = False
+This configuration:
+* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
+* augmentations: elastic deformation (probability = 80%)
+* output image resolution: 512x512 pixels
+"""
 
-        self.data_transforms = data_transforms
-        self.model_transforms = model_transforms
-        self.train_transforms = train_transforms
+import importlib.resources
 
-        """[
-            transforms.ToPILImage(),
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]"""
+from torchvision import transforms
 
-    def setup(self, stage: str):
-        if self.cache_samples:
-            logger.info(
-                "Argument cache_samples set to True. Samples will be loaded in memory."
-            )
-            samples_loader = _cached_loader
-        else:
-            logger.info(
-                "Argument cache_samples set to False. Samples will be loaded at runtime."
-            )
-            samples_loader = _delayed_loader
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .raw_data_loader import raw_data_loader
 
-        self.json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-            loader=samples_loader,
-            post_transforms=self.post_transforms,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
         )
-
-        if not self.has_setup_fit and stage == "fit":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(self.json_protocol, "default", stage)
-            self.has_setup_fit = True
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=raw_data_loader,
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    model_transforms=[
+        transforms.ToPILImage(),
+        transforms.Lambda(lambda x: x.convert("RGB")),
+        transforms.ToTensor(),
+    ],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
+)
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index e61d7dece53bce91df406048145382b86cfa5f81..ae866e720ba4df0be292bd2ee3f23878a714818a 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -8,6 +8,7 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torchvision.models as models
+import torchvision.transforms
 
 logger = logging.getLogger(__name__)
 
@@ -20,25 +21,37 @@ class Densenet(pl.LightningModule):
 
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
+        criterion=None,
+        criterion_valid=None,
+        optimizer=None,
+        optimizer_configs=None,
         pretrained=False,
-        nb_channels=3,
+        augmentation_transforms=[],
     ):
         super().__init__()
 
-        # Saves all hyper parameters declared on __init__ into ``self.hparams`.
-        # You can access those by their name, like `self.hparams.optimizer`
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
-
         self.name = "Densenet"
 
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
+
         self.normalizer = None
+        self.pretrained = pretrained
 
         # Load pretrained model
-        weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
+        if not pretrained:
+            weights = None
+        else:
+            logger.info("Loading pretrained model weights")
+            weights = models.DenseNet121_Weights.DEFAULT
+
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
@@ -52,17 +65,24 @@ class Densenet(pl.LightningModule):
 
         return x
 
-    def set_normalizer(self, dataloader):
-        """TODO: Write this function to set the Normalizer
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.
 
         This function is NOOP if ``pretrained = True`` (normalizer set to
         imagenet weights, during contruction).
+
+        Parameters
+        ----------
+
+        dataloader: :py:class:`torch.utils.data.DataLoader`
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
         """
         if self.pretrained:
             from .normalizer import make_imagenet_normalizer
 
             logger.warning(
-                "ImageNet pre-trained densenet model - NOT"
+                "ImageNet pre-trained densenet model - NOT "
                 "computing z-norm factors from training data. "
                 "Using preset factors from torchvision."
             )
@@ -76,9 +96,38 @@ class Densenet(pl.LightningModule):
             )
             self.normalizer = make_z_normalizer(dataloader)
 
+    def set_bce_loss_weights(self, datamodule):
+        """Reweights loss weights if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader()
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -86,17 +135,20 @@ class Densenet(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        training_loss = self.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -106,11 +158,7 @@ class Densenet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
+        validation_loss = self.criterion_valid(outputs, labels.float())
 
         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
@@ -118,8 +166,9 @@ class Densenet(pl.LightningModule):
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -129,12 +178,8 @@ class Densenet(pl.LightningModule):
         if isinstance(outputs, list):
             outputs = outputs[-1]
 
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
-
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer