From 10cd1e243cc951edfd72c22fcc3fe54eeacc0976 Mon Sep 17 00:00:00 2001 From: mdelitroz <maxime.delitroz@idiap.ch> Date: Thu, 20 Jul 2023 19:03:03 +0200 Subject: [PATCH] We decided to handle the resizing of the input at the model level and not at the dataset level anymore. Loaders, docstrings and tests of Shenzhen and Montgomery were updated accordingly to only load data, remove black borders and crop to square shape. [image_utils.RemoveBlackBorders] had to be modified to allow functional utilization --- src/ptbench/data/image_utils.py | 38 ++++++++++++++++++++++++-- src/ptbench/data/montgomery/default.py | 5 ++-- src/ptbench/data/montgomery/fold_0.py | 5 ++-- src/ptbench/data/montgomery/fold_1.py | 5 ++-- src/ptbench/data/montgomery/fold_2.py | 5 ++-- src/ptbench/data/montgomery/fold_3.py | 5 ++-- src/ptbench/data/montgomery/fold_4.py | 5 ++-- src/ptbench/data/montgomery/fold_5.py | 5 ++-- src/ptbench/data/montgomery/fold_6.py | 5 ++-- src/ptbench/data/montgomery/fold_7.py | 5 ++-- src/ptbench/data/montgomery/fold_8.py | 5 ++-- src/ptbench/data/montgomery/fold_9.py | 5 ++-- src/ptbench/data/montgomery/loader.py | 27 ++++++------------ src/ptbench/data/shenzhen/default.py | 10 +++---- src/ptbench/data/shenzhen/fold_0.py | 10 +++---- src/ptbench/data/shenzhen/fold_1.py | 10 +++---- src/ptbench/data/shenzhen/fold_2.py | 10 +++---- src/ptbench/data/shenzhen/fold_3.py | 10 +++---- src/ptbench/data/shenzhen/fold_4.py | 10 +++---- src/ptbench/data/shenzhen/fold_5.py | 10 +++---- src/ptbench/data/shenzhen/fold_6.py | 10 +++---- src/ptbench/data/shenzhen/fold_7.py | 10 +++---- src/ptbench/data/shenzhen/fold_8.py | 10 +++---- src/ptbench/data/shenzhen/fold_9.py | 10 +++---- src/ptbench/data/shenzhen/loader.py | 21 ++++---------- tests/test_ch.py | 11 +++----- tests/test_mc.py | 8 ++---- 27 files changed, 133 insertions(+), 137 deletions(-) diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py index ac31b9ce..ed284afc 100644 --- a/src/ptbench/data/image_utils.py +++ b/src/ptbench/data/image_utils.py @@ -31,6 +31,40 @@ class SingleAutoLevel16to8: ).convert("L") +def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image: + """Remove black borders of CXR + + Parameters + ---------- + img + A PIL image + threshold + Threshold value from which borders are considered black. + Defaults to 0. + + Returns + ------- + A PIL image with black borders removed + """ + + img = numpy.asarray(img) + + if len(img.shape) == 2: # single channel + mask = numpy.asarray(img) > threshold + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + elif len(img.shape) == 3 and img.shape[2] == 3: + r_mask = img[:, :, 0] > threshold + g_mask = img[:, :, 1] > threshold + b_mask = img[:, :, 2] > threshold + + mask = r_mask | g_mask | b_mask + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + else: + raise NotImplementedError + + class RemoveBlackBorders: """Remove black borders of CXR.""" @@ -38,9 +72,7 @@ class RemoveBlackBorders: self.threshold = threshold def __call__(self, img): - img = numpy.asarray(img) - mask = numpy.asarray(img) > self.threshold - return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + return remove_black_borders(img, self.threshold) def load_pil(path: str | pathlib.Path) -> PIL.Image.Image: diff --git a/src/ptbench/data/montgomery/default.py b/src/ptbench/data/montgomery/default.py index 5a0aad35..35878b1a 100644 --- a/src/ptbench/data/montgomery/default.py +++ b/src/ptbench/data/montgomery/default.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions Protocol ``default``: diff --git a/src/ptbench/data/montgomery/fold_0.py b/src/ptbench/data/montgomery/fold_0.py index 8567b7a5..4bfc4784 100644 --- a/src/ptbench/data/montgomery/fold_0.py +++ b/src/ptbench/data/montgomery/fold_0.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_1.py b/src/ptbench/data/montgomery/fold_1.py index 6b9679b2..0a74516f 100644 --- a/src/ptbench/data/montgomery/fold_1.py +++ b/src/ptbench/data/montgomery/fold_1.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_2.py b/src/ptbench/data/montgomery/fold_2.py index b2bcc3cd..386d3080 100644 --- a/src/ptbench/data/montgomery/fold_2.py +++ b/src/ptbench/data/montgomery/fold_2.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_3.py b/src/ptbench/data/montgomery/fold_3.py index 63bc27ec..1bfaa888 100644 --- a/src/ptbench/data/montgomery/fold_3.py +++ b/src/ptbench/data/montgomery/fold_3.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_4.py b/src/ptbench/data/montgomery/fold_4.py index cbf31799..b955cb15 100644 --- a/src/ptbench/data/montgomery/fold_4.py +++ b/src/ptbench/data/montgomery/fold_4.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_5.py b/src/ptbench/data/montgomery/fold_5.py index aca9c62b..5604cfdd 100644 --- a/src/ptbench/data/montgomery/fold_5.py +++ b/src/ptbench/data/montgomery/fold_5.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_6.py b/src/ptbench/data/montgomery/fold_6.py index ac5cdd57..72178209 100644 --- a/src/ptbench/data/montgomery/fold_6.py +++ b/src/ptbench/data/montgomery/fold_6.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_7.py b/src/ptbench/data/montgomery/fold_7.py index edf8957a..de895133 100644 --- a/src/ptbench/data/montgomery/fold_7.py +++ b/src/ptbench/data/montgomery/fold_7.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_8.py b/src/ptbench/data/montgomery/fold_8.py index ed0e5a91..bc901fc5 100644 --- a/src/ptbench/data/montgomery/fold_8.py +++ b/src/ptbench/data/montgomery/fold_8.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_9.py b/src/ptbench/data/montgomery/fold_9.py index 476463f0..758ae340 100644 --- a/src/ptbench/data/montgomery/fold_9.py +++ b/src/ptbench/data/montgomery/fold_9.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/loader.py b/src/ptbench/data/montgomery/loader.py index 04451438..45733dad 100644 --- a/src/ptbench/data/montgomery/loader.py +++ b/src/ptbench/data/montgomery/loader.py @@ -19,21 +19,20 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import os -import torchvision.transforms +from torchvision.transforms.functional import center_crop, to_tensor from ...utils.rc import load_rc -from ..image_utils import RemoveBlackBorders, load_pil +from ..image_utils import load_pil, remove_black_borders from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import Sample @@ -47,28 +46,15 @@ class RawDataLoader(_BaseRawDataLoader): datadir This variable contains the base directory where the database raw data is stored. - - transform - Transforms that are always applied to the loaded raw images. """ datadir: str - transform: torchvision.transforms.Compose def __init__(self): self.datadir = load_rc().get( "datadir.montgomery", os.path.realpath(os.curdir) ) - self.transform = torchvision.transforms.Compose( - [ - RemoveBlackBorders(), - torchvision.transforms.Resize(512), - torchvision.transforms.CenterCrop(512), - torchvision.transforms.ToTensor(), - ] - ) - def sample(self, sample: tuple[str, int]) -> Sample: """Loads a single image sample from the disk. @@ -87,7 +73,10 @@ class RawDataLoader(_BaseRawDataLoader): sample The sample representation """ - tensor = self.transform(load_pil(os.path.join(self.datadir, sample[0]))) + tensor = load_pil(os.path.join(self.datadir, sample[0])) + tensor = remove_black_borders(tensor) + tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1])) + tensor = to_tensor(tensor) return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index bfe93f44..ba8a2b57 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py index 888a0e60..d5c3e447 100644 --- a/src/ptbench/data/shenzhen/fold_0.py +++ b/src/ptbench/data/shenzhen/fold_0.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py index 62d7fbd5..5c7a294a 100644 --- a/src/ptbench/data/shenzhen/fold_1.py +++ b/src/ptbench/data/shenzhen/fold_1.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py index b41284cd..31480af5 100644 --- a/src/ptbench/data/shenzhen/fold_2.py +++ b/src/ptbench/data/shenzhen/fold_2.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py index cca55506..a1881ad6 100644 --- a/src/ptbench/data/shenzhen/fold_3.py +++ b/src/ptbench/data/shenzhen/fold_3.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py index 89742007..f86eb665 100644 --- a/src/ptbench/data/shenzhen/fold_4.py +++ b/src/ptbench/data/shenzhen/fold_4.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py index c520399d..16ea44f7 100644 --- a/src/ptbench/data/shenzhen/fold_5.py +++ b/src/ptbench/data/shenzhen/fold_5.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py index a28f8fc5..caecaa0d 100644 --- a/src/ptbench/data/shenzhen/fold_6.py +++ b/src/ptbench/data/shenzhen/fold_6.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py index b0ea7b43..ee34486c 100644 --- a/src/ptbench/data/shenzhen/fold_7.py +++ b/src/ptbench/data/shenzhen/fold_7.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py index 9bbfbe84..6c81ca70 100644 --- a/src/ptbench/data/shenzhen/fold_8.py +++ b/src/ptbench/data/shenzhen/fold_8.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py index 87c2afb3..546e449e 100644 --- a/src/ptbench/data/shenzhen/fold_9.py +++ b/src/ptbench/data/shenzhen/fold_9.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py index 49ccf8bf..df4d4afb 100644 --- a/src/ptbench/data/shenzhen/loader.py +++ b/src/ptbench/data/shenzhen/loader.py @@ -22,10 +22,10 @@ Philips DR Digital Diagnose systems. import os -import torchvision.transforms +from torchvision.transforms.functional import center_crop, to_tensor from ...utils.rc import load_rc -from ..image_utils import RemoveBlackBorders, load_pil_baw +from ..image_utils import load_pil_rgb, remove_black_borders from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import Sample @@ -45,22 +45,12 @@ class RawDataLoader(_BaseRawDataLoader): """ datadir: str - transform: torchvision.transforms.Compose def __init__(self): self.datadir = load_rc().get( "datadir.shenzhen", os.path.realpath(os.curdir) ) - self.transform = torchvision.transforms.Compose( - [ - RemoveBlackBorders(), - torchvision.transforms.Resize(512), - torchvision.transforms.CenterCrop(512), - torchvision.transforms.ToTensor(), - ] - ) - def sample(self, sample: tuple[str, int]) -> Sample: """Loads a single image sample from the disk. @@ -79,9 +69,10 @@ class RawDataLoader(_BaseRawDataLoader): sample The sample representation """ - tensor = self.transform( - load_pil_baw(os.path.join(self.datadir, sample[0])) - ) + tensor = load_pil_rgb(os.path.join(self.datadir, sample[0])) + tensor = remove_black_borders(tensor) + tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1])) + tensor = to_tensor(tensor) return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] diff --git a/tests/test_ch.py b/tests/test_ch.py index 659e2c35..b28c81e9 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -120,11 +120,6 @@ def test_loading(): from ptbench.data.datamodule import _DelayedLoadingDataset - def _check_size(shape): - if shape[0] == 1 and shape[1] == 512 and shape[2] == 512: - return True - return False - def _check_sample(s): assert len(s) == 2 @@ -132,10 +127,12 @@ def test_loading(): metadata = s[1] assert isinstance(data, torch.Tensor) - assert _check_size(data.shape) # Check size + + assert data.size(0) == 3 # check 3 channels + assert data.size(1) == data.size(2) # check square image assert ( - torchvision.transforms.ToPILImage()(data).mode == "L" + torchvision.transforms.ToPILImage()(data).mode == "RGB" ) # Check colors assert "label" in metadata diff --git a/tests/test_mc.py b/tests/test_mc.py index ec4d85cd..2fcd14ac 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -126,11 +126,9 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size() in ( - (1, 4020, 4892), # portrait - (1, 4892, 4020), # landscape - (1, 512, 512), # test database @ CI - ) + assert data.size(0) == 1 # check single channel + assert data.size(1) == data.size(2) # check square image + assert ( torchvision.transforms.ToPILImage()(data).mode == "L" ) # Check colors -- GitLab