Commit 9eacc051, authored by Daniel CARRON, committed by André Anjos

[mednet] Fixes after rebase

parent d8242837
1 merge request: !46 Create common library
Showing 67 additions and 142 deletions
@@ -427,6 +427,9 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet"
 chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
 chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator"
+# cxr8 - cxr
+cxr8 = "mednet.libs.segmentation.config.data.cxr8.default"
 # drive dataset - retinography
 drive = "mednet.libs.segmentation.config.data.drive.default"
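
These entry points map short dataset names to importable configuration modules. As an editorial illustration (not part of the commit), one such module can be resolved at runtime with importlib; the `.datamodule` attribute follows the convention visible in the tests further below:

```python
import importlib

# Resolve the "drive" entry point to its configuration module and pick up
# the DataModule instance it exposes (attribute name taken from the tests
# in this same commit).
module = importlib.import_module("mednet.libs.segmentation.config.data.drive.default")
datamodule = module.datamodule
```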
@@ -15,10 +15,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
 from torchvision.transforms.functional import to_tensor
 from ....utils.rc import load_rc

@@ -127,7 +125,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
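
Every DataModule in this commit now delegates split loading to a shared make_split(package, basename) helper instead of a per-dataset copy. A minimal sketch of what the shared helper in mednet.libs.common.data.split plausibly looks like, assuming it generalizes the per-dataset make_split() functions deleted near the end of this diff (which wrapped JSONDatabaseSplit around a packaged JSON resource):

```python
import importlib.resources

# JSONDatabaseSplit and DatabaseSplit are the types visible elsewhere in
# this diff; imports written as if from outside the module being sketched.
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit


def make_split(package: str, basename: str) -> DatabaseSplit:
    """Load the JSON split file `basename` shipped inside `package`."""
    return JSONDatabaseSplit(
        importlib.resources.files(package).joinpath(basename)
    )
```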
@@ -11,9 +11,8 @@ import pathlib
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.classification.data.typing import (
-    ClassificationRawDataLoader as _ClassificationRawDataLoader,
-)
+from ....config.data.shenzhen.datamodule import ClassificationRawDataLoader

 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
 database."""

@@ -64,8 +63,9 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(
+                config_variable=CONFIGURATION_KEY_DATADIR
+            ),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
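
This datamodule reuses the Shenzhen loader but points it at its own data directory through config_variable. A hedged sketch of how such a parameter is plausibly consumed, modeled on the load_rc()/CONFIGURATION_KEY_DATADIR pattern visible in the segmentation loaders below; the real constructor and load_rc helper may differ:

```python
import os
import pathlib


def load_rc() -> dict:
    """Stand-in for the package's load_rc(); returns the user configuration."""
    return {}


class ClassificationRawDataLoader:  # illustrative stand-in, not the real class
    def __init__(self, config_variable: str):
        # Resolve the database root directory from the runtime configuration,
        # falling back to the current directory (the pattern used by the
        # segmentation loaders later in this diff).
        self.datadir = pathlib.Path(
            load_rc().get(config_variable, os.path.realpath(os.curdir))
        )


# Hypothetical configuration key, for illustration only:
loader = ClassificationRawDataLoader(config_variable="datadir.example")
```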
@@ -15,10 +15,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
 from torchvision.transforms.functional import to_tensor
 from ....utils.rc import load_rc

@@ -28,7 +26,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""

-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the Montgomery dataset.

     Parameters

@@ -140,7 +138,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
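
The rename above follows a consistent aliasing pattern across the classification datamodules: the library-level type is imported under a private alias and specialized per dataset. In isolation (editorial sketch; only the names shown in this diff are assumed):

```python
from mednet.libs.classification.data.typing import (
    ClassificationRawDataLoader as _ClassificationRawDataLoader,
)


class ClassificationRawDataLoader(_ClassificationRawDataLoader):
    """Montgomery-specific loader, overriding sample() for this database."""
```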
@@ -14,10 +14,8 @@ from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
 from torchvision.transforms.functional import to_tensor
 from ....utils.rc import load_rc

@@ -86,7 +84,9 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
         # for folder lookup efficiency, data is split into subfolders
         # each original file is on the subfolder `f[:5]/f`, where f
         # is the original file basename
-        file_path = file_path.parent / file_path.name[:5] / file_path.name
+        file_path = pathlib.Path(
+            file_path.parent / file_path.name[:5] / file_path.name
+        )

         # N.B.: some NIH CXR-14 images are encoded as color PNGs with an alpha
         # channel. Most, are grayscale PNGs
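
The subfolder scheme is purely mechanical: a file's basename is nested under its own first five characters. A quick self-contained illustration (the filename is made up):

```python
import pathlib

# Hypothetical NIH CXR-14 entry; basenames are nested under their first
# five characters for faster directory lookups.
file_path = pathlib.Path("images/00000001_000.png")
nested = file_path.parent / file_path.name[:5] / file_path.name
print(nested)  # images/00000/00000001_000.png
```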
@@ -176,7 +176,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
@@ -16,10 +16,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
 from torchvision.transforms.functional import to_tensor
 from ....utils.rc import load_rc

@@ -329,7 +327,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
@@ -15,10 +15,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
 from torchvision.transforms.functional import to_tensor
 from ....utils.rc import load_rc

@@ -139,7 +137,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
@@ -11,10 +11,8 @@ from mednet.libs.classification.data.typing import (
 )
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
 from torchvision.transforms.functional import to_tensor
 from ....utils.rc import load_rc

@@ -125,7 +123,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=_ClassificationRawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
@@ -39,7 +39,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.hivtb", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.hivtb", f"{split}.json"),
         lengths=lengths,
         prefixes=("HIV-TB_Algorithm_study_X-rays",),
         possible_labels=(0, 1),
@@ -43,7 +43,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.indian", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.indian", f"{split}.json"),
         lengths=lengths,
         prefixes=("DatasetA/Training", "DatasetA/Testing"),
         possible_labels=(0, 1),
@@ -41,7 +41,9 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.montgomery", f"{split}.json"),
+        make_split(
+            "mednet.libs.classification.config.data.montgomery", f"{split}.json"
+        ),
         lengths=lengths,
         prefixes=("CXR_png/MCUCXR_0",),
         possible_labels=(0, 1),

@@ -114,12 +116,8 @@ def test_loading(database_checkers, name: str, dataset: str):
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_raw_transforms_image_quality(database_checkers, datadir: pathlib.Path):
     datamodule = importlib.import_module(
-<<<<<<< HEAD
-        ".default", "mednet.config.data.montgomery"
-=======
         ".default",
         "mednet.libs.classification.config.data.montgomery",
->>>>>>> b1ea1c0 ([mednet] Start reorganizing into monorepo)
     ).datamodule
     datamodule.model_transforms = []
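
The hunk above removes conflict markers left over from the rebase, keeping the monorepo-side import path. For reference (editorial note, not part of the commit), the two-argument form of importlib.import_module used here performs a relative import anchored at a package:

```python
import importlib

# These two calls resolve to the same module:
m1 = importlib.import_module(
    ".default", "mednet.libs.classification.config.data.montgomery"
)
m2 = importlib.import_module(
    "mednet.libs.classification.config.data.montgomery.default"
)
assert m1 is m2
```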
@@ -31,7 +31,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.nih_cxr14", f"{split}"),
+        make_split("mednet.libs.classification.config.data.nih_cxr14", f"{split}"),
         lengths=lengths,
         prefixes=("images/000",),
         possible_labels=(0, 1),
@@ -36,7 +36,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.padchest", split),
+        make_split("mednet.libs.classification.config.data.padchest", split),
         lengths=lengths,
         prefixes=("",),
         possible_labels=(0, 1),
@@ -40,7 +40,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.shenzhen", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.shenzhen", f"{split}.json"),
         lengths=lengths,
         prefixes=("CXR_png/CHNCXR_0",),
         possible_labels=(0, 1),
@@ -39,7 +39,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.tbpoc", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.tbpoc", f"{split}.json"),
         lengths=lengths,
         prefixes=(
             "TBPOC_CXR/TBPOC-",
@@ -142,7 +142,7 @@ def test_protocol_consistency(
     from mednet.libs.common.data.split import make_split

     database_checkers.check_split(
-        make_split("mednet.config.data.tbx11k", f"{split}.json"),
+        make_split("mednet.libs.classification.config.data.tbx11k", f"{split}.json"),
         lengths=lengths,
         prefixes=prefixes,
         possible_labels=(0, 1),
@@ -11,8 +11,8 @@ import torch.nn
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
-from mednet.libs.common.data.typing import TransformSequence
+from medbase.data.typing import TransformSequence

 from .loss_weights import get_positive_weights
 from .typing import Checkpoint

@@ -224,7 +224,5 @@ class Model(pl.LightningModule):
             logger.warning(
                 "Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
             )
-            validation_weights = get_positive_weights(
-                datamodule.train_dataloader()
-            )
+            validation_weights = get_positive_weights(datamodule.train_dataloader())

         self._validation_loss_arguments["pos_weight"] = validation_weights
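
get_positive_weights derives the pos_weight argument of the validation loss from label frequencies, falling back to the training dataloader when no validation dataloader exists. A hedged sketch of what such a helper typically computes for binary targets; the actual implementation in .loss_weights may differ (e.g. for multi-label data):

```python
import torch
from torch.utils.data import DataLoader


def get_positive_weights(dataloader: DataLoader) -> torch.Tensor:
    """Ratio of negative to positive samples, usable as a BCE pos_weight.

    Assumes each batch is an (input, info) pair whose info dict carries a
    binary "target" tensor, matching the sample layout seen in this diff.
    """
    targets = torch.cat(
        [batch[1]["target"].float().flatten() for batch in dataloader]
    )
    positives = targets.sum()
    return (targets.numel() - positives) / positives
```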
@@ -3,31 +3,9 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 """Test code for datasets."""

-<<<<<<< HEAD
-from medbase.data.split import JSONDatabaseSplit
-=======
-from mednet.libs.common.data.split import CSVDatabaseSplit, JSONDatabaseSplit
+from mednet.libs.common.data.split import JSONDatabaseSplit

-def test_csv_loading(datadir):
-    # tests if we can build a simple CSV loader for the Iris Flower dataset
-    database_split = CSVDatabaseSplit(datadir)
-
-    assert len(database_split["iris-train"]) == 75
-    for k in database_split["iris-train"]:
-        for f in range(4):
-            assert isinstance(k[f], str)  # csv only loads strings
-        assert isinstance(k[4], str)
-
-    assert len(database_split["iris-test"]) == 75
-    for k in database_split["iris-test"]:
-        for f in range(4):
-            assert isinstance(k[f], str)  # csv only loads strings
-        assert isinstance(k[4], str)
-        assert k[4] in ("Iris-setosa", "Iris-versicolor", "Iris-virginica")
->>>>>>> b1ea1c0 ([mednet] Start reorganizing into monorepo)

 def test_json_loading(datadir):
     # tests if we can build a simple JSON loader for the Iris Flower dataset
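
For orientation (editorial example, not from the commit): a JSON database split as loaded by JSONDatabaseSplit maps subset names to lists of samples. For the classification tests above, each sample is a [path, label] pair; the segmentation loaders below consume [image, target, mask] triples instead. A made-up classification split might look like:

```python
import json

# Hypothetical split content; file names are invented, the prefix and label
# conventions follow the montgomery checks above.
split = {
    "train": [
        ["CXR_png/MCUCXR_0001_0.png", 0],
        ["CXR_png/MCUCXR_0104_1.png", 1],
    ],
    "test": [
        ["CXR_png/MCUCXR_0399_1.png", 1],
    ],
}
print(json.dumps(split, indent=2))
```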
@@ -3,15 +3,14 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 """CHASE-DB1 dataset for Vessel Segmentation."""

-import importlib.resources
 import os
-from pathlib import Path
+import pathlib

 import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -28,16 +27,18 @@ database."""
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Chase-db1 dataset."""

-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data is
    stored."""

     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
         )
-        self._pkg_path = pkg_resources.resource_filename(__name__, "masks")
+        self._pkg_path = pathlib.Path(
+            pkg_resources.resource_filename(__name__, "masks")
+        )

     def sample(self, sample: tuple[str, str, str]) -> Sample:
         """Load a single image sample from the disk.
@@ -53,20 +54,16 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         The sample representation.
         """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
+        image = PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB")
         tensor = tv_tensors.Image(to_tensor(image))

         target = tv_tensors.Image(
             to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
+                PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
             )
         )
         mask = tv_tensors.Mask(
             to_tensor(
-                PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
+                PIL.Image.open(self._pkg_path / sample[2]).convert(
                     mode="1", dither=None
                 )
             )
@@ -75,24 +72,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]

-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Chase-db1 database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-        An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )

 class DataModule(CachingDataModule):
     """CHASE-DB1 dataset for Vessel Segmentation.
@@ -129,7 +108,10 @@ class DataModule(CachingDataModule):
     """

     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
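
A small worked example of the two derived constructor arguments (editorial illustration; the values follow directly from the expressions above and the chasedb1 entry points at the top of this commit):

```python
import pathlib

package = "mednet.libs.segmentation.config.data.chasedb1"
print(package.rsplit(".", 1)[1])                  # -> chasedb1
print(pathlib.Path("first_annotator.json").stem)  # -> first_annotator
```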
@@ -3,14 +3,13 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 """COVD-DRIVE for Vessel Segmentation."""

-import importlib.resources
 import os
-from pathlib import Path
+import pathlib

 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
-from mednet.libs.common.data.split import JSONDatabaseSplit
-from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
@@ -28,13 +27,13 @@ database."""
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Drive dataset."""

-    datadir: str
+    datadir: pathlib.Path
     """This variable contains the base directory where the database raw data is
     stored."""

     def __init__(self):
-        self.datadir = load_rc().get(
-            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        self.datadir = pathlib.Path(
+            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
         )
def sample(self, sample: tuple[str, str, str]) -> Sample:
......@@ -51,20 +50,12 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
The sample representation.
"""
image = to_tensor(
PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
mode="RGB"
)
)
image = to_tensor(PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB"))
target = to_tensor(
PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
mode="1", dither=None
)
PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
)
mask = to_tensor(
PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
mode="1", dither=None
)
PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
)
tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
@@ -74,24 +65,6 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]

-def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Drive database.
-
-    Parameters
-    ----------
-    basename
-        Name of the .json file containing the split to load.
-
-    Returns
-    -------
-        An instance of DatabaseSplit.
-    """
-
-    return JSONDatabaseSplit(
-        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
-    )

 class DataModule(CachingDataModule):
     """DRIVE dataset for Vessel Segmentation.
@@ -117,7 +90,10 @@ class DataModule(CachingDataModule):
     """

     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
-            database_split=make_split(split_filename),
+            database_split=make_split(__package__, split_filename),
             raw_data_loader=SegmentationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],
+            split_name=pathlib.Path(split_filename).stem,
         )
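
Putting it together, a hedged usage sketch: the split file name follows the `drive.default` entry point at the top of this commit, and the module path of the DataModule class is an assumption modeled on the classification-side `....config.data.shenzhen.datamodule` import shown earlier.

```python
# Hypothetical import path for the DRIVE DataModule defined above:
from mednet.libs.segmentation.config.data.drive.datamodule import DataModule

dm = DataModule("default.json")
# Internally: database_name -> "drive", split_name -> "default".
```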