Commit 0632bb72 authored by André Anjos

[data.nih_cxr14_padchest] Reimplements aggregated database

parent 1d7cbaff
Merge request !6: Making use of LightningDataModule and simplification of data loading
Pipeline #76704 canceled
@@ -193,7 +193,7 @@ montgomery-shenzhen-indian-tbx11k-v2-f7 = "ptbench.data.montgomery_shenzhen_indi
 montgomery-shenzhen-indian-tbx11k-v2-f8 = "ptbench.data.montgomery_shenzhen_indian_tbx11k.v2_fold_8"
 montgomery-shenzhen-indian-tbx11k-v2-f9 = "ptbench.data.montgomery_shenzhen_indian_tbx11k.v2_fold_9"
-# tbpoc dataset (and cross-validation folds)
+# tbpoc dataset (only cross-validation folds)
 tbpoc_f0 = "ptbench.data.tbpoc.fold_0"
 tbpoc_f1 = "ptbench.data.tbpoc.fold_1"
 tbpoc_f2 = "ptbench.data.tbpoc.fold_2"
@@ -205,7 +205,7 @@ tbpoc_f7 = "ptbench.data.tbpoc.fold_7"
 tbpoc_f8 = "ptbench.data.tbpoc.fold_8"
 tbpoc_f9 = "ptbench.data.tbpoc.fold_9"
-# hivtb dataset (and cross-validation folds)
+# hivtb dataset (only cross-validation folds)
 hivtb_f0 = "ptbench.data.hivtb.fold_0"
 hivtb_f1 = "ptbench.data.hivtb.fold_1"
 hivtb_f2 = "ptbench.data.hivtb.fold_2"
@@ -217,9 +217,6 @@ hivtb_f7 = "ptbench.data.hivtb.fold_7"
 hivtb_f8 = "ptbench.data.hivtb.fold_8"
 hivtb_f9 = "ptbench.data.hivtb.fold_9"
-# montgomery-shenzhen-indian-padchest aggregated dataset
-mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
 # NIH CXR14 (relabeled), multi-class (14 labels)
 nih-cxr14 = "ptbench.data.nih_cxr14.default"
 nih-cxr14-cardiomegaly = "ptbench.data.nih_cxr14.cardiomegaly"
@@ -231,7 +228,10 @@ padchest-no-tb-idiap = "ptbench.data.padchest.no_tb_idiap"
 padchest-cardiomegaly-idiap = "ptbench.data.padchest.cardiomegaly_idiap"
 # NIH CXR14 / PadChest aggregated dataset
-nih_cxr14_pc_idiap = "ptbench.data.nih_cxr14_re_pc.idiap"
+nih-cxr14-padchest = "ptbench.data.nih_cxr14_padchest.idiap"
+# montgomery-shenzhen-indian-padchest aggregated dataset
+mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
 [tool.setuptools]
 zip-safe = true
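The renamed entry point now points at ptbench.data.nih_cxr14_padchest.idiap, whose top-level datamodule attribute carries the pre-configured aggregated datamodule (see the new files below). A minimal sketch of resolving that value by hand, assuming only standard-library importing; ptbench's own configuration machinery may do this differently:

import importlib

# Resolve the registered module path by hand (illustration only).
module = importlib.import_module("ptbench.data.nih_cxr14_padchest.idiap")
datamodule = module.datamodule  # the pre-configured aggregated datamodule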
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from ..datamodule import ConcatDataModule
from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
from ..nih_cxr14.datamodule import make_split as make_cxr14_split
from ..padchest.datamodule import RawDataLoader as PadchestLoader
from ..padchest.datamodule import make_split as make_padchest_split


class DataModule(ConcatDataModule):
    """Aggregated dataset composed of NIH CXR14 (relabeled) and PadChest
    (normalized) datasets.

    Parameters
    ----------
    cxr14_split_filename
        Name of the split file to load for the NIH CXR14 portion.
    padchest_split_filename
        Name of the split file to load for the PadChest portion.
    """

    def __init__(self, cxr14_split_filename: str, padchest_split_filename: str):
        cxr14_loader = CXR14Loader()
        cxr14_split = make_cxr14_split(cxr14_split_filename)

        padchest_loader = PadchestLoader()
        padchest_split = make_padchest_split(padchest_split_filename)

        super().__init__(
            splits={
                "train": [
                    (cxr14_split["train"], cxr14_loader),
                    (padchest_split["train"], padchest_loader),
                ],
                "validation": [
                    (cxr14_split["validation"], cxr14_loader),
                    (padchest_split["validation"], padchest_loader),
                ],
                "test": [
                    (cxr14_split["test"], cxr14_loader),
                    (padchest_split["test"], padchest_loader),
                ],
            }
        )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from .datamodule import DataModule

# Pre-configured aggregated datamodule: NIH CXR14 "default" split combined
# with the PadChest "no-tb-idiap" split.
datamodule = DataModule("default.json.bz2", "no-tb-idiap.json.bz2")
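A hedged usage sketch for the configuration above, assuming ConcatDataModule follows the usual LightningDataModule protocol (setup() followed by the *_dataloader() accessors), as the merge-request title suggests; the exact ptbench workflow may differ:

from ptbench.data.nih_cxr14_padchest.idiap import datamodule

datamodule.setup(stage="fit")  # assembles the combined train/validation splits
train_loader = datamodule.train_dataloader()
for batch in train_loader:
    break  # samples from both NIH CXR14 and PadChest come through the same loader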
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Aggregated dataset composed of NIH CXR14 relabeled and PadChest (normalized)
datasets."""

from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset

from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..nih_cxr14_re.default import datamodule as nih_cxr14_re_datamodule
from ..padchest.no_tb_idiap import datamodule as padchest_no_tb_datamodule

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        self.train_batch_size = train_batch_size
        self.predict_batch_size = predict_batch_size
        self.drop_incomplete_batch = drop_incomplete_batch
        self.multiproc_kwargs = multiproc_kwargs

        super().__init__(
            train_batch_size=train_batch_size,
            predict_batch_size=predict_batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            multiproc_kwargs=multiproc_kwargs,
        )

    def setup(self, stage: str):
        # Instantiate both source datamodules with the same batching options
        # and retrieve their per-subset datasets.
        module_args = {
            "train_batch_size": self.train_batch_size,
            "predict_batch_size": self.predict_batch_size,
            "drop_incomplete_batch": self.drop_incomplete_batch,
            "multiproc_kwargs": self.multiproc_kwargs,
        }

        nih_cxr14_re = get_dataset_from_module(
            nih_cxr14_re_datamodule, stage, **module_args
        )
        padchest_no_tb = get_dataset_from_module(
            padchest_no_tb_datamodule, stage, **module_args
        )

        # Concatenate the corresponding subsets of both datasets; the test
        # subset comes from NIH CXR14 only.
        self.dataset = {}
        self.dataset["__train__"] = ConcatDataset(
            [nih_cxr14_re["__train__"], padchest_no_tb["__train__"]]
        )
        self.dataset["train"] = ConcatDataset(
            [nih_cxr14_re["train"], padchest_no_tb["train"]]
        )
        self.dataset["__valid__"] = ConcatDataset(
            [nih_cxr14_re["__valid__"], padchest_no_tb["__valid__"]]
        )
        self.dataset["validation"] = ConcatDataset(
            [nih_cxr14_re["validation"], padchest_no_tb["validation"]]
        )
        self.dataset["test"] = nih_cxr14_re["test"]

        (
            self.train_dataset,
            self.validation_dataset,
            self.extra_validation_datasets,
            self.predict_dataset,
        ) = return_subsets(self.dataset)


datamodule = DefaultModule
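For reference, the removed module relied on torch.utils.data.ConcatDataset to merge the per-subset datasets; a small, self-contained illustration of that behaviour (unrelated to the ptbench data themselves):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

a = TensorDataset(torch.zeros(3, 1))   # 3 samples
b = TensorDataset(torch.ones(2, 1))    # 2 samples
combined = ConcatDataset([a, b])

assert len(combined) == 5              # lengths add up
assert combined[3][0].item() == 1.0    # index 3 falls into the second dataset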