Skip to content
Snippets Groups Projects
Commit ce135a5d authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Merge branch 'update-tests' into 'issue-23-and-39'

Update tests

See merge request biosignal/software/mednet!18
parents 2b40fbd6 1c9ce4ba
Branches
Tags
2 merge requests!18Update tests,!16Make square centre-padding a model transform
Pipeline #84309 passed
Showing
with 279 additions and 156 deletions
Subproject commit 05344d20182ad4169fc5b9c38052d629aded30ed
Subproject commit fe158e393b7904c6b381b0d9862895bf05c6cf7a
[datadir]
montgomery = "/idiap/resource/database/MontgomeryXraySet"
shenzhen = "/idiap/resource/database/ShenzhenXraySet"
indian = "/idiap/resource/database/TBXpredict"
tbx11k = "/idiap/resource/database/tbx11k"
......@@ -436,108 +436,97 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
)
# NOTE: this script no longer works; either fix it or remove both the script and this test.
# def test_evaluatevis(temporary_basedir):
# import pandas as pd
# from mednet.scripts.evaluatevis import evaluatevis
# runner = CliRunner()
# # Create a sample directory structure and CSV files
# input_folder = temporary_basedir / "camutils_cli" / "gradcam"
# input_folder.mkdir(parents=True, exist_ok=True)
# class1_dir = input_folder / "class1"
# class1_dir.mkdir(parents=True, exist_ok=True)
# class2_dir = input_folder / "class2"
# class2_dir.mkdir(parents=True, exist_ok=True)
# data = {
# "MoRF": [1, 2, 3],
# "LeRF": [2, 4, 6],
# "Combined Score ((LeRF-MoRF) / 2)": [1.5, 3, 4.5],
# "IoU": [1, 2, 3],
# "IoDA": [2, 4, 6],
# "propEnergy": [1.5, 3, 4.5],
# "ASF": [1, 2, 3],
# }
# df = pd.DataFrame(data)
# df.to_csv(class1_dir / "file1.csv", index=False)
# df.to_csv(class2_dir / "file1.csv", index=False)
# df.to_csv(class1_dir / "file2.csv", index=False)
# df.to_csv(class2_dir / "file2.csv", index=False)
# result = runner.invoke(evaluatevis, ["-vv", "-i", str(input_folder)])
# assert result.exit_code == 0
# assert (input_folder / "file1_summary.csv").exists()
# assert (input_folder / "file2_summary.csv").exists()
# Skipped: not enough RAM is available to run this test.
# @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
# def test_predict_densenetrs_montgomery(temporary_basedir, datadir):
# from mednet.scripts.predict import predict
# runner = CliRunner()
# with stdout_logging() as buf:
# output_folder = str(temporary_basedir / "predictions")
# result = runner.invoke(
# predict,
# [
# "densenet_rs",
# "montgomery_f0_rgb",
# "-vv",
# "--batch-size=1",
# f"--weight={str(datadir / 'lfs' / 'models' / 'densenetrs.pth')}",
# f"--output-folder={output_folder}",
# "--grad-cams"
# ],
# )
# _assert_exit_0(result)
# # check predictions are there
# predictions_file1 = os.path.join(output_folder, "train/predictions.csv")
# predictions_file2 = os.path.join(output_folder, "validation/predictions.csv")
# predictions_file3 = os.path.join(output_folder, "test/predictions.csv")
# assert os.path.exists(predictions_file1)
# assert os.path.exists(predictions_file2)
# assert os.path.exists(predictions_file3)
# # check some grad cams are there
# cam1 = os.path.join(output_folder, "train/cams/MCUCXR_0002_0_cam.png")
# cam2 = os.path.join(output_folder, "train/cams/MCUCXR_0126_1_cam.png")
# cam3 = os.path.join(output_folder, "train/cams/MCUCXR_0275_1_cam.png")
# cam4 = os.path.join(output_folder, "validation/cams/MCUCXR_0399_1_cam.png")
# cam5 = os.path.join(output_folder, "validation/cams/MCUCXR_0113_1_cam.png")
# cam6 = os.path.join(output_folder, "validation/cams/MCUCXR_0013_0_cam.png")
# cam7 = os.path.join(output_folder, "test/cams/MCUCXR_0027_0_cam.png")
# cam8 = os.path.join(output_folder, "test/cams/MCUCXR_0094_0_cam.png")
# cam9 = os.path.join(output_folder, "test/cams/MCUCXR_0375_1_cam.png")
# assert os.path.exists(cam1)
# assert os.path.exists(cam2)
# assert os.path.exists(cam3)
# assert os.path.exists(cam4)
# assert os.path.exists(cam5)
# assert os.path.exists(cam6)
# assert os.path.exists(cam7)
# assert os.path.exists(cam8)
# assert os.path.exists(cam9)
# keywords = {
# r"^Loading checkpoint from.*$": 1,
# r"^Total time:.*$": 3,
# r"^Grad cams folder:.*$": 3,
# }
# buf.seek(0)
# logging_output = buf.read()
# for k, v in keywords.items():
# assert _str_counter(k, logging_output) == v, (
# f"Count for string '{k}' appeared "
# f"({_str_counter(k, logging_output)}) "
# f"instead of the expected {v}:\nOutput:\n{logging_output}"
# )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_experiment(temporary_basedir):
    """Run the full ``experiment`` CLI (pasa model, montgomery database) for two
    epochs and check that every expected output artifact was produced."""
    from mednet.scripts.experiment import experiment

    runner = CliRunner()
    output_folder = str(temporary_basedir / "experiment")
    num_epochs = 2
    result = runner.invoke(
        experiment,
        [
            "-vv",
            "pasa",
            "montgomery",
            f"--epochs={num_epochs}",
            f"--output-folder={output_folder}",
        ],
    )
    _assert_exit_0(result)

    # Artifacts with fully deterministic paths: assert plain existence.
    fixed_artifacts = (
        ("command.sh",),
        ("predictions.json",),
        ("model", "command.sh"),
        ("model", "constants.csv"),
        ("model", f"model-at-epoch={num_epochs - 1}.ckpt"),
        ("model", "model-summary.txt"),
        ("model", "trainlog.pdf"),
        ("evaluation", "plots.pdf"),
        ("evaluation", "summary.rst"),
        ("gradcam", "saliencies"),
        ("gradcam", "visualizations"),
    )
    for parts in fixed_artifacts:
        assert os.path.exists(os.path.join(output_folder, *parts))

    # Artifacts whose exact names are not predictable: match by glob and check
    # the expected count.  E.g. we cannot know beforehand which epoch produced
    # the lowest validation loss, nor the tensorboard event-file suffix.
    globbed_artifacts = (
        (("model", "model-at-lowest-validation-loss-epoch=*.ckpt"), 1),
        (("model", "logs", "events.out.tfevents.*"), 1),
        (("gradcam", "saliencies", "CXR_png", "MCUCXR_*.npy"), 138),
        (("gradcam", "visualizations", "CXR_png", "MCUCXR_*.png"), 58),
    )
    for parts, expected_count in globbed_artifacts:
        matches = glob.glob(os.path.join(output_folder, *parts))
        assert len(matches) == expected_count
......@@ -87,5 +87,22 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1,
prefixes=("HIV-TB_Algorithm_study_X-rays",),
possible_labels=(0, 1),
expected_num_labels=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loaded_image_quality(database_checkers, datadir):
    """Check raw HIV-TB images (fold 0) against the reference histogram file."""
    histogram_path = datadir / "histograms/raw_data/histograms_hivtb_fold_0.json"

    module = importlib.import_module(".fold_0", "mednet.config.data.hivtb")
    datamodule = module.datamodule
    datamodule.model_transforms = []  # raw data: no model transforms applied
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, str(histogram_path))
......@@ -92,5 +92,22 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1,
prefixes=("DatasetA/Training", "DatasetA/Testing"),
possible_labels=(0, 1),
expected_num_labels=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
def test_loaded_image_quality(database_checkers, datadir):
    """Check raw Indian-database images against the reference histogram file."""
    histogram_path = datadir / "histograms/raw_data/histograms_indian_default.json"

    module = importlib.import_module(".default", "mednet.config.data.indian")
    datamodule = module.datamodule
    datamodule.model_transforms = []  # raw data: no model transforms applied
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, str(histogram_path))
......@@ -89,5 +89,63 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1,
prefixes=("CXR_png/MCUCXR_0",),
possible_labels=(0, 1),
expected_num_labels=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_raw_transforms_image_quality(database_checkers, datadir):
    """Check raw Montgomery images against the reference histogram file."""
    histogram_path = (
        datadir / "histograms/raw_data/histograms_montgomery_default.json"
    )

    module = importlib.import_module(".default", "mednet.config.data.montgomery")
    datamodule = module.datamodule
    datamodule.model_transforms = []  # raw data: no model transforms applied
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, str(histogram_path))
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.parametrize(
    "model_name",
    [
        "alexnet",
        "densenet",
        "pasa",
    ],
)
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    """Check Montgomery images, after model transforms, against per-model
    reference histograms (statistical comparison)."""
    # The densenet config module is named "densenet", but the histogram
    # reference files are keyed on the model's own name, "densenet-121"
    # (the original comment said "densenet-212", which was a typo).
    histogram_key = "densenet-121" if model_name == "densenet" else model_name
    reference_histogram_file = str(
        datadir
        / f"histograms/models/histograms_{histogram_key}_montgomery_default.json"
    )

    datamodule = importlib.import_module(
        ".default", "mednet.config.data.montgomery"
    ).datamodule
    model = importlib.import_module(
        f".{model_name}", "mednet.config.models"
    ).model

    # Apply the model's own transforms before comparing histograms.
    datamodule.model_transforms = model.model_transforms
    datamodule.setup("predict")
    database_checkers.check_image_quality(
        datamodule,
        reference_histogram_file,
        compare_type="statistical",
        pearson_coeff_threshold=0.005,
    )
