diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py
index 1e846e51eb83afeb8635fd18825d7f45caa2b999..8f5286fe1b804fd64f56f6b1e8e2c91fb9d395f9 100644
--- a/src/mednet/models/transforms.py
+++ b/src/mednet/models/transforms.py
@@ -13,8 +13,8 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
 
     If the image is already in RGB format, then this is a NOOP - the same
     tensor is returned (no cloning).  If the image is in grayscale format
-    (number of channels = 1), then triplicate that channel 3 times (a new copy is
-    returned in this case).
+    (number of color channels = 1), then replicate it to obtain 3 color channels
+    (a new copy is returned in this case).
 
 
     Parameters
@@ -28,7 +28,7 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
     -------
 
     img
-        Transformed tensor with the channel dimension replicated 3 times.
+        Transformed tensor with 3 identical color channels.
     """
     if img.ndim < 3:
         raise TypeError(
@@ -56,9 +56,9 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     """Converts an image in RGB to grayscale.
 
     If the image is already in grayscale format, then this is a NOOP - the same
-    tensor is returned (no cloning).  If the image is in RGB format, then
-    compresses the color channels into a single grayscale channel
-    following this equation:
+    tensor is returned (no cloning).  If the image is in RGB format
+    (number of color channels = 3), then compresses the color channels into
+    a single grayscale channel following this equation:
 
     .. math::