diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py index 1e846e51eb83afeb8635fd18825d7f45caa2b999..8f5286fe1b804fd64f56f6b1e8e2c91fb9d395f9 100644 --- a/src/mednet/models/transforms.py +++ b/src/mednet/models/transforms.py @@ -13,8 +13,8 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: If the image is already in RGB format, then this is a NOOP - the same tensor is returned (no cloning). If the image is in grayscale format - (number of channels = 1), then triplicate that channel 3 times (a new copy is - returned in this case). + (number of color channels = 1), then replicate it to obtain 3 color channels + (a new copy is returned in this case). Parameters @@ -28,7 +28,7 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: ------- img - Transformed tensor with the channel dimension replicated 3 times. + Transformed tensor with 3 identical color channels. """ if img.ndim < 3: raise TypeError( @@ -56,9 +56,9 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: """Converts an image in RGB to grayscale. If the image is already in grayscale format, then this is a NOOP - the same - tensor is returned (no cloning). If the image is in RGB format, then - compresses the color channels into a single grayscale channel - following this equation: + tensor is returned (no cloning). If the image is in RGB format + (number of color channels = 3), then compresses the color channels into + a single grayscale channel following this equation: .. math::