Skip to content
Snippets Groups Projects
Commit bfc106ab authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Some reformatting to make black happy

parent aea0a2e2
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -31,14 +31,16 @@ class SingleAutoLevel16to8: ...@@ -31,14 +31,16 @@ class SingleAutoLevel16to8:
).convert("L") ).convert("L")
def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image: def remove_black_borders(
"""Remove black borders of CXR img: PIL.Image.Image, threshold: int = 0
) -> PIL.Image.Image:
"""Remove black borders of CXR.
Parameters Parameters
---------- ----------
img img
A PIL image A PIL image
threshold threshold
Threshold value from which borders are considered black. Threshold value from which borders are considered black.
Defaults to 0. Defaults to 0.
...@@ -49,10 +51,10 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im ...@@ -49,10 +51,10 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im
img = numpy.asarray(img) img = numpy.asarray(img)
if len(img.shape) == 2: # single channel if len(img.shape) == 2: # single channel
mask = numpy.asarray(img) > threshold mask = numpy.asarray(img) > threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
elif len(img.shape) == 3 and img.shape[2] == 3: elif len(img.shape) == 3 and img.shape[2] == 3:
r_mask = img[:, :, 0] > threshold r_mask = img[:, :, 0] > threshold
g_mask = img[:, :, 1] > threshold g_mask = img[:, :, 1] > threshold
...@@ -60,7 +62,7 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im ...@@ -60,7 +62,7 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im
mask = r_mask | g_mask | b_mask mask = r_mask | g_mask | b_mask
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
else: else:
raise NotImplementedError raise NotImplementedError
......
import logging import logging
import os import os
import typing import typing
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None : def get_checkpoint(
output_folder: str, resume_from: typing.Literal["last", "best"] | str | None
) -> str | None:
"""Gets a checkpoint file. """Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path. Can return the best or last checkpoint, or a checkpoint at a specific path.
...@@ -56,7 +58,9 @@ def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best ...@@ -56,7 +58,9 @@ def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best
elif resume_from is None: elif resume_from is None:
if os.path.isfile(last_checkpoint_path): if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path checkpoint_file = last_checkpoint_path
logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.") logger.info(
f"Found existing checkpoint {last_checkpoint_path}. Loading."
)
else: else:
return None return None
......
...@@ -128,8 +128,8 @@ def test_loading(): ...@@ -128,8 +128,8 @@ def test_loading():
assert isinstance(data, torch.Tensor) assert isinstance(data, torch.Tensor)
assert data.size(0) == 3 # check 3 channels assert data.size(0) == 3 # check 3 channels
assert data.size(1) == data.size(2) # check square image assert data.size(1) == data.size(2) # check square image
assert ( assert (
torchvision.transforms.ToPILImage()(data).mode == "RGB" torchvision.transforms.ToPILImage()(data).mode == "RGB"
......
...@@ -10,7 +10,6 @@ import pytest ...@@ -10,7 +10,6 @@ import pytest
def test_protocol_consistency(): def test_protocol_consistency():
# Default protocol # Default protocol
datamodule = importlib.import_module( datamodule = importlib.import_module(
"ptbench.data.montgomery.default" "ptbench.data.montgomery.default"
...@@ -126,12 +125,12 @@ def test_loading(): ...@@ -126,12 +125,12 @@ def test_loading():
assert isinstance(data, torch.Tensor) assert isinstance(data, torch.Tensor)
assert data.size(0) == 1 # check single channel assert data.size(0) == 1 # check single channel
assert data.size(1) == data.size(2) # check square image assert data.size(1) == data.size(2) # check square image
assert ( assert (
torchvision.transforms.ToPILImage()(data).mode == "L" torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors ) # Check colors
assert "label" in metadata assert "label" in metadata
assert metadata["label"] in [0, 1] # Check labels assert metadata["label"] in [0, 1] # Check labels
...@@ -145,10 +144,7 @@ def test_loading(): ...@@ -145,10 +144,7 @@ def test_loading():
raw_data_loader = datamodule.raw_data_loader raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use # Need to use private function so we can limit the number of samples to use
dataset = _DelayedLoadingDataset( dataset = _DelayedLoadingDataset(subset["train"][:limit], raw_data_loader)
subset["train"][:limit],
raw_data_loader
)
for s in dataset: for s in dataset:
_check_sample(s) _check_sample(s)
...@@ -188,4 +184,3 @@ def test_check(): ...@@ -188,4 +184,3 @@ def test_check():
) )
== 0 == 0
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment