# Skip to content
# Snippets Groups Projects
# default.py 2.96 KiB
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
datasets."""

from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset

from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..indian.default import datamodule as indian_datamodule
from ..montgomery.default import datamodule as mc_datamodule
from ..padchest.tb_idiap import datamodule as pc_datamodule
from ..shenzhen.default import datamodule as ch_datamodule

# Logger named after the first component of this module's dotted name (the
# top-level package), emitting records as "LEVEL: message".
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


class DefaultModule(BaseDataModule):
    """Datamodule concatenating the Montgomery, Shenzhen, Indian and Padchest
    datasets.

    The constructor arguments are stored on the instance and also handed to
    the base class; :py:meth:`setup` later forwards the stored values to each
    of the wrapped datamodules.

    Parameters
    ----------
    train_batch_size
        Batch size setting, forwarded as ``train_batch_size``.
    predict_batch_size
        Batch size setting, forwarded as ``predict_batch_size``.
    drop_incomplete_batch
        Flag forwarded as ``drop_incomplete_batch``.
    multiproc_kwargs
        Value forwarded as ``multiproc_kwargs`` (``None`` by default).
    """

    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        # Stored before delegating to the base class, so that setup() can
        # replay the very same settings on the wrapped datamodules.
        self.train_batch_size = train_batch_size
        self.predict_batch_size = predict_batch_size
        self.drop_incomplete_batch = drop_incomplete_batch
        self.multiproc_kwargs = multiproc_kwargs

        super().__init__(
            train_batch_size=train_batch_size,
            predict_batch_size=predict_batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            multiproc_kwargs=multiproc_kwargs,
        )

    def setup(self, stage: str):
        """Build the concatenated datasets for the given stage.

        Each wrapped datamodule is queried through
        ``get_dataset_from_module`` and, for every split name, the per-source
        datasets are joined with ``ConcatDataset`` in the order Montgomery,
        Shenzhen, Indian, Padchest.  The result is stored in
        ``self.dataset`` and then split into the train / validation /
        extra-validation / predict attributes by ``return_subsets``.

        Parameters
        ----------
        stage
            Stage identifier, passed unchanged to every wrapped datamodule.
        """
        forwarded = dict(
            train_batch_size=self.train_batch_size,
            predict_batch_size=self.predict_batch_size,
            drop_incomplete_batch=self.drop_incomplete_batch,
            multiproc_kwargs=self.multiproc_kwargs,
        )

        # The tuple order fixes both the instantiation order and the position
        # of each source inside the concatenated datasets.
        source_modules = (
            mc_datamodule,
            ch_datamodule,
            indian_datamodule,
            pc_datamodule,
        )
        sources = [
            get_dataset_from_module(module, stage, **forwarded)
            for module in source_modules
        ]

        # One concatenated dataset per split name, filled in incrementally
        self.dataset = {}
        for split in ("__train__", "train", "__valid__", "test"):
            self.dataset[split] = ConcatDataset(
                [source[split] for source in sources]
            )

        subsets = return_subsets(self.dataset)
        (
            self.train_dataset,
            self.validation_dataset,
            self.extra_validation_datasets,
            self.predict_dataset,
        ) = subsets


# Module-level alias for the datamodule class; the sibling dataset modules
# imported above are consumed under this same ``datamodule`` name.
datamodule = DefaultModule