From 2fcec25b7b623f40f66024c5b572a9a9f25bcdff Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 27 Jun 2023 16:41:42 +0200
Subject: [PATCH] Remove TBDataset; use RuntimeDataset or CachedDataset instead

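TBDataset duplicated behaviour that RuntimeDataset and CachedDataset
already provide, hidden behind a single cache_samples flag: with the
flag set it pre-applied the transforms to every sample at construction
time, otherwise it loaded and transformed each sample on access. Remove
the class and let the Shenzhen default datamodule select the right
implementation once during setup, instead of threading the flag through
every constructor call.

A minimal sketch of the call-site change (assuming, as the hunks below
suggest, that both remaining classes keep TBDataset's positional
constructor signature minus the flag):

    # Before: one class, behaviour toggled per instance.
    dataset = TBDataset(
        json_protocol, protocol, "train",
        _raw_data_loader, transforms,
        cache_samples=cache_samples,
    )

    # After: choose the implementation once, construct without the flag.
    dataset_cls = CachedDataset if cache_samples else RuntimeDataset
    dataset = dataset_cls(
        json_protocol, protocol, "train",
        _raw_data_loader, transforms,
    )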
---
 src/ptbench/data/dataset.py          | 44 ----------------------------
 src/ptbench/data/shenzhen/default.py | 19 ++++++------
 2 files changed, 10 insertions(+), 53 deletions(-)

diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index fad353a1..243425f1 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -367,50 +367,6 @@ class RuntimeDataset(torch.utils.data.Dataset):
         return len(self._samples)
 
 
-class TBDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        json_protocol,
-        protocol,
-        subset,
-        raw_data_loader,
-        transforms,
-        cache_samples=False,
-    ):
-        self.json_protocol = json_protocol
-        self.subset = subset
-        self.raw_data_loader = raw_data_loader
-        self.transforms = transforms
-
-        self.cache_samples = cache_samples
-
-        self._samples = json_protocol.subsets(protocol)[self.subset]
-
-        # Dict entry with relative path to files
-        for s in self._samples:
-            s["name"] = s["data"]
-
-        if self.cache_samples:
-            logger.info(f"Caching {self.subset} samples")
-            for sample in tqdm(self._samples):
-                sample["data"] = self.transforms(
-                    self.raw_data_loader(sample["data"])
-                )
-
-    def __getitem__(self, idx):
-        if self.cache_samples:
-            return self._samples[idx]
-        else:
-            sample = self._samples[idx].copy()
-            sample["data"] = self.transforms(
-                self.raw_data_loader(sample["data"])
-            )
-            return sample
-
-    def __len__(self):
-        return len(self._samples)
-
-
 def get_samples_weights(dataset):
     """Compute the weights of all the samples of the dataset to balance it
     using the sampler of the dataloader.
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index bf75eacb..8afac846 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -14,7 +14,7 @@ from clapper.logging import setup
 from torchvision import transforms
 
 from ..base_datamodule import BaseDataModule
-from ..dataset import JSONProtocol, TBDataset
+from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
 from ..shenzhen import _protocols, _raw_data_loader
 from ..transforms import ElasticDeformation, RemoveBlackBorders
 
@@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule):
             fieldnames=("data", "label"),
         )
 
+        if self._cache_samples:
+            dataset = CachedDataset
+        else:
+            dataset = RuntimeDataset
+
         if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = TBDataset(
+            self.train_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "train",
                 _raw_data_loader,
                 self._build_transforms(is_train=True),
-                cache_samples=self._cache_samples,
             )
 
-            self.validation_dataset = TBDataset(
+            self.validation_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "validation",
                 _raw_data_loader,
                 self._build_transforms(is_train=False),
-                cache_samples=self._cache_samples,
             )
 
             self._has_setup_fit = True
 
         if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = TBDataset(
+            self.train_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "train",
                 _raw_data_loader,
                 self._build_transforms(is_train=False),
-                cache_samples=self._cache_samples,
             )
-            self.validation_dataset = TBDataset(
+            self.validation_dataset = dataset(
                 json_protocol,
                 self._protocol,
                 "validation",
                 _raw_data_loader,
                 self._build_transforms(is_train=False),
-                cache_samples=self._cache_samples,
             )
 
             self._has_setup_predict = True
-- 
GitLab