From dd849de43e5fb91ef2533aa03a490a59e826149b Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 14 Jun 2023 14:57:00 +0200 Subject: [PATCH] Cleaner way to apply train_transforms --- src/ptbench/data/transforms.py | 6 +++--- src/ptbench/models/pasa.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index 34c3a605..6d3c17d2 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -77,10 +77,10 @@ class ElasticDeformation: self.random_state = random_state self.p = p - self.tensor_transform = transforms.Compose([transforms.ToTensor()]) - def __call__(self, img): if random.random() < self.p: + img = transforms.ToPILImage()(img) + img = numpy.asarray(img) assert img.ndim == 2 @@ -117,6 +117,6 @@ class ElasticDeformation: result[:, :] = map_coordinates( img[:, :], indices, order=self.spline_order, mode=self.mode ).reshape(shape) - return self.tensor_transform(PIL.Image.fromarray(result)) + return transforms.ToTensor()(PIL.Image.fromarray(result)) else: return img diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index c127239d..ed8c4b30 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -133,10 +133,9 @@ class PASA(pl.LightningModule): def training_step(self, batch, batch_idx): images = batch[1] labels = batch[2] + for img in images: - img = torch.unsqueeze( - self.train_transforms(torch.squeeze(img, 0)), 0 - ) + img = self.train_transforms(img) # Increase label dimension if too low # Allows single and multiclass usage -- GitLab