diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
index 4e6052498acb8255f46020714e0e36269e80d5f8..db2f79010073dbbbefbcbbfb1a75dae3b283bea8 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
@@ -40,7 +40,7 @@ class DataModule(ConcatDataModule):
         indian_loader = IndianLoader(INDIAN_KEY_DATADIR)
         indian_package = IndianDataModule.__module__.rsplit(".", 1)[0]
         indian_split = make_split(indian_package, split_filename)
-        tbx11k_loader = TBX11kLoader()
+        tbx11k_loader = TBX11kLoader(ignore_bboxes=True)
         tbx11k_package = TBX11kLoader.__module__.rsplit(".", 1)[0]
         tbx11k_split = make_split(tbx11k_package, tbx11k_split_filename)
 
diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py
index c98b28775499a7a623c2a0cdfa058d2580f9afab..993f9e50d1cbb8bdfe0d52f9d526f8e142e356e4 100644
--- a/src/mednet/config/data/tbx11k/datamodule.py
+++ b/src/mednet/config/data/tbx11k/datamodule.py
@@ -168,19 +168,26 @@ finding locations, as described above.
 
 
 class RawDataLoader(_BaseRawDataLoader):
-    """A specialized raw-data-loader for the TBX11k dataset."""
+    """A specialized raw-data-loader for the TBX11k dataset.
+
+    Parameters
+    ----------
+    ignore_bboxes
+        If True, sample() does not return bounding boxes.
+    """
 
     datadir: pathlib.Path
     """This variable contains the base directory where the database raw data is
     stored."""
 
-    def __init__(self):
+    def __init__(self, ignore_bboxes: bool = False):
         self.datadir = pathlib.Path(
             load_rc().get(
                 CONFIGURATION_KEY_DATADIR,
                 os.path.realpath(os.curdir),
             ),
         )
+        self.ignore_bboxes = ignore_bboxes
 
     def sample(self, sample: DatabaseSample) -> Sample:
         """Load a single image sample from the disk.
@@ -206,6 +213,12 @@ class RawDataLoader(_BaseRawDataLoader):
         # to_pil_image(tensor).show()
         # __import__("pdb").set_trace()
 
+        if self.ignore_bboxes:
+            return tensor, dict(
+                label=sample[1],
+                name=sample[0],
+            )
+
         return tensor, dict(
             label=sample[1],
             name=sample[0],
@@ -356,13 +369,15 @@ class DataModule(CachingDataModule):
     ----------
     split_filename
         Name of the .json file containing the split to load.
+    ignore_bboxes
+        If True, sample() does not return bounding boxes.
""" - def __init__(self, split_filename: str): + def __init__(self, split_filename: str, ignore_bboxes: bool = False): assert __package__ is not None super().__init__( database_split=make_split(__package__, split_filename), - raw_data_loader=RawDataLoader(), + raw_data_loader=RawDataLoader(ignore_bboxes=ignore_bboxes), database_name=__package__.rsplit(".", 1)[1], split_name=pathlib.Path(split_filename).stem, ) diff --git a/tests/test_montgomery_shenzhen_indian_tbx11k.py b/tests/test_montgomery_shenzhen_indian_tbx11k.py index 75464104c1f789358144e2ddcff3553b3c9a14fe..491095a6389fd1b592a051b1582dc53a5e46b133 100644 --- a/tests/test_montgomery_shenzhen_indian_tbx11k.py +++ b/tests/test_montgomery_shenzhen_indian_tbx11k.py @@ -100,6 +100,36 @@ def test_split_consistency(name: str, tbx11k_name: str): assert isinstance(combined.splits[split][3][1], tbx11k_loader) +@pytest.mark.parametrize( + "dataset", + [ + "train", + ], +) +@pytest.mark.parametrize( + "tbx11k_name", + [ + ("v1_healthy_vs_atb"), + ], +) +def test_batch_uniformity(tbx11k_name: str, dataset: str): + combined = importlib.import_module( + f".{tbx11k_name}", + "mednet.config.data.montgomery_shenzhen_indian_tbx11k", + ).datamodule + + combined.model_transforms = [] # should be done before setup() + combined.setup("predict") # sets up all datasets + + loader = combined.predict_dataloader()[dataset] + + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + assert len(batch[1]) == 2 # label, name. No radiological sign bounding-boxes + + @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")