Skip to content
Snippets Groups Projects
Commit 0ad0f1c4 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Moved mc_ch_in_pc configs to data

parent 725e44ea
No related branches found
No related tags found
No related merge requests found
...@@ -465,8 +465,8 @@ hivtb_rs_f7 = "ptbench.configs.datasets.hivtb_RS.fold_7" ...@@ -465,8 +465,8 @@ hivtb_rs_f7 = "ptbench.configs.datasets.hivtb_RS.fold_7"
hivtb_rs_f8 = "ptbench.configs.datasets.hivtb_RS.fold_8" hivtb_rs_f8 = "ptbench.configs.datasets.hivtb_RS.fold_8"
hivtb_rs_f9 = "ptbench.configs.datasets.hivtb_RS.fold_9" hivtb_rs_f9 = "ptbench.configs.datasets.hivtb_RS.fold_9"
# montgomery-shenzhen-indian-padchest aggregated dataset # montgomery-shenzhen-indian-padchest aggregated dataset
mc_ch_in_pc = "ptbench.configs.datasets.mc_ch_in_pc.default" mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
mc_ch_in_pc_rgb = "ptbench.configs.datasets.mc_ch_in_pc.rgb" mc_ch_in_pc_rgb = "ptbench.data.mc_ch_in_pc.rgb"
# extended montgomery-shenzhen-indian-padchest aggregated dataset # extended montgomery-shenzhen-indian-padchest aggregated dataset
# (with radiological signs) # (with radiological signs)
mc_ch_in_pc_rs = "ptbench.configs.datasets.mc_ch_in_pc_RS.default" mc_ch_in_pc_rs = "ptbench.configs.datasets.mc_ch_in_pc_RS.default"
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from torch.utils.data.dataset import ConcatDataset
def _maker(protocol):
    """Build the aggregated Montgomery-Shenzhen-Indian-PadChest dataset.

    Parameters
    ----------
    protocol : str
        Either ``"default"`` (grayscale variants) or ``"rgb"`` (RGB
        variants) — selects which configuration of each constituent
        dataset is aggregated.

    Returns
    -------
    dict
        Dictionary with keys ``"__train__"``, ``"train"``,
        ``"__valid__"`` and ``"test"``, each mapping to a
        :class:`torch.utils.data.ConcatDataset` combining the matching
        subset of the Montgomery, Shenzhen, Indian and PadChest datasets.

    Raises
    ------
    ValueError
        If ``protocol`` is not one of the supported values.
    """
    if protocol == "default":
        from ..indian import default as indian
        from ..montgomery import default as mc
        from ..padchest import tb_idiap as pc
        from ..shenzhen import default as ch
    elif protocol == "rgb":
        from ..indian import rgb as indian
        from ..montgomery import rgb as mc
        from ..padchest import tb_idiap_rgb as pc
        from ..shenzhen import rgb as ch
    else:
        # Fail fast with a clear message instead of the confusing
        # UnboundLocalError the fall-through used to produce below.
        raise ValueError(
            f"Unsupported protocol {protocol!r}: expected 'default' or 'rgb'"
        )

    mc = mc.dataset
    ch = ch.dataset
    indian = indian.dataset
    pc = pc.dataset

    dataset = {}
    # Aggregate each split across the four datasets, preserving the
    # (mc, ch, indian, pc) concatenation order.
    dataset["__train__"] = ConcatDataset(
        [mc["__train__"], ch["__train__"], indian["__train__"], pc["__train__"]]
    )
    dataset["train"] = ConcatDataset(
        [mc["train"], ch["train"], indian["train"], pc["train"]]
    )
    dataset["__valid__"] = ConcatDataset(
        [mc["__valid__"], ch["__valid__"], indian["__valid__"], pc["__valid__"]]
    )
    dataset["test"] = ConcatDataset(
        [mc["test"], ch["test"], indian["test"], pc["test"]]
    )
    return dataset
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
datasets."""
from . import _maker
dataset = _maker("default")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
(RGB) datasets."""
from . import _maker
dataset = _maker("rgb")
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
datasets."""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..indian.default import datamodule as indian_datamodule
from ..montgomery.default import datamodule as mc_datamodule
from ..padchest.tb_idiap import datamodule as pc_datamodule
from ..shenzhen.default import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
    """Datamodule aggregating the Montgomery, Shenzhen, Indian and PadChest
    datasets (default/grayscale configurations)."""

    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        # Keep copies so setup() can forward them to the sub-datamodules.
        self.train_batch_size = train_batch_size
        self.predict_batch_size = predict_batch_size
        self.drop_incomplete_batch = drop_incomplete_batch
        self.multiproc_kwargs = multiproc_kwargs

        super().__init__(
            train_batch_size=train_batch_size,
            predict_batch_size=predict_batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            multiproc_kwargs=multiproc_kwargs,
        )

    def setup(self, stage: str):
        """Instantiate the constituent datamodules and concatenate their
        subsets split by split."""
        forwarded = {
            "train_batch_size": self.train_batch_size,
            "predict_batch_size": self.predict_batch_size,
            "drop_incomplete_batch": self.drop_incomplete_batch,
            "multiproc_kwargs": self.multiproc_kwargs,
        }
        # Order matters: samples are concatenated as (mc, ch, indian, pc).
        parts = [
            get_dataset_from_module(module, stage, **forwarded)
            for module in (
                mc_datamodule,
                ch_datamodule,
                indian_datamodule,
                pc_datamodule,
            )
        ]

        self.dataset = {
            split: ConcatDataset([part[split] for part in parts])
            for split in ("__train__", "train", "__valid__", "test")
        }

        (
            self.train_dataset,
            self.validation_dataset,
            self.extra_validation_datasets,
            self.predict_dataset,
        ) = return_subsets(self.dataset)


datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
(RGB) datasets."""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..indian.rgb import datamodule as indian_datamodule
from ..montgomery.rgb import datamodule as mc_datamodule
from ..padchest.tb_idiap_rgb import datamodule as pc_datamodule
from ..shenzhen.rgb import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
    """Datamodule aggregating the Montgomery, Shenzhen, Indian and PadChest
    datasets (RGB configurations)."""

    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        # Keep copies so setup() can forward them to the sub-datamodules.
        self.train_batch_size = train_batch_size
        self.predict_batch_size = predict_batch_size
        self.drop_incomplete_batch = drop_incomplete_batch
        self.multiproc_kwargs = multiproc_kwargs

        super().__init__(
            train_batch_size=train_batch_size,
            predict_batch_size=predict_batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            multiproc_kwargs=multiproc_kwargs,
        )

    def setup(self, stage: str):
        """Instantiate the constituent datamodules and concatenate their
        subsets split by split."""
        forwarded = {
            "train_batch_size": self.train_batch_size,
            "predict_batch_size": self.predict_batch_size,
            "drop_incomplete_batch": self.drop_incomplete_batch,
            "multiproc_kwargs": self.multiproc_kwargs,
        }
        # Order matters: samples are concatenated as (mc, ch, indian, pc).
        parts = [
            get_dataset_from_module(module, stage, **forwarded)
            for module in (
                mc_datamodule,
                ch_datamodule,
                indian_datamodule,
                pc_datamodule,
            )
        ]

        self.dataset = {
            split: ConcatDataset([part[split] for part in parts])
            for split in ("__train__", "train", "__valid__", "test")
        }

        (
            self.train_dataset,
            self.validation_dataset,
            self.extra_validation_datasets,
            self.predict_dataset,
        ) = return_subsets(self.dataset)


datamodule = DefaultModule
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment