diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 5e5b3e7eee3d810aaa4c8dfed31ad3d9fa5414b5..5177316da7d1b505b324429babfd226b5ae33d51 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -15,7 +15,7 @@ import torchvision.transforms
 
 from ..data.typing import TransformSequence
 from .separate import separate
-from .transforms import RGB
+from .transforms import RGB, SquareCenterPad
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -80,6 +80,7 @@ class Alexnet(pl.LightningModule):
         self.num_classes = num_classes
 
         self.model_transforms = [
+            SquareCenterPad(),
             torchvision.transforms.Resize(512, antialias=True),
             RGB(),
         ]
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index 333edb11e2112db9ef1f1bf37b2d333f5b418de6..f6c4d916f1791d83653ed4cd986a4be025de5a43 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -15,7 +15,7 @@ import torchvision.transforms
 
 from ..data.typing import TransformSequence
 from .separate import separate
-from .transforms import RGB
+from .transforms import RGB, SquareCenterPad
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -77,8 +77,8 @@ class Densenet(pl.LightningModule):
         self.name = "densenet-121"
         self.num_classes = num_classes
 
-        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
+            SquareCenterPad(),
             torchvision.transforms.Resize(512, antialias=True),
             RGB(),
         ]
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index a7b8ae62c0c56bc2c5172058fdd3382c987514a1..21cfc2e1e6345b9318ae5e4a15e9b9a1981f6130 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -15,7 +15,7 @@ import torchvision.transforms
 
 from ..data.typing import TransformSequence
 from .separate import separate
-from .transforms import Grayscale
+from .transforms import Grayscale, SquareCenterPad
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -78,10 +78,10 @@ class Pasa(pl.LightningModule):
         self.name = "pasa"
         self.num_classes = num_classes
 
-        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
-            torchvision.transforms.Resize(512, antialias=True),
             Grayscale(),
+            SquareCenterPad(),
+            torchvision.transforms.Resize(512, antialias=True),
         ]
 
         self._train_loss = train_loss
diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py
index 8ff0af726aada9256603882dd8a6c0aeb6d34cad..92869c7b8210740c062bd5f67114c26481f96a58 100644
--- a/src/mednet/models/transforms.py
+++ b/src/mednet/models/transforms.py
@@ -3,11 +3,45 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """A transform that turns grayscale images to RGB."""
 
+import numpy
 import torch
 import torch.nn
 import torchvision.transforms.functional
 
 
+def square_center_pad(img: torch.Tensor) -> torch.Tensor:
+    """Returns a squared version of the image, centered on a canvas padded with
+    zeros.
+
+    Parameters
+    ----------
+
+    img
+        The tensor to be transformed, shaped ``[..., C, H, W]``, where ``C``
+        is 1 or 3 and ``...`` is any number of leading dimensions.
+
+    Returns
+    -------
+
+    img
+        The transformed tensor, guaranteed to be square (i.e. with equal
+        height and width).
+    """
+
+    height, width = img.shape[-2:]
+    maxdim = numpy.max([height, width])
+
+    # padding
+    left = (maxdim - width) // 2
+    top = (maxdim - height) // 2
+    right = maxdim - width - left
+    bottom = maxdim - height - top
+
+    return torchvision.transforms.functional.pad(
+        img, [left, top, right, bottom], 0, "constant"
+    )
+
+
 def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
     """Converts an image in grayscale to RGB.
 
@@ -97,6 +131,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     return torchvision.transforms.functional.rgb_to_grayscale(img)
 
 
+class SquareCenterPad(torch.nn.Module):
+    """Pads the image into a squared version of itself, centered on a
+    zero-filled canvas."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return square_center_pad(img)
+
+
 class RGB(torch.nn.Module):
     """Converts an image in grayscale to RGB.