diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 11e76697b96a49836d0de435e12c7d2289e5252f..6670a3099bd4f855f89eba0001380421a08e176f 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -10,11 +10,7 @@ import typing import lightning import torch import torch.utils.data - -from clapper.logging import setup - -# TODO: No logging on this module... -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +import torchvision.transforms def _setup_dataloader_multiproc_parameters( @@ -93,11 +89,13 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset): raw_data_loader: typing.Callable[ [typing.Any], tuple[torch.Tensor, typing.Mapping] ], - transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None, + transforms: typing.Sequence[ + typing.Callable[[torch.Tensor], torch.Tensor] + ] = [], ): self.split = split self.raw_data_loader = raw_data_loader - self.transform = torch.nn.Sequential(*transforms) + self.transform = torchvision.transforms.Compose(*transforms) def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: tensor, metadata = self.raw_data_loader(self.split[key]) @@ -137,10 +135,12 @@ class _CachedDataset(torch.utils.data.Dataset): raw_data_loader: typing.Callable[ [typing.Any], tuple[torch.Tensor, typing.Mapping] ], - transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None, + transforms: typing.Sequence[ + typing.Callable[[torch.Tensor], torch.Tensor] + ] = [], ): self.data = [raw_data_loader(k) for k in split] - self.transform = torch.nn.Sequential(*transforms) + self.transform = torchvision.transforms.Compose(*transforms) def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: tensor, metadata = self.data[key] @@ -344,22 +344,21 @@ class CachingDataModule(lightning.LightningDataModule): ], cache_samples: bool = False, train_sampler: typing.Optional[torch.utils.data.Sampler] = None, - data_augmentations: list[torch.nn.Module] = [], - model_transforms: list[torch.nn.Module] = [], + data_augmentations: list[ + typing.Callable[[torch.Tensor], torch.Tensor] + ] = [], + model_transforms: list[ + typing.Callable[[torch.Tensor], torch.Tensor] + ] = [], batch_size: int = 1, batch_chunk_count: int = 1, drop_incomplete_batch: bool = False, parallel: int = -1, ): - # validation - if batch_size % batch_chunk_count != 0: - raise RuntimeError( - f"batch_size ({batch_size}) must be divisible by " - f"batch_chunk_size ({batch_chunk_count})." 
- ) - super().__init__() + self.set_chunk_size(batch_size, batch_chunk_count) + self.database_split = database_split self.raw_data_loader = raw_data_loader self.cache_samples = cache_samples @@ -367,16 +366,8 @@ class CachingDataModule(lightning.LightningDataModule): self.data_augmentations = data_augmentations self.model_transforms = model_transforms - self._batch_size = batch_size - self._batch_chunk_count = batch_chunk_count - self._chunk_size = self._batch_size // self._batch_chunk_count - self.drop_incomplete_batch = drop_incomplete_batch - self._parallel = parallel # immutable, otherwise would need to call - # the next function again - self._dataloader_multiproc = _setup_dataloader_multiproc_parameters( - parallel - ) + self.parallel = parallel # immutable, otherwise would need to call self.pin_memory = ( torch.cuda.is_available() @@ -385,6 +376,63 @@ class CachingDataModule(lightning.LightningDataModule): # datasets that have been setup() for the current stage self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + @property + def parallel(self) -> int: + """The parallel property.""" + return self._parallel + + @parallel.setter + def parallel(self, value: int) -> None: + self._parallel = value + self._dataloader_multiproc = _setup_dataloader_multiproc_parameters( + value + ) + # datasets that have been setup() for the current stage + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + + def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: + """Coherently sets the batch-chunk-size after validation. + + Parameters + ---------- + + batch_size + Number of samples in every **training** batch (this parameter affects + memory requirements for the network). If the number of samples in the + batch is larger than the total number of samples available for + training, this value is truncated. If this number is smaller, then + batches of the specified size are created and fed to the network until + there are no more new samples to feed (epoch is finished). If the + total number of training samples is not a multiple of the batch-size, + the last batch will be smaller than the first, unless + ``drop_incomplete_batch`` is set to ``true``, in which case this batch + is not used. + + batch_chunk_count + Number of chunks in every batch (this parameter affects memory + requirements for the network). The number of samples loaded for every + iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` + needs to be divisible by ``batch_chunk_count``, otherwise an error will + be raised. This parameter is used to reduce number of samples loaded in + each iteration, in order to reduce the memory usage in exchange for + processing time (more iterations). This is specially interesting whe + one is running with GPUs with limited RAM. The default of 1 forces the + whole batch to be processed at once. Otherwise the batch is broken into + batch-chunk-count pieces, and gradients are accumulated to complete + each batch. + """ + + # validation + if batch_size % batch_chunk_count != 0: + raise RuntimeError( + f"batch_size ({batch_size}) must be divisible by " + f"batch_chunk_size ({batch_chunk_count})." + ) + + self._batch_size = batch_size + self._batch_chunk_count = batch_chunk_count + self._chunk_size = self._batch_size // self._batch_chunk_count + def setup(self, stage: str) -> None: """Sets up datasets for different tasks on the pipeline. 
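A minimal sketch of the arithmetic that ``set_chunk_size()`` enforces; the ``accumulate_grad_batches`` wiring is an assumption about how the chunk count would be consumed downstream and is not something this patch sets up:

.. code-block:: python

    # Sketch of the relation enforced by set_chunk_size().  The Trainer call is
    # an assumption about how the chunk count would be consumed downstream.
    import lightning

    batch_size = 32
    batch_chunk_count = 4

    if batch_size % batch_chunk_count != 0:
        raise RuntimeError(
            f"batch_size ({batch_size}) must be divisible by "
            f"batch_chunk_count ({batch_chunk_count})."
        )

    chunk_size = batch_size // batch_chunk_count  # 8 samples loaded per step

    # gradients accumulated over the chunks rebuild one effective batch of 32
    trainer = lightning.Trainer(accumulate_grad_batches=batch_chunk_count)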
@@ -440,11 +488,40 @@ class CachingDataModule(lightning.LightningDataModule): elif stage == "predict": _setup("test", self.model_transforms) - def train_dataloader(self): + def unaugmented_train_dataloader(self) -> torch.utils.data.DataLoader: + """Returns a version of the train dataloader without augmentations. + + Use this method to obtain a version of the train dataloader without + augmentations, to compute input normalisation factors (e.g. mean and + standard deviation or min-max parameterisations). + + + Returns + ------- + + dataloader + The unaugmented train dataloader + """ + dataset = _DelayedLoadingDataset( + self.database_split["train"], + self.raw_data_loader, + self.model_transforms, + ) + return torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=self._chunk_size, + drop_last=self.drop_incomplete_batch, + pin_memory=self.pin_memory, + **self._dataloader_multiproc, + ) + + def train_dataloader(self) -> torch.utils.data.DataLoader: """Returns the train data loader.""" return torch.utils.data.DataLoader( self._datasets["train"], + shuffle=True, batch_size=self._chunk_size, drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, @@ -452,14 +529,19 @@ class CachingDataModule(lightning.LightningDataModule): **self._dataloader_multiproc, ) - def val_dataloader(self): - """Returns the validation data loader(s)""" + def available_dataset_keys(self) -> typing.KeysView[str]: + """Returns all names for datasets that are setup.""" + return self._datasets.keys() - extra_valid = [ + def val_database_split_keys(self) -> list[str]: + """Returns list of validation dataset names.""" + return ["validation"] + [ k for k in self.database_split.keys() if k.startswith("monitor-") ] - # TODO: do we really need the train sampler here? + def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]: + """Returns the validation data loader(s)""" + validation_loader_opts = { "batch_size": self._chunk_size, "shuffle": False, @@ -468,22 +550,13 @@ class CachingDataModule(lightning.LightningDataModule): } validation_loader_opts.update(self._dataloader_multiproc) - # TODO: not sure this is the right way to handle multiple validation - # loaders, please check and fix - if not extra_valid: - return torch.utils.data.DataLoader( - self._datasets["validation"], - **validation_loader_opts, + # select all keys of interest + return { + k: torch.utils.data.DataLoader( + self._datasets[k], **validation_loader_opts ) - - else: - return [ - torch.utils.data.DataLoader( - self._datasets[k], - **validation_loader_opts, - ) - for k in ["validation"] + extra_valid - ] + for k in self.val_database_split_keys() + } def test_dataloader(self): """Returns the test data loader(s)""" diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index 6b373db4130a4422ac3ebce7b830392ac8bd2b79..c85fbcb3d221a2e73125d53785e0fff8cbdaa464 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -2,13 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import collections.abc -import csv -import importlib.abc -import json import logging -import pathlib -import typing import torch import torch.utils.data @@ -16,249 +10,6 @@ import torch.utils.data logger = logging.getLogger(__name__) -class JSONDatabaseSplit( - dict, - typing.Mapping[str, typing.Sequence[typing.Any]], -): - """Defines a loader that understands a database split (train, test, etc) in - JSON format. 
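A usage sketch for the dictionary now returned by ``val_dataloader()`` above, keyed by the names reported by ``val_database_split_keys()``; the ``datamodule`` import is taken from ``shenzhen/default.py`` below, and the ``monitor-train`` key is illustrative:

.. code-block:: python

    # Sketch: consuming the dict of validation loaders per named split.
    # The "monitor-train" key in the comment is hypothetical; actual keys
    # depend on the database split behind the datamodule.
    from ptbench.data.shenzhen.default import datamodule

    datamodule.prepare_data()
    datamodule.setup(stage="fit")

    loaders = datamodule.val_dataloader()
    # e.g. {"validation": DataLoader(...), "monitor-train": DataLoader(...)}

    for name in datamodule.val_database_split_keys():
        for batch in loaders[name]:
            ...  # per-split validation logic goes here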
- - To create a new database split, you need to provide a JSON formatted - dictionary in a file, with contents similar to the following: - - .. code-block:: json - - { - "subset1": [ - [ - "sample1-data1", - "sample1-data2", - "sample1-data3", - ], - [ - "sample2-data1", - "sample2-data2", - "sample2-data3", - ] - ], - "subset2": [ - [ - "sample42-data1", - "sample42-data2", - "sample42-data3", - ], - ] - } - - Your database split many contain any number of subsets (dictionary keys). - For simplicity, we recommend all sample entries are formatted similarly so - that raw-data-loading is simplified. Use the function - :py:func:`check_database_split_loading` to test raw data loading and fine - tune the dataset split, or its loading. - - Objects of this class behave like a dictionary in which keys are subset - names in the split, and values represent samples data and meta-data. - - - Parameters - ---------- - - path - Absolute path to a JSON formatted file containing the database split to be - recognized by this object. - """ - - def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable): - if isinstance(path, str): - path = pathlib.Path(path) - self.path = path - self.subsets = self._load_split_from_disk() - - def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]: - """Loads all subsets in a split from its file system representation. - - This method will load JSON information for the current split and return - all subsets of the given split after converting each entry through the - loader function. - - - Returns - ------- - - subsets : dict - A dictionary mapping subset names to lists of JSON objects - """ - - if str(self.path).endswith(".bz2"): - logger.debug(f"Loading database split from {str(self.path)}...") - with __import__("bz2").open(self.path) as f: - return json.load(f) - else: - with self.path.open() as f: - return json.load(f) - - def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: - """Accesses subset ``key`` from this split.""" - return self.subsets[key] - - def __iter__(self): - """Iterates over the subsets.""" - return iter(self.subsets) - - def __len__(self) -> int: - """How many subsets we currently have.""" - return len(self.subsets) - - -class CSVDatabaseSplit(collections.abc.Mapping): - """Defines a loader that understands a database split (train, test, etc) in - CSV format. - - To create a new database split, you need to provide one or more CSV - formatted files, each representing a subset of this split, containing the - sample data (one per row). Example: - - Inside the directory ``my-split/``, one can file files ``train.csv``, - ``validation.csv``, and ``test.csv``. Each file has a structure similar to - the following: - - .. code-block:: text - - sample1-value1,sample1-value2,sample1-value3 - sample2-value1,sample2-value2,sample2-value3 - ... - - Each file in the provided directory defines the subset name on the split. - So, the file ``train.csv`` will contain the data from the ``train`` subset, - and so on. - - Objects of this class behave like a dictionary in which keys are subset - names in the split, and values represent samples data and meta-data. - - - Parameters - ---------- - - directory - Absolute path to a directory containing the database split layed down - as a set of CSV files, one per subset. 
- """ - - def __init__( - self, directory: pathlib.Path | str | importlib.abc.Traversable - ): - if isinstance(directory, str): - directory = pathlib.Path(directory) - assert ( - directory.is_dir() - ), f"`{str(directory)}` is not a valid directory" - self.directory = directory - self.subsets = self._load_split_from_disk() - - def _load_split_from_disk( - self, - ) -> dict[str, list[typing.Any]]: - """Loads all subsets in a split from its file system representation. - - This method will load CSV information for the current split and return all - subsets of the given split after converting each entry through the - loader function. - - - Returns - ------- - - subsets : dict - A dictionary mapping subset names to lists of JSON objects - """ - - retval = {} - for subset in self.directory.iterdir(): - if str(subset).endswith(".csv.bz2"): - logger.debug(f"Loading database split from {subset}...") - with __import__("bz2").open(subset) as f: - reader = csv.reader(f) - retval[subset.name[: -len(".csv.bz2")]] = [ - k for k in reader - ] - elif str(subset).endswith(".csv"): - with subset.open() as f: - reader = csv.reader(f) - retval[subset.name[: -len(".csv")]] = [k for k in reader] - else: - logger.debug( - f"Ignoring file {subset} in CSVDatabaseSplit readout" - ) - return retval - - def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: - """Accesses subset ``key`` from this split.""" - return self.subsets[key] - - def __iter__(self): - """Iterates over the subsets.""" - return iter(self.subsets) - - def __len__(self) -> int: - """How many subsets we currently have.""" - return len(self.subsets) - - -def check_database_split_loading( - database_split: typing.Mapping[str, typing.Sequence[typing.Any]], - loader: typing.Callable[[typing.Any], torch.Tensor], - limit: int = 0, -) -> int: - """For each subset in the split, check if all data can be correctly loaded - using the provided loader function. - - This function will return the number of errors loading samples, and will - log more detailed information to the logging stream. - - - Parameters - ---------- - - database_split - A mapping that, contains the database split. Each key represents the - name of a subset in the split. Each value is a (potentially complex) - object that represents a single sample. - - loader - A callable that transforms sample entries in the database split - into :py:class:`torch.Tensor` objects that can be used for training - or inference. - - limit - Maximum number of samples to check (in each split/subset - combination) in this dataset. If set to zero, then check - everything. - - - Returns - ------- - - errors - Number of errors found - """ - logger.info( - "Checking if can load all samples in all subsets of this split..." - ) - errors = 0 - for subset in database_split.keys(): - samples = subset if not limit else subset[:limit] - for pos, sample in enumerate(samples): - try: - data = loader(sample) - assert isinstance(data, torch.Tensor) - except Exception as e: - logger.info( - f"Found error loading entry {pos} in subset `{subset}`: {e}" - ) - errors += 1 - return errors - - def get_positive_weights(dataset): """Compute the positive weights of each class of the dataset to balance the BCEWithLogitsLoss criterion. 
@@ -350,45 +101,3 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): ) else: logger.warning("Weighted valid criterion not supported") - - -def normalize_data(normalization, model, datamodule): - from torch.utils.data import DataLoader - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - train_dataset = datamodule.train_dataset - - # Create z-normalization model layer if needed - if normalization == "imagenet": - model.normalizer.set_mean_std( - [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] - ) - logger.info("Z-normalization with ImageNet mean and std") - elif normalization == "current": - # Compute mean/std of current train subset - temp_dl = DataLoader( - dataset=train_dataset, batch_size=len(train_dataset) - ) - - data = next(iter(temp_dl)) - mean = data[1].mean(dim=[0, 2, 3]) - std = data[1].std(dim=[0, 2, 3]) - - model.normalizer.set_mean_std(mean, std) - - # Format mean and std for logging - mean = str( - [ - round(x, 3) - for x in ((mean * 10**3).round() / (10**3)).tolist() - ] - ) - std = str( - [ - round(x, 3) - for x in ((std * 10**3).round() / (10**3)).tolist() - ] - ) - logger.info(f"Z-normalization with mean {mean} and std {std}") diff --git a/src/ptbench/data/padchest/__init__.py b/src/ptbench/data/padchest/__init__.py index af1dd3ec9f8cbaa1b3e32f5195363592602d94ac..f6f39a9ebe98eb68a88374b8d83a70447a4ad995 100644 --- a/src/ptbench/data/padchest/__init__.py +++ b/src/ptbench/data/padchest/__init__.py @@ -264,7 +264,7 @@ json_dataset = JSONDataset( def _maker(protocol, resize_size=512, cc_size=512, RGB=True): import torchvision.transforms as transforms - from ..transforms import SingleAutoLevel16to8 + from ..loader import SingleAutoLevel16to8 post_transforms = [] if not RGB: diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/raw_data_loader.py similarity index 51% rename from src/ptbench/data/loader.py rename to src/ptbench/data/raw_data_loader.py index d6e86ed06dd04abb230830ebc8b13e2ea9720cfa..d852743f03cd23940868d8212662210bae70e7aa 100644 --- a/src/ptbench/data/loader.py +++ b/src/ptbench/data/raw_data_loader.py @@ -5,9 +5,42 @@ """Data loading code.""" +import numpy import PIL.Image +class SingleAutoLevel16to8: + """Converts a 16-bit image to 8-bit representation using "auto-level". + + This transform assumes that the input image is gray-scaled. + + To auto-level, we calculate the maximum and the minimum of the image, and + consider such a range should be mapped to the [0,255] range of the + destination image. + """ + + def __call__(self, img): + imin, imax = img.getextrema() + irange = imax - imin + return PIL.Image.fromarray( + numpy.round( + 255.0 * (numpy.array(img).astype(float) - imin) / irange + ).astype("uint8"), + ).convert("L") + + +class RemoveBlackBorders: + """Remove black borders of CXR.""" + + def __init__(self, threshold=0): + self.threshold = threshold + + def __call__(self, img): + img = numpy.asarray(img) + mask = numpy.asarray(img) > self.threshold + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + def load_pil(path): """Loads a sample data. 
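A short usage sketch for the callables relocated into ``raw_data_loader.py``, assuming a 16-bit grayscale chest X-ray; the file path is hypothetical, and ``torchvision.transforms.Compose`` receives a single list of callables:

.. code-block:: python

    # Sketch: chaining the relocated callables on a raw CXR image.  Assumes a
    # 16-bit grayscale input; the path is hypothetical.  Compose receives a
    # single list of callables.
    import torchvision.transforms

    from ptbench.data.raw_data_loader import (
        RemoveBlackBorders,
        SingleAutoLevel16to8,
        load_pil,
    )

    preprocess = torchvision.transforms.Compose(
        [
            RemoveBlackBorders(threshold=0),  # crop away the black frame
            SingleAutoLevel16to8(),  # map the 16-bit range onto [0, 255]
        ]
    )

    img = load_pil("/path/to/cxr.png")  # hypothetical 16-bit grayscale sample
    img = preprocess(img)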
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index c9068112992b3dc5382f802fa6f72f6c6f118c14..8d94329247f866a29aa9b07c05952e6d2fbfa296 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -15,13 +15,15 @@ This configuration: import importlib.resources from ..datamodule import CachingDataModule -from ..dataset import JSONDatabaseSplit +from ..split import JSONDatabaseSplit from ..transforms import ElasticDeformation -from .loader import raw_data_loader +from .raw_data_loader import raw_data_loader datamodule = CachingDataModule( database_split=JSONDatabaseSplit( - importlib.resources.files(__name__).joinpath("default.json.bz2") + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + "default.json.bz2" + ) ), raw_data_loader=raw_data_loader, cache_samples=False, diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/raw_data_loader.py similarity index 96% rename from src/ptbench/data/shenzhen/loader.py rename to src/ptbench/data/shenzhen/raw_data_loader.py index 6578fc8fbf2657d51a100d20a66b5f0b303591d1..0c2c39df4b0005bcf7f45f75e3cc143b234d3f92 100644 --- a/src/ptbench/data/shenzhen/loader.py +++ b/src/ptbench/data/shenzhen/raw_data_loader.py @@ -27,8 +27,7 @@ import torch.nn import torchvision.transforms from ...utils.rc import load_rc -from ..loader import load_pil_baw -from ..transforms import RemoveBlackBorders +from ..raw_data_loader import RemoveBlackBorders, load_pil_baw _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir)) """This variable contains the base directory where the database raw data is diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py new file mode 100644 index 0000000000000000000000000000000000000000..06e813eef0e532095e465458568c9957c5b95904 --- /dev/null +++ b/src/ptbench/data/split.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import collections.abc +import csv +import importlib.abc +import json +import logging +import pathlib +import typing + +import torch + +logger = logging.getLogger(__name__) + + +class JSONDatabaseSplit( + dict, + typing.Mapping[str, typing.Sequence[typing.Any]], +): + """Defines a loader that understands a database split (train, test, etc) in + JSON format. + + To create a new database split, you need to provide a JSON formatted + dictionary in a file, with contents similar to the following: + + .. code-block:: json + + { + "subset1": [ + [ + "sample1-data1", + "sample1-data2", + "sample1-data3", + ], + [ + "sample2-data1", + "sample2-data2", + "sample2-data3", + ] + ], + "subset2": [ + [ + "sample42-data1", + "sample42-data2", + "sample42-data3", + ], + ] + } + + Your database split many contain any number of subsets (dictionary keys). + For simplicity, we recommend all sample entries are formatted similarly so + that raw-data-loading is simplified. Use the function + :py:func:`check_database_split_loading` to test raw data loading and fine + tune the dataset split, or its loading. + + Objects of this class behave like a dictionary in which keys are subset + names in the split, and values represent samples data and meta-data. + + + Parameters + ---------- + + path + Absolute path to a JSON formatted file containing the database split to be + recognized by this object. 
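Regarding the ``importlib.resources`` change in ``shenzhen/default.py`` above — a small sketch of what the ``rsplit`` accomplishes (names spelled out for illustration):

.. code-block:: python

    # Why the rsplit: inside ptbench/data/shenzhen/default.py, __name__ names
    # the *module*, while default.json.bz2 ships next to it inside the *package*.
    import importlib.resources

    module_name = "ptbench.data.shenzhen.default"   # value of __name__ there
    package_name = module_name.rsplit(".", 1)[0]    # "ptbench.data.shenzhen"

    path = importlib.resources.files(package_name).joinpath("default.json.bz2")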
+ """ + + def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable): + if isinstance(path, str): + path = pathlib.Path(path) + self.path = path + self.subsets = self._load_split_from_disk() + + def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]: + """Loads all subsets in a split from its file system representation. + + This method will load JSON information for the current split and return + all subsets of the given split after converting each entry through the + loader function. + + + Returns + ------- + + subsets : dict + A dictionary mapping subset names to lists of JSON objects + """ + + if str(self.path).endswith(".bz2"): + logger.debug(f"Loading database split from {str(self.path)}...") + with __import__("bz2").open(self.path) as f: + return json.load(f) + else: + with self.path.open() as f: + return json.load(f) + + def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: + """Accesses subset ``key`` from this split.""" + return self.subsets[key] + + def __iter__(self): + """Iterates over the subsets.""" + return iter(self.subsets) + + def __len__(self) -> int: + """How many subsets we currently have.""" + return len(self.subsets) + + +class CSVDatabaseSplit(collections.abc.Mapping): + """Defines a loader that understands a database split (train, test, etc) in + CSV format. + + To create a new database split, you need to provide one or more CSV + formatted files, each representing a subset of this split, containing the + sample data (one per row). Example: + + Inside the directory ``my-split/``, one can file files ``train.csv``, + ``validation.csv``, and ``test.csv``. Each file has a structure similar to + the following: + + .. code-block:: text + + sample1-value1,sample1-value2,sample1-value3 + sample2-value1,sample2-value2,sample2-value3 + ... + + Each file in the provided directory defines the subset name on the split. + So, the file ``train.csv`` will contain the data from the ``train`` subset, + and so on. + + Objects of this class behave like a dictionary in which keys are subset + names in the split, and values represent samples data and meta-data. + + + Parameters + ---------- + + directory + Absolute path to a directory containing the database split layed down + as a set of CSV files, one per subset. + """ + + def __init__( + self, directory: pathlib.Path | str | importlib.abc.Traversable + ): + if isinstance(directory, str): + directory = pathlib.Path(directory) + assert ( + directory.is_dir() + ), f"`{str(directory)}` is not a valid directory" + self.directory = directory + self.subsets = self._load_split_from_disk() + + def _load_split_from_disk( + self, + ) -> dict[str, list[typing.Any]]: + """Loads all subsets in a split from its file system representation. + + This method will load CSV information for the current split and return all + subsets of the given split after converting each entry through the + loader function. 
+ + + Returns + ------- + + subsets : dict + A dictionary mapping subset names to lists of JSON objects + """ + + retval = {} + for subset in self.directory.iterdir(): + if str(subset).endswith(".csv.bz2"): + logger.debug(f"Loading database split from {subset}...") + with __import__("bz2").open(subset) as f: + reader = csv.reader(f) + retval[subset.name[: -len(".csv.bz2")]] = [ + k for k in reader + ] + elif str(subset).endswith(".csv"): + with subset.open() as f: + reader = csv.reader(f) + retval[subset.name[: -len(".csv")]] = [k for k in reader] + else: + logger.debug( + f"Ignoring file {subset} in CSVDatabaseSplit readout" + ) + return retval + + def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: + """Accesses subset ``key`` from this split.""" + return self.subsets[key] + + def __iter__(self): + """Iterates over the subsets.""" + return iter(self.subsets) + + def __len__(self) -> int: + """How many subsets we currently have.""" + return len(self.subsets) + + +def check_database_split_loading( + database_split: typing.Mapping[str, typing.Sequence[typing.Any]], + loader: typing.Callable[[typing.Any], torch.Tensor], + limit: int = 0, +) -> int: + """For each subset in the split, check if all data can be correctly loaded + using the provided loader function. + + This function will return the number of errors loading samples, and will + log more detailed information to the logging stream. + + + Parameters + ---------- + + database_split + A mapping that, contains the database split. Each key represents the + name of a subset in the split. Each value is a (potentially complex) + object that represents a single sample. + + loader + A callable that transforms sample entries in the database split + into :py:class:`torch.Tensor` objects that can be used for training + or inference. + + limit + Maximum number of samples to check (in each split/subset + combination) in this dataset. If set to zero, then check + everything. + + + Returns + ------- + + errors + Number of errors found + """ + logger.info( + "Checking if can load all samples in all subsets of this split..." + ) + errors = 0 + for subset in database_split.keys(): + samples = subset if not limit else subset[:limit] + for pos, sample in enumerate(samples): + try: + data = loader(sample) + assert isinstance(data, torch.Tensor) + except Exception as e: + logger.info( + f"Found error loading entry {pos} in subset `{subset}`: {e}" + ) + errors += 1 + return errors diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index c85c3e9d699e3e27203fe4939486733664eb5664..cf1946e97b7701abf6a1b557ea3ef3dcaaa2a1eb 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -22,44 +22,9 @@ from scipy.ndimage import gaussian_filter, map_coordinates from torchvision import transforms -class SingleAutoLevel16to8: - """Converts a 16-bit image to 8-bit representation using "auto-level". - - This transform assumes that the input image is gray-scaled. - - To auto-level, we calculate the maximum and the minimum of the - image, and - consider such a range should be mapped to the [0,255] range of the - destination image. 
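Tying together the helpers relocated into ``split.py`` above — a usage sketch with a hypothetical split file and a stand-in loader:

.. code-block:: python

    # Sketch: sanity-checking a split with the helpers now living in split.py.
    # The split path and the stand-in loader are hypothetical.
    import torch

    from ptbench.data.split import JSONDatabaseSplit, check_database_split_loading

    split = JSONDatabaseSplit("/path/to/default.json.bz2")
    print(list(split))  # subset names, e.g. ["train", "validation", "test"]

    def loader(sample) -> torch.Tensor:
        # a real loader would read the image referenced by `sample`
        return torch.zeros(1, 512, 512)

    errors = check_database_split_loading(split, loader, limit=5)
    print(f"samples that failed to load: {errors}")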
- """ - - def __call__(self, img): - imin, imax = img.getextrema() - irange = imax - imin - return PIL.Image.fromarray( - numpy.round( - 255.0 * (numpy.array(img).astype(float) - imin) / irange - ).astype("uint8"), - ).convert("L") - - -class RemoveBlackBorders: - """Remove black borders of CXR.""" - - def __init__(self, threshold=0): - self.threshold = threshold - - def __call__(self, img): - img = numpy.asarray(img) - mask = numpy.asarray(img) > self.threshold - return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) - - class ElasticDeformation: """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_. - TODO: needs to be converted into a torch.nn.Module to become scriptable! - Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0 """ @@ -70,7 +35,7 @@ class ElasticDeformation: spline_order=1, mode="nearest", random_state=numpy.random, - p=1, + p=1.0, ): self.alpha = alpha self.sigma = sigma diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index a7cf9d567946899c874efce1664d0e7f65d5ace2..52bd27f07b16131deea9e3d0aa11c8c1cf294f88 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -48,6 +48,21 @@ class Densenet(pl.LightningModule): return x + def set_normalizer(self, dataloader): + """TODO: Write this function to set the Normalizer + + This function is NOOP if ``pretrained = True`` (normalizer set to + imagenet weights, during contruction). + """ + if self.pretrained: + from .normalizer import TorchVisionNormalizer + + self.normalizer = TorchVisionNormalizer(..., ...) + else: + from .normalizer import get_znorm_normalizer + + self.normalizer = get_znorm_normalizer(dataloader) + def training_step(self, batch, batch_idx): images = batch[1] labels = batch[2] diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py index aa2142a6bb01de4dfc011cb8b49c1521b82a5e29..10320ab1a7ef2aa3bd2a7b1b43566a21da865150 100644 --- a/src/ptbench/models/normalizer.py +++ b/src/ptbench/models/normalizer.py @@ -6,6 +6,7 @@ import torch import torch.nn +import torch.utils.data class TorchVisionNormalizer(torch.nn.Module): @@ -20,19 +21,33 @@ class TorchVisionNormalizer(torch.nn.Module): Number of images channels fed to the model """ - def __init__(self, nb_channels=3): + def __init__(self, subtract: torch.Tensor, divide: torch.Tensor): super().__init__() - mean = torch.zeros(nb_channels)[None, :, None, None] - std = torch.ones(nb_channels)[None, :, None, None] - self.register_buffer("mean", mean) - self.register_buffer("std", std) + assert len(subtract) == len(divide), "TODO" + assert len(subtract) in (1, 3), "TODO" + self.subtract = subtract + self.divided = divide + subtract = torch.zeros(len(subtract.shape))[None, :, None, None] + divide = torch.ones(len(divide.shape))[None, :, None, None] + self.register_buffer("subtract", subtract) + self.register_buffer("divide", divide) self.name = "torchvision-normalizer" - def set_mean_std(self, mean, std): - mean = torch.as_tensor(mean)[None, :, None, None] - std = torch.as_tensor(std)[None, :, None, None] - self.register_buffer("mean", mean) - self.register_buffer("std", std) + def forward(self, inputs: torch.Tensor): + """inputs shape [batches, planes, height, width]""" + return inputs.sub(self.subtract).div(self.divide) - def forward(self, inputs): - return inputs.sub(self.mean).div(self.std) + +def get_znorm_normalizer( + dataloader: torch.utils.data.DataLoader, +) -> TorchVisionNormalizer: + # TODO: Fix this function to use unaugmented 
training set + # TODO: This function is only applicable IFF we are not fine-tuning (i.e. + # model does not re-use weights from imagenet training!) + # TODO: Add type hints + # TODO: Add documentation + + # 1 extract mean/std from dataloader + + # 2 return TorchVisionNormalizer(mean, std) + pass diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 9da9702f31436df441cdb56d9415cc06e6829623..d14ea5d4f745a4f237c3cc96337d862013318188 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -6,6 +6,7 @@ import lightning.pytorch as pl import torch import torch.nn as nn import torch.nn.functional as F +import torch.utils.data from .normalizer import TorchVisionNormalizer @@ -127,6 +128,12 @@ class PASA(pl.LightningModule): return x + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """TODO: Write this function documentation""" + from .normalizer import get_znorm_normalizer + + self.normalizer = get_znorm_normalizer(dataloader) + def training_step(self, batch, batch_idx): images = batch["data"] labels = batch["label"]
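One way the ``get_znorm_normalizer()`` placeholder could eventually be completed is sketched below. This is not part of the patch: the ``batch["data"]`` layout and the final constructor call are assumptions, and per the TODO above the statistics should come from ``unaugmented_train_dataloader()``.

.. code-block:: python

    # Sketch: per-channel z-normalisation statistics from an *unaugmented*
    # train dataloader.  The batch["data"] key and [batch, channels, H, W]
    # layout mirror PASA.training_step above but remain assumptions, as does
    # the final TorchVisionNormalizer(...) call (commented out).
    import torch
    import torch.utils.data

    def compute_znorm_stats(
        dataloader: torch.utils.data.DataLoader,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Streams through the dataloader accumulating per-channel mean/std."""
        count = 0
        total = torch.tensor(0.0)
        total_sq = torch.tensor(0.0)
        for batch in dataloader:
            images = batch["data"]  # adjust indexing to the actual sample layout
            count += images.shape[0] * images.shape[2] * images.shape[3]
            total = total + images.sum(dim=[0, 2, 3])
            total_sq = total_sq + images.pow(2).sum(dim=[0, 2, 3])
        mean = total / count
        std = (total_sq / count - mean.pow(2)).sqrt()
        return mean, std

    # mean, std = compute_znorm_stats(datamodule.unaugmented_train_dataloader())
    # normalizer = TorchVisionNormalizer(subtract=mean, divide=std)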