Skip to content
Snippets Groups Projects
Commit 2eeb0c14 authored by ogueler@idiap.ch's avatar ogueler@idiap.ch
Browse files

Merge branch 'add-datamodule' of gitlab.idiap.ch:biosignal/software/ptbench...

Merge branch 'add-datamodule' of gitlab.idiap.ch:biosignal/software/ptbench into add-datamodule-gradcam
parents 6621ede5 695e7029
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black
- repo: https://github.com/pycqa/docformatter
......@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
rev: v1.5.1
hooks:
- id: mypy
args: [
......@@ -33,7 +33,7 @@ repos:
--ignore-missing-imports,
]
- repo: https://github.com/asottile/pyupgrade
rev: v3.10.1
rev: v3.11.0
hooks:
- id: pyupgrade
args: [--py39-plus]
......
......@@ -87,7 +87,11 @@ class RawDataLoader(_BaseRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(
label=sample[1],
name=sample[0],
radsign_bboxes=self.bbox_annotations(sample),
)
def label(self, sample: DatabaseSample) -> int:
"""Loads a single image sample label from the disk.
......
......@@ -528,11 +528,11 @@ class ConcatDataModule(lightning.LightningDataModule):
self.parallel = parallel # immutable, otherwise would need to call
self.pin_memory = (
torch.cuda.is_available() or torch.backends.mps.is_available()
torch.cuda.is_available() or torch.backends.mps.is_available() # type: ignore
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
self._datasets: CachingDataModule.DatasetDictionary = {}
self._datasets: ConcatDataModule.DatasetDictionary = {}
@property
def parallel(self) -> int:
......
......@@ -7,6 +7,7 @@ import importlib
import typing
import pytest
import torch
def id_function(val):
......@@ -147,6 +148,71 @@ def test_protocol_consistency(
)
def check_loaded_batch(
batch,
batch_size: int,
prefixes: typing.Sequence[str],
):
"""Checks the consistence of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
size
The mini-batch size
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == 3 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert batch[0].shape[2] == 512 # image is 512 pixels large
assert isinstance(batch[1], dict) # metadata
assert (
len(batch[1]) == 3
) # label, name and radiological sign bounding-boxes
assert "label" in batch[1]
assert all([k in (0, 1) for k in batch[1]["label"]])
assert "name" in batch[1]
assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
)
assert "radsign_bboxes" in batch[1]
for sample, label, bboxes in zip(
batch[0], batch[1]["label"], batch[1]["radsign_bboxes"]
):
# there must be a sign indicated on the image, if active TB is detected
if label == 1:
assert len(bboxes[0]) != 0
# eif label == 0: # not true, may have TBI!
# assert len(bboxes) == 0
# asserts all bounding boxes are within the raw image width and height
for bbox_label, xmin, ymin, width, height in zip(*bboxes):
if label == 1:
assert bbox_label == 1
else:
assert bbox_label == 0
assert (xmin + width) < sample.shape[2]
assert (ymin + height) < sample.shape[1]
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
@pytest.mark.parametrize(
"dataset",
......@@ -183,9 +249,7 @@ def test_protocol_consistency(
("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")),
],
)
def test_loading(
database_checkers, name: str, dataset: str, prefixes: typing.Sequence[str]
):
def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
datamodule = importlib.import_module(
f".{name}", "ptbench.config.data.tbx11k"
).datamodule
......@@ -195,21 +259,13 @@ def test_loading(
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
limit = 50 # limit load checking
for batch in loader:
if limit == 0:
break
database_checkers.check_loaded_batch(
check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
prefixes=prefixes,
possible_labels=(0, 1),
)
limit -= 1
# TODO: Tests for loading bounding boxes:
# if patient has active tb, then has to have 1 or more bounding boxes
# if patient does not have active tb, there should be no bounding boxes
# bounding boxes must be within image (512 x 512 pixels)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment