diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py index afb158ac7140de1554250a05751acf47e3a90b02..2a4f933ae4f541ea71015176e0a048e99bdfd4d8 100644 --- a/src/mednet/models/transforms.py +++ b/src/mednet/models/transforms.py @@ -10,7 +10,7 @@ import torchvision.transforms.functional def square_center_pad(img: torch.Tensor) -> torch.Tensor: - """Returns a squared version of the image, centered on a canvas padded with + """Return a squared version of the image, centered on a canvas padded with zeros. Parameters @@ -23,8 +23,7 @@ def square_center_pad(img: torch.Tensor) -> torch.Tensor: Returns ------- - img - transformed tensor, guaranteed to be square (ie. equal height and + Transformed tensor, guaranteed to be square (ie. equal height and width). """ @@ -129,8 +128,7 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: class SquareCenterPad(torch.nn.Module): - """Transforms to a squared version of the image, centered on a canvas - padded with zeros.""" + """Transform to a squared version of the image, centered on a canvas padded with zeros.""" def __init__(self): super().__init__()