Skip to content
Snippets Groups Projects
Commit dd849de4 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Cleaner way to apply train_transforms

parent 9855ed27
No related branches found
No related tags found
No related merge requests found
......@@ -77,10 +77,10 @@ class ElasticDeformation:
self.random_state = random_state
self.p = p
self.tensor_transform = transforms.Compose([transforms.ToTensor()])
def __call__(self, img):
if random.random() < self.p:
img = transforms.ToPILImage()(img)
img = numpy.asarray(img)
assert img.ndim == 2
......@@ -117,6 +117,6 @@ class ElasticDeformation:
result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape)
return self.tensor_transform(PIL.Image.fromarray(result))
return transforms.ToTensor()(PIL.Image.fromarray(result))
else:
return img
......@@ -133,10 +133,9 @@ class PASA(pl.LightningModule):
def training_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
for img in images:
img = torch.unsqueeze(
self.train_transforms(torch.squeeze(img, 0)), 0
)
img = self.train_transforms(img)
# Increase label dimension if too low
# Allows single and multiclass usage
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment