diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py index ac31b9ce7fbce85fb688b394c99d591b83049f7f..ed284afc4ab63fd804a8110ff676c29917d968b4 100644 --- a/src/ptbench/data/image_utils.py +++ b/src/ptbench/data/image_utils.py @@ -31,6 +31,40 @@ class SingleAutoLevel16to8: ).convert("L") +def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image: + """Remove black borders of CXR + + Parameters + ---------- + img + A PIL image + threshold + Threshold value from which borders are considered black. + Defaults to 0. + + Returns + ------- + A PIL image with black borders removed + """ + + img = numpy.asarray(img) + + if len(img.shape) == 2: # single channel + mask = numpy.asarray(img) > threshold + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + elif len(img.shape) == 3 and img.shape[2] == 3: + r_mask = img[:, :, 0] > threshold + g_mask = img[:, :, 1] > threshold + b_mask = img[:, :, 2] > threshold + + mask = r_mask | g_mask | b_mask + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + else: + raise NotImplementedError + + class RemoveBlackBorders: """Remove black borders of CXR.""" @@ -38,9 +72,7 @@ class RemoveBlackBorders: self.threshold = threshold def __call__(self, img): - img = numpy.asarray(img) - mask = numpy.asarray(img) > self.threshold - return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + return remove_black_borders(img, self.threshold) def load_pil(path: str | pathlib.Path) -> PIL.Image.Image: diff --git a/src/ptbench/data/montgomery/default.py b/src/ptbench/data/montgomery/default.py index 5a0aad35973c733eafa6f68b407eabec1a5c2117..35878b1a761f1b93a273c585207541921d23a8f8 100644 --- a/src/ptbench/data/montgomery/default.py +++ b/src/ptbench/data/montgomery/default.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions Protocol ``default``: diff --git a/src/ptbench/data/montgomery/fold_0.py b/src/ptbench/data/montgomery/fold_0.py index 8567b7a5b2749e538c365fd69de8f43f26cd22d2..4bfc478404555fa16af02469c24ab3aa1797fc1c 100644 --- a/src/ptbench/data/montgomery/fold_0.py +++ b/src/ptbench/data/montgomery/fold_0.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_1.py b/src/ptbench/data/montgomery/fold_1.py index 6b9679b287060bf16200f3ff3bfdca291b7cc623..0a74516fd365a34614419da75a8beadbee5f6dda 100644 --- a/src/ptbench/data/montgomery/fold_1.py +++ b/src/ptbench/data/montgomery/fold_1.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_2.py b/src/ptbench/data/montgomery/fold_2.py index b2bcc3cd6a56840022f0de81e6733d78a71f43e5..386d30803c3cdc7a1c78622eaea412b6025ac10f 100644 --- a/src/ptbench/data/montgomery/fold_2.py +++ b/src/ptbench/data/montgomery/fold_2.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_3.py b/src/ptbench/data/montgomery/fold_3.py index 63bc27ec80c79a4e78fa4add406e6dae68f12628..1bfaa888987ff7053dbfc9634f60829e2748a078 100644 --- a/src/ptbench/data/montgomery/fold_3.py +++ b/src/ptbench/data/montgomery/fold_3.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_4.py b/src/ptbench/data/montgomery/fold_4.py index cbf3179959cf49c6cf48c826aa76d6bde91e1e85..b955cb1503f41f3bd0b83376a61fc49fba985619 100644 --- a/src/ptbench/data/montgomery/fold_4.py +++ b/src/ptbench/data/montgomery/fold_4.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_5.py b/src/ptbench/data/montgomery/fold_5.py index aca9c62b71f5000a1a26b879ed9a7493de791dd1..5604cfdd3496ac085200b5c06793a1ecdafb19f1 100644 --- a/src/ptbench/data/montgomery/fold_5.py +++ b/src/ptbench/data/montgomery/fold_5.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_6.py b/src/ptbench/data/montgomery/fold_6.py index ac5cdd57f52d53c4732903104e53420d0ac24f51..72178209acbe75c938c73d3f055d95aee71f74a8 100644 --- a/src/ptbench/data/montgomery/fold_6.py +++ b/src/ptbench/data/montgomery/fold_6.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_7.py b/src/ptbench/data/montgomery/fold_7.py index edf8957a5820ecf859289a345a961270cc098dc0..de895133b4b329ef6b8344c2228e1603f68d1315 100644 --- a/src/ptbench/data/montgomery/fold_7.py +++ b/src/ptbench/data/montgomery/fold_7.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_8.py b/src/ptbench/data/montgomery/fold_8.py index ed0e5a91eaac05f9563d8053c2b814b02d282133..bc901fc5038a19fb80b465cdb18e10647e3a03ee 100644 --- a/src/ptbench/data/montgomery/fold_8.py +++ b/src/ptbench/data/montgomery/fold_8.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/fold_9.py b/src/ptbench/data/montgomery/fold_9.py index 476463f0fcd80139732952e3849c89f9cb5dff5b..758ae3403a2944518fc17219a31ff71175e21fae 100644 --- a/src/ptbench/data/montgomery/fold_9.py +++ b/src/ptbench/data/montgomery/fold_9.py @@ -19,13 +19,12 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import importlib.resources diff --git a/src/ptbench/data/montgomery/loader.py b/src/ptbench/data/montgomery/loader.py index 04451438b4164e82f467bb0145cb7bd3be16e06d..45733dadde59eac18c43be953f54ed15bc89db47 100644 --- a/src/ptbench/data/montgomery/loader.py +++ b/src/ptbench/data/montgomery/loader.py @@ -19,21 +19,20 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing (512 x 512 px) - * Torch center cropping (512 x 512 px) + * Torch center cropping to get square image * Final specifications - * Fixed resolution: 512 x 512 px * Grayscale (single channel), 8 bits + * Varying resolutions """ import os -import torchvision.transforms +from torchvision.transforms.functional import center_crop, to_tensor from ...utils.rc import load_rc -from ..image_utils import RemoveBlackBorders, load_pil +from ..image_utils import load_pil, remove_black_borders from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import Sample @@ -47,28 +46,15 @@ class RawDataLoader(_BaseRawDataLoader): datadir This variable contains the base directory where the database raw data is stored. - - transform - Transforms that are always applied to the loaded raw images. """ datadir: str - transform: torchvision.transforms.Compose def __init__(self): self.datadir = load_rc().get( "datadir.montgomery", os.path.realpath(os.curdir) ) - self.transform = torchvision.transforms.Compose( - [ - RemoveBlackBorders(), - torchvision.transforms.Resize(512), - torchvision.transforms.CenterCrop(512), - torchvision.transforms.ToTensor(), - ] - ) - def sample(self, sample: tuple[str, int]) -> Sample: """Loads a single image sample from the disk. @@ -87,7 +73,10 @@ class RawDataLoader(_BaseRawDataLoader): sample The sample representation """ - tensor = self.transform(load_pil(os.path.join(self.datadir, sample[0]))) + tensor = load_pil(os.path.join(self.datadir, sample[0])) + tensor = remove_black_borders(tensor) + tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1])) + tensor = to_tensor(tensor) return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index bfe93f44faaa9df235f357c1cc3a927412f4a011..ba8a2b5714acbbe62e7e0f75c1fe8ec040721459 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py index 888a0e60024480a3aaff65f6e3d819370fd22669..d5c3e44737fe7a090b3b7ae33c7cc49c0843edf3 100644 --- a/src/ptbench/data/shenzhen/fold_0.py +++ b/src/ptbench/data/shenzhen/fold_0.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py index 62d7fbd55c83ed746754cbc99dcc65fe48efbc6a..5c7a294aa86109e3700b99ff953237fc8afab18e 100644 --- a/src/ptbench/data/shenzhen/fold_1.py +++ b/src/ptbench/data/shenzhen/fold_1.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py index b41284cd9d1c4a56c70eff715078f82213dabb3c..31480af5e70a72db203795d49631dd13c7df290a 100644 --- a/src/ptbench/data/shenzhen/fold_2.py +++ b/src/ptbench/data/shenzhen/fold_2.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py index cca555064e9923433ef39f591b3e342365cf7afc..a1881ad69e928bc6dfaccec4db9c30b8da5027da 100644 --- a/src/ptbench/data/shenzhen/fold_3.py +++ b/src/ptbench/data/shenzhen/fold_3.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py index 897420076303e47406cc9efb3b6bf0d294ab3611..f86eb6659b64ac24204f7a5a57edccb533e3379a 100644 --- a/src/ptbench/data/shenzhen/fold_4.py +++ b/src/ptbench/data/shenzhen/fold_4.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py index c520399d98ead9eeb1e3bdcfbe4dc48393adcebc..16ea44f7809b0e595b9fb62d2496eaaf04b585ba 100644 --- a/src/ptbench/data/shenzhen/fold_5.py +++ b/src/ptbench/data/shenzhen/fold_5.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py index a28f8fc5ca3e0ebd4b49fceaec99d3a2e94dd34c..caecaa0dd41e1c20a2618640c56f781db7937e39 100644 --- a/src/ptbench/data/shenzhen/fold_6.py +++ b/src/ptbench/data/shenzhen/fold_6.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py index b0ea7b4324334980a2e55e4496ac4ab6af705d17..ee34486cf3cad7856dc0055d02d87f74b68a8b13 100644 --- a/src/ptbench/data/shenzhen/fold_7.py +++ b/src/ptbench/data/shenzhen/fold_7.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py index 9bbfbe84ab942cf5da5a8c5fc8318724908998f9..6c81ca7062e3b55441f3052b04e7ae7d317fc9d2 100644 --- a/src/ptbench/data/shenzhen/fold_8.py +++ b/src/ptbench/data/shenzhen/fold_8.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py index 87c2afb328f9b09f420a1ddce5f5d0ea54346c43..546e449e4288c262886c3570f4e3a09735f923ad 100644 --- a/src/ptbench/data/shenzhen/fold_9.py +++ b/src/ptbench/data/shenzhen/fold_9.py @@ -10,7 +10,7 @@ This configuration: * Raw data input (on disk): - * PNG images (black and white, encoded as color images) + * PNG images (grayscale, encoded as RGB images) * Variable width and height: * widths: from 1130 to 3001 pixels @@ -22,13 +22,13 @@ This configuration: * Load raw PNG with :py:mod:`PIL` * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Torch center cropping to get square image * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * grayscale, encoded as RGB image + * varying resolutions + """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py index 49ccf8bfb217e411004228b4acf7c924e3ffec66..df4d4afbb0e4afa22522b7ebc403dc91372fefcd 100644 --- a/src/ptbench/data/shenzhen/loader.py +++ b/src/ptbench/data/shenzhen/loader.py @@ -22,10 +22,10 @@ Philips DR Digital Diagnose systems. import os -import torchvision.transforms +from torchvision.transforms.functional import center_crop, to_tensor from ...utils.rc import load_rc -from ..image_utils import RemoveBlackBorders, load_pil_baw +from ..image_utils import load_pil_rgb, remove_black_borders from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import Sample @@ -45,22 +45,12 @@ class RawDataLoader(_BaseRawDataLoader): """ datadir: str - transform: torchvision.transforms.Compose def __init__(self): self.datadir = load_rc().get( "datadir.shenzhen", os.path.realpath(os.curdir) ) - self.transform = torchvision.transforms.Compose( - [ - RemoveBlackBorders(), - torchvision.transforms.Resize(512), - torchvision.transforms.CenterCrop(512), - torchvision.transforms.ToTensor(), - ] - ) - def sample(self, sample: tuple[str, int]) -> Sample: """Loads a single image sample from the disk. @@ -79,9 +69,10 @@ class RawDataLoader(_BaseRawDataLoader): sample The sample representation """ - tensor = self.transform( - load_pil_baw(os.path.join(self.datadir, sample[0])) - ) + tensor = load_pil_rgb(os.path.join(self.datadir, sample[0])) + tensor = remove_black_borders(tensor) + tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1])) + tensor = to_tensor(tensor) return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] diff --git a/tests/test_ch.py b/tests/test_ch.py index 659e2c35ae092f3a90f7d072ba033786bb80bdf9..b28c81e93ef0765c59dd7252b93ecdecaeff6946 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -120,11 +120,6 @@ def test_loading(): from ptbench.data.datamodule import _DelayedLoadingDataset - def _check_size(shape): - if shape[0] == 1 and shape[1] == 512 and shape[2] == 512: - return True - return False - def _check_sample(s): assert len(s) == 2 @@ -132,10 +127,12 @@ def test_loading(): metadata = s[1] assert isinstance(data, torch.Tensor) - assert _check_size(data.shape) # Check size + + assert data.size(0) == 3 # check 3 channels + assert data.size(1) == data.size(2) # check square image assert ( - torchvision.transforms.ToPILImage()(data).mode == "L" + torchvision.transforms.ToPILImage()(data).mode == "RGB" ) # Check colors assert "label" in metadata diff --git a/tests/test_mc.py b/tests/test_mc.py index ec4d85cd01a42d018104e053fb3a20d7c7c62e4f..2fcd14ac131f5d919e954bd4d9562effd4a19296 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -126,11 +126,9 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size() in ( - (1, 4020, 4892), # portrait - (1, 4892, 4020), # landscape - (1, 512, 512), # test database @ CI - ) + assert data.size(0) == 1 # check single channel + assert data.size(1) == data.size(2) # check square image + assert ( torchvision.transforms.ToPILImage()(data).mode == "L" ) # Check colors