diff --git a/src/ptbench/configs/datasets/__init__.py b/src/ptbench/configs/datasets/__init__.py deleted file mode 100644 index 88ac7a33e7051a69ff6d8ecfdb55aeb94b13b72f..0000000000000000000000000000000000000000 --- a/src/ptbench/configs/datasets/__init__.py +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import torch - -from torchvision.transforms import RandomRotation - -"""Standard configurations for dataset setup""" - -RANDOM_ROTATION = [RandomRotation(15)] -"""Shared data augmentation based on random rotation only.""" - - -def get_samples_weights(dataset): - """Compute the weights of all the samples of the dataset to balance it - using the sampler of the dataloader. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the weights to balance each class in the dataset and the - datasets themselves if we have a ConcatDataset. - - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - samples_weights : :py:class:`torch.Tensor` - the weights for all the samples in the dataset given as input - """ - samples_weights = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - # Weighting only for binary labels - if isinstance(ds._samples[0].label, int): - # Groundtruth - targets = [] - for s in ds._samples: - targets.append(s.label) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights.append( - torch.tensor([weight[t] for t in targets]) - ) - - else: - # We only weight to sample equally from each dataset - samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) - - # Concatenate sample weights from all the datasets - samples_weights = torch.cat(samples_weights) - - else: - # Weighting only for binary labels - if isinstance(dataset._samples[0].label, int): - # Groundtruth - targets = [] - for s in dataset._samples: - targets.append(s.label) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights = torch.tensor([weight[t] for t in targets]) - - else: - # Equal weights for non-binary labels - samples_weights = torch.ones(len(dataset._samples)) - - return samples_weights - - -def get_positive_weights(dataset): - """Compute the positive weights of each class of the dataset to balance the - BCEWithLogitsLoss criterion. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the positive weights of each class to use them to have - a balanced loss. 
- - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - positive_weights : :py:class:`torch.Tensor` - the positive weight of each class in the dataset given as input - """ - targets = [] - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - for s in ds._samples: - targets.append(s.label) - - else: - for s in dataset._samples: - targets.append(s.label) - - targets = torch.tensor(targets) - - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] - - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]] - ).reshape(-1) - - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) - - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) - - return positive_weights diff --git a/src/ptbench/configs/datasets/shenzhen/__init__.py b/src/ptbench/configs/datasets/shenzhen/__init__.py deleted file mode 100644 index 84b9088ea60cbbf9ddee2fdf1bfc14203beda01f..0000000000000000000000000000000000000000 --- a/src/ptbench/configs/datasets/shenzhen/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py deleted file mode 100644 index 09aa54748d382122fd1f67a3b3e1871f3f0aa132..0000000000000000000000000000000000000000 --- a/src/ptbench/configs/datasets/shenzhen/default.py +++ /dev/null @@ -1,79 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Shenzhen dataset for TB detection (default protocol) - -* Split reference: first 64% of TB and healthy CXR for "train" 16% for -* "validation", 20% for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.shenzhen` for dataset details -""" - -from clapper.logging import setup - -from ....data import return_subsets -from ....data.base_datamodule import BaseDataModule -from ....data.dataset import JSONProtocol -from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - cache_samples=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - self.cache_samples = cache_samples - self.has_setup_fit = False - self.has_setup_predict = False - - def setup(self, stage: str): - if self.cache_samples: - logger.info( - "Argument cache_samples set to True. Samples will be loaded in memory." - ) - samples_loader = _cached_loader - else: - logger.info( - "Argument cache_samples set to False. Samples will be loaded at runtime." 
- ) - samples_loader = _delayed_loader - - json_protocol = JSONProtocol( - protocols=_protocols, - fieldnames=("data", "label"), - loader=samples_loader, - ) - - if not self.has_setup_fit and stage == "fit": - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - ) = return_subsets(json_protocol, "default", stage) - self.has_setup_fit = True - - if not self.has_setup_predict and stage == "predict": - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - ) = return_subsets(json_protocol, "default", stage) - - self.has_setup_predict = True - - -datamodule = DefaultModule() diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index cda0540fe0d7ea01c67df571cbeeecaa8199c53f..3ee0b92164b5531b65049b94e71b01b07e2ad27e 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -13,9 +13,7 @@ Reference: [PASA-2019]_ from torch import empty from torch.nn import BCEWithLogitsLoss -from torchvision import transforms -from ...data.transforms import ElasticDeformation from ...models.pasa import PASA # config @@ -28,9 +26,5 @@ optimizer = "Adam" criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) -train_transforms = transforms.Compose([ElasticDeformation(p=0.8)]) - # model -model = PASA( - train_transforms, criterion, criterion_valid, optimizer, optimizer_configs -) +model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py index 682d5d1ad8e88ae4f8dd72dccb888200bc3d161a..84b9088ea60cbbf9ddee2fdf1bfc14203beda01f 100644 --- a/src/ptbench/data/__init__.py +++ b/src/ptbench/data/__init__.py @@ -1,356 +1,3 @@ -"""Data manipulation and raw dataset definitions.""" - -import random - -import torch - -from clapper.logging import setup - -from .utils import SampleListDataset - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -def make_subset(samples, transforms=[], prefixes=[], suffixes=[]): - """Creates a new data set, applying transforms. - - .. note:: - - This is a convenience function for our own dataset definitions inside - this module, guaranteeting homogenity between dataset definitions - provided in this package. It assumes certain strategies for data - augmentation that may not be translatable to other applications. - - - Parameters - ---------- - - samples : list - List of delayed samples - - transforms : list - A list of transforms that needs to be applied to all samples in the set - - prefixes : list - A list of data augmentation operations that needs to be applied - **before** the transforms above - - suffixes : list - A list of data augmentation operations that needs to be applied - **after** the transforms above - - - Returns - ------- - - subset : :py:class:`ptbench.data.utils.SampleListDataset` - A pre-formatted dataset that can be fed to one of our engines - """ - from .utils import SampleListDataset as wrapper - - return wrapper(samples, prefixes + transforms + suffixes) - - -def make_dataset( - subsets_groups, transforms=[], t_transforms=[], post_transforms=[] -): - """Creates a new configuration dataset from a list of dictionaries and - transforms. 
- - This function takes as input a list of dictionaries as those that can be - returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets` - mapping protocol names (such as ``train``, ``dev`` and ``test``) to - :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of - transforms, and returns a dictionary applying - :py:class:`ptbench.data.utils.SampleListDataset` to these - lists, and our standard data augmentation if a ``train`` set exists. - - For example, if ``subsets`` is composed of two sets named ``train`` and - ``test``, this function will yield a dictionary with the following entries: - - * ``__train__``: Wraps the ``train`` subset, includes data augmentation - (note: datasets with names starting with ``_`` (underscore) are excluded - from prediction and evaluation by default, as they contain data - augmentation transformations.) - * ``train``: Wraps the ``train`` subset, **without** data augmentation - * ``test``: Wraps the ``test`` subset, **without** data augmentation - - .. note:: - - This is a convenience function for our own dataset definitions inside - this module, guaranteeting homogenity between dataset definitions - provided in this package. It assumes certain strategies for data - augmentation that may not be translatable to other applications. - - - Parameters - ---------- - - subsets : list - A list of dictionaries that contains the delayed sample lists - for a number of named lists. The subsets will be aggregated in one - final subset. If one of the keys is ``train``, our standard dataset - augmentation transforms are appended to the definition of that subset. - All other subsets remain un-augmented. - - transforms : list - A list of transforms that needs to be applied to all samples in the set - - t_transforms : list - A list of transforms that needs to be applied to the train samples - - post_transforms : list - A list of transforms that needs to be applied to all samples in the set - after all the other transforms - - - Returns - ------- - - dataset : dict - A pre-formatted dataset that can be fed to one of our engines. It maps - string names to :py:class:`ptbench.data.utils.SampleListDataset`'s. - """ - - retval = {} - - if len(subsets_groups) == 1: - subsets = subsets_groups[0] - else: - # If multiple subsets groups: aggregation - aggregated_subsets = {} - for subsets in subsets_groups: - for key in subsets.keys(): - if key in aggregated_subsets: - aggregated_subsets[key] += subsets[key] - # Shuffle if data comes from multiple datasets - random.shuffle(aggregated_subsets[key]) - else: - aggregated_subsets[key] = subsets[key] - subsets = aggregated_subsets - - # Add post_transforms after t_transforms for the train set - t_transforms += post_transforms - - for key in subsets.keys(): - retval[key] = make_subset( - subsets[key], transforms=transforms, suffixes=post_transforms - ) - if key == "train": - retval["__train__"] = make_subset( - subsets[key], transforms=transforms, suffixes=(t_transforms) - ) - if key == "validation": - # also use it for validation during training - retval["__valid__"] = retval[key] - - if ( - ("__train__" in retval) - and ("train" in retval) - and ("__valid__" not in retval) - ): - # if the dataset does not have a validation set, we use the unaugmented - # training set as validation set - retval["__valid__"] = retval["train"] - - return retval - - -def get_samples_weights(dataset): - """Compute the weights of all the samples of the dataset to balance it - using the sampler of the dataloader. 
- - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the weights to balance each class in the dataset and the - datasets themselves if we have a ConcatDataset. - - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - samples_weights : :py:class:`torch.Tensor` - the weights for all the samples in the dataset given as input - """ - samples_weights = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - # Weighting only for binary labels - if isinstance(ds._samples[0].label, int): - # Groundtruth - targets = [] - for s in ds._samples: - targets.append(s.label) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights.append( - torch.tensor([weight[t] for t in targets]) - ) - - else: - # We only weight to sample equally from each dataset - samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) - - # Concatenate sample weights from all the datasets - samples_weights = torch.cat(samples_weights) - - else: - # Weighting only for binary labels - if isinstance(dataset._samples[0].label, int): - # Groundtruth - targets = [] - for s in dataset._samples: - targets.append(s.label) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights = torch.tensor([weight[t] for t in targets]) - - else: - # Equal weights for non-binary labels - samples_weights = torch.ones(len(dataset._samples)) - - return samples_weights - - -def get_positive_weights(dataset): - """Compute the positive weights of each class of the dataset to balance the - BCEWithLogitsLoss criterion. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the positive weights of each class to use them to have - a balanced loss. 
- - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - positive_weights : :py:class:`torch.Tensor` - the positive weight of each class in the dataset given as input - """ - targets = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - for s in ds._samples: - targets.append(s.label) - - else: - for s in dataset._samples: - targets.append(s.label) - - targets = torch.tensor(targets) - - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] - - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]] - ).reshape(-1) - - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) - - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) - - return positive_weights - - -def return_subsets(dataset, protocol, stage): - train_set = None - valid_set = None - extra_valid_sets = None - - subsets = dataset.subsets(protocol) - - def get_train_subset(): - if "train" in subsets.keys(): - nonlocal train_set - train_set = SampleListDataset(subsets["train"], []) - - def get_valid_subset(): - if "validation" in subsets.keys(): - nonlocal valid_set - valid_set = SampleListDataset(subsets["validation"], []) - else: - logger.warning( - "No validation dataset found, using training set instead." - ) - if train_set is None: - get_train_subset() - - valid_set = train_set - - def get_extra_valid_subset(): - if "__extra_valid__" in subsets.keys(): - if not isinstance(subsets["__extra_valid__"], list): - raise RuntimeError( - f"If present, dataset['__extra_valid__'] must be a list, " - f"but you passed a {type(subsets['__extra_valid__'])}, " - f"which is invalid." - ) - logger.info( - f"Found {len(subsets['__extra_valid__'])} extra validation " - f"set(s) to be tracked during training" - ) - logger.info( - "Extra validation sets are NOT used for model checkpointing!" 
- ) - nonlocal extra_valid_sets - extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], []) - - if stage == "fit": - get_train_subset() - get_valid_subset() - get_extra_valid_subset() - - return train_set, valid_set, extra_valid_sets - else: - raise ValueError(f"Stage {stage} is unknown.") +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py index 1c51a1055dc226efc72292a3e439041c67a6d37b..5cfa1620d14a186ec034821b95c3d1a2a0fdbb8b 100644 --- a/src/ptbench/data/base_datamodule.py +++ b/src/ptbench/data/base_datamodule.py @@ -2,13 +2,17 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import multiprocessing +import sys + import lightning as pl import torch from clapper.logging import setup from torch.utils.data import DataLoader, WeightedRandomSampler +from torchvision import transforms -from ptbench.configs.datasets import get_samples_weights +from .dataset import get_samples_weights logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -16,24 +20,30 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") class BaseDataModule(pl.LightningDataModule): def __init__( self, - train_batch_size=1, - predict_batch_size=1, + batch_size=1, batch_chunk_count=1, drop_incomplete_batch=False, - multiproc_kwargs={}, + parallel=-1, ): super().__init__() - self.train_batch_size = train_batch_size - self.predict_batch_size = predict_batch_size + self.batch_size = batch_size self.batch_chunk_count = batch_chunk_count + self.train_dataset = None + self.validation_dataset = None + self.extra_validation_datasets = None + + self._raw_data_transforms = [] + self._model_transforms = [] + self._augmentation_transforms = [] + self.drop_incomplete_batch = drop_incomplete_batch self.pin_memory = ( torch.cuda.is_available() ) # should only be true if GPU available and using it - self.multiproc_kwargs = multiproc_kwargs + self.parallel = parallel def setup(self, stage: str): # Implemented by user @@ -43,29 +53,32 @@ class BaseDataModule(pl.LightningDataModule): def train_dataloader(self): train_samples_weights = get_samples_weights(self.train_dataset) + multiproc_kwargs = self._setup_multiproc(self.parallel) train_sampler = WeightedRandomSampler( train_samples_weights, len(train_samples_weights), replacement=True ) return DataLoader( self.train_dataset, - batch_size=self.compute_chunk_size(self.train_batch_size), + batch_size=self._compute_chunk_size(self.batch_size), drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, sampler=train_sampler, - **self.multiproc_kwargs, + **multiproc_kwargs, ) def val_dataloader(self): loaders_dict = {} + multiproc_kwargs = self._setup_multiproc(self.parallel) + val_loader = DataLoader( dataset=self.validation_dataset, - batch_size=self.compute_chunk_size(self.train_batch_size), + batch_size=self._compute_chunk_size(self.batch_size), shuffle=False, drop_last=False, pin_memory=self.pin_memory, - **self.multiproc_kwargs, + **multiproc_kwargs, ) loaders_dict["validation_loader"] = val_loader @@ -74,11 +87,11 @@ class BaseDataModule(pl.LightningDataModule): for set_idx, extra_set in enumerate(self.extra_validation_datasets): extra_val_loader = DataLoader( dataset=extra_set, - batch_size=self.train_batch_size, + batch_size=self._compute_chunk_size(self.batch_size), shuffle=False, drop_last=False, pin_memory=self.pin_memory, - **self.multiproc_kwargs, + 
**multiproc_kwargs, ) loaders_dict[ @@ -96,7 +109,7 @@ class BaseDataModule(pl.LightningDataModule): return loaders_dict - def compute_chunk_size(self, batch_size): + def _compute_chunk_size(self, batch_size): batch_chunk_size = batch_size if batch_size % self.batch_chunk_count != 0: # batch_size must be divisible by batch_chunk_count. @@ -109,6 +122,56 @@ class BaseDataModule(pl.LightningDataModule): return batch_chunk_size + def _setup_multiproc(self, parallel): + multiproc_kwargs = dict() + if parallel < 0: + multiproc_kwargs["num_workers"] = 0 + else: + multiproc_kwargs["num_workers"] = ( + parallel or multiprocessing.cpu_count() + ) + + if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin": + multiproc_kwargs[ + "multiprocessing_context" + ] = multiprocessing.get_context("spawn") + + return multiproc_kwargs + + def _build_transforms(self, is_train): + all_transforms = self.raw_data_transforms + self.model_transforms + + if is_train: + all_transforms = all_transforms + self.augmentation_transforms + + # all_transforms.append(transforms.ToTensor()) + + return transforms.Compose(all_transforms) + + @property + def raw_data_transforms(self): + return self._raw_data_transforms + + @raw_data_transforms.setter + def raw_data_transforms(self, transforms): + self._raw_data_transforms = transforms + + @property + def model_transforms(self): + return self._model_transforms + + @model_transforms.setter + def model_transforms(self, transforms): + self._model_transforms = transforms + + @property + def augmentation_transforms(self): + return self._augmentation_transforms + + @augmentation_transforms.setter + def augmentation_transforms(self, transforms): + self._augmentation_transforms = transforms + def get_dataset_from_module(module, stage, **module_args): """Instantiates a DataModule and retrieves the corresponding dataset. diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index b568e78a225dc515c5e0b53d7d6b9a5b64f95cd1..fad353a13cfd2b38279e94193a565dc5abbb337f 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -7,16 +7,11 @@ import json import logging import os import pathlib -import random import torch -from torchvision.transforms import RandomRotation from tqdm import tqdm -RANDOM_ROTATION = [RandomRotation(15)] -"""Shared data augmentation based on random rotation only.""" - logger = logging.getLogger(__name__) @@ -75,7 +70,7 @@ class JSONProtocol: * ``data``: which contains the data associated witht this sample """ - def __init__(self, protocols, fieldnames, loader, post_transforms=[]): + def __init__(self, protocols, fieldnames): if isinstance(protocols, dict): self._protocols = protocols else: @@ -86,8 +81,6 @@ class JSONProtocol: for k in protocols } self.fieldnames = fieldnames - self._loader = loader - self.post_transforms = post_transforms def check(self, limit=0): """For each protocol, check if all data can be correctly accessed. @@ -174,11 +167,7 @@ class JSONProtocol: logger.info(f"Loading subset {subset} samples.") retval[subset] = [ - self._loader( - dict(protocol=protocol, subset=subset, order=n), - dict(zip(self.fieldnames, k)), - self.post_transforms, - ) + dict(zip(self.fieldnames, k)) for n, k in enumerate(tqdm(samples)) ] @@ -328,149 +317,98 @@ class CSVDataset: ] -def make_subset(samples, transforms=[], prefixes=[], suffixes=[]): - """Creates a new data set, applying transforms. - - .. 
note:: - - This is a convenience function for our own dataset definitions inside - this module, guaranteeting homogenity between dataset definitions - provided in this package. It assumes certain strategies for data - augmentation that may not be translatable to other applications. - - - Parameters - ---------- - - samples : list - List of delayed samples - - transforms : list - A list of transforms that needs to be applied to all samples in the set - - prefixes : list - A list of data augmentation operations that needs to be applied - **before** the transforms above - - suffixes : list - A list of data augmentation operations that needs to be applied - **after** the transforms above - - - Returns - ------- - - subset : :py:class:`ptbench.data.utils.SampleListDataset` - A pre-formatted dataset that can be fed to one of our engines - """ - from .utils import SampleListDataset as wrapper - - return wrapper(samples, prefixes + transforms + suffixes) - - -def make_dataset( - subsets_groups, transforms=[], t_transforms=[], post_transforms=[] -): - """Creates a new configuration dataset from a list of dictionaries and - transforms. - - This function takes as input a list of dictionaries as those that can be - returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets` - mapping protocol names (such as ``train``, ``dev`` and ``test``) to - :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of - transforms, and returns a dictionary applying - :py:class:`ptbench.data.utils.SampleListDataset` to these - lists, and our standard data augmentation if a ``train`` set exists. - - For example, if ``subsets`` is composed of two sets named ``train`` and - ``test``, this function will yield a dictionary with the following entries: - - * ``__train__``: Wraps the ``train`` subset, includes data augmentation - (note: datasets with names starting with ``_`` (underscore) are excluded - from prediction and evaluation by default, as they contain data - augmentation transformations.) - * ``train``: Wraps the ``train`` subset, **without** data augmentation - * ``test``: Wraps the ``test`` subset, **without** data augmentation - - .. note:: - - This is a convenience function for our own dataset definitions inside - this module, guaranteeting homogenity between dataset definitions - provided in this package. It assumes certain strategies for data - augmentation that may not be translatable to other applications. - - - Parameters - ---------- +class CachedDataset(torch.utils.data.Dataset): + def __init__( + self, json_protocol, protocol, subset, raw_data_loader, transforms + ): + self.json_protocol = json_protocol + self.subset = subset + self.raw_data_loader = raw_data_loader + self.transforms = transforms + + self._samples = json_protocol.subsets(protocol)[self.subset] + # Dict entry with relative path to files, used during prediction + for s in self._samples: + s["name"] = s["data"] + + logger.info(f"Caching {self.subset} samples") + for sample in tqdm(self._samples): + sample["data"] = self.transforms( + self.raw_data_loader(sample["data"]) + ) - subsets : list - A list of dictionaries that contains the delayed sample lists - for a number of named lists. The subsets will be aggregated in one - final subset. If one of the keys is ``train``, our standard dataset - augmentation transforms are appended to the definition of that subset. - All other subsets remain un-augmented. 
+ def __getitem__(self, idx): + return self._samples[idx] - transforms : list - A list of transforms that needs to be applied to all samples in the set + def __len__(self): + return len(self._samples) - t_transforms : list - A list of transforms that needs to be applied to the train samples - post_transforms : list - A list of transforms that needs to be applied to all samples in the set - after all the other transforms +class RuntimeDataset(torch.utils.data.Dataset): + def __init__( + self, json_protocol, protocol, subset, raw_data_loader, transforms + ): + self.json_protocol = json_protocol + self.subset = subset + self.raw_data_loader = raw_data_loader + self.transforms = transforms + + self._samples = json_protocol.subsets(protocol)[self.subset] + # Dict entry with relative path to files + for s in self._samples: + s["name"] = s["data"] + + def __getitem__(self, idx): + sample = self._samples[idx].copy() + sample["data"] = self.transforms(self.raw_data_loader(sample["data"])) + return sample + + def __len__(self): + return len(self._samples) + + +class TBDataset(torch.utils.data.Dataset): + def __init__( + self, + json_protocol, + protocol, + subset, + raw_data_loader, + transforms, + cache_samples=False, + ): + self.json_protocol = json_protocol + self.subset = subset + self.raw_data_loader = raw_data_loader + self.transforms = transforms + self.cache_samples = cache_samples - Returns - ------- + self._samples = json_protocol.subsets(protocol)[self.subset] - dataset : dict - A pre-formatted dataset that can be fed to one of our engines. It maps - string names to :py:class:`ptbench.data.utils.SampleListDataset`'s. - """ + # Dict entry with relative path to files + for s in self._samples: + s["name"] = s["data"] - retval = {} + if self.cache_samples: + logger.info(f"Caching {self.subset} samples") + for sample in tqdm(self._samples): + sample["data"] = self.transforms( + self.raw_data_loader(sample["data"]) + ) - if len(subsets_groups) == 1: - subsets = subsets_groups[0] - else: - # If multiple subsets groups: aggregation - aggregated_subsets = {} - for subsets in subsets_groups: - for key in subsets.keys(): - if key in aggregated_subsets: - aggregated_subsets[key] += subsets[key] - # Shuffle if data comes from multiple datasets - random.shuffle(aggregated_subsets[key]) - else: - aggregated_subsets[key] = subsets[key] - subsets = aggregated_subsets - - # Add post_transforms after t_transforms for the train set - t_transforms += post_transforms - - for key in subsets.keys(): - retval[key] = make_subset( - subsets[key], transforms=transforms, suffixes=post_transforms - ) - if key == "train": - retval["__train__"] = make_subset( - subsets[key], transforms=transforms, suffixes=(t_transforms) + def __getitem__(self, idx): + if self.cache_samples: + return self._samples[idx] + else: + sample = self._samples[idx].copy() + sample["data"] = self.transforms( + self.raw_data_loader(sample["data"]) ) - if key == "validation": - # also use it for validation during training - retval["__valid__"] = retval[key] - - if ( - ("__train__" in retval) - and ("train" in retval) - and ("__valid__" not in retval) - ): - # if the dataset does not have a validation set, we use the unaugmented - # training set as validation set - retval["__valid__"] = retval["train"] + return sample - return retval + def __len__(self): + return len(self._samples) def get_samples_weights(dataset): @@ -501,11 +439,11 @@ def get_samples_weights(dataset): if isinstance(dataset, torch.utils.data.ConcatDataset): for ds in 
dataset.datasets: # Weighting only for binary labels - if isinstance(ds._samples[0].label, int): + if isinstance(ds._samples[0]["label"], int): # Groundtruth targets = [] for s in ds._samples: - targets.append(s.label) + targets.append(s["label"]) targets = torch.tensor(targets) # Count number of samples per class @@ -531,11 +469,11 @@ def get_samples_weights(dataset): else: # Weighting only for binary labels - if isinstance(dataset._samples[0].label, int): + if isinstance(dataset._samples[0]["label"], int): # Groundtruth targets = [] for s in dataset._samples: - targets.append(s.label) + targets.append(s["label"]) targets = torch.tensor(targets) # Count number of samples per class @@ -585,11 +523,11 @@ def get_positive_weights(dataset): if isinstance(dataset, torch.utils.data.ConcatDataset): for ds in dataset.datasets: for s in ds._samples: - targets.append(s.label) + targets.append(s["label"]) else: for s in dataset._samples: - targets.append(s.label) + targets.append(s["label"]) targets = torch.tensor(targets) @@ -618,3 +556,75 @@ def get_positive_weights(dataset): ) return positive_weights + + +def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): + from torch.nn import BCEWithLogitsLoss + + datamodule.prepare_data() + datamodule.setup(stage="fit") + + train_dataset = datamodule.train_dataset + validation_dataset = datamodule.validation_dataset + + # Redefine a weighted criterion if possible + if isinstance(criterion, torch.nn.BCEWithLogitsLoss): + positive_weights = get_positive_weights(train_dataset) + model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) + else: + logger.warning("Weighted criterion not supported") + + if validation_dataset is not None: + # Redefine a weighted valid criterion if possible + if ( + isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) + or criterion_valid is None + ): + positive_weights = get_positive_weights(validation_dataset) + model.hparams.criterion_valid = BCEWithLogitsLoss( + pos_weight=positive_weights + ) + else: + logger.warning("Weighted valid criterion not supported") + + +def normalize_data(normalization, model, datamodule): + from torch.utils.data import DataLoader + + datamodule.prepare_data() + datamodule.setup(stage="fit") + + train_dataset = datamodule.train_dataset + + # Create z-normalization model layer if needed + if normalization == "imagenet": + model.normalizer.set_mean_std( + [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] + ) + logger.info("Z-normalization with ImageNet mean and std") + elif normalization == "current": + # Compute mean/std of current train subset + temp_dl = DataLoader( + dataset=train_dataset, batch_size=len(train_dataset) + ) + + data = next(iter(temp_dl)) + mean = data[1].mean(dim=[0, 2, 3]) + std = data[1].std(dim=[0, 2, 3]) + + model.normalizer.set_mean_std(mean, std) + + # Format mean and std for logging + mean = str( + [ + round(x, 3) + for x in ((mean * 10**3).round() / (10**3)).tolist() + ] + ) + std = str( + [ + round(x, 3) + for x in ((std * 10**3).round() / (10**3)).tolist() + ] + ) + logger.info(f"Z-normalization with mean {mean} and std {std}") diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py index a11aefee77becbbbb07e15a25d5404594c16999a..d6e86ed06dd04abb230830ebc8b13e2ea9720cfa 100644 --- a/src/ptbench/data/loader.py +++ b/src/ptbench/data/loader.py @@ -5,13 +5,8 @@ """Data loading code.""" - -import functools - import PIL.Image -from .sample import DelayedSample, Sample - def load_pil(path): """Loads a sample data. 
@@ -68,78 +63,3 @@ def load_pil_rgb(path): A PIL image in RGB mode """ return load_pil(path).convert("RGB") - - -def make_cached(sample, loader, additional_transforms=[], key=None): - return Sample( - loader(sample, additional_transforms), - key=key or sample["data"], - label=sample["label"], - ) - - -def make_delayed(sample, loader, additional_transforms=[], key=None): - """Returns a delayed-loading Sample object. - - Parameters - ---------- - - sample : dict - A dictionary that maps field names to sample data values (e.g. paths) - - loader : object - A function that inputs ``sample`` dictionaries and returns the loaded - data. - - key : str - A unique key identifier for this sample. If not provided, assumes - ``sample`` is a dictionary with a ``data`` entry and uses its path as - key. - - - Returns - ------- - - sample : ptbench.data.sample.DelayedSample - In which ``key`` is as provided and ``data`` can be accessed to trigger - sample loading. - """ - return DelayedSample( - functools.partial(loader, sample, additional_transforms), - key=key or sample["data"], - label=sample["label"], - ) - - -def make_delayed_bbox(sample, loader, key=None): - """Returns a delayed-loading Sample object. - - Parameters - ---------- - - sample : dict - A dictionary that maps field names to sample data values (e.g. paths) - - loader : object - A function that inputs ``sample`` dictionaries and returns the loaded - data. - - key : str - A unique key identifier for this sample. If not provided, assumes - ``sample`` is a dictionary with a ``data`` entry and uses its path as - key. - - - Returns - ------- - - sample : ptbench.data.sample.DelayedSample - In which ``key`` is as provided and ``data`` can be accessed to trigger - sample loading. - """ - return DelayedSample( - functools.partial(loader, sample), - key=key or sample["data"], - label=sample["label"], - bboxes=sample["bboxes"], - ) diff --git a/src/ptbench/data/sample.py b/src/ptbench/data/sample.py deleted file mode 100644 index 8d5102af091977fb4d746ca2f83de7fdd351cd91..0000000000000000000000000000000000000000 --- a/src/ptbench/data/sample.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Base definition of sample.""" - - -def _copy_attributes(s, d): - """Copies attributes from a dictionary to self.""" - s.__dict__.update( - {k: v for k, v in d.items() if k not in ("data", "load", "samples")} - ) - - -class DelayedSample: - """Representation of sample that can be loaded via a callable. - - The optional ``**kwargs`` argument allows you to attach more attributes to - this sample instance. - - - Parameters - ---------- - - load : object - A python function that can be called parameterlessly, to load the - sample in question from whatever medium - - parent : :py:class:`DelayedSample`, :py:class:`Sample`, None - If passed, consider this as a parent of this sample, to copy - information - - kwargs : dict - Further attributes of this sample, to be stored and eventually - transmitted to transformed versions of the sample - """ - - def __init__(self, load, parent=None, **kwargs): - self.load = load - if parent is not None: - _copy_attributes(self, parent.__dict__) - _copy_attributes(self, kwargs) - - @property - def data(self): - """Loads the data from the disk file.""" - return self.load() - - -class Sample: - """Representation of sample that is sufficient for the blocks in this - module. 
- - Each sample must have the following attributes: - - * attribute ``data``: Contains the data for this sample - - - Parameters - ---------- - - data : object - Object representing the data to initialize this sample with. - - parent : object - A parent object from which to inherit all other attributes (except - ``data``) - """ - - def __init__(self, data, parent=None, **kwargs): - self.data = data - if parent is not None: - _copy_attributes(self, parent.__dict__) - _copy_attributes(self, kwargs) diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py index d284b1b28905594e0c9c7fd20cb31f7c566468e3..54eb632684a83e16ac28481f8499137f594b662a 100644 --- a/src/ptbench/data/shenzhen/__init__.py +++ b/src/ptbench/data/shenzhen/__init__.py @@ -23,11 +23,9 @@ import importlib.resources import os from clapper.logging import setup -from torchvision import transforms from ...utils.rc import load_rc -from ..loader import load_pil_baw, make_cached, make_delayed -from ..transforms import RemoveBlackBorders +from ..loader import load_pil_baw logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -48,34 +46,8 @@ _protocols = [ _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir)) -_resize_size = 512 -_cc_size = 512 - -_data_transforms = [ - RemoveBlackBorders(), - transforms.Resize(_resize_size), - transforms.CenterCrop(_cc_size), - transforms.ToTensor(), -] - - -def _raw_data_loader(sample, additional_transforms=[]): - raw_data = load_pil_baw(os.path.join(_datadir, sample["data"])) - - base_transforms = transforms.Compose( - _data_transforms + additional_transforms - ) - return dict( - data=base_transforms(raw_data), - label=sample["label"], - ) - - -def _cached_loader(context, sample, additional_transforms=[]): - return make_cached(sample, _raw_data_loader, additional_transforms) +def _raw_data_loader(img_path): + raw_data = load_pil_baw(os.path.join(_datadir, img_path)) -def _delayed_loader(context, sample, additional_transforms=[]): - # "context" is ignored in this case - database is homogeneous - # we returned delayed samples to avoid loading all images at once - return make_delayed(sample, _raw_data_loader, additional_transforms) + return raw_data diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py new file mode 100644 index 0000000000000000000000000000000000000000..f471d426e20e25c3517050b68f54650ce3074314 --- /dev/null +++ b/src/ptbench/data/shenzhen/default.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Shenzhen dataset for TB detection (default protocol) + +* Split reference: first 64% of TB and healthy CXR for "train" 16% for +* "validation", 20% for "test" +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.shenzhen` for dataset details +""" + +from clapper.logging import setup +from torchvision import transforms + +from ..base_datamodule import BaseDataModule +from ..dataset import JSONProtocol, TBDataset +from ..shenzhen import _protocols, _raw_data_loader +from ..transforms import ElasticDeformation, RemoveBlackBorders + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +class DefaultModule(BaseDataModule): + def __init__( + self, + batch_size=1, + batch_chunk_count=1, + drop_incomplete_batch=False, + cache_samples=False, + parallel=-1, + ): + super().__init__( + batch_size=batch_size, + 
drop_incomplete_batch=drop_incomplete_batch, + batch_chunk_count=batch_chunk_count, + parallel=parallel, + ) + + self._cache_samples = cache_samples + self._has_setup_fit = False + self._has_setup_predict = False + self._protocol = "default" + + self.raw_data_transforms = [ + RemoveBlackBorders(), + transforms.Resize(512), + transforms.CenterCrop(512), + transforms.ToTensor(), + ] + + self.model_transforms = [] + + self.augmentation_transforms = [ElasticDeformation(p=0.8)] + + def setup(self, stage: str): + json_protocol = JSONProtocol( + protocols=_protocols, + fieldnames=("data", "label"), + ) + + if not self._has_setup_fit and stage == "fit": + self.train_dataset = TBDataset( + json_protocol, + self._protocol, + "train", + _raw_data_loader, + self._build_transforms(is_train=True), + cache_samples=self._cache_samples, + ) + self.validation_dataset = TBDataset( + json_protocol, + self._protocol, + "validation", + _raw_data_loader, + self._build_transforms(is_train=False), + cache_samples=self._cache_samples, + ) + + self._has_setup_fit = True + + if not self._has_setup_predict and stage == "predict": + self.train_dataset = TBDataset( + json_protocol, + self._protocol, + "train", + _raw_data_loader, + self._build_transforms(is_train=False), + cache_samples=self._cache_samples, + ) + self.validation_dataset = TBDataset( + json_protocol, + self._protocol, + "validation", + _raw_data_loader, + self._build_transforms(is_train=False), + cache_samples=self._cache_samples, + ) + + self._has_setup_predict = True + + +datamodule = DefaultModule diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py similarity index 90% rename from src/ptbench/configs/datasets/shenzhen/rgb.py rename to src/ptbench/data/shenzhen/rgb.py index 7cf77faa1406e6b34b0db06415002c2e23de49ab..2d93c07edb4c0824caa8149737ec42783966ad4c 100644 --- a/src/ptbench/configs/datasets/shenzhen/rgb.py +++ b/src/ptbench/data/shenzhen/rgb.py @@ -10,7 +10,6 @@ """ from clapper.logging import setup -from torchvision import transforms from ....data import return_subsets from ....data.base_datamodule import BaseDataModule @@ -28,6 +27,9 @@ class DefaultModule(BaseDataModule): drop_incomplete_batch=False, cache_samples=False, multiproc_kwargs=None, + data_transforms=[], + model_transforms=[], + train_transforms=[], ): super().__init__( train_batch_size=train_batch_size, @@ -39,11 +41,15 @@ class DefaultModule(BaseDataModule): self.cache_samples = cache_samples self.has_setup_fit = False - self.post_transforms = [ + self.data_transforms = data_transforms + self.model_transforms = model_transforms + self.train_transforms = train_transforms + + """[ transforms.ToPILImage(), transforms.Lambda(lambda x: x.convert("RGB")), transforms.ToTensor(), - ] + ]""" def setup(self, stage: str): if self.cache_samples: diff --git a/src/ptbench/data/utils.py b/src/ptbench/data/utils.py deleted file mode 100644 index bc5bdbd4b5fbda35361018e40479199435209fe0..0000000000000000000000000000000000000000 --- a/src/ptbench/data/utils.py +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -"""Common utilities.""" - -import numpy as np -import PIL -import torch -import torch.utils.data - -from torchvision.transforms import Compose, ToTensor - - -class SampleListDataset(torch.utils.data.Dataset): - """PyTorch dataset wrapper around Sample lists. 
- - A transform object can be passed that will be applied to the image, ground - truth and mask (if present). - - It supports indexing such that dataset[i] can be used to get ith sample. - - Parameters - ---------- - - samples : list - A list of :py:class:`ptbench.data.sample.Sample` objects - - transforms : :py:class:`list`, Optional - a list of transformations to be applied to **both** image and - ground-truth data. Notice a last transform - (:py:class:`torchvision.transforms.transforms.ToTensor`) is always - applied - you do not need to add that. - """ - - def __init__(self, samples, transforms=[]): - self._samples = samples - self.transforms = transforms - - @property - def transforms(self): - return self._transforms.transforms[:-1] - - @transforms.setter - def transforms(self, data): - if any([isinstance(t, ToTensor) for t in data]): - self._transforms = Compose(data) - else: - self._transforms = Compose(data + [ToTensor()]) - - def copy(self, transforms=None): - """Returns a deep copy of itself, optionally resetting transforms. - - Parameters - ---------- - - transforms : :py:class:`list`, Optional - An optional list of transforms to set in the copy. If not - specified, use ``self.transforms``. - """ - return SampleListDataset(self._samples, transforms or self.transforms) - - def random_permute(self, feature): - """Randomly permute feature values from all samples. - - Useful for permutation feature importance computation - - Parameters - ---------- - - feature : int - The position of the feature - """ - feature_values = np.zeros(len(self)) - - for k, s in enumerate(self._samples): - features = s.data["data"] - if isinstance(features, list): - feature_values[k] = features[feature] - - np.random.shuffle(feature_values) - - for k, s in enumerate(self._samples): - features = s.data["data"] - features[feature] = feature_values[k] - - def __len__(self): - """ - - Returns - ------- - - size : int - size of the dataset - - """ - return len(self._samples) - - def __getitem__(self, key): - """ - - Parameters - ---------- - - key : int, slice - - Returns - ------- - - sample : list - The sample data: ``[key, image, label]`` - - """ - if isinstance(key, slice): - return [self[k] for k in range(*key.indices(len(self)))] - else: # we try it as an int - item = data = self._samples[key] - if not isinstance(data, dict): - key = item.key - data = item.data # triggers data loading - - retval = data["data"] - - if self._transforms and isinstance(retval, PIL.Image.Image): - retval = self._transforms(retval) - elif isinstance(retval, list): - retval = torch.FloatTensor(retval) - - if "label" in data: - if isinstance(data["label"], list): - return [key, retval, torch.FloatTensor(data["label"])] - else: - return [key, retval, data["label"]] - - return [item.key, retval] diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index dc037af0679d3493a718a26dc393a61d00659a27..4535b368bca54daf97fa350b91516a126eff319a 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -13,7 +13,7 @@ from .callbacks import PredictionsWriter logger = logging.getLogger(__name__) -def run(model, data_loader, name, accelerator, output_folder, grad_cams=False): +def run(model, datamodule, name, accelerator, output_folder, grad_cams=False): """Runs inference on input data, outputs csv files with predictions. 
Parameters @@ -73,6 +73,6 @@ def run(model, data_loader, name, accelerator, output_folder, grad_cams=False): ], ) - all_predictions = trainer.predict(model, data_loader) + all_predictions = trainer.predict(model, datamodule) return all_predictions diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index ed8c4b30a9412ebd2afb78c0b21f5a0ecb26fd91..d9e86b7055ed1ca3a06d807ab5d02b175b2f0cef 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -18,7 +18,6 @@ class PASA(pl.LightningModule): def __init__( self, - train_transforms, criterion, criterion_valid, optimizer, @@ -26,14 +25,12 @@ class PASA(pl.LightningModule): ): super().__init__() - self.save_hyperparameters(ignore=["train_transforms"]) + self.save_hyperparameters() self.name = "pasa" self.normalizer = TorchVisionNormalizer(nb_channels=1) - self.train_transforms = train_transforms - # First convolution block self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) @@ -131,11 +128,8 @@ class PASA(pl.LightningModule): return x def training_step(self, batch, batch_idx): - images = batch[1] - labels = batch[2] - - for img in images: - img = self.train_transforms(img) + images = batch["data"] + labels = batch["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -152,8 +146,8 @@ class PASA(pl.LightningModule): return {"loss": training_loss} def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[1] - labels = batch[2] + images = batch["data"] + labels = batch["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -175,8 +169,8 @@ class PASA(pl.LightningModule): return {f"extra_validation_loss_{dataloader_idx}": validation_loss} def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): - names = batch[0] - images = batch[1] + names = batch["names"] + images = batch["data"] outputs = self(images) probabilities = torch.sigmoid(outputs) @@ -186,7 +180,18 @@ class PASA(pl.LightningModule): if isinstance(outputs, list): outputs = outputs[-1] - return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + return { + f"dataloader_{dataloader_idx}_predictions": ( + names[0], + torch.flatten(probabilities), + torch.flatten(batch[2]), + ) + } + + def on_predict_epoch_end(self): + # Need to cache predictions in the predict step, then reorder by key + # Clear prediction dict + raise NotImplementedError def configure_optimizers(self): # Dynamically instantiates the optimizer given the configs diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 3606c9f9d2c0179f171acc4b3dcba7987242047e..6b37e1a109623290e10d6c79009cc69bb403ab05 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -260,104 +260,34 @@ def train( procedure in case it stops abruptly. 
""" - import multiprocessing - import sys - import torch.cuda import torch.nn - from torch.nn import BCEWithLogitsLoss - from torch.utils.data import DataLoader - - from ..configs.datasets import get_positive_weights + from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss from ..engine.trainer import run seed_everything(seed) - # PyTorch dataloader - multiproc_kwargs = dict() - if parallel < 0: - multiproc_kwargs["num_workers"] = 0 - else: - multiproc_kwargs["num_workers"] = ( - parallel or multiprocessing.cpu_count() - ) - - if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin": - multiproc_kwargs[ - "multiprocessing_context" - ] = multiprocessing.get_context("spawn") + checkpoint_file = get_checkpoint(output_folder, resume_from) - datamodule.train_batch_size = batch_size - datamodule.batch_chunk_count = batch_chunk_count - datamodule.multiproc_kwargs = multiproc_kwargs + datamodule = datamodule( + batch_size=batch_size, + batch_chunk_count=batch_chunk_count, + drop_incomplete_batch=drop_incomplete_batch, + cache_samples=cache_samples, + parallel=parallel, + ) - # Manually calling these as we need to access some values to reweight the criterion datamodule.prepare_data() datamodule.setup(stage="fit") - train_dataset = datamodule.train_dataset - validation_dataset = datamodule.validation_dataset - - # Redefine a weighted criterion if possible - if isinstance(criterion, torch.nn.BCEWithLogitsLoss): - positive_weights = get_positive_weights(train_dataset) - model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) - else: - logger.warning("Weighted criterion not supported") - - if validation_dataset is not None: - # Redefine a weighted valid criterion if possible - if ( - isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) - or criterion_valid is None - ): - positive_weights = get_positive_weights(validation_dataset) - model.hparams.criterion_valid = BCEWithLogitsLoss( - pos_weight=positive_weights - ) - else: - logger.warning("Weighted valid criterion not supported") - - # Create z-normalization model layer if needed - if normalization == "imagenet": - model.normalizer.set_mean_std( - [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] - ) - logger.info("Z-normalization with ImageNet mean and std") - elif normalization == "current": - # Compute mean/std of current train subset - temp_dl = DataLoader( - dataset=train_dataset, batch_size=len(train_dataset) - ) - - data = next(iter(temp_dl)) - mean = data[1].mean(dim=[0, 2, 3]) - std = data[1].std(dim=[0, 2, 3]) - - model.normalizer.set_mean_std(mean, std) - - # Format mean and std for logging - mean = str( - [ - round(x, 3) - for x in ((mean * 10**3).round() / (10**3)).tolist() - ] - ) - std = str( - [ - round(x, 3) - for x in ((std * 10**3).round() / (10**3)).tolist() - ] - ) - logger.info(f"Z-normalization with mean {mean} and std {std}") + reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid) + normalize_data(normalization, model, datamodule) arguments = {} arguments["max_epoch"] = epochs arguments["epoch"] = 0 - checkpoint_file = get_checkpoint(output_folder, resume_from) - # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit() if checkpoint_file is not None: checkpoint = torch.load(checkpoint_file)
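Notes on the relocated class-balancing helpers (now living in src/ptbench/data/dataset.py and consumed by BaseDataModule.train_dataloader() and reweight_BCEWithLogitsLoss()): the sketch below shows their behaviour on a toy binary dataset. Only get_samples_weights(), get_positive_weights() and the dict-based sample layout ({"name", "data", "label"}) come from this patch; ToyDataset and its labels are hypothetical.

import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import WeightedRandomSampler

from ptbench.data.dataset import get_positive_weights, get_samples_weights


class ToyDataset(torch.utils.data.Dataset):
    # Hypothetical stand-in mimicking the dict-based samples used after this patch.
    def __init__(self, labels):
        self._samples = [
            {"name": f"img_{i}.png", "data": None, "label": label}
            for i, label in enumerate(labels)
        ]

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, idx):
        return self._samples[idx]


dataset = ToyDataset([0, 0, 0, 1])  # 3 negatives, 1 positive

# Per-sample weights (inverse class frequency) feeding the WeightedRandomSampler
# built in BaseDataModule.train_dataloader().
weights = get_samples_weights(dataset)   # tensor([0.3333, 0.3333, 0.3333, 1.0000])
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

# Positive weight (negatives / positives) used by reweight_BCEWithLogitsLoss()
# to rebuild the training criterion.
pos_weight = get_positive_weights(dataset)   # tensor([3.])
criterion = BCEWithLogitsLoss(pos_weight=pos_weight)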
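Transforms likewise move out of the loaders (old shenzhen/__init__.py) and model configs (old configs/models/pasa.py) onto the datamodule, where BaseDataModule._build_transforms() concatenates them. A sketch of the composition order implied by the new Shenzhen defaults; the pipeline variable names are only illustrative:

from torchvision import transforms

from ptbench.data.transforms import ElasticDeformation, RemoveBlackBorders

raw_data_transforms = [
    RemoveBlackBorders(),
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
]
model_transforms = []                                    # empty in the new Shenzhen default
augmentation_transforms = [ElasticDeformation(p=0.8)]    # applied to training data only

# Equivalent of _build_transforms(is_train=True) and _build_transforms(is_train=False)
train_pipeline = transforms.Compose(
    raw_data_transforms + model_transforms + augmentation_transforms
)
eval_pipeline = transforms.Compose(raw_data_transforms + model_transforms)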
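Finally, the new Shenzhen datamodule is exported as the class itself (datamodule = DefaultModule), so callers instantiate it with their run-time options, which is what scripts/train.py now does. A hedged usage sketch; the option values are illustrative, and running it assumes the Shenzhen images are available under the configured datadir.shenzhen:

from ptbench.data.shenzhen.default import DefaultModule

dm = DefaultModule(
    batch_size=8,                 # must be divisible by batch_chunk_count
    batch_chunk_count=2,
    drop_incomplete_batch=False,
    cache_samples=False,          # True pre-loads and pre-transforms every sample
    parallel=-1,                  # < 0 keeps data loading in the main process
)

dm.prepare_data()
dm.setup(stage="fit")

# train_dataloader() balances classes with a WeightedRandomSampler built from
# get_samples_weights(); each batch is a dict with "name", "data" and "label" keys.
batch = next(iter(dm.train_dataloader()))
print(batch["data"].shape, batch["label"])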