diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py index 2617ec0f2a5dfc7e294035a8debf153cfe83062b..516af66bcae176129d985555b6b7facd885ce5b4 100644 --- a/src/ptbench/data/__init__.py +++ b/src/ptbench/data/__init__.py @@ -1 +1,345 @@ """Data manipulation and raw dataset definitions.""" + +import random + +import torch + +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +def make_subset(samples, transforms=[], prefixes=[], suffixes=[]): + """Creates a new data set, applying transforms. + + .. note:: + + This is a convenience function for our own dataset definitions inside + this module, guaranteeting homogenity between dataset definitions + provided in this package. It assumes certain strategies for data + augmentation that may not be translatable to other applications. + + + Parameters + ---------- + + samples : list + List of delayed samples + + transforms : list + A list of transforms that needs to be applied to all samples in the set + + prefixes : list + A list of data augmentation operations that needs to be applied + **before** the transforms above + + suffixes : list + A list of data augmentation operations that needs to be applied + **after** the transforms above + + + Returns + ------- + + subset : :py:class:`ptbench.data.utils.SampleListDataset` + A pre-formatted dataset that can be fed to one of our engines + """ + from .utils import SampleListDataset as wrapper + + return wrapper(samples, prefixes + transforms + suffixes) + + +def make_dataset( + subsets_groups, transforms=[], t_transforms=[], post_transforms=[] +): + """Creates a new configuration dataset from a list of dictionaries and + transforms. + + This function takes as input a list of dictionaries as those that can be + returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets` + mapping protocol names (such as ``train``, ``dev`` and ``test``) to + :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of + transforms, and returns a dictionary applying + :py:class:`ptbench.data.utils.SampleListDataset` to these + lists, and our standard data augmentation if a ``train`` set exists. + + For example, if ``subsets`` is composed of two sets named ``train`` and + ``test``, this function will yield a dictionary with the following entries: + + * ``__train__``: Wraps the ``train`` subset, includes data augmentation + (note: datasets with names starting with ``_`` (underscore) are excluded + from prediction and evaluation by default, as they contain data + augmentation transformations.) + * ``train``: Wraps the ``train`` subset, **without** data augmentation + * ``test``: Wraps the ``test`` subset, **without** data augmentation + + .. note:: + + This is a convenience function for our own dataset definitions inside + this module, guaranteeting homogenity between dataset definitions + provided in this package. It assumes certain strategies for data + augmentation that may not be translatable to other applications. + + + Parameters + ---------- + + subsets : list + A list of dictionaries that contains the delayed sample lists + for a number of named lists. The subsets will be aggregated in one + final subset. If one of the keys is ``train``, our standard dataset + augmentation transforms are appended to the definition of that subset. + All other subsets remain un-augmented. + + transforms : list + A list of transforms that needs to be applied to all samples in the set + + t_transforms : list + A list of transforms that needs to be applied to the train samples + + post_transforms : list + A list of transforms that needs to be applied to all samples in the set + after all the other transforms + + + Returns + ------- + + dataset : dict + A pre-formatted dataset that can be fed to one of our engines. It maps + string names to :py:class:`ptbench.data.utils.SampleListDataset`'s. + """ + + retval = {} + + if len(subsets_groups) == 1: + subsets = subsets_groups[0] + else: + # If multiple subsets groups: aggregation + aggregated_subsets = {} + for subsets in subsets_groups: + for key in subsets.keys(): + if key in aggregated_subsets: + aggregated_subsets[key] += subsets[key] + # Shuffle if data comes from multiple datasets + random.shuffle(aggregated_subsets[key]) + else: + aggregated_subsets[key] = subsets[key] + subsets = aggregated_subsets + + # Add post_transforms after t_transforms for the train set + t_transforms += post_transforms + + for key in subsets.keys(): + retval[key] = make_subset( + subsets[key], transforms=transforms, suffixes=post_transforms + ) + if key == "train": + retval["__train__"] = make_subset( + subsets[key], transforms=transforms, suffixes=(t_transforms) + ) + if key == "validation": + # also use it for validation during training + retval["__valid__"] = retval[key] + + if ( + ("__train__" in retval) + and ("train" in retval) + and ("__valid__" not in retval) + ): + # if the dataset does not have a validation set, we use the unaugmented + # training set as validation set + retval["__valid__"] = retval["train"] + + return retval + + +def get_samples_weights(dataset): + """Compute the weights of all the samples of the dataset to balance it + using the sampler of the dataloader. + + This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + and computes the weights to balance each class in the dataset and the + datasets themselves if we have a ConcatDataset. + + + Parameters + ---------- + + dataset : torch.utils.data.dataset.Dataset + An instance of torch.utils.data.dataset.Dataset + ConcatDataset are supported + + + Returns + ------- + + samples_weights : :py:class:`torch.Tensor` + the weights for all the samples in the dataset given as input + """ + samples_weights = [] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + # Weighting only for binary labels + if isinstance(ds._samples[0].label, int): + # Groundtruth + targets = [] + for s in ds._samples: + targets.append(s.label) + targets = torch.tensor(targets) + + # Count number of samples per class + class_sample_count = torch.tensor( + [ + (targets == t).sum() + for t in torch.unique(targets, sorted=True) + ] + ) + + weight = 1.0 / class_sample_count.float() + + samples_weights.append( + torch.tensor([weight[t] for t in targets]) + ) + + else: + # We only weight to sample equally from each dataset + samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) + + # Concatenate sample weights from all the datasets + samples_weights = torch.cat(samples_weights) + + else: + # Weighting only for binary labels + if isinstance(dataset._samples[0].label, int): + # Groundtruth + targets = [] + for s in dataset._samples: + targets.append(s.label) + targets = torch.tensor(targets) + + # Count number of samples per class + class_sample_count = torch.tensor( + [ + (targets == t).sum() + for t in torch.unique(targets, sorted=True) + ] + ) + + weight = 1.0 / class_sample_count.float() + + samples_weights = torch.tensor([weight[t] for t in targets]) + + else: + # Equal weights for non-binary labels + samples_weights = torch.ones(len(dataset._samples)) + + return samples_weights + + +def get_positive_weights(dataset): + """Compute the positive weights of each class of the dataset to balance the + BCEWithLogitsLoss criterion. + + This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + and computes the positive weights of each class to use them to have + a balanced loss. + + + Parameters + ---------- + + dataset : torch.utils.data.dataset.Dataset + An instance of torch.utils.data.dataset.Dataset + ConcatDataset are supported + + + Returns + ------- + + positive_weights : :py:class:`torch.Tensor` + the positive weight of each class in the dataset given as input + """ + targets = [] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + for s in ds._samples: + targets.append(s.label) + + else: + for s in dataset._samples: + targets.append(s.label) + + targets = torch.tensor(targets) + + # Binary labels + if len(list(targets.shape)) == 1: + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + positive_weights = torch.tensor( + [class_sample_count[0] / class_sample_count[1]] + ).reshape(-1) + + # Multiclass labels + else: + class_sample_count = torch.sum(targets, dim=0) + negative_class_sample_count = ( + torch.full((targets.size()[1],), float(targets.size()[0])) + - class_sample_count + ) + + positive_weights = negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) + + return positive_weights + + +def return_subsets(dataset): + train_dataset = None + validation_dataset = None + extra_validation_datasets = None + predict_dataset = None + + if "__train__" in dataset: + logger.info("Found (dedicated) '__train__' set for training") + train_dataset = dataset["__train__"] + else: + train_dataset = dataset["train"] + + if "__valid__" in dataset: + logger.info("Found (dedicated) '__valid__' set for validation") + validation_dataset = dataset["__valid__"] + + if "__extra_valid__" in dataset: + if not isinstance(dataset["__extra_valid__"], list): + raise RuntimeError( + f"If present, dataset['__extra_valid__'] must be a list, " + f"but you passed a {type(dataset['__extra_valid__'])}, " + f"which is invalid." + ) + logger.info( + f"Found {len(dataset['__extra_valid__'])} extra validation " + f"set(s) to be tracked during training" + ) + logger.info( + "Extra validation sets are NOT used for model checkpointing!" + ) + extra_validation_datasets = dataset["__extra_valid__"] + else: + extra_validation_datasets = None + + predict_dataset = dataset + + return ( + train_dataset, + validation_dataset, + extra_validation_datasets, + predict_dataset, + ) diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py index 60a004a3afc76239ffae144556feaaa888850ab7..a854e559bc033cd12eff0b10fc246b47edd73b64 100644 --- a/src/ptbench/data/shenzhen/__init__.py +++ b/src/ptbench/data/shenzhen/__init__.py @@ -22,13 +22,11 @@ the daily routine using Philips DR Digital Diagnose systems. import importlib.resources import os -import random - -import torch from clapper.logging import setup from ...utils.rc import load_rc +from .. import make_dataset from ..dataset import JSONDataset from ..loader import load_pil_baw, make_delayed @@ -93,339 +91,3 @@ def _maker(protocol, resize_size=512, cc_size=512, RGB=False): [ElasticDeformation(p=0.8)], post_transforms, ) - - -def make_subset(samples, transforms=[], prefixes=[], suffixes=[]): - """Creates a new data set, applying transforms. - - .. note:: - - This is a convenience function for our own dataset definitions inside - this module, guaranteeting homogenity between dataset definitions - provided in this package. It assumes certain strategies for data - augmentation that may not be translatable to other applications. - - - Parameters - ---------- - - samples : list - List of delayed samples - - transforms : list - A list of transforms that needs to be applied to all samples in the set - - prefixes : list - A list of data augmentation operations that needs to be applied - **before** the transforms above - - suffixes : list - A list of data augmentation operations that needs to be applied - **after** the transforms above - - - Returns - ------- - - subset : :py:class:`ptbench.data.utils.SampleListDataset` - A pre-formatted dataset that can be fed to one of our engines - """ - from ...data.utils import SampleListDataset as wrapper - - return wrapper(samples, prefixes + transforms + suffixes) - - -def make_dataset( - subsets_groups, transforms=[], t_transforms=[], post_transforms=[] -): - """Creates a new configuration dataset from a list of dictionaries and - transforms. - - This function takes as input a list of dictionaries as those that can be - returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets` - mapping protocol names (such as ``train``, ``dev`` and ``test``) to - :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of - transforms, and returns a dictionary applying - :py:class:`ptbench.data.utils.SampleListDataset` to these - lists, and our standard data augmentation if a ``train`` set exists. - - For example, if ``subsets`` is composed of two sets named ``train`` and - ``test``, this function will yield a dictionary with the following entries: - - * ``__train__``: Wraps the ``train`` subset, includes data augmentation - (note: datasets with names starting with ``_`` (underscore) are excluded - from prediction and evaluation by default, as they contain data - augmentation transformations.) - * ``train``: Wraps the ``train`` subset, **without** data augmentation - * ``test``: Wraps the ``test`` subset, **without** data augmentation - - .. note:: - - This is a convenience function for our own dataset definitions inside - this module, guaranteeting homogenity between dataset definitions - provided in this package. It assumes certain strategies for data - augmentation that may not be translatable to other applications. - - - Parameters - ---------- - - subsets : list - A list of dictionaries that contains the delayed sample lists - for a number of named lists. The subsets will be aggregated in one - final subset. If one of the keys is ``train``, our standard dataset - augmentation transforms are appended to the definition of that subset. - All other subsets remain un-augmented. - - transforms : list - A list of transforms that needs to be applied to all samples in the set - - t_transforms : list - A list of transforms that needs to be applied to the train samples - - post_transforms : list - A list of transforms that needs to be applied to all samples in the set - after all the other transforms - - - Returns - ------- - - dataset : dict - A pre-formatted dataset that can be fed to one of our engines. It maps - string names to :py:class:`ptbench.data.utils.SampleListDataset`'s. - """ - - retval = {} - - if len(subsets_groups) == 1: - subsets = subsets_groups[0] - else: - # If multiple subsets groups: aggregation - aggregated_subsets = {} - for subsets in subsets_groups: - for key in subsets.keys(): - if key in aggregated_subsets: - aggregated_subsets[key] += subsets[key] - # Shuffle if data comes from multiple datasets - random.shuffle(aggregated_subsets[key]) - else: - aggregated_subsets[key] = subsets[key] - subsets = aggregated_subsets - - # Add post_transforms after t_transforms for the train set - t_transforms += post_transforms - - for key in subsets.keys(): - retval[key] = make_subset( - subsets[key], transforms=transforms, suffixes=post_transforms - ) - if key == "train": - retval["__train__"] = make_subset( - subsets[key], transforms=transforms, suffixes=(t_transforms) - ) - if key == "validation": - # also use it for validation during training - retval["__valid__"] = retval[key] - - if ( - ("__train__" in retval) - and ("train" in retval) - and ("__valid__" not in retval) - ): - # if the dataset does not have a validation set, we use the unaugmented - # training set as validation set - retval["__valid__"] = retval["train"] - - return retval - - -def get_samples_weights(dataset): - """Compute the weights of all the samples of the dataset to balance it - using the sampler of the dataloader. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the weights to balance each class in the dataset and the - datasets themselves if we have a ConcatDataset. - - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - samples_weights : :py:class:`torch.Tensor` - the weights for all the samples in the dataset given as input - """ - samples_weights = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - # Weighting only for binary labels - if isinstance(ds._samples[0].label, int): - # Groundtruth - targets = [] - for s in ds._samples: - targets.append(s.label) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights.append( - torch.tensor([weight[t] for t in targets]) - ) - - else: - # We only weight to sample equally from each dataset - samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) - - # Concatenate sample weights from all the datasets - samples_weights = torch.cat(samples_weights) - - else: - # Weighting only for binary labels - if isinstance(dataset._samples[0].label, int): - # Groundtruth - targets = [] - for s in dataset._samples: - targets.append(s.label) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights = torch.tensor([weight[t] for t in targets]) - - else: - # Equal weights for non-binary labels - samples_weights = torch.ones(len(dataset._samples)) - - return samples_weights - - -def get_positive_weights(dataset): - """Compute the positive weights of each class of the dataset to balance the - BCEWithLogitsLoss criterion. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the positive weights of each class to use them to have - a balanced loss. - - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - positive_weights : :py:class:`torch.Tensor` - the positive weight of each class in the dataset given as input - """ - targets = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - for s in ds._samples: - targets.append(s.label) - - else: - for s in dataset._samples: - targets.append(s.label) - - targets = torch.tensor(targets) - - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] - - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]] - ).reshape(-1) - - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) - - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) - - return positive_weights - - -def return_subsets(dataset): - train_dataset = None - validation_dataset = None - extra_validation_datasets = None - predict_dataset = None - - if "__train__" in dataset: - logger.info("Found (dedicated) '__train__' set for training") - train_dataset = dataset["__train__"] - else: - train_dataset = dataset["train"] - - if "__valid__" in dataset: - logger.info("Found (dedicated) '__valid__' set for validation") - validation_dataset = dataset["__valid__"] - - if "__extra_valid__" in dataset: - if not isinstance(dataset["__extra_valid__"], list): - raise RuntimeError( - f"If present, dataset['__extra_valid__'] must be a list, " - f"but you passed a {type(dataset['__extra_valid__'])}, " - f"which is invalid." - ) - logger.info( - f"Found {len(dataset['__extra_valid__'])} extra validation " - f"set(s) to be tracked during training" - ) - logger.info( - "Extra validation sets are NOT used for model checkpointing!" - ) - extra_validation_datasets = dataset["__extra_valid__"] - else: - extra_validation_datasets = None - - predict_dataset = dataset - - return ( - train_dataset, - validation_dataset, - extra_validation_datasets, - predict_dataset, - ) diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index 3801efbc2a5f66d467a7681be37819e7fda74ede..bbeabcafa98554f6f8bcae71df287e6c3eda939e 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -12,8 +12,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py index d65d513bcb4cefcfc553e3c6bde864392db01add..5b4d45602d13a0e6bd0e2724a6c3202c1532eef6 100644 --- a/src/ptbench/data/shenzhen/fold_0.py +++ b/src/ptbench/data/shenzhen/fold_0.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_0_rgb.py b/src/ptbench/data/shenzhen/fold_0_rgb.py index bcc853dd7c52543d4aa27a4b10083d4407582786..143ef731fae0f7d3c746e384e08e400f36c92511 100644 --- a/src/ptbench/data/shenzhen/fold_0_rgb.py +++ b/src/ptbench/data/shenzhen/fold_0_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py index b9494f156e69ef9ed2af68d607f8d8db2540381f..f01adef0af4da152ad756036aacb4bf83c5c20cc 100644 --- a/src/ptbench/data/shenzhen/fold_1.py +++ b/src/ptbench/data/shenzhen/fold_1.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_1_rgb.py b/src/ptbench/data/shenzhen/fold_1_rgb.py index 01e23967e51878a348b89d14e263ff2c5e1d0503..9d457adfa8835e45c7d5dc993ab23ffc8baafeb9 100644 --- a/src/ptbench/data/shenzhen/fold_1_rgb.py +++ b/src/ptbench/data/shenzhen/fold_1_rgb.py @@ -12,8 +12,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py index 8d5cf816ce4ae950f19a71ea0c57662f427e2ee4..04dd656263cc6e71d2c61e3f5e221f7129fa7035 100644 --- a/src/ptbench/data/shenzhen/fold_2.py +++ b/src/ptbench/data/shenzhen/fold_2.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_2_rgb.py b/src/ptbench/data/shenzhen/fold_2_rgb.py index baf6752d66093a059f295f0e540a9d3ffd6b681f..37cbe10ebd72057535bacd241ecb4dbfb2bded3d 100644 --- a/src/ptbench/data/shenzhen/fold_2_rgb.py +++ b/src/ptbench/data/shenzhen/fold_2_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py index b42882ccf663b2635dbc106b817f1956a113863e..b43fcb29c9392d6e9bc5c20d9e93969b33bd2708 100644 --- a/src/ptbench/data/shenzhen/fold_3.py +++ b/src/ptbench/data/shenzhen/fold_3.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_3_rgb.py b/src/ptbench/data/shenzhen/fold_3_rgb.py index a02c2b1d4e7b33321aa45d537f9af2155f4ae321..162a3f82d640633ce6d13d25b2a500dc0fb63a54 100644 --- a/src/ptbench/data/shenzhen/fold_3_rgb.py +++ b/src/ptbench/data/shenzhen/fold_3_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py index a9ad14710ad37c02c34e6575fbac064db64a4054..58e0a2f221e0a6a48b3466c015c9c5790fc6ae3a 100644 --- a/src/ptbench/data/shenzhen/fold_4.py +++ b/src/ptbench/data/shenzhen/fold_4.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_4_rgb.py b/src/ptbench/data/shenzhen/fold_4_rgb.py index 3620ba428533a9e97392bde40f77b55b454809bb..0dd4ccf89c553a19030774772447bb66bf0cf9b7 100644 --- a/src/ptbench/data/shenzhen/fold_4_rgb.py +++ b/src/ptbench/data/shenzhen/fold_4_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py index 426a9d6609e212168b3c2775b791de416a6b7dfd..ff115340f723d33ebdeb49524cc7e3b269738563 100644 --- a/src/ptbench/data/shenzhen/fold_5.py +++ b/src/ptbench/data/shenzhen/fold_5.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_5_rgb.py b/src/ptbench/data/shenzhen/fold_5_rgb.py index 29e7013885ac1ba917ce40c7a027c827af54d705..46e255e7c37c1f547d25626970324f88d58ba03e 100644 --- a/src/ptbench/data/shenzhen/fold_5_rgb.py +++ b/src/ptbench/data/shenzhen/fold_5_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py index fb0a91b896723240f26a283cb76c5a2245470fc0..eb81ae882369a1a99bca60e93b1922c7a9d7e9fc 100644 --- a/src/ptbench/data/shenzhen/fold_6.py +++ b/src/ptbench/data/shenzhen/fold_6.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_6_rgb.py b/src/ptbench/data/shenzhen/fold_6_rgb.py index 35e7e6d7e7af0758aa4aaa01f24887e3eea6711f..b9654d08008bd10f0a5d953f2a79870214218651 100644 --- a/src/ptbench/data/shenzhen/fold_6_rgb.py +++ b/src/ptbench/data/shenzhen/fold_6_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py index 743b344d77b9324611e1490dc47966bb52dc41a4..79b0d1fff4483ef98eed42b6dce6c6b64076bd1e 100644 --- a/src/ptbench/data/shenzhen/fold_7.py +++ b/src/ptbench/data/shenzhen/fold_7.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_7_rgb.py b/src/ptbench/data/shenzhen/fold_7_rgb.py index 0a9f83d7a0072aa722a7d1cb881249e4722be36c..8a36acb2c79a4467f5dccad532c17cf528613ab5 100644 --- a/src/ptbench/data/shenzhen/fold_7_rgb.py +++ b/src/ptbench/data/shenzhen/fold_7_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py index ee7f716714adc35beeab1b4eb450253d408320ad..cf1cd36a8d2d2a40113a775131f9d5b8ba0a092a 100644 --- a/src/ptbench/data/shenzhen/fold_8.py +++ b/src/ptbench/data/shenzhen/fold_8.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_8_rgb.py b/src/ptbench/data/shenzhen/fold_8_rgb.py index 1351790faf6234c2d7a3b51c500b2d16c94fd2c0..1aa0bcec76d875d8e9b966b4461228230bdf79f2 100644 --- a/src/ptbench/data/shenzhen/fold_8_rgb.py +++ b/src/ptbench/data/shenzhen/fold_8_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py index dbd8ab3146dca6bda7048c1d47fe1ee0852129c7..e1bb569d0389a8eb67b0ad186e48c6f98aa7cf57 100644 --- a/src/ptbench/data/shenzhen/fold_9.py +++ b/src/ptbench/data/shenzhen/fold_9.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/fold_9_rgb.py b/src/ptbench/data/shenzhen/fold_9_rgb.py index 729141dc3a134e4d5e815441262a58004b98ed80..c0a577df0a0d3667420c32042816f65ba9ad20ce 100644 --- a/src/ptbench/data/shenzhen/fold_9_rgb.py +++ b/src/ptbench/data/shenzhen/fold_9_rgb.py @@ -11,8 +11,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py index 6186990ec8023df5030ba84a9be627e80a95ba44..7bdb8fe3ce6826fb98d0d6356f2e1b429670a3d1 100644 --- a/src/ptbench/data/shenzhen/rgb.py +++ b/src/ptbench/data/shenzhen/rgb.py @@ -12,8 +12,9 @@ from clapper.logging import setup +from .. import return_subsets from ..base_datamodule import BaseDataModule -from . import _maker, return_subsets +from . import _maker logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")