Skip to content
Snippets Groups Projects
fold_7.py 1.28 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross-validation fold 7).

* Split reference: none (stratified kfolding)
* Resolution used by this configuration: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""

from clapper.logging import setup

from .. import return_subsets
from ..base_datamodule import BaseDataModule
from . import _maker

# Logger keyed on the top-level package name (everything before the first
# dot of this module's dotted path), emitting "LEVEL: message" lines.
logger = setup(
    __name__.partition(".")[0],
    format="%(levelname)s: %(message)s",
)


class DefaultModule(BaseDataModule):
    """Data module serving cross-validation fold 7 of the HIV-TB dataset.

    Every constructor argument is handed to :py:class:`BaseDataModule`
    unchanged; this subclass only chooses which split (``"fold_7"``) is
    loaded in :py:meth:`setup`.

    Parameters
    ----------
    train_batch_size
        Forwarded to the base class (training batch size).
    predict_batch_size
        Forwarded to the base class (prediction batch size).
    drop_incomplete_batch
        Forwarded to the base class; presumably controls whether a final,
        smaller batch is discarded -- confirm in ``BaseDataModule``.
    multiproc_kwargs
        Forwarded to the base class as-is; ``None`` by default.
    """

    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        forwarded = {
            "train_batch_size": train_batch_size,
            "predict_batch_size": predict_batch_size,
            "drop_incomplete_batch": drop_incomplete_batch,
            "multiproc_kwargs": multiproc_kwargs,
        }
        super().__init__(**forwarded)

    def setup(self, stage: str):
        """Build the fold-7 dataset and split it into its subsets.

        Sets ``self.dataset`` to the raw split returned by ``_maker`` and
        then ``train_dataset``, ``validation_dataset``,
        ``extra_validation_datasets`` and ``predict_dataset`` from
        ``return_subsets``.

        Parameters
        ----------
        stage
            Accepted for interface compatibility but not consulted: the
            same subsets are (re)built on every call, whatever its value.
        """
        self.dataset = _maker("fold_7")
        # Unpack first so a wrong subset count fails before any attribute
        # below is touched (same as direct tuple-target assignment).
        train, validation, extra_validation, predict = return_subsets(
            self.dataset
        )
        self.train_dataset = train
        self.validation_dataset = validation
        self.extra_validation_datasets = extra_validation
        self.predict_dataset = predict


# Module-level alias exposing the class (not an instance) under the name
# ``datamodule``; presumably looked up by name by the configuration loader.
datamodule = DefaultModule