From cd44c9fde10b65d28092eef62bf54451b35e6b92 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 28 Jun 2023 10:38:07 +0200
Subject: [PATCH] Apply transforms during __getitem__ in CachedDataset

---
 src/ptbench/data/dataset.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index 243425f1..15dc32a9 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -333,12 +333,12 @@ class CachedDataset(torch.utils.data.Dataset):
 
         logger.info(f"Caching {self.subset} samples")
         for sample in tqdm(self._samples):
-            sample["data"] = self.transforms(
-                self.raw_data_loader(sample["data"])
-            )
+            sample["data"] = self.raw_data_loader(sample["data"])
 
     def __getitem__(self, idx):
-        return self._samples[idx]
+        sample = self._samples[idx].copy()
+        sample["data"] = self.transforms(sample["data"])
+        return sample
 
     def __len__(self):
         return len(self._samples)
-- 
GitLab