Skip to content
Snippets Groups Projects
Commit afda408c authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[tests] Fix tests

parent d354514c
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 25 additions and 25 deletions
......@@ -38,4 +38,4 @@ class ClassificationRawDataLoader(RawDataLoader):
The label corresponding to the specified sample.
"""
return self.sample(k)[1]["label"]
return self.sample(k)[1]["target"]
......@@ -168,7 +168,7 @@ def _process_sample(
"""
name: str = sample[1]["name"][0]
label: int = int(sample[1]["label"].item())
label: int = int(sample[1]["target"].item())
image = sample[0]
# in binary classification systems, negative labels may be skipped
......
......@@ -426,7 +426,7 @@ def run(
disable=None,
):
name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item())
label = int(sample[1]["target"].item())
if label != target_label:
# we add the entry for dataset completeness, but do not treat
......
......@@ -44,12 +44,12 @@ def test_split_consistency(name: str):
montgomery_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.montgomery",
).RawDataLoader
).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.shenzhen",
).RawDataLoader
).ClassificationRawDataLoader
for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
......@@ -49,17 +49,17 @@ def test_split_consistency(name: str):
montgomery_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.montgomery",
).RawDataLoader
).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.shenzhen",
).RawDataLoader
).ClassificationRawDataLoader
indian_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.indian",
).RawDataLoader
).ClassificationRawDataLoader
for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
......@@ -44,22 +44,22 @@ def test_split_consistency(name: str, padchest_name: str):
montgomery_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.montgomery",
).RawDataLoader
).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.shenzhen",
).RawDataLoader
).ClassificationRawDataLoader
indian_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.indian",
).RawDataLoader
).ClassificationRawDataLoader
padchest_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.padchest",
).RawDataLoader
).ClassificationRawDataLoader
for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
......@@ -65,22 +65,22 @@ def test_split_consistency(name: str, tbx11k_name: str):
montgomery_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.montgomery",
).RawDataLoader
).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.shenzhen",
).RawDataLoader
).ClassificationRawDataLoader
indian_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.indian",
).RawDataLoader
).ClassificationRawDataLoader
tbx11k_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.tbx11k",
).RawDataLoader
).ClassificationRawDataLoader
for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
......@@ -34,12 +34,12 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str):
cxr14_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.nih_cxr14",
).RawDataLoader
).ClassificationRawDataLoader
padchest_loader = importlib.import_module(
".datamodule",
"mednet.libs.classification.config.data.padchest",
).RawDataLoader
).ClassificationRawDataLoader
for split in ("train", "validation", "test"):
assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0]
......
......@@ -192,11 +192,11 @@ def check_loaded_batch(
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
assert "target" in batch[1]
assert all([k in possible_labels for k in batch[1]["target"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert len(batch[1]["target"]) == expected_num_labels
assert "name" in batch[1]
assert all(
......@@ -207,7 +207,7 @@ def check_loaded_batch(
for sample, label, bboxes in zip(
batch[0],
batch[1]["label"],
batch[1]["target"],
batch[1]["bounding_boxes"],
):
# there must be a sign indicated on the image, if active TB is detected
......
......@@ -190,11 +190,11 @@ class DatabaseCheckers:
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
assert "target" in batch[1]
assert all([k in possible_labels for k in batch[1]["target"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert len(batch[1]["target"]) == expected_num_labels
assert "name" in batch[1]
assert all(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment