Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • medai/software/mednet
1 result
Show changes
Showing
with 297 additions and 50 deletions
[datadir]
montgomery = "/idiap/resource/database/MontgomeryXraySet"
shenzhen = "/idiap/resource/database/ShenzhenXraySet"
indian = "/idiap/resource/database/TBXpredict"
tbx11k = "/idiap/resource/database/tbx11k"
...@@ -436,6 +436,102 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ...@@ -436,6 +436,102 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
) )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_experiment(temporary_basedir):
    """Run the ``experiment`` command (pasa on montgomery) and check its outputs.

    The command is invoked through ``CliRunner`` for two epochs, writing to
    ``<temporary_basedir>/experiment``.  The test then verifies that the
    expected files and directories exist under the ``model``, ``evaluation``
    and ``gradcam`` sub-folders of the output folder.
    """
    from mednet.scripts.experiment import experiment

    num_epochs = 2
    output_folder = str(temporary_basedir / "experiment")

    cli_arguments = [
        "-vv",
        "pasa",
        "montgomery",
        f"--epochs={num_epochs}",
        f"--output-folder={output_folder}",
    ]
    runner = CliRunner()
    result = runner.invoke(experiment, cli_arguments)
    _assert_exit_0(result)

    model_dir = os.path.join(output_folder, "model")
    evaluation_dir = os.path.join(output_folder, "evaluation")
    gradcam_dir = os.path.join(output_folder, "gradcam")

    # Top-level outputs.
    assert os.path.exists(os.path.join(output_folder, "command.sh"))
    assert os.path.exists(os.path.join(output_folder, "predictions.json"))

    # Training outputs.
    assert os.path.exists(os.path.join(model_dir, "command.sh"))
    assert os.path.exists(os.path.join(model_dir, "constants.csv"))
    last_checkpoint = os.path.join(
        model_dir, f"model-at-epoch={num_epochs-1}.ckpt"
    )
    assert os.path.exists(last_checkpoint)

    # The epoch with the lowest validation loss is not known in advance, so
    # match the checkpoint name with a wildcard and expect exactly one hit.
    best_checkpoints = glob.glob(
        os.path.join(model_dir, "model-at-lowest-validation-loss-epoch=*.ckpt")
    )
    assert len(best_checkpoints) == 1

    assert os.path.exists(os.path.join(model_dir, "model-summary.txt"))
    assert os.path.exists(os.path.join(model_dir, "trainlog.pdf"))

    event_files = glob.glob(
        os.path.join(model_dir, "logs", "events.out.tfevents.*")
    )
    assert len(event_files) == 1

    # Evaluation outputs.
    assert os.path.exists(os.path.join(evaluation_dir, "plots.pdf"))
    assert os.path.exists(os.path.join(evaluation_dir, "summary.rst"))

    # Saliency maps: one ``.npy`` file is expected for each of 138 images.
    saliencies_dir = os.path.join(gradcam_dir, "saliencies")
    assert os.path.exists(saliencies_dir)
    saliency_files = glob.glob(
        os.path.join(saliencies_dir, "CXR_png", "MCUCXR_*.npy")
    )
    assert len(saliency_files) == 138

    # Visualizations: 58 ``.png`` files are expected.
    visualizations_dir = os.path.join(gradcam_dir, "visualizations")
    assert os.path.exists(visualizations_dir)
    visualization_files = glob.glob(
        os.path.join(visualizations_dir, "CXR_png", "MCUCXR_*.png")
    )
    assert len(visualization_files) == 58
# This script does not work anymore, either fix or remove the script + this test # This script does not work anymore, either fix or remove the script + this test
# def test_evaluatevis(temporary_basedir): # def test_evaluatevis(temporary_basedir):
# import pandas as pd # import pandas as pd
......
...@@ -87,5 +87,22 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -87,5 +87,22 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("HIV-TB_Algorithm_study_X-rays",), prefixes=("HIV-TB_Algorithm_study_X-rays",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loaded_image_quality(database_checkers, datadir):
    """Check hivtb ``fold_0`` images against the stored reference histograms.

    Model transforms are cleared before ``setup("predict")`` is called; the
    comparison itself is delegated to ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".fold_0", "mednet.config.data.hivtb"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_hivtb_fold_0.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))
...@@ -92,5 +92,22 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -92,5 +92,22 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("DatasetA/Training", "DatasetA/Testing"), prefixes=("DatasetA/Training", "DatasetA/Testing"),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
def test_loaded_image_quality(database_checkers, datadir):
    """Check indian ``default`` images against the stored reference histograms.

    Model transforms are cleared before ``setup("predict")`` is called; the
    comparison itself is delegated to ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".default", "mednet.config.data.indian"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_indian_default.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))
...@@ -89,5 +89,63 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -89,5 +89,63 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("CXR_png/MCUCXR_0",), prefixes=("CXR_png/MCUCXR_0",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_raw_transforms_image_quality(database_checkers, datadir):
    """Check montgomery ``default`` images, without model transforms.

    The datamodule's ``model_transforms`` list is emptied before
    ``setup("predict")`` is called, and the result is compared with the
    reference histogram file by ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".default", "mednet.config.data.montgomery"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_montgomery_default.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.parametrize(
    "model_name",
    [
        "alexnet",
        "densenet",
        "pasa",
    ],
)
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    """Check montgomery images after applying each model's own transforms.

    For every parametrized ``model_name``, the model configuration is imported
    from ``mednet.config.models``, its ``model_transforms`` are installed on the
    montgomery ``default`` datamodule, and the result is compared with the
    matching reference histogram file using a statistical comparison.

    Parameters
    ----------
    database_checkers
        Fixture providing ``check_image_quality``.
    datadir
        Fixture pointing at the directory holding the reference histograms.
    model_name
        Name of the model configuration module under ``mednet.config.models``.
    """
    # The densenet reference file is stored under "densenet-121" (presumably
    # the model's own ``name``), which differs from its configuration module
    # name "densenet".  All other models use the module name directly.
    histogram_label = "densenet-121" if model_name == "densenet" else model_name
    reference_histogram_file = str(
        datadir
        / f"histograms/models/histograms_{histogram_label}_montgomery_default.json"
    )

    datamodule = importlib.import_module(
        ".default", "mednet.config.data.montgomery"
    ).datamodule

    model = importlib.import_module(
        f".{model_name}", "mednet.config.models"
    ).model

    # Transforms must be installed before setup() so they are taken into
    # account when the datasets are built.
    datamodule.model_transforms = model.model_transforms
    datamodule.setup("predict")

    database_checkers.check_image_quality(
        datamodule,
        reference_histogram_file,
        compare_type="statistical",
        pearson_coeff_threshold=0.005,
    )
...@@ -35,22 +35,18 @@ def test_protocol_consistency( ...@@ -35,22 +35,18 @@ def test_protocol_consistency(
) )
# Parametrization cases for test_loading, one tuple per case in the order
# (name, dataset, num_labels): the configuration module name under
# mednet.config.data.nih_cxr14, the dataset split to load, and the number of
# labels expected per sample.
testdata = [
("default", "train", 14),
("default", "validation", 14),
("default", "test", 14),
("cardiomegaly", "train", 14),
("cardiomegaly", "validation", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
@pytest.mark.parametrize( @pytest.mark.parametrize("name,dataset,num_labels", testdata)
"dataset", def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
[
"train",
"validation",
"test",
],
)
@pytest.mark.parametrize(
"name",
[
"default",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module( datamodule = importlib.import_module(
f".{name}", "mednet.config.data.nih_cxr14" f".{name}", "mednet.config.data.nih_cxr14"
).datamodule ).datamodule
...@@ -70,9 +66,23 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -70,9 +66,23 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("images/000",), prefixes=("images/000",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=num_labels,
expected_image_shape=(1, 1024, 1024),
) )
limit -= 1 limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
def test_loaded_image_quality(database_checkers, datadir):
    """Check nih_cxr14 ``default`` images against the stored reference histograms.

    Model transforms are cleared before ``setup("predict")`` is called; the
    comparison itself is delegated to ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".default", "mednet.config.data.nih_cxr14"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_nih_cxr14_default.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))
...@@ -40,21 +40,18 @@ def test_protocol_consistency( ...@@ -40,21 +40,18 @@ def test_protocol_consistency(
) )
# TODO: Improve this to include other protocols testdata = [
("idiap", "train", 193),
("idiap", "test", 1),
("tb_idiap", "train", 1),
("no_tb_idiap", "train", 14),
("cardiomegaly_idiap", "train", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") @pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.parametrize( @pytest.mark.parametrize("name,dataset,num_labels", testdata)
"dataset", def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
[
"train",
],
)
@pytest.mark.parametrize(
"name",
[
"idiap",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module( datamodule = importlib.import_module(
f".{name}", "mednet.config.data.padchest" f".{name}", "mednet.config.data.padchest"
).datamodule ).datamodule
...@@ -62,22 +59,35 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -62,22 +59,35 @@ def test_loading(database_checkers, name: str, dataset: str):
datamodule.model_transforms = [] # should be done before setup() datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets datamodule.setup("predict") # sets up all datasets
loader = datamodule.predict_dataloader()[dataset] if dataset in datamodule.predict_dataloader():
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
for batch in loader: limit = 3 # limit load checking
if limit == 0: for batch in loader:
break if limit == 0:
database_checkers.check_loaded_batch( break
batch, database_checkers.check_loaded_batch(
batch_size=1, batch,
color_planes=1, batch_size=1,
prefixes=("",), color_planes=1,
possible_labels=(0, 1), prefixes=("",),
) possible_labels=(0, 1),
limit -= 1 expected_num_labels=num_labels,
)
limit -= 1
# TODO: check size 1024x1024
# TODO: check there are 14 binary labels (0, 1) (in some cases, in others much
# more)...
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_loaded_image_quality(database_checkers, datadir):
    """Check padchest ``idiap`` images against the stored reference histograms.

    Model transforms are cleared before ``setup("predict")`` is called; the
    comparison itself is delegated to ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".idiap", "mednet.config.data.padchest"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_padchest_idiap.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))
...@@ -89,5 +89,22 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -89,5 +89,22 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("CXR_png/CHNCXR_0",), prefixes=("CXR_png/CHNCXR_0",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loaded_image_quality(database_checkers, datadir):
    """Check shenzhen ``default`` images against the stored reference histograms.

    Model transforms are cleared before ``setup("predict")`` is called; the
    comparison itself is delegated to ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".default", "mednet.config.data.shenzhen"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_shenzhen_default.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))
...@@ -93,5 +93,22 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -93,5 +93,22 @@ def test_loading(database_checkers, name: str, dataset: str):
"TBPOC_CXR/tbpoc-", "TBPOC_CXR/tbpoc-",
), ),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
def test_loaded_image_quality(database_checkers, datadir):
    """Check tbpoc ``fold_0`` images against the stored reference histograms.

    Model transforms are cleared before ``setup("predict")`` is called; the
    comparison itself is delegated to ``database_checkers.check_image_quality``.
    """
    config_module = importlib.import_module(
        ".fold_0", "mednet.config.data.tbpoc"
    )
    dm = config_module.datamodule
    dm.model_transforms = []
    dm.setup("predict")

    histograms_path = (
        datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json"
    )
    database_checkers.check_image_quality(dm, str(histograms_path))