diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 3bcd441a59a8e08c2e4d9aa33808c8d651026100..8377c66328241f549dbb2f10946f9cc973ef7a6f 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -115,6 +115,10 @@ class BaseDataModule(pl.LightningDataModule):
 
         return loaders_dict
 
+    def update_module_properties(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)  # NOTE(review): silently creates attrs whose names do not match existing ones (e.g. "_"-prefixed internals like _cache_samples) -- confirm callers pass matching names
+
     def _compute_chunk_size(self, batch_size, chunk_count):
         batch_chunk_size = batch_size
         if batch_size % chunk_count != 0:
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 8afac8469920dde1777b0efc6ae3919c1e381f13..b5fb23f59277f1f6511c6707929856cc2357be2f 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -11,95 +11,18 @@
 """
 
 from clapper.logging import setup
-from torchvision import transforms
 
-from ..base_datamodule import BaseDataModule
-from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
-from ..shenzhen import _protocols, _raw_data_loader
-from ..transforms import ElasticDeformation, RemoveBlackBorders
+from ..transforms import ElasticDeformation
+from .utils import ShenzhenDataModule
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
+protocol_name = "default"
 
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        parallel=-1,
-    ):
-        super().__init__(
-            batch_size=batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            batch_chunk_count=batch_chunk_count,
-            parallel=parallel,
-        )
+augmentation_transforms = [ElasticDeformation(p=0.8)]
 
-        self._cache_samples = cache_samples
-        self._has_setup_fit = False
-        self._has_setup_predict = False
-        self._protocol = "default"
-
-        self.raw_data_transforms = [
-            RemoveBlackBorders(),
-            transforms.Resize(512),
-            transforms.CenterCrop(512),
-            transforms.ToTensor(),
-        ]
-
-        self.model_transforms = []
-
-        self.augmentation_transforms = [ElasticDeformation(p=0.8)]
-
-    def setup(self, stage: str):
-        json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-        )
-
-        if self._cache_samples:
-            dataset = CachedDataset
-        else:
-            dataset = RuntimeDataset
-
-        if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=True),
-            )
-
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_fit = True
-
-        if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_predict = True
-
-
-datamodule = DefaultModule
+datamodule = ShenzhenDataModule(
+    protocol=protocol_name,
+    model_transforms=[],
+    augmentation_transforms=augmentation_transforms,
+)
diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1521b674212feec942e73df081e7b20c19d89e29
--- /dev/null
+++ b/src/ptbench/data/shenzhen/utils.py
@@ -0,0 +1,114 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for computer-aided diagnosis.
+
+The standard digital image database for Tuberculosis was created by the
+National Library of Medicine, Maryland, USA in collaboration with Shenzhen
+No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
+The Chest X-rays are from out-patient clinics, and were captured as part of
+the daily routine using Philips DR Digital Diagnose systems.
+
+* Reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 3000 x 3000 or less
+* Split reference: none
+* Protocol ``default``:
+
+  * Training samples: 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+from clapper.logging import setup
+from torchvision import transforms
+
+from ..base_datamodule import BaseDataModule
+from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
+from ..shenzhen import _protocols, _raw_data_loader
+from ..transforms import RemoveBlackBorders
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class ShenzhenDataModule(BaseDataModule):
+    def __init__(
+        self,
+        protocol="default",
+        model_transforms=None,
+        augmentation_transforms=None,
+        batch_size=1,
+        batch_chunk_count=1,
+        drop_incomplete_batch=False,
+        cache_samples=False,
+        parallel=-1,
+    ):
+        super().__init__(
+            batch_size=batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            batch_chunk_count=batch_chunk_count,
+            parallel=parallel,
+        )
+
+        self._cache_samples = cache_samples
+        self._has_setup_fit = False
+        self._has_setup_predict = False
+        self._protocol = protocol
+
+        self.raw_data_transforms = [
+            RemoveBlackBorders(),
+            transforms.Resize(512),
+            transforms.CenterCrop(512),
+            transforms.ToTensor(),
+        ]
+
+        self.model_transforms = [] if model_transforms is None else model_transforms
+
+        self.augmentation_transforms = [] if augmentation_transforms is None else augmentation_transforms
+
+    def setup(self, stage: str):
+        json_protocol = JSONProtocol(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+        )
+
+        if self._cache_samples:
+            dataset = CachedDataset
+        else:
+            dataset = RuntimeDataset
+
+        if not self._has_setup_fit and stage == "fit":
+            self.train_dataset = dataset(
+                json_protocol,
+                self._protocol,
+                "train",
+                _raw_data_loader,
+                self._build_transforms(is_train=True),
+            )
+
+            self.validation_dataset = dataset(
+                json_protocol,
+                self._protocol,
+                "validation",
+                _raw_data_loader,
+                self._build_transforms(is_train=False),
+            )
+
+            self._has_setup_fit = True
+
+        if not self._has_setup_predict and stage == "predict":
+            self.train_dataset = dataset(
+                json_protocol,
+                self._protocol,
+                "train",
+                _raw_data_loader,
+                self._build_transforms(is_train=False),
+            )
+            self.validation_dataset = dataset(
+                json_protocol,
+                self._protocol,
+                "validation",
+                _raw_data_loader,
+                self._build_transforms(is_train=False),
+            )
+
+            self._has_setup_predict = True
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 6b37e1a109623290e10d6c79009cc69bb403ab05..4d2a226b5b0b479b3f84c748d24aebd43d8d6dea 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -270,7 +270,7 @@ def train(
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule = datamodule(
+    datamodule.update_module_properties(
         batch_size=batch_size,
         batch_chunk_count=batch_chunk_count,
         drop_incomplete_batch=drop_incomplete_batch,