diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py
index afb158ac7140de1554250a05751acf47e3a90b02..2a4f933ae4f541ea71015176e0a048e99bdfd4d8 100644
--- a/src/mednet/models/transforms.py
+++ b/src/mednet/models/transforms.py
@@ -10,7 +10,7 @@ import torchvision.transforms.functional
 
 
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
-    """Returns a squared version of the image, centered on a canvas padded with
+    """Return a squared version of the image, centered on a canvas padded with
     zeros.
 
     Parameters
@@ -23,8 +23,7 @@ def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     Returns
     -------
 
-    img
-        transformed tensor, guaranteed to be square (ie. equal height and
+        Transformed tensor, guaranteed to be square (ie. equal height and
         width).
     """
 
@@ -129,8 +128,7 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
 
 
 class SquareCenterPad(torch.nn.Module):
-    """Transforms to a squared version of the image, centered on a canvas
-    padded with zeros."""
+    """Transform to a squared version of the image, centered on a canvas padded with zeros."""
 
     def __init__(self):
         super().__init__()