Commit 327a53ce authored by Daniel CARRON

Moved nih_cxr14_re_pc configs to data

parent ff35275f
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from torch.utils.data.dataset import ConcatDataset


def _maker(protocol):
    """Builds the aggregated NIH CXR14 (relabeled) + PadChest dataset for the
    given protocol."""

    if protocol == "idiap":
        from ..nih_cxr14_re import default as nih_cxr14_re
        from ..padchest import no_tb_idiap as padchest_no_tb
    else:
        raise RuntimeError(f"Unsupported protocol: {protocol}")

    nih_cxr14_re = nih_cxr14_re.dataset
    padchest_no_tb = padchest_no_tb.dataset

    # Concatenate the training and validation splits of both datasets;
    # testing is performed on NIH CXR14 (relabeled) only.
    dataset = {}
    dataset["__train__"] = ConcatDataset(
        [nih_cxr14_re["__train__"], padchest_no_tb["__train__"]]
    )
    dataset["train"] = ConcatDataset(
        [nih_cxr14_re["train"], padchest_no_tb["train"]]
    )
    dataset["__valid__"] = ConcatDataset(
        [nih_cxr14_re["__valid__"], padchest_no_tb["__valid__"]]
    )
    dataset["validation"] = ConcatDataset(
        [nih_cxr14_re["validation"], padchest_no_tb["validation"]]
    )
    dataset["test"] = nih_cxr14_re["test"]

    return dataset
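
As a side note, ConcatDataset simply chains the underlying datasets end to end: indices run through the first dataset, then continue into the second. A minimal, self-contained illustration with toy tensor datasets (the names a and b are purely illustrative and not part of this commit):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

a = TensorDataset(torch.zeros(3, 1))   # stands in for an NIH CXR14 split
b = TensorDataset(torch.ones(2, 1))    # stands in for a PadChest split
combined = ConcatDataset([a, b])

assert len(combined) == 5              # lengths add up
assert combined[0][0].item() == 0.0    # indices 0-2 come from the first dataset
assert combined[3][0].item() == 1.0    # indices 3-4 come from the second dataset
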
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized)
datasets."""
from . import _maker
dataset = _maker("idiap")
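
For orientation, a minimal sketch of how the aggregated dataset dictionary defined above might be consumed; the batch size and the assumption that each split behaves like a regular map-style PyTorch dataset are illustrative only and not taken from this repository.

# Hypothetical consumption sketch (assumptions: splits are map-style PyTorch
# datasets; the batch size is arbitrary and for illustration only).
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset["__train__"], batch_size=8, shuffle=True)
test_loader = DataLoader(dataset["test"], batch_size=8, shuffle=False)

for batch in train_loader:
    ...  # one training step over the concatenated NIH CXR14 + PadChest samples
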
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized)
datasets."""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset

from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..nih_cxr14_re.default import datamodule as nih_cxr14_re_datamodule
from ..padchest.no_tb_idiap import datamodule as padchest_no_tb_datamodule

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


class DefaultModule(BaseDataModule):
    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        self.train_batch_size = train_batch_size
        self.predict_batch_size = predict_batch_size
        self.drop_incomplete_batch = drop_incomplete_batch
        self.multiproc_kwargs = multiproc_kwargs

        super().__init__(
            train_batch_size=train_batch_size,
            predict_batch_size=predict_batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            multiproc_kwargs=multiproc_kwargs,
        )

    def setup(self, stage: str):
        # Instantiate the two source datamodules with the same configuration.
        module_args = {
            "train_batch_size": self.train_batch_size,
            "predict_batch_size": self.predict_batch_size,
            "drop_incomplete_batch": self.drop_incomplete_batch,
            "multiproc_kwargs": self.multiproc_kwargs,
        }

        nih_cxr14_re = get_dataset_from_module(
            nih_cxr14_re_datamodule, stage, **module_args
        )
        padchest_no_tb = get_dataset_from_module(
            padchest_no_tb_datamodule, stage, **module_args
        )

        # Concatenate the training and validation splits of both datasets;
        # testing is performed on NIH CXR14 (relabeled) only.
        self.dataset = {}
        self.dataset["__train__"] = ConcatDataset(
            [nih_cxr14_re["__train__"], padchest_no_tb["__train__"]]
        )
        self.dataset["train"] = ConcatDataset(
            [nih_cxr14_re["train"], padchest_no_tb["train"]]
        )
        self.dataset["__valid__"] = ConcatDataset(
            [nih_cxr14_re["__valid__"], padchest_no_tb["__valid__"]]
        )
        self.dataset["validation"] = ConcatDataset(
            [nih_cxr14_re["validation"], padchest_no_tb["validation"]]
        )
        self.dataset["test"] = nih_cxr14_re["test"]

        (
            self.train_dataset,
            self.validation_dataset,
            self.extra_validation_datasets,
            self.predict_dataset,
        ) = return_subsets(self.dataset)


datamodule = DefaultModule
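
A hypothetical usage sketch for the datamodule above; the stage name and batch sizes are assumptions for illustration and do not appear in this commit.

# Hypothetical usage sketch (not part of this commit): the batch sizes and
# the "fit" stage are illustrative assumptions.
dm = DefaultModule(train_batch_size=8, predict_batch_size=8)
dm.setup("fit")                    # builds the concatenated splits above
train_split = dm.train_dataset     # as assigned by return_subsets()
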