Skip to content
Snippets Groups Projects
Commit bfc106ab authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Some reformatting to make black happy

parent aea0a2e2
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -31,14 +31,16 @@ class SingleAutoLevel16to8:
).convert("L")
def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image:
"""Remove black borders of CXR
def remove_black_borders(
img: PIL.Image.Image, threshold: int = 0
) -> PIL.Image.Image:
"""Remove black borders of CXR.
Parameters
----------
img
img
A PIL image
threshold
threshold
Threshold value from which borders are considered black.
Defaults to 0.
......@@ -49,10 +51,10 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im
img = numpy.asarray(img)
if len(img.shape) == 2: # single channel
if len(img.shape) == 2: # single channel
mask = numpy.asarray(img) > threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
elif len(img.shape) == 3 and img.shape[2] == 3:
r_mask = img[:, :, 0] > threshold
g_mask = img[:, :, 1] > threshold
......@@ -60,7 +62,7 @@ def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Im
mask = r_mask | g_mask | b_mask
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
else:
raise NotImplementedError
......
import logging
import os
import typing
logger = logging.getLogger(__name__)
def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None :
def get_checkpoint(
output_folder: str, resume_from: typing.Literal["last", "best"] | str | None
) -> str | None:
"""Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path.
......@@ -56,7 +58,9 @@ def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best
elif resume_from is None:
if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.")
logger.info(
f"Found existing checkpoint {last_checkpoint_path}. Loading."
)
else:
return None
......
......@@ -128,8 +128,8 @@ def test_loading():
assert isinstance(data, torch.Tensor)
assert data.size(0) == 3 # check 3 channels
assert data.size(1) == data.size(2) # check square image
assert data.size(0) == 3 # check 3 channels
assert data.size(1) == data.size(2) # check square image
assert (
torchvision.transforms.ToPILImage()(data).mode == "RGB"
......
......@@ -10,7 +10,6 @@ import pytest
def test_protocol_consistency():
# Default protocol
datamodule = importlib.import_module(
"ptbench.data.montgomery.default"
......@@ -126,12 +125,12 @@ def test_loading():
assert isinstance(data, torch.Tensor)
assert data.size(0) == 1 # check single channel
assert data.size(1) == data.size(2) # check square image
assert data.size(0) == 1 # check single channel
assert data.size(1) == data.size(2) # check square image
assert (
torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors
torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors
assert "label" in metadata
assert metadata["label"] in [0, 1] # Check labels
......@@ -145,10 +144,7 @@ def test_loading():
raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use
dataset = _DelayedLoadingDataset(
subset["train"][:limit],
raw_data_loader
)
dataset = _DelayedLoadingDataset(subset["train"][:limit], raw_data_loader)
for s in dataset:
_check_sample(s)
......@@ -188,4 +184,3 @@ def test_check():
)
== 0
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment