From 2fcec25b7b623f40f66024c5b572a9a9f25bcdff Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 27 Jun 2023 16:41:42 +0200 Subject: [PATCH] Removed TBDataset, using Runtime or Cached datasets instead --- src/ptbench/data/dataset.py | 44 ---------------------------- src/ptbench/data/shenzhen/default.py | 19 ++++++------ 2 files changed, 10 insertions(+), 53 deletions(-) diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index fad353a1..243425f1 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -367,50 +367,6 @@ class RuntimeDataset(torch.utils.data.Dataset): return len(self._samples) -class TBDataset(torch.utils.data.Dataset): - def __init__( - self, - json_protocol, - protocol, - subset, - raw_data_loader, - transforms, - cache_samples=False, - ): - self.json_protocol = json_protocol - self.subset = subset - self.raw_data_loader = raw_data_loader - self.transforms = transforms - - self.cache_samples = cache_samples - - self._samples = json_protocol.subsets(protocol)[self.subset] - - # Dict entry with relative path to files - for s in self._samples: - s["name"] = s["data"] - - if self.cache_samples: - logger.info(f"Caching {self.subset} samples") - for sample in tqdm(self._samples): - sample["data"] = self.transforms( - self.raw_data_loader(sample["data"]) - ) - - def __getitem__(self, idx): - if self.cache_samples: - return self._samples[idx] - else: - sample = self._samples[idx].copy() - sample["data"] = self.transforms( - self.raw_data_loader(sample["data"]) - ) - return sample - - def __len__(self): - return len(self._samples) - - def get_samples_weights(dataset): """Compute the weights of all the samples of the dataset to balance it using the sampler of the dataloader. diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index bf75eacb..8afac846 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -14,7 +14,7 @@ from clapper.logging import setup from torchvision import transforms from ..base_datamodule import BaseDataModule -from ..dataset import JSONProtocol, TBDataset +from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset from ..shenzhen import _protocols, _raw_data_loader from ..transforms import ElasticDeformation, RemoveBlackBorders @@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule): fieldnames=("data", "label"), ) + if self._cache_samples: + dataset = CachedDataset + else: + dataset = RuntimeDataset + if not self._has_setup_fit and stage == "fit": - self.train_dataset = TBDataset( + self.train_dataset = dataset( json_protocol, self._protocol, "train", _raw_data_loader, self._build_transforms(is_train=True), - cache_samples=self._cache_samples, ) - self.validation_dataset = TBDataset( + self.validation_dataset = dataset( json_protocol, self._protocol, "validation", _raw_data_loader, self._build_transforms(is_train=False), - cache_samples=self._cache_samples, ) self._has_setup_fit = True if not self._has_setup_predict and stage == "predict": - self.train_dataset = TBDataset( + self.train_dataset = dataset( json_protocol, self._protocol, "train", _raw_data_loader, self._build_transforms(is_train=False), - cache_samples=self._cache_samples, ) - self.validation_dataset = TBDataset( + self.validation_dataset = dataset( json_protocol, self._protocol, "validation", _raw_data_loader, self._build_transforms(is_train=False), - cache_samples=self._cache_samples, ) self._has_setup_predict = True -- GitLab