Skip to content
Snippets Groups Projects
Commit 1b773088 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Cleaner way to apply train_transforms

parent c3f14d0d
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -77,10 +77,10 @@ class ElasticDeformation: ...@@ -77,10 +77,10 @@ class ElasticDeformation:
self.random_state = random_state self.random_state = random_state
self.p = p self.p = p
self.tensor_transform = transforms.Compose([transforms.ToTensor()])
def __call__(self, img): def __call__(self, img):
if random.random() < self.p: if random.random() < self.p:
img = transforms.ToPILImage()(img)
img = numpy.asarray(img) img = numpy.asarray(img)
assert img.ndim == 2 assert img.ndim == 2
...@@ -117,6 +117,6 @@ class ElasticDeformation: ...@@ -117,6 +117,6 @@ class ElasticDeformation:
result[:, :] = map_coordinates( result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape) ).reshape(shape)
return self.tensor_transform(PIL.Image.fromarray(result)) return transforms.ToTensor()(PIL.Image.fromarray(result))
else: else:
return img return img
...@@ -133,10 +133,9 @@ class PASA(pl.LightningModule): ...@@ -133,10 +133,9 @@ class PASA(pl.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[1] images = batch[1]
labels = batch[2] labels = batch[2]
for img in images: for img in images:
img = torch.unsqueeze( img = self.train_transforms(img)
self.train_transforms(torch.squeeze(img, 0)), 0
)
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment