diff --git a/tests/test_montgomery_shenzhen_indian_tbx11k.py b/tests/test_montgomery_shenzhen_indian_tbx11k.py index 75464104c1f789358144e2ddcff3553b3c9a14fe..8e45c6d206f8f6c3dfdbf6c7127a5da35edad16b 100644 --- a/tests/test_montgomery_shenzhen_indian_tbx11k.py +++ b/tests/test_montgomery_shenzhen_indian_tbx11k.py @@ -100,6 +100,37 @@ def test_split_consistency(name: str, tbx11k_name: str): assert isinstance(combined.splits[split][3][1], tbx11k_loader) +@pytest.mark.parametrize( + "dataset", + [ + "train", + ], +) +@pytest.mark.parametrize( + "tbx11k_name", + [ + ("v1_healthy_vs_atb"), + ], +) +def test_batch_uniformity(tbx11k_name: str, dataset: str): + + combined = importlib.import_module( + f".{tbx11k_name}", + "mednet.config.data.montgomery_shenzhen_indian_tbx11k", + ).datamodule + + combined.model_transforms = [] # should be done before setup() + combined.setup("predict") # sets up all datasets + + loader = combined.predict_dataloader()[dataset] + + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + assert len(batch[1]) == 2 # label, name. No radiological sign bounding-boxes + + @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")