From cd44c9fde10b65d28092eef62bf54451b35e6b92 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 28 Jun 2023 10:38:07 +0200 Subject: [PATCH] Apply transforms during __getitem__ in CachedDataset --- src/ptbench/data/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index 243425f1..15dc32a9 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -333,12 +333,12 @@ class CachedDataset(torch.utils.data.Dataset): logger.info(f"Caching {self.subset} samples") for sample in tqdm(self._samples): - sample["data"] = self.transforms( - self.raw_data_loader(sample["data"]) - ) + sample["data"] = self.raw_data_loader(sample["data"]) def __getitem__(self, idx): - return self._samples[idx] + sample = self._samples[idx].copy() + sample["data"] = self.transforms(sample["data"]) + return sample def __len__(self): return len(self._samples) -- GitLab