diff --git a/pyproject.toml b/pyproject.toml
index 0a83ba62bb8938c36e967bc0dcc415552f7a6175..910ff540249164d41d13a00824dfcdfc5df94552 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -427,6 +427,9 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet"
 chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
 chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator"
 
+# cxr8 dataset - cxr
+cxr8 = "mednet.libs.segmentation.config.data.cxr8.default"
+
 # drive dataset - retinography
 drive = "mednet.libs.segmentation.config.data.drive.default"
diff --git a/src/mednet/libs/classification/config/data/hivtb/datamodule.py b/src/mednet/libs/classification/config/data/hivtb/datamodule.py
index a93606bb6547b5cd196ce58100613a41ed130169..93627558cc62239c7341e412895c1d8931fb3910 100644
--- a/src/mednet/libs/classification/config/data/hivtb/datamodule.py
+++ b/src/mednet/libs/classification/config/data/hivtb/datamodule.py
@@ -15,10 +15,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
-
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -127,7 +125,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/indian/datamodule.py b/src/mednet/libs/classification/config/data/indian/datamodule.py
index 63ff00c284201f797dc4eb51f5b1e7ef8c723f4f..efcce8e7e7e8b3d50e94c34ebfe394ca4aff933b 100644
--- a/src/mednet/libs/classification/config/data/indian/datamodule.py
+++ b/src/mednet/libs/classification/config/data/indian/datamodule.py
@@ -11,9 +11,8 @@ import pathlib
 
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.classification.data.typing import (
-    ClassificationRawDataLoader as _ClassificationRawDataLoader,
-)
+from ....config.data.shenzhen.datamodule import ClassificationRawDataLoader
+
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
 database."""
@@ -64,8 +63,9 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(
+                config_variable=CONFIGURATION_KEY_DATADIR
+            ),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
-
diff --git a/src/mednet/libs/classification/config/data/montgomery/datamodule.py b/src/mednet/libs/classification/config/data/montgomery/datamodule.py
index 2afbb8d6cd478371bff6f2828a737a5d93959183..8ccd7536981665214855016d98c482ef6dd90acc 100644
--- a/src/mednet/libs/classification/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery/datamodule.py
@@ -15,10 +15,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
-
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -28,7 +26,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the Montgomery dataset.
 
     Parameters
@@ -140,7 +138,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
index 0edb1aef91c0f0a19acbd83bf6dfcc1103b89fae..b022d50a0e46edfdcf2572a18a736f558284d7ec 100644
--- a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
+++ b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
@@ -14,10 +14,8 @@ from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
-
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -86,7 +84,9 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
         # for folder lookup efficiency, data is split into subfolders
         # each original file is in the subfolder `f[:5]/f`, where f
         # is the original file basename
-        file_path = file_path.parent / file_path.name[:5] / file_path.name
+        file_path = pathlib.Path(
+            file_path.parent / file_path.name[:5] / file_path.name
+        )
 
         # N.B.: some NIH CXR-14 images are encoded as color PNGs with an alpha
         # channel. Most are grayscale PNGs
@@ -176,7 +176,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/padchest/datamodule.py b/src/mednet/libs/classification/config/data/padchest/datamodule.py
index 51a20bbdbc9392ce7917cb44b6db3a0233f9a857..0c6dfba30d92ee353010ea2220220c96c19fb8e4 100644
--- a/src/mednet/libs/classification/config/data/padchest/datamodule.py
+++ b/src/mednet/libs/classification/config/data/padchest/datamodule.py
@@ -16,10 +16,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
-
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -329,7 +327,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
index 45aa327e06abbb22d9f7bbbc6f957a49291a4a06..b1b78c574b60f9b00fc48cd88d43746265cdc8b4 100644
--- a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
+++ b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
@@ -15,10 +15,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
-
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -139,7 +137,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
index f92afc9d2e2ac141f7c4889db32e72c187f93f13..7c943b3e503c69bf1efb1f77be1468e2be910bf9 100644
--- a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
+++ b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
@@ -11,10 +11,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
-
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -125,7 +123,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/tests/test_hivtb.py b/src/mednet/libs/classification/tests/test_hivtb.py
index 2075f91a4f184d1ae1f191123144fc25c155c102..91258839bb00566b1330dd364a5f3555e4e4c47f 100644
--- a/src/mednet/libs/classification/tests/test_hivtb.py
+++ b/src/mednet/libs/classification/tests/test_hivtb.py
@@ -39,7 +39,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.hivtb", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.hivtb", f"{split}.json"),
         lengths=lengths,
         prefixes=("HIV-TB_Algorithm_study_X-rays",),
         possible_labels=(0, 1),
diff --git a/src/mednet/libs/classification/tests/test_indian.py b/src/mednet/libs/classification/tests/test_indian.py
index c55f82b388942d2cea721de5cea428797bc8aa06..952bc9b76727e72e25c41efe3bdfeb1de8e17cc8 100644
--- a/src/mednet/libs/classification/tests/test_indian.py
+++ b/src/mednet/libs/classification/tests/test_indian.py
@@ -43,7 +43,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.indian", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.indian", f"{split}.json"),
         lengths=lengths,
         prefixes=("DatasetA/Training", "DatasetA/Testing"),
         possible_labels=(0, 1),
diff --git a/src/mednet/libs/classification/tests/test_montgomery.py b/src/mednet/libs/classification/tests/test_montgomery.py
index 24ff0c9eae169475069867aac7051e94e8981b73..ef8715ea8b8caf2e3047831bcccd01863e53e25e 100644
--- a/src/mednet/libs/classification/tests/test_montgomery.py
+++ b/src/mednet/libs/classification/tests/test_montgomery.py
@@ -41,7 +41,9 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.montgomery", f"{split}.json"),
+        make_split(
+            "mednet.libs.classification.config.data.montgomery", f"{split}.json"
+        ),
         lengths=lengths,
         prefixes=("CXR_png/MCUCXR_0",),
         possible_labels=(0, 1),
@@ -114,12 +116,8 @@ def test_loading(database_checkers, name: str, dataset: str):
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_raw_transforms_image_quality(database_checkers, datadir: pathlib.Path):
     datamodule = importlib.import_module(
-<<<<<<< HEAD
-        ".default", "mednet.config.data.montgomery"
-=======
         ".default",
         "mednet.libs.classification.config.data.montgomery",
->>>>>>> b1ea1c0 ([mednet] Start reorganizing into monorepo)
     ).datamodule
 
     datamodule.model_transforms = []
diff --git a/src/mednet/libs/classification/tests/test_nih_cxr14.py b/src/mednet/libs/classification/tests/test_nih_cxr14.py
index e3466a8f53d055743c9888c92800cb9375decca8..da5504ba37695368c7a77a3bc1e25b5a51a14b89 100644
--- a/src/mednet/libs/classification/tests/test_nih_cxr14.py
+++ b/src/mednet/libs/classification/tests/test_nih_cxr14.py
@@ -31,7 +31,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.nih_cxr14", f"{split}"),
+        make_split("mednet.libs.classification.config.data.nih_cxr14", f"{split}"),
         lengths=lengths,
         prefixes=("images/000",),
         possible_labels=(0, 1),
diff --git a/src/mednet/libs/classification/tests/test_padchest.py b/src/mednet/libs/classification/tests/test_padchest.py
index 8e420e17ea9cadd26f3886f3c4af09c4a533ea1e..af6007df4abb53f4d14fd964da4ad9e129eca804 100644
--- a/src/mednet/libs/classification/tests/test_padchest.py
+++ b/src/mednet/libs/classification/tests/test_padchest.py
@@ -36,7 +36,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.padchest", split),
+        make_split("mednet.libs.classification.config.data.padchest", split),
         lengths=lengths,
         prefixes=("",),
         possible_labels=(0, 1),
diff --git a/src/mednet/libs/classification/tests/test_shenzhen.py b/src/mednet/libs/classification/tests/test_shenzhen.py
index ec40545ad811a4281311179f553f92bbd1746ae3..afc55032dc27fa6e29d6515387f37888d2ed8f20 100644
--- a/src/mednet/libs/classification/tests/test_shenzhen.py
+++ b/src/mednet/libs/classification/tests/test_shenzhen.py
@@ -40,7 +40,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.shenzhen", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.shenzhen", f"{split}.json"),
         lengths=lengths,
         prefixes=("CXR_png/CHNCXR_0",),
         possible_labels=(0, 1),
diff --git a/src/mednet/libs/classification/tests/test_tbpoc.py b/src/mednet/libs/classification/tests/test_tbpoc.py
index 3fa562406403c0e9dd6fa6ba4974e64195cafa90..81288179e3cc7590726483faefa300344c50edeb 100644
--- a/src/mednet/libs/classification/tests/test_tbpoc.py
+++ b/src/mednet/libs/classification/tests/test_tbpoc.py
@@ -39,7 +39,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.tbpoc", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.tbpoc", f"{split}.json"),
         lengths=lengths,
         prefixes=(
             "TBPOC_CXR/TBPOC-",
diff --git a/src/mednet/libs/classification/tests/test_tbx11k.py b/src/mednet/libs/classification/tests/test_tbx11k.py
index be965f9a906eb2ff99a6a3ff8574c8de19d9ac31..932b87825a6af716e6a986bbbe0a2f18f2a2bdd5 100644
--- a/src/mednet/libs/classification/tests/test_tbx11k.py
+++ b/src/mednet/libs/classification/tests/test_tbx11k.py
@@ -142,7 +142,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split("mednet.config.data.tbx11k", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.tbx11k", f"{split}.json"),
         lengths=lengths,
         prefixes=prefixes,
         possible_labels=(0, 1),
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index d01cc14a4c7e58d10cb8adb5fe875c9198761d9e..53a14658fbdf6edadd8190c9ae2f0e308e7cd54b 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -11,8 +11,8 @@ import torch.nn
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
+from mednet.libs.common.data.typing import TransformSequence
 
-from medbase.data.typing import TransformSequence
 from .loss_weights import get_positive_weights
 from .typing import Checkpoint
 
@@ -224,7 +224,5 @@ class Model(pl.LightningModule):
             logger.warning(
                 "Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
             )
-            validation_weights = get_positive_weights(
-                datamodule.train_dataloader()
-            )
+            validation_weights = get_positive_weights(datamodule.train_dataloader())
         self._validation_loss_arguments["pos_weight"] = validation_weights
diff --git a/src/mednet/libs/common/tests/test_database_split.py b/src/mednet/libs/common/tests/test_database_split.py
index e77c4cbf8c932f39f5a794b10a23778a1e48d837..114dd1bac72185685e41544bf3b7f393660c18f7 100644
--- a/src/mednet/libs/common/tests/test_database_split.py
+++ b/src/mednet/libs/common/tests/test_database_split.py
@@ -3,31 +3,9 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Test code for datasets."""
 
-<<<<<<< HEAD
-from medbase.data.split import JSONDatabaseSplit
-=======
-from mednet.libs.common.data.split import CSVDatabaseSplit, JSONDatabaseSplit
+from mednet.libs.common.data.split import JSONDatabaseSplit
 
 
-def test_csv_loading(datadir):
-    # tests if we can build a simple CSV loader for the Iris Flower dataset
-    database_split = CSVDatabaseSplit(datadir)
-
-    assert len(database_split["iris-train"]) == 75
-    for k in database_split["iris-train"]:
-        for f in range(4):
-            assert isinstance(k[f], str)  # csv only loads stringd
-        assert isinstance(k[4], str)
-
-    assert len(database_split["iris-test"]) == 75
-    for k in database_split["iris-test"]:
-        for f in range(4):
-            assert isinstance(k[f], str)  # csv only loads stringd
-        assert isinstance(k[4], str)
-        assert k[4] in ("Iris-setosa", "Iris-versicolor", "Iris-virginica")
-
->>>>>>> b1ea1c0 ([mednet] Start reorganizing into monorepo)
-
 def test_json_loading(datadir):
     # tests if we can build a simple JSON loader for the Iris Flower dataset
diff --git a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
index f8dfa0bc895d49e77f4191259b7ef4d67f8a2e8b..e635ea2903b5a05c174c2b4686935bd3b9f8178f 100644
--- a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
@@ -3,15 +3,14 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """CHASE-DB1 dataset for Vessel Segmentation."""
 
-import importlib.resources
 import os
-from pathlib import Path
+import pathlib
 
 import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -28,16 +27,18 @@ database."""
 
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Chase-db1 dataset."""
 
-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data
    is stored."""
 
     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
         )
-        self._pkg_path = pkg_resources.resource_filename(__name__, "masks")
+        self._pkg_path = pathlib.Path(
+            pkg_resources.resource_filename(__name__, "masks")
+        )
 
     def sample(self, sample: tuple[str, str, str]) -> Sample:
         """Load a single image sample from the disk.
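[annotation] Every datamodule in this patch now obtains its split through the
shared helper `mednet.libs.common.data.split.make_split(package, basename)`.
The helper's body is not shown here; judging from the per-dataset
`make_split()` functions it replaces (removed below for chasedb1, and further
down for drive, iostar, montgomery and stare), it is presumably the same JSON
lookup parameterized by the package name, along these lines:

    # sketch only: reconstructed from the removed per-dataset helpers; the
    # actual shared implementation is not part of this patch
    import importlib.resources

    from mednet.libs.common.data.split import JSONDatabaseSplit
    from mednet.libs.common.data.typing import DatabaseSplit

    def make_split(package: str, basename: str) -> DatabaseSplit:
        """Return the split stored as JSON resource ``basename`` of ``package``."""
        return JSONDatabaseSplit(
            importlib.resources.files(package).joinpath(basename)
        )

Callers pass `__package__` plus the split file name, from which the new
`database_name` and `split_name` constructor arguments are also derived.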
@@ -53,20 +54,16 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         The sample representation.
         """
 
-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
+        image = PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB")
         tensor = tv_tensors.Image(to_tensor(image))
 
         target = tv_tensors.Image(
             to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
+                PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
             )
         )
 
         mask = tv_tensors.Mask(
             to_tensor(
-                PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
+                PIL.Image.open(self._pkg_path / sample[2]).convert(
                     mode="1", dither=None
                 )
             )
@@ -75,24 +72,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Chase-db1 database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-    An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )
-
-
 class DataModule(CachingDataModule):
     """CHASE-DB1 dataset for Vessel Segmentation.
 
@@ -129,7 +108,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index b3a2a76730d5be54881d57dccf08e941dceff260..14aa36aeab721b7213adc5a4a552877d2ae4a0a7 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -3,14 +3,13 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """COVD-DRIVE for Vessel Segmentation."""
 
-import importlib.resources
 import os
-from pathlib import Path
+import pathlib
 
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
@@ -28,13 +27,13 @@ database."""
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Drive dataset."""
 
-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data
     is stored."""
 
     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
         )
 
     def sample(self, sample: tuple[str, str, str]) -> Sample:
@@ -51,20 +50,12 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         The sample representation.
         """
 
-        image = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-                mode="RGB"
-            )
-        )
+        image = to_tensor(PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB"))
 
         target = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                mode="1", dither=None
-            )
+            PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
         )
 
         mask = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                mode="1", dither=None
-            )
+            PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
         tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
@@ -74,24 +65,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Drive database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-    An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )
-
-
 class DataModule(CachingDataModule):
     """DRIVE dataset for Vessel Segmentation.
 
@@ -117,7 +90,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
index 72500523a933a0bd60e50d9f72ea296281799ca4..50df56780bdf891c26e47195e4e073d0ccb45fca 100644
--- a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
@@ -3,14 +3,13 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """IOSTAR (training set) for Vessel and Optic-Disc Segmentation."""
 
-import importlib.resources
 import os
-from pathlib import Path
+import pathlib
 
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
@@ -28,13 +27,13 @@ database."""
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the iostar dataset."""
 
-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data
     is stored."""
 
     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
         )
 
     def sample(self, sample: tuple[str, str, str]) -> Sample:
@@ -52,22 +51,14 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         The sample representation.
         """
 
-        image = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-                mode="RGB"
-            )
-        )
+        image = to_tensor(PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB"))
 
         target = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                mode="1", dither=None
-            )
+            PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
         )
 
         mask = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                mode="1", dither=None
-            )
+            PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
         tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
@@ -77,24 +68,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the iostar database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-    An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )
-
-
 class DataModule(CachingDataModule):
     """IOSTAR (training set) for Vessel and Optic-Disc Segmentation.
 
@@ -123,7 +96,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
index 3828022c58caf09b36ef2eccf097818ecf6f5e50..f24b9925a52ba7f38ffc0ce0e28180f9f9c9683c 100644
--- a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
@@ -3,17 +3,16 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Montgomery DataModule for TB detection."""
 
-import importlib.resources
 import os
-from pathlib import Path
+import pathlib
 
 import numpy as np
 import PIL.Image
 import pkg_resources
 import torch
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -30,15 +29,17 @@ database."""
 
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Montgomery dataset."""
 
-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data
     is stored."""
 
     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
+        )
+        self._pkg_path = pathlib.Path(
+            pkg_resources.resource_filename(__name__, "masks")
         )
-        self._pkg_path = pkg_resources.resource_filename(__name__, "masks")
 
     def sample(self, sample: tuple[str, str, str]) -> Sample:
         """Load a single image sample from the disk.
@@ -55,9 +56,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         The sample representation.
         """
 
-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
+        image = PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB")
         tensor = tv_tensors.Image(to_tensor(image))
 
         # Combine left and right lung masks into a single tensor
@@ -65,14 +64,14 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             to_tensor(
                 np.ma.mask_or(
                     np.asarray(
-                        PIL.Image.open(
-                            Path(self.datadir) / str(sample[1])
-                        ).convert(mode="1", dither=None)
+                        PIL.Image.open(self.datadir / sample[1]).convert(
+                            mode="1", dither=None
+                        )
                     ),
                     np.asarray(
-                        PIL.Image.open(
-                            Path(self.datadir) / str(sample[2])
-                        ).convert(mode="1", dither=None)
+                        PIL.Image.open(self.datadir / sample[2]).convert(
+                            mode="1", dither=None
+                        )
                     ),
                 )
             ).float()
@@ -83,24 +82,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Montgomery database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-    An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )
-
-
 class DataModule(CachingDataModule):
     """Montgomery DataModule for TB detection.
 
@@ -140,7 +121,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/segmentation/config/data/stare/datamodule.py b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
index d037a6947169d7c3659f17f5064fe921f843f5ef..9bf18d18827707c1d8ad868c815028802326970a 100644
--- a/src/mednet/libs/segmentation/config/data/stare/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
@@ -3,15 +3,14 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """STARE dataset for Vessel Segmentation."""
 
-import importlib.resources
 import os
-from pathlib import Path
+import pathlib
 
 import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
@@ -29,15 +28,17 @@ database."""
 
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Stare dataset."""
 
-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data
     is stored."""
 
     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
+        )
+        self._pkg_path = pathlib.Path(
+            pkg_resources.resource_filename(__name__, "masks")
         )
-        self._pkg_path = pkg_resources.resource_filename(__name__, "masks")
 
     def sample(self, sample: tuple[str, str, str]) -> Sample:
         """Load a single image sample from the disk.
@@ -54,22 +55,14 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         The sample representation.
         """
 
-        image = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-                mode="RGB"
-            )
-        )
+        image = to_tensor(PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB"))
 
         target = to_tensor(
-            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                mode="1", dither=None
-            )
+            PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
         )
 
         mask = to_tensor(
-            PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
-                mode="1", dither=None
-            )
+            PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
         tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
@@ -79,24 +72,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Stare database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-    An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )
-
-
 class DataModule(CachingDataModule):
     """STARE dataset for Vessel Segmentation.
 
@@ -126,7 +101,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/segmentation/tests/test_chasedb1.py b/src/mednet/libs/segmentation/tests/test_chasedb1.py
index 145bb4dbafeb5aea50e67211f6f2a3a7607c3279..77c8f2d60475939df42f243220deb6d728bb714a 100644
--- a/src/mednet/libs/segmentation/tests/test_chasedb1.py
+++ b/src/mednet/libs/segmentation/tests/test_chasedb1.py
@@ -28,12 +28,10 @@ def test_protocol_consistency(
     split: str,
     lengths: dict[str, int],
 ):
-    from mednet.libs.segmentation.config.data.chasedb1.datamodule import (
-        make_split,
-    )
+    from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split(f"{split}.json"),
+        make_split("mednet.libs.segmentation.config.data.chasedb1", f"{split}.json"),
         lengths=lengths,
     )
 
@@ -91,8 +89,7 @@ def test_loading(database_checkers, name: str, dataset: str):
 @pytest.mark.skip_if_rc_var_not_set("datadir.chasedb1")
 def test_raw_transforms_image_quality(database_checkers, datadir):
     reference_histogram_file = str(
-        datadir
-        / "histograms/raw_data/histograms_chasedb1_first_annotator.json",
+        datadir / "histograms/raw_data/histograms_chasedb1_first_annotator.json",
     )
 
     datamodule = importlib.import_module(
diff --git a/src/mednet/libs/segmentation/tests/test_cli_segmentation.py b/src/mednet/libs/segmentation/tests/test_cli_segmentation.py
index 412509c8d98d5938f0f359732c69ddaea305a1d3..7a824c04f00049b6f2f4c4af176db947a94910fc 100644
--- a/src/mednet/libs/segmentation/tests/test_cli_segmentation.py
+++ b/src/mednet/libs/segmentation/tests/test_cli_segmentation.py
@@ -84,10 +84,7 @@ def test_config_describe_drive():
     runner = CliRunner()
     result = runner.invoke(describe, ["drive"])
     _assert_exit_0(result)
-    assert (
-        "DRIVE dataset for Vessel Segmentation (default protocol)."
-        in result.output
-    )
+    assert "DRIVE dataset for Vessel Segmentation (default protocol)." in result.output
 
 
 def test_database_help():
@@ -154,6 +151,7 @@ def test_evaluate_help():
     _check_help(evaluate)
 
 
+@pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.drive")
 def test_train_lwnet_drive(temporary_basedir):
     from mednet.libs.common.utils.checkpointer import (
@@ -185,10 +183,7 @@ def test_train_lwnet_drive(temporary_basedir):
 
     best = _get_checkpoint_from_alias(output_folder, "best")
     assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
-    assert (
-        len(list((output_folder / "logs").glob("events.out.tfevents.*")))
-        == 1
-    )
+    assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
     assert (output_folder / "meta.json").exists()
 
     keywords = {
@@ -209,6 +204,7 @@ def test_train_lwnet_drive(temporary_basedir):
     )
 
 
+@pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.drive")
 def test_train_lwnet_drive_from_checkpoint(temporary_basedir):
     from mednet.libs.common.utils.checkpointer import (
@@ -240,9 +236,7 @@ def test_train_lwnet_drive_from_checkpoint(temporary_basedir):
     assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
     assert (output_folder / "meta.json").exists()
-    assert (
-        len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
-    )
+    assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
 
     with stdout_logging() as buf:
         result = runner.invoke(
@@ -264,10 +258,7 @@ def test_train_lwnet_drive_from_checkpoint(temporary_basedir):
 
     best = _get_checkpoint_from_alias(output_folder, "best")
     assert (output_folder / "meta.json").exists()
-    assert (
-        len(list((output_folder / "logs").glob("events.out.tfevents.*")))
-        == 2
-    )
+    assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 2
 
     keywords = {
         r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
@@ -288,6 +279,7 @@ def test_train_lwnet_drive_from_checkpoint(temporary_basedir):
     )
 
 
+@pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.drive")
 def test_predict_lwnet_drive(temporary_basedir, datadir):
     from mednet.libs.common.utils.checkpointer import (
@@ -340,6 +332,7 @@ def test_predict_lwnet_drive(temporary_basedir, datadir):
     )
 
 
+@pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.drive")
 def test_evaluate_lwnet_drive(temporary_basedir):
     from mednet.libs.segmentation.scripts.evaluate import evaluate
@@ -388,6 +381,7 @@ def test_evaluate_lwnet_drive(temporary_basedir):
     )
 
 
+@pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.drive")
 def test_experiment(temporary_basedir):
     from mednet.libs.segmentation.scripts.experiment import experiment
@@ -409,9 +403,7 @@ def test_experiment(temporary_basedir):
     _assert_exit_0(result)
 
     assert (output_folder / "model" / "meta.json").exists()
-    assert (
-        output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
-    ).exists()
+    assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists()
 
     assert (output_folder / "predictions" / "predictions.json").exists()
     assert (output_folder / "predictions" / "predictions.meta.json").exists()
diff --git a/src/mednet/libs/segmentation/tests/test_drive.py b/src/mednet/libs/segmentation/tests/test_drive.py
index d1a1301d47522459397e5b7b1b22b733a42faeb2..060d113238af73a47c45e12e6dc6cc2d65a671fd 100644
--- a/src/mednet/libs/segmentation/tests/test_drive.py
+++ b/src/mednet/libs/segmentation/tests/test_drive.py
@@ -27,12 +27,10 @@ def test_protocol_consistency(
     split: str,
     lengths: dict[str, int],
 ):
-    from mednet.libs.segmentation.config.data.drive.datamodule import (
-        make_split,
-    )
+    from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split(f"{split}.json"),
+        make_split("mednet.libs.segmentation.config.data.drive", f"{split}.json"),
         lengths=lengths,
     )
 
@@ -110,8 +108,7 @@ def test_raw_transforms_image_quality(database_checkers, datadir):
 )
 def test_model_transforms_image_quality(database_checkers, datadir, model_name):
     reference_histogram_file = str(
-        datadir
-        / f"histograms/models/histograms_{model_name}_drive_default.json",
+        datadir / f"histograms/models/histograms_{model_name}_drive_default.json",
     )
 
     datamodule = importlib.import_module(
diff --git a/src/mednet/libs/segmentation/tests/test_iostar.py b/src/mednet/libs/segmentation/tests/test_iostar.py
index b00439a38d54e49c0dddf8de380d50f779c6bf4a..e50967e0a6a74622f463f448a23b5365e05e8564 100644
--- a/src/mednet/libs/segmentation/tests/test_iostar.py
+++ b/src/mednet/libs/segmentation/tests/test_iostar.py
@@ -28,12 +28,10 @@ def test_protocol_consistency(
     split: str,
     lengths: dict[str, int],
 ):
-    from mednet.libs.segmentation.config.data.iostar.datamodule import (
-        make_split,
-    )
+    from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split(f"{split}.json"),
+        make_split("mednet.libs.segmentation.config.data.iostar", f"{split}.json"),
         lengths=lengths,
     )
 
@@ -109,8 +107,7 @@ def test_raw_transforms_image_quality(database_checkers, datadir):
 )
 def test_model_transforms_image_quality(database_checkers, datadir, model_name):
     reference_histogram_file = str(
-        datadir
-        / f"histograms/models/histograms_{model_name}_iostar_vessel.json",
+        datadir / f"histograms/models/histograms_{model_name}_iostar_vessel.json",
    )
 
     datamodule = importlib.import_module(
diff --git a/src/mednet/libs/segmentation/tests/test_montgomery.py b/src/mednet/libs/segmentation/tests/test_montgomery.py
index 16ab9bb7ae73957eaaafb9ccedca603469b16fc1..2f90200f70722edc54e4eda27e0dcc6ba5d870c8 100644
--- a/src/mednet/libs/segmentation/tests/test_montgomery.py
+++ b/src/mednet/libs/segmentation/tests/test_montgomery.py
@@ -27,12 +27,10 @@ def test_protocol_consistency(
     split: str,
     lengths: dict[str, int],
 ):
-    from mednet.libs.segmentation.config.data.montgomery.datamodule import (
-        make_split,
-    )
+    from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split(f"{split}.json"),
+        make_split("mednet.libs.segmentation.config.data.montgomery", f"{split}.json"),
         lengths=lengths,
     )
 
@@ -112,8 +110,7 @@ def test_raw_transforms_image_quality(database_checkers, datadir):
 )
 def test_model_transforms_image_quality(database_checkers, datadir, model_name):
     reference_histogram_file = str(
-        datadir
-        / f"histograms/models/histograms_{model_name}_montgomery_default.json",
+        datadir / f"histograms/models/histograms_{model_name}_montgomery_default.json",
     )
 
     datamodule = importlib.import_module(
diff --git a/src/mednet/libs/segmentation/tests/test_stare.py b/src/mednet/libs/segmentation/tests/test_stare.py
index 9b0038945da8c6648664dcd124add57eed898cf5..6268ec9f32cf53c4ba8377cd1f9bba77f011d12f 100644
--- a/src/mednet/libs/segmentation/tests/test_stare.py
+++ b/src/mednet/libs/segmentation/tests/test_stare.py
@@ -28,12 +28,10 @@ def test_protocol_consistency(
     split: str,
     lengths: dict[str, int],
 ):
-    from mednet.libs.segmentation.config.data.stare.datamodule import (
-        make_split,
-    )
+    from mednet.libs.common.data.split import make_split
 
     database_checkers.check_split(
-        make_split(f"{split}.json"),
+        make_split("mednet.libs.segmentation.config.data.stare", f"{split}.json"),
         lengths=lengths,
     )
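[annotation] With the per-dataset `make_split()` helpers gone, the tests
address splits by their fully-qualified package path. A usage sketch under the
assumptions visible in this patch (the `make_split(package, basename)`
signature is shown throughout; subscript access to subsets is grounded by the
removed CSV test, while iterating with `.items()` is an assumption about the
split type):

    from mednet.libs.common.data.split import make_split

    # load the "default" split shipped with the DRIVE segmentation package
    split = make_split("mednet.libs.segmentation.config.data.drive", "default.json")
    for subset, samples in split.items():  # assumption: mapping-like interface
        print(subset, len(samples))

The new `@pytest.mark.slow` markers on the training, prediction and evaluation
tests presumably pair with a registered `slow` marker in the pytest
configuration, so the long-running tests can be deselected with
`pytest -m "not slow"`.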