diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index e2cb9b053a57c6c3a786c19c4cb47a99e8d5cddb..650aa7d4f60dfe5feda914deb8518415fda8876d 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -73,7 +73,8 @@ class Pasa(pl.LightningModule): self.name = "pasa" self.model_transforms = [ - torchvision.transforms.Resize(512), + torchvision.transforms.Grayscale(), + torchvision.transforms.Resize(512, antialias=True), ] self._train_loss = train_loss