# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""HIV-TB dataset for TB detection (cross validation fold 7).

* Split reference: none (stratified kfolding)
* Configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""

from clapper.logging import setup

from .. import return_subsets
from ..base_datamodule import BaseDataModule
from . import _maker

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


class DefaultModule(BaseDataModule):
    """Data module for fold 7 of the HIV-TB cross-validation splits.

    Every constructor argument is handed, by keyword and unchanged, to
    :py:class:`BaseDataModule`; this subclass only decides which fold is
    loaded in :py:meth:`setup`.
    """

    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        super().__init__(
            train_batch_size=train_batch_size,
            predict_batch_size=predict_batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            multiproc_kwargs=multiproc_kwargs,
        )

    def setup(self, stage: str):
        """Load the fold-7 dataset and split it into its four subsets.

        Parameters
        ----------
        stage
            Accepted to match the base-class hook signature; it is not used
            here, so the same subsets are built for every stage.
        """
        fold = _maker("fold_7")
        self.dataset = fold

        # ``return_subsets`` must yield exactly four items, in this order.
        train, validation, extra_validation, predict = return_subsets(fold)
        self.train_dataset = train
        self.validation_dataset = validation
        self.extra_validation_datasets = extra_validation
        self.predict_dataset = predict


# Module-level alias; presumably the name the configuration loader looks up
# (NOTE(review): confirm against the config/entry-point machinery).
datamodule = DefaultModule