From 7be6a4eeb05f6da43b780e22ecd1b253b87ef694 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Tue, 1 Aug 2023 09:07:46 +0200 Subject: [PATCH] [data.nih_cxr14_re] Update datamodule; Prepare framework for multi-class classification --- src/ptbench/data/datamodule.py | 15 +- src/ptbench/data/nih_cxr14_re/__init__.py | 1 - .../data/nih_cxr14_re/cardiomegaly.json | 86 ++++++++++ .../data/nih_cxr14_re/cardiomegaly.json.bz2 | Bin 392 -> 0 bytes src/ptbench/data/nih_cxr14_re/cardiomegaly.py | 45 +---- src/ptbench/data/nih_cxr14_re/datamodule.py | 157 ++++++++++++++++++ src/ptbench/data/nih_cxr14_re/default.py | 44 +---- src/ptbench/data/typing.py | 4 +- 8 files changed, 257 insertions(+), 95 deletions(-) create mode 100644 src/ptbench/data/nih_cxr14_re/cardiomegaly.json delete mode 100644 src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 create mode 100644 src/ptbench/data/nih_cxr14_re/datamodule.py diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 2cbc4b84..9b5eab61 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -106,7 +106,7 @@ class _DelayedLoadingDataset(Dataset): sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" return [self.loader.label(k) for k in self.raw_dataset] @@ -223,7 +223,7 @@ class _CachedDataset(Dataset): f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb" ) - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" return [k[1]["label"] for k in self.data] @@ -256,7 +256,7 @@ class _ConcatDataset(Dataset): for j in range(len(datasets[i])) ] - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the 
integer labels for all samples in the dataset.""" return list(itertools.chain(*[k.labels() for k in self._datasets])) @@ -379,11 +379,11 @@ def _make_balanced_random_sampler( for ds in dataset.datasets for k in typing.cast(Dataset, ds).labels() ] - weights = _calculate_weights(targets) + weights = _calculate_weights(targets) # type: ignore else: logger.warning( f"Balancing samples **and** concatenated-datasets " - f"WITHOUT metadata targets (`{target}` not available)" + f"by using dataset totals as `{target}: int` is not true" ) weights = [ k @@ -403,10 +403,11 @@ def _make_balanced_random_sampler( f"Balancing samples from dataset using metadata " f"targets `{target}`" ) - weights = _calculate_weights(dataset.labels()) + weights = _calculate_weights(dataset.labels()) # type: ignore else: raise RuntimeError( - f"Cannot balance samples without metadata targets `{target}`" + f"Cannot balance samples with multiple class labels " + f"({target}: list[int]) or without metadata targets `{target}`" ) return torch.utils.data.WeightedRandomSampler( diff --git a/src/ptbench/data/nih_cxr14_re/__init__.py b/src/ptbench/data/nih_cxr14_re/__init__.py index 27d1903c..b9954cf1 100644 --- a/src/ptbench/data/nih_cxr14_re/__init__.py +++ b/src/ptbench/data/nih_cxr14_re/__init__.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later - """NIH CXR14 (relabeled) dataset for computer-aided diagnosis. 
This dataset was extracted from the clinical PACS database at the National diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json new file mode 100644 index 00000000..b9af6ad7 --- /dev/null +++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json @@ -0,0 +1,86 @@ +{ + "train": [ + ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]], + ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + 
["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]], + ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]], + ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]], + ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ], + "validation": [ + ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_003.png", 
[1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]], + ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]], + ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]], + ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + 
["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]], + ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ] +} diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 deleted file mode 100644 index 13b6d810cd7bb430b332476e363970a5a728fa3e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 392 zcmV;30eAjFT4*^jL0KkKS=7hkqW}wZTYwl4PzC?+01Bq3-3mYgyZ`_OFqoPQ5r|}A zFvb%TL4q+1j3yYuVrVc%A(4c@sj8Dg6g0+@X+1`n8L8kI*kjE%=Kdbn0fF;)hHx4E z27R6ZoV#tY?|Bnx6G+JfCXzui1dS$58f`HqX_HMEG-R1HB#^{u5+qGDiKd%L5h6^I zX|{=^Ns=^;5u$A-B1lAyB*~K*B#kDRjUy&a6KRtWnoXoc+G&y=;K}BXig~A3N6FE; zrw&~nj;^O?3G+I=9i82ix=viakHBXko{Y`1Z18zFY}xAB>KQiRnrXB{b!I%D2L=*h z88RUfMnn;&h#@wKlSwqwmq%v~@1ulfj;+i2b|jON+?=L6qq<4DI-5GWxO8$lq?4Du zJ31U3I5>OVoV*?mj)!+w!Q{g}FFCd^)Z4|T+3<O~JX6H|9i5!F?1?!xeZxN<-*5NS mGt@Kg8UK8PIQF{+es5sQ@)_6IGq-o6{}*yaI8cz($Ks>l&#)o@ diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py index 1904ebfa..0715650d 100644 --- a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py +++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py @@ -2,47 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 dataset for computer-aided diagnosis. +from .datamodule import DataModule -First 40 images with cardiomegaly. - -* See :py:mod:`ptbench.data.nih_cxr14_re` for split details -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("cardiomegaly") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = Fold0Module +datamodule = DataModule("cardiomegaly.json") diff --git a/src/ptbench/data/nih_cxr14_re/datamodule.py b/src/ptbench/data/nih_cxr14_re/datamodule.py new file mode 100644 index 00000000..66a4379c --- /dev/null +++ b/src/ptbench/data/nih_cxr14_re/datamodule.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import importlib.resources +import os + +import PIL.Image + +from torchvision.transforms.functional import to_tensor + +from ...utils.rc import load_rc +from ..datamodule import CachingDataModule +from ..split import JSONDatabaseSplit +from ..typing import DatabaseSplit +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the NIH CXR14 (relabeled) dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. + + idiap_file_organisation + This variable will be ``True``, if the user has set the configuration + parameter ``nih_cxr14_re.idiap_folder_structure`` in the global + configuration file.
It will cause the internal loader to search for files + in a slightly different folder structure, that was adapted to Idiap's + requirements (number of files per folder to be less than 10k). + """ + + datadir: str + idiap_file_organisation: bool + + def __init__(self): + rc = load_rc() + self.datadir = rc.get( + "datadir.nih_cxr14_re", os.path.realpath(os.curdir) + ) + self.idiap_file_organisation = rc.get( + "nih_cxr14_re.idiap_folder_structure", False + ) + + def sample(self, sample: tuple[str, list[int]]) -> Sample: + """Loads a single image sample from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and a list of integers, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + file_path = sample[0] # default + if self.idiap_file_organisation: + # for folder lookup efficiency, data is split into subfolders + # each original file is on the subfolder `f[:5]/f`, where f + # is the original file basename + basename = os.path.basename(sample[0]) + file_path = os.path.join( + os.path.dirname(sample[0]), + basename[:5], + basename, + ) + + # N.B.: NIH CXR-14 images are encoded as color PNGs + image = PIL.Image.open(os.path.join(self.datadir, file_path)) + tensor = to_tensor(image) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() + + return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, list[int]]) -> list[int]: + """Loads a single image sample label from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and a list of integers, representing the
+ + + Returns + ------- + + labels + The integer labels associated with the sample + """ + return sample[1] + + +def make_split(basename: str) -> DatabaseSplit: + """Returns a database split for the Montgomery database.""" + + return JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + ) + + +class DataModule(CachingDataModule): + """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis. + + This dataset was extracted from the clinical PACS database at the National + Institutes of Health Clinical Center (USA) and represents 60% of all their + radiographs. It contains labels for 14 common radiological signs in this + order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass, + nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis, + edema and consolidation. This is the relabeled version created in the + CheXNeXt study. + + * Reference: [NIH-CXR14-2017]_ + * Original resolution (height x width): 1024 x 1024 + * Labels: [CHEXNEXT-2018]_ + * Split reference: [CHEXNEXT-2018]_ + * Protocol ``default``: + + * Training samples: 98637 + * Validation samples: 6350 + * Test samples: 4355 + + * Output image: + + * Transforms: + + * Load raw PNG with :py:mod:`PIL` + + * Final specifications + + * RGB, encoded as a 3-plane image, 8 bits + * Square (1024x1024 px) + """ + + def __init__(self, split_filename: str): + super().__init__( + database_split=make_split(split_filename), + raw_data_loader=RawDataLoader(), + ) diff --git a/src/ptbench/data/nih_cxr14_re/default.py b/src/ptbench/data/nih_cxr14_re/default.py index 0ea6ef5a..7fe993a9 100644 --- a/src/ptbench/data/nih_cxr14_re/default.py +++ b/src/ptbench/data/nih_cxr14_re/default.py @@ -2,46 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default -protocol) +from .datamodule import DataModule -* See :py:mod:`ptbench.data.nih_cxr14_re` for split details -* This configuration 
resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("default") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +datamodule = DataModule("default.json.bz2") diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py index bf821068..6f41b39e 100644 --- a/src/ptbench/data/typing.py +++ b/src/ptbench/data/typing.py @@ -28,7 +28,7 @@ class RawDataLoader: """Loads whole samples from media.""" raise NotImplementedError("You must implement the `sample()` method") - def label(self, k: typing.Any) -> int: + def label(self, k: typing.Any) -> int | list[int]: """Loads only sample label from media. If you do not override this implementation, then, by default, @@ -79,7 +79,7 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): provide a dunder len method. """ - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" raise NotImplementedError("You must implement the `labels()` method") -- GitLab