......@@ -35,22 +35,18 @@ def test_protocol_consistency(
)
# Parametrization for test_loading below: each tuple is
# (config name, split name, expected number of labels per sample).
testdata = [
("default", "train", 14),
("default", "validation", 14),
("default", "test", 14),
("cardiomegaly", "train", 14),
("cardiomegaly", "validation", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
@pytest.mark.parametrize(
"dataset",
[
"train",
"validation",
"test",
],
)
@pytest.mark.parametrize(
"name",
[
"default",
],
)
def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.parametrize("name,dataset,num_labels", testdata)
def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
datamodule = importlib.import_module(
f".{name}", "mednet.config.data.nih_cxr14"
).datamodule
......@@ -70,9 +66,23 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1,
prefixes=("images/000",),
possible_labels=(0, 1),
expected_num_labels=num_labels,
expected_image_shape=(1, 1024, 1024),
)
limit -= 1
# TODO: check size 1024x1024
# TODO: check there are 14 binary labels (0, 1)
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
def test_loaded_image_quality(database_checkers, datadir):
    """Check raw NIH CXR-14 images against the reference histogram file."""
    histogram_path = (
        datadir / "histograms/raw_data/histograms_nih_cxr14_default.json"
    )

    module = importlib.import_module(".default", "mednet.config.data.nih_cxr14")
    datamodule = module.datamodule
    datamodule.model_transforms = []  # raw data: no model transforms applied
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, str(histogram_path))
......@@ -40,21 +40,18 @@ def test_protocol_consistency(
)
# TODO: Improve this to include other protocols
# Parametrization for test_loading below: each tuple is
# (config name, split name, expected number of labels per sample).
testdata = [
("idiap", "train", 193),
("idiap", "test", 1),
("tb_idiap", "train", 1),
("no_tb_idiap", "train", 14),
("cardiomegaly_idiap", "train", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.parametrize(
"dataset",
[
"train",
],
)
@pytest.mark.parametrize(
"name",
[
"idiap",
],
)
def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.parametrize("name,dataset,num_labels", testdata)
def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
datamodule = importlib.import_module(
f".{name}", "mednet.config.data.padchest"
).datamodule
......@@ -62,22 +59,35 @@ def test_loading(database_checkers, name: str, dataset: str):
datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
for batch in loader:
if limit == 0:
break
database_checkers.check_loaded_batch(
batch,
batch_size=1,
color_planes=1,
prefixes=("",),
possible_labels=(0, 1),
)
limit -= 1
# TODO: check size 1024x1024
# TODO: check there are 14 binary labels (0, 1) (in some cases, in others much
# more)...
if dataset in datamodule.predict_dataloader():
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
for batch in loader:
if limit == 0:
break
database_checkers.check_loaded_batch(
batch,
batch_size=1,
color_planes=1,
prefixes=("",),
possible_labels=(0, 1),
expected_num_labels=num_labels,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_loaded_image_quality(database_checkers, datadir):
    """Check raw PadChest (idiap protocol) images against the reference
    histogram file."""
    histogram_path = datadir / "histograms/raw_data/histograms_padchest_idiap.json"

    module = importlib.import_module(".idiap", "mednet.config.data.padchest")
    datamodule = module.datamodule
    datamodule.model_transforms = []  # raw data: no model transforms applied
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, str(histogram_path))
......@@ -89,5 +89,22 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1,
prefixes=("CXR_png/CHNCXR_0",),
possible_labels=(0, 1),
expected_num_labels=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loaded_image_quality(database_checkers, datadir):
    """Check raw Shenzhen images against the reference histogram file."""
    histogram_path = (
        datadir / "histograms/raw_data/histograms_shenzhen_default.json"
    )

    module = importlib.import_module(".default", "mednet.config.data.shenzhen")
    datamodule = module.datamodule
    datamodule.model_transforms = []  # raw data: no model transforms applied
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, str(histogram_path))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment