diff --git a/src/ptbench/configs/datasets/__init__.py b/src/ptbench/configs/datasets/__init__.py
index 1e4d913188fcbccad5aa0effa18837434f37e49b..88ac7a33e7051a69ff6d8ecfdb55aeb94b13b72f 100644
--- a/src/ptbench/configs/datasets/__init__.py
+++ b/src/ptbench/configs/datasets/__init__.py
@@ -2,8 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import random
-
 import torch
 from torchvision.transforms import RandomRotation
 
@@ -14,151 +12,6 @@ RANDOM_ROTATION = [RandomRotation(15)]
 """Shared data augmentation based on random rotation only."""
 
 
-def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
-    """Creates a new data set, applying transforms.
-
-    .. note::
-
-        This is a convenience function for our own dataset definitions inside
-        this module, guaranteeting homogenity between dataset definitions
-        provided in this package. It assumes certain strategies for data
-        augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    samples : list
-        List of delayed samples
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    prefixes : list
-        A list of data augmentation operations that needs to be applied
-        **before** the transforms above
-
-    suffixes : list
-        A list of data augmentation operations that needs to be applied
-        **after** the transforms above
-
-
-    Returns
-    -------
-
-    subset : :py:class:`ptbench.data.utils.SampleListDataset`
-        A pre-formatted dataset that can be fed to one of our engines
-    """
-    from ...data.utils import SampleListDataset as wrapper
-
-    return wrapper(samples, prefixes + transforms + suffixes)
-
-
-def make_dataset(
-    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
-):
-    """Creates a new configuration dataset from a list of dictionaries and
-    transforms.
-
-    This function takes as input a list of dictionaries as those that can be
-    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`
-    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
-    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
-    transforms, and returns a dictionary applying
-    :py:class:`ptbench.data.utils.SampleListDataset` to these
-    lists, and our standard data augmentation if a ``train`` set exists.
-
-    For example, if ``subsets`` is composed of two sets named ``train`` and
-    ``test``, this function will yield a dictionary with the following entries:
-
-    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
-      (note: datasets with names starting with ``_`` (underscore) are excluded
-      from prediction and evaluation by default, as they contain data
-      augmentation transformations.)
-    * ``train``: Wraps the ``train`` subset, **without** data augmentation
-    * ``test``: Wraps the ``test`` subset, **without** data augmentation
-
-    .. note::
-
-        This is a convenience function for our own dataset definitions inside
-        this module, guaranteeting homogenity between dataset definitions
-        provided in this package. It assumes certain strategies for data
-        augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    subsets : list
-        A list of dictionaries that contains the delayed sample lists
-        for a number of named lists. The subsets will be aggregated in one
-        final subset. If one of the keys is ``train``, our standard dataset
-        augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    t_transforms : list
-        A list of transforms that needs to be applied to the train samples
-
-    post_transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-        after all the other transforms
-
-
-    Returns
-    -------
-
-    dataset : dict
-        A pre-formatted dataset that can be fed to one of our engines. It maps
-        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
-    """
-
-    retval = {}
-
-    if len(subsets_groups) == 1:
-        subsets = subsets_groups[0]
-    else:
-        # If multiple subsets groups: aggregation
-        aggregated_subsets = {}
-        for subsets in subsets_groups:
-            for key in subsets.keys():
-                if key in aggregated_subsets:
-                    aggregated_subsets[key] += subsets[key]
-                    # Shuffle if data comes from multiple datasets
-                    random.shuffle(aggregated_subsets[key])
-                else:
-                    aggregated_subsets[key] = subsets[key]
-        subsets = aggregated_subsets
-
-    # Add post_transforms after t_transforms for the train set
-    t_transforms += post_transforms
-
-    for key in subsets.keys():
-        retval[key] = make_subset(
-            subsets[key], transforms=transforms, suffixes=post_transforms
-        )
-        if key == "train":
-            retval["__train__"] = make_subset(
-                subsets[key], transforms=transforms, suffixes=(t_transforms)
-            )
-        if key == "validation":
-            # also use it for validation during training
-            retval["__valid__"] = retval[key]
-
-    if (
-        ("__train__" in retval)
-        and ("train" in retval)
-        and ("__valid__" not in retval)
-    ):
-        # if the dataset does not have a validation set, we use the unaugmented
-        # training set as validation set
-        retval["__valid__"] = retval["train"]
-
-    return retval
-
-
 def get_samples_weights(dataset):
     """Compute the weights of all the samples of the dataset to balance it
     using the sampler of the dataloader.
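Note on the hunk above: the deleted `make_dataset` helper wrapped each named subset in a `SampleListDataset` and added augmented `__train__`/`__valid__` entries. The sketch below reconstructs its call pattern from the removed docstring, for reference only; the empty sample lists and the transform choices are hypothetical placeholders, not part of this patch:

    # Hypothetical sketch of the removed helper's call pattern; the empty
    # sample lists and the transforms chosen here are illustrative only.
    from torchvision.transforms import CenterCrop, RandomRotation

    # One group of named subsets, e.g. as returned by JSONDataset.subsets()
    subsets_groups = [{"train": [], "validation": [], "test": []}]

    dataset = make_dataset(  # the helper this patch deletes
        subsets_groups,
        transforms=[CenterCrop(512)],       # applied to every subset
        t_transforms=[RandomRotation(15)],  # appended for "__train__" only
    )
    # Resulting keys: "train", "validation" and "test" (un-augmented
    # wrappers), plus "__train__" (augmented) and "__valid__" (the
    # "validation" wrapper, reused during training).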