From bfc106ab413c9b0245556a9280db81f0d16ceb4b Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 21 Jul 2023 20:35:02 +0200
Subject: [PATCH] [tests] Some reformatting to make black happy

---
 src/ptbench/data/image_utils.py   | 16 +++++++++-------
 src/ptbench/utils/checkpointer.py | 10 +++++++---
 tests/test_ch.py                  |  4 ++--
 tests/test_mc.py                  | 15 +++++----------
 4 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py
index ed284afc..b1c9d82e 100644
--- a/src/ptbench/data/image_utils.py
+++ b/src/ptbench/data/image_utils.py
@@ -31,14 +31,16 @@ class SingleAutoLevel16to8:
         ).convert("L")
 
 
-def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image:
-    """Remove black borders of CXR
+def remove_black_borders(
+    img: PIL.Image.Image, threshold: int = 0
+) -> PIL.Image.Image:
+    """Remove black borders of CXR.
 
     Parameters
     ----------
-        img 
+        img
             A PIL image
-        threshold 
+        threshold
             Threshold value from which borders are considered black.
             Defaults to 0.
 
@@ -49,10 +51,10 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im
 
     img = numpy.asarray(img)
 
-    if len(img.shape) == 2: # single channel
+    if len(img.shape) == 2:  # single channel
         mask = numpy.asarray(img) > threshold
         return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
-    
+
     elif len(img.shape) == 3 and img.shape[2] == 3:
         r_mask = img[:, :, 0] > threshold
         g_mask = img[:, :, 1] > threshold
@@ -60,7 +62,7 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im
 
         mask = r_mask | g_mask | b_mask
         return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
-    
+
     else:
         raise NotImplementedError
 
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index 5c2f272c..318811b0 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -1,11 +1,13 @@
 import logging
 import os
-
 import typing
+
 logger = logging.getLogger(__name__)
 
 
-def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None : 
+def get_checkpoint(
+    output_folder: str, resume_from: typing.Literal["last", "best"] | str | None
+) -> str | None:
     """Gets a checkpoint file.
 
     Can return the best or last checkpoint, or a checkpoint at a specific path.
@@ -56,7 +58,9 @@ def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best
     elif resume_from is None:
         if os.path.isfile(last_checkpoint_path):
             checkpoint_file = last_checkpoint_path
-            logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.")
+            logger.info(
+                f"Found existing checkpoint {last_checkpoint_path}. Loading."
+            )
         else:
             return None
 
diff --git a/tests/test_ch.py b/tests/test_ch.py
index b28c81e9..c678e087 100644
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -128,8 +128,8 @@ def test_loading():
 
         assert isinstance(data, torch.Tensor)
 
-        assert data.size(0) == 3 # check 3 channels
-        assert data.size(1) == data.size(2) # check square image
+        assert data.size(0) == 3  # check 3 channels
+        assert data.size(1) == data.size(2)  # check square image
 
         assert (
             torchvision.transforms.ToPILImage()(data).mode == "RGB"
diff --git a/tests/test_mc.py b/tests/test_mc.py
index 2fcd14ac..25bd4709 100644
--- a/tests/test_mc.py
+++ b/tests/test_mc.py
@@ -10,7 +10,6 @@ import pytest
 
 
 def test_protocol_consistency():
-
     # Default protocol
     datamodule = importlib.import_module(
         "ptbench.data.montgomery.default"
@@ -126,12 +125,12 @@ def test_loading():
 
         assert isinstance(data, torch.Tensor)
 
-        assert data.size(0) == 1 # check single channel
-        assert data.size(1) == data.size(2) # check square image
+        assert data.size(0) == 1  # check single channel
+        assert data.size(1) == data.size(2)  # check square image
 
         assert (
-            torchvision.transforms.ToPILImage()(data).mode == "L" 
-        ) # Check colors
+            torchvision.transforms.ToPILImage()(data).mode == "L"
+        )  # Check colors
 
         assert "label" in metadata
         assert metadata["label"] in [0, 1]  # Check labels
@@ -145,10 +144,7 @@ def test_loading():
     raw_data_loader = datamodule.raw_data_loader
 
     # Need to use private function so we can limit the number of samples to use
-    dataset = _DelayedLoadingDataset(
-        subset["train"][:limit],
-        raw_data_loader
-    )
+    dataset = _DelayedLoadingDataset(subset["train"][:limit], raw_data_loader)
 
     for s in dataset:
         _check_sample(s)
@@ -188,4 +184,3 @@ def test_check():
             )
             == 0
         )
-
-- 
GitLab