diff --git a/tests/test_padchest.py b/tests/test_padchest.py index dfa491c879880ec01126dc11e4a112f775abaac0..97bf8a9cae7e8329b3afe4284dab176cde034015 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -40,18 +40,18 @@ def test_protocol_consistency( ) -# TODO: Improve this to include other protocols @pytest.mark.skip_if_rc_var_not_set("datadir.padchest") @pytest.mark.parametrize( "dataset", - [ - "train", - ], + ["train", "test"], ) @pytest.mark.parametrize( "name", [ "idiap", + "tb_idiap", + "no_tb_idiap", + "cardiomegaly_idiap", ], ) def test_loading(database_checkers, name: str, dataset: str): @@ -62,20 +62,32 @@ def test_loading(database_checkers, name: str, dataset: str): datamodule.model_transforms = [] # should be done before setup() datamodule.setup("predict") # sets up all datasets - loader = datamodule.predict_dataloader()[dataset] + if dataset in datamodule.predict_dataloader(): + # Not all datamodules have a test set + try: + loader = datamodule.predict_dataloader()[dataset] + except KeyError as e: + if str(e) == "test" and name in [ + "idiap", + "no_tb_idiap", + "cardiomegaly_idiap", + ]: + assert True + else: + assert False - limit = 3 # limit load checking - for batch in loader: - if limit == 0: - break - database_checkers.check_loaded_batch( - batch, - batch_size=1, - color_planes=1, - prefixes=("",), - possible_labels=(0, 1), - ) - limit -= 1 + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=("",), + possible_labels=(0, 1), + ) + limit -= 1 # TODO: check size 1024x1024