Skip to content
Snippets Groups Projects
Commit afda408c authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[tests] Fix tests

parent d354514c
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 25 additions and 25 deletions
...@@ -38,4 +38,4 @@ class ClassificationRawDataLoader(RawDataLoader): ...@@ -38,4 +38,4 @@ class ClassificationRawDataLoader(RawDataLoader):
The label corresponding to the specified sample. The label corresponding to the specified sample.
""" """
return self.sample(k)[1]["label"] return self.sample(k)[1]["target"]
...@@ -168,7 +168,7 @@ def _process_sample( ...@@ -168,7 +168,7 @@ def _process_sample(
""" """
name: str = sample[1]["name"][0] name: str = sample[1]["name"][0]
label: int = int(sample[1]["label"].item()) label: int = int(sample[1]["target"].item())
image = sample[0] image = sample[0]
# in binary classification systems, negative labels may be skipped # in binary classification systems, negative labels may be skipped
......
...@@ -426,7 +426,7 @@ def run( ...@@ -426,7 +426,7 @@ def run(
disable=None, disable=None,
): ):
name = str(sample[1]["name"][0]) name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item()) label = int(sample[1]["target"].item())
if label != target_label: if label != target_label:
# we add the entry for dataset completeness, but do not treat # we add the entry for dataset completeness, but do not treat
......
...@@ -44,12 +44,12 @@ def test_split_consistency(name: str): ...@@ -44,12 +44,12 @@ def test_split_consistency(name: str):
montgomery_loader = importlib.import_module( montgomery_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.montgomery", "mednet.libs.classification.config.data.montgomery",
).RawDataLoader ).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module( shenzhen_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.shenzhen", "mednet.libs.classification.config.data.shenzhen",
).RawDataLoader ).ClassificationRawDataLoader
for split in ("train", "validation", "test"): for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0] assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
...@@ -49,17 +49,17 @@ def test_split_consistency(name: str): ...@@ -49,17 +49,17 @@ def test_split_consistency(name: str):
montgomery_loader = importlib.import_module( montgomery_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.montgomery", "mednet.libs.classification.config.data.montgomery",
).RawDataLoader ).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module( shenzhen_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.shenzhen", "mednet.libs.classification.config.data.shenzhen",
).RawDataLoader ).ClassificationRawDataLoader
indian_loader = importlib.import_module( indian_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.indian", "mednet.libs.classification.config.data.indian",
).RawDataLoader ).ClassificationRawDataLoader
for split in ("train", "validation", "test"): for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0] assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
...@@ -44,22 +44,22 @@ def test_split_consistency(name: str, padchest_name: str): ...@@ -44,22 +44,22 @@ def test_split_consistency(name: str, padchest_name: str):
montgomery_loader = importlib.import_module( montgomery_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.montgomery", "mednet.libs.classification.config.data.montgomery",
).RawDataLoader ).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module( shenzhen_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.shenzhen", "mednet.libs.classification.config.data.shenzhen",
).RawDataLoader ).ClassificationRawDataLoader
indian_loader = importlib.import_module( indian_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.indian", "mednet.libs.classification.config.data.indian",
).RawDataLoader ).ClassificationRawDataLoader
padchest_loader = importlib.import_module( padchest_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.padchest", "mednet.libs.classification.config.data.padchest",
).RawDataLoader ).ClassificationRawDataLoader
for split in ("train", "validation", "test"): for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0] assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
...@@ -65,22 +65,22 @@ def test_split_consistency(name: str, tbx11k_name: str): ...@@ -65,22 +65,22 @@ def test_split_consistency(name: str, tbx11k_name: str):
montgomery_loader = importlib.import_module( montgomery_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.montgomery", "mednet.libs.classification.config.data.montgomery",
).RawDataLoader ).ClassificationRawDataLoader
shenzhen_loader = importlib.import_module( shenzhen_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.shenzhen", "mednet.libs.classification.config.data.shenzhen",
).RawDataLoader ).ClassificationRawDataLoader
indian_loader = importlib.import_module( indian_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.indian", "mednet.libs.classification.config.data.indian",
).RawDataLoader ).ClassificationRawDataLoader
tbx11k_loader = importlib.import_module( tbx11k_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.tbx11k", "mednet.libs.classification.config.data.tbx11k",
).RawDataLoader ).ClassificationRawDataLoader
for split in ("train", "validation", "test"): for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0] assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
......
...@@ -34,12 +34,12 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str): ...@@ -34,12 +34,12 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str):
cxr14_loader = importlib.import_module( cxr14_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.nih_cxr14", "mednet.libs.classification.config.data.nih_cxr14",
).RawDataLoader ).ClassificationRawDataLoader
padchest_loader = importlib.import_module( padchest_loader = importlib.import_module(
".datamodule", ".datamodule",
"mednet.libs.classification.config.data.padchest", "mednet.libs.classification.config.data.padchest",
).RawDataLoader ).ClassificationRawDataLoader
for split in ("train", "validation", "test"): for split in ("train", "validation", "test"):
assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0] assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0]
......
...@@ -192,11 +192,11 @@ def check_loaded_batch( ...@@ -192,11 +192,11 @@ def check_loaded_batch(
assert isinstance(batch[1], dict) # metadata assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes
assert "label" in batch[1] assert "target" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]]) assert all([k in possible_labels for k in batch[1]["target"]])
if expected_num_labels: if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels assert len(batch[1]["target"]) == expected_num_labels
assert "name" in batch[1] assert "name" in batch[1]
assert all( assert all(
...@@ -207,7 +207,7 @@ def check_loaded_batch( ...@@ -207,7 +207,7 @@ def check_loaded_batch(
for sample, label, bboxes in zip( for sample, label, bboxes in zip(
batch[0], batch[0],
batch[1]["label"], batch[1]["target"],
batch[1]["bounding_boxes"], batch[1]["bounding_boxes"],
): ):
# there must be a sign indicated on the image, if active TB is detected # there must be a sign indicated on the image, if active TB is detected
......
...@@ -190,11 +190,11 @@ class DatabaseCheckers: ...@@ -190,11 +190,11 @@ class DatabaseCheckers:
assert isinstance(batch[1], dict) # metadata assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name assert len(batch[1]) == 2 # label and name
assert "label" in batch[1] assert "target" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]]) assert all([k in possible_labels for k in batch[1]["target"]])
if expected_num_labels: if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels assert len(batch[1]["target"]) == expected_num_labels
assert "name" in batch[1] assert "name" in batch[1]
assert all( assert all(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment