From bfc106ab413c9b0245556a9280db81f0d16ceb4b Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 21 Jul 2023 20:35:02 +0200 Subject: [PATCH] [tests] Some reformatting to make black happy --- src/ptbench/data/image_utils.py | 16 +++++++++------- src/ptbench/utils/checkpointer.py | 10 +++++++--- tests/test_ch.py | 4 ++-- tests/test_mc.py | 15 +++++---------- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py index ed284afc..b1c9d82e 100644 --- a/src/ptbench/data/image_utils.py +++ b/src/ptbench/data/image_utils.py @@ -31,14 +31,16 @@ class SingleAutoLevel16to8: ).convert("L") -def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image: - """Remove black borders of CXR +def remove_black_borders( + img: PIL.Image.Image, threshold: int = 0 +) -> PIL.Image.Image: + """Remove black borders of CXR. Parameters ---------- - img + img A PIL image - threshold + threshold Threshold value from which borders are considered black. Defaults to 0. @@ -49,10 +51,10 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im img = numpy.asarray(img) - if len(img.shape) == 2: # single channel + if len(img.shape) == 2: # single channel mask = numpy.asarray(img) > threshold return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) - + elif len(img.shape) == 3 and img.shape[2] == 3: r_mask = img[:, :, 0] > threshold g_mask = img[:, :, 1] > threshold @@ -60,7 +62,7 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im mask = r_mask | g_mask | b_mask return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) - + else: raise NotImplementedError diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py index 5c2f272c..318811b0 100644 --- a/src/ptbench/utils/checkpointer.py +++ b/src/ptbench/utils/checkpointer.py @@ -1,11 +1,13 @@ import logging import os - import typing + logger = logging.getLogger(__name__) -def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None : +def get_checkpoint( + output_folder: str, resume_from: typing.Literal["last", "best"] | str | None +) -> str | None: """Gets a checkpoint file. Can return the best or last checkpoint, or a checkpoint at a specific path. @@ -56,7 +58,9 @@ def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best elif resume_from is None: if os.path.isfile(last_checkpoint_path): checkpoint_file = last_checkpoint_path - logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.") + logger.info( + f"Found existing checkpoint {last_checkpoint_path}. Loading." + ) else: return None diff --git a/tests/test_ch.py b/tests/test_ch.py index b28c81e9..c678e087 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -128,8 +128,8 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size(0) == 3 # check 3 channels - assert data.size(1) == data.size(2) # check square image + assert data.size(0) == 3 # check 3 channels + assert data.size(1) == data.size(2) # check square image assert ( torchvision.transforms.ToPILImage()(data).mode == "RGB" diff --git a/tests/test_mc.py b/tests/test_mc.py index 2fcd14ac..25bd4709 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -10,7 +10,6 @@ import pytest def test_protocol_consistency(): - # Default protocol datamodule = importlib.import_module( "ptbench.data.montgomery.default" @@ -126,12 +125,12 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size(0) == 1 # check single channel - assert data.size(1) == data.size(2) # check square image + assert data.size(0) == 1 # check single channel + assert data.size(1) == data.size(2) # check square image assert ( - torchvision.transforms.ToPILImage()(data).mode == "L" - ) # Check colors + torchvision.transforms.ToPILImage()(data).mode == "L" + ) # Check colors assert "label" in metadata assert metadata["label"] in [0, 1] # Check labels @@ -145,10 +144,7 @@ def test_loading(): raw_data_loader = datamodule.raw_data_loader # Need to use private function so we can limit the number of samples to use - dataset = _DelayedLoadingDataset( - subset["train"][:limit], - raw_data_loader - ) + dataset = _DelayedLoadingDataset(subset["train"][:limit], raw_data_loader) for s in dataset: _check_sample(s) @@ -188,4 +184,3 @@ def test_check(): ) == 0 ) - -- GitLab