diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index fad353a13cfd2b38279e94193a565dc5abbb337f..243425f1ea9199a152163ad24b711e46cf88b3b6 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -367,50 +367,6 @@ class RuntimeDataset(torch.utils.data.Dataset):
         return len(self._samples)
 
 
-class TBDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        json_protocol,
-        protocol,
-        subset,
-        raw_data_loader,
-        transforms,
-        cache_samples=False,
-    ):
-        self.json_protocol = json_protocol
-        self.subset = subset
-        self.raw_data_loader = raw_data_loader
-        self.transforms = transforms
-
-        self.cache_samples = cache_samples
-
-        self._samples = json_protocol.subsets(protocol)[self.subset]
-
-        # Dict entry with relative path to files
-        for s in self._samples:
-            s["name"] = s["data"]
-
-        if self.cache_samples:
-            logger.info(f"Caching {self.subset} samples")
-            for sample in tqdm(self._samples):
-                sample["data"] = self.transforms(
-                    self.raw_data_loader(sample["data"])
-                )
-
-    def __getitem__(self, idx):
-        if self.cache_samples:
-            return self._samples[idx]
-        else:
-            sample = self._samples[idx].copy()
-            sample["data"] = self.transforms(
-                self.raw_data_loader(sample["data"])
-            )
-            return sample
-
-    def __len__(self):
-        return len(self._samples)
-
-
 def get_samples_weights(dataset):
     """Compute the weights of all the samples of the dataset to balance it
     using the sampler of the dataloader.
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index bf75eacb36b83f121e53ecb19ab97edfdd77d9ab..8afac8469920dde1777b0efc6ae3919c1e381f13 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -14,7 +14,7 @@ from clapper.logging import setup
 from torchvision import transforms
 
 from ..base_datamodule import BaseDataModule
-from ..dataset import JSONProtocol, TBDataset
+from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
 from ..shenzhen import _protocols, _raw_data_loader
 from ..transforms import ElasticDeformation, RemoveBlackBorders
 
@@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule):
             fieldnames=("data", "label"),
         )
 
+        if self._cache_samples:
+            dataset = CachedDataset
+        else:
+            dataset = RuntimeDataset
+
         if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = TBDataset(
+            self.train_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "train",
                 _raw_data_loader,
                 self._build_transforms(is_train=True),
-                cache_samples=self._cache_samples,
             )
 
-            self.validation_dataset = TBDataset(
+            self.validation_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "validation",
                 _raw_data_loader,
                 self._build_transforms(is_train=False),
-                cache_samples=self._cache_samples,
             )
 
             self._has_setup_fit = True
 
         if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = TBDataset(
+            self.train_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "train",
                 _raw_data_loader,
                 self._build_transforms(is_train=False),
-                cache_samples=self._cache_samples,
             )
-            self.validation_dataset = TBDataset(
+            self.validation_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "validation",
                 _raw_data_loader,
                 self._build_transforms(is_train=False),
-                cache_samples=self._cache_samples,
             )
 
             self._has_setup_predict = True