diff --git a/pyproject.toml b/pyproject.toml index b2a52e5b215be2743b5267a6ed51c194c7546146..f22d566a950755597aac41eb69394dec6795c3cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,7 @@ montgomery_rs_f7 = "ptbench.configs.datasets.montgomery_RS.fold_7" montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8" montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9" # shenzhen dataset (and cross-validation folds) -shenzhen = "ptbench.configs.datasets.shenzhen.default" +shenzhen = "ptbench.data.shenzhen.default" shenzhen_rgb = "ptbench.configs.datasets.shenzhen.rgb" shenzhen_f0 = "ptbench.configs.datasets.shenzhen.fold_0" shenzhen_f1 = "ptbench.configs.datasets.shenzhen.fold_1" diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py deleted file mode 100644 index 635c9ed9b7ed23a255ba9f5e5f79828de17b7f17..0000000000000000000000000000000000000000 --- a/src/ptbench/configs/datasets/shenzhen/default.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Shenzhen dataset for TB detection (default protocol) - -* Split reference: first 64% of TB and healthy CXR for "train" 16% for -* "validation", 20% for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.shenzhen` for dataset details -""" - -from . 
import _maker - -dataset = _maker("default") diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/base_datamodule.py similarity index 63% rename from src/ptbench/data/datamodule.py rename to src/ptbench/data/base_datamodule.py index d3c03d76578b624f4d38a06e255501f91e2fab49..fb0970f0b2c1ebb1cfcae291abd626428f556bf6 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/base_datamodule.py @@ -13,10 +13,9 @@ from ptbench.configs.datasets import get_samples_weights logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -class DataModule(pl.LightningDataModule): +class BaseDataModule(pl.LightningDataModule): def __init__( self, - dataset, train_batch_size=1, predict_batch_size=1, drop_incomplete_batch=False, @@ -24,8 +23,6 @@ class DataModule(pl.LightningDataModule): ): super().__init__() - self.dataset = dataset - self.train_batch_size = train_batch_size self.predict_batch_size = predict_batch_size @@ -37,37 +34,9 @@ class DataModule(pl.LightningDataModule): self.multiproc_kwargs = multiproc_kwargs def setup(self, stage: str): - if stage == "fit": - if "__train__" in self.dataset: - logger.info("Found (dedicated) '__train__' set for training") - self.train_dataset = self.dataset["__train__"] - else: - self.train_dataset = self.dataset["train"] - - if "__valid__" in self.dataset: - logger.info("Found (dedicated) '__valid__' set for validation") - self.validation_dataset = self.dataset["__valid__"] - - if "__extra_valid__" in self.dataset: - if not isinstance(self.dataset["__extra_valid__"], list): - raise RuntimeError( - f"If present, dataset['__extra_valid__'] must be a list, " - f"but you passed a {type(self.dataset['__extra_valid__'])}, " - f"which is invalid." - ) - logger.info( - f"Found {len(self.dataset['__extra_valid__'])} extra validation " - f"set(s) to be tracked during training" - ) - logger.info( - "Extra validation sets are NOT used for model checkpointing!" 
- ) - self.extra_validation_datasets = self.dataset["__extra_valid__"] - else: - self.extra_validation_datasets = None - - if stage == "predict": - self.predict_dataset = self.dataset + # Implemented by user + # Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset + raise NotImplementedError def train_dataloader(self): train_samples_weights = get_samples_weights(self.train_dataset) diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index 995f227c2e44fea1bc77898d5e009a505108932e..1c425562633f346c7eac2a3af9f667520e785520 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -7,6 +7,14 @@ import json import logging import os import pathlib +import random + +import torch + +from torchvision.transforms import RandomRotation + +RANDOM_ROTATION = [RandomRotation(15)] +"""Shared data augmentation based on random rotation only.""" logger = logging.getLogger(__name__) @@ -313,3 +321,295 @@ class CSVDataset: ) for n, k in enumerate(samples) ] + + +def make_subset(samples, transforms=[], prefixes=[], suffixes=[]): + """Creates a new data set, applying transforms. + + .. note:: + + This is a convenience function for our own dataset definitions inside + this module, guaranteeting homogenity between dataset definitions + provided in this package. It assumes certain strategies for data + augmentation that may not be translatable to other applications. 
+ + + Parameters + ---------- + + samples : list + List of delayed samples + + transforms : list + A list of transforms that needs to be applied to all samples in the set + + prefixes : list + A list of data augmentation operations that needs to be applied + **before** the transforms above + + suffixes : list + A list of data augmentation operations that needs to be applied + **after** the transforms above + + + Returns + ------- + + subset : :py:class:`ptbench.data.utils.SampleListDataset` + A pre-formatted dataset that can be fed to one of our engines + """ + from .utils import SampleListDataset as wrapper + + return wrapper(samples, prefixes + transforms + suffixes) + + +def make_dataset( + subsets_groups, transforms=[], t_transforms=[], post_transforms=[] +): + """Creates a new configuration dataset from a list of dictionaries and + transforms. + + This function takes as input a list of dictionaries as those that can be + returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets` + mapping protocol names (such as ``train``, ``dev`` and ``test``) to + :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of + transforms, and returns a dictionary applying + :py:class:`ptbench.data.utils.SampleListDataset` to these + lists, and our standard data augmentation if a ``train`` set exists. + + For example, if ``subsets`` is composed of two sets named ``train`` and + ``test``, this function will yield a dictionary with the following entries: + + * ``__train__``: Wraps the ``train`` subset, includes data augmentation + (note: datasets with names starting with ``_`` (underscore) are excluded + from prediction and evaluation by default, as they contain data + augmentation transformations.) + * ``train``: Wraps the ``train`` subset, **without** data augmentation + * ``test``: Wraps the ``test`` subset, **without** data augmentation + + .. 
note:: + + This is a convenience function for our own dataset definitions inside + this module, guaranteeing homogeneity between dataset definitions + provided in this package. It assumes certain strategies for data + augmentation that may not be translatable to other applications. + + + Parameters + ---------- + + subsets_groups : list + A list of dictionaries that contains the delayed sample lists + for a number of named lists. The subsets will be aggregated in one + final subset. If one of the keys is ``train``, our standard dataset + augmentation transforms are appended to the definition of that subset. + All other subsets remain un-augmented. + + transforms : list + A list of transforms that needs to be applied to all samples in the set + + t_transforms : list + A list of transforms that needs to be applied to the train samples + + post_transforms : list + A list of transforms that needs to be applied to all samples in the set + after all the other transforms + + + Returns + ------- + + dataset : dict + A pre-formatted dataset that can be fed to one of our engines. It maps + string names to :py:class:`ptbench.data.utils.SampleListDataset`'s. 
+ """ + + retval = {} + + if len(subsets_groups) == 1: + subsets = subsets_groups[0] + else: + # If multiple subsets groups: aggregation + aggregated_subsets = {} + for subsets in subsets_groups: + for key in subsets.keys(): + if key in aggregated_subsets: + aggregated_subsets[key] += subsets[key] + # Shuffle if data comes from multiple datasets + random.shuffle(aggregated_subsets[key]) + else: + aggregated_subsets[key] = subsets[key] + subsets = aggregated_subsets + + # Add post_transforms after t_transforms for the train set + t_transforms += post_transforms + + for key in subsets.keys(): + retval[key] = make_subset( + subsets[key], transforms=transforms, suffixes=post_transforms + ) + if key == "train": + retval["__train__"] = make_subset( + subsets[key], transforms=transforms, suffixes=(t_transforms) + ) + if key == "validation": + # also use it for validation during training + retval["__valid__"] = retval[key] + + if ( + ("__train__" in retval) + and ("train" in retval) + and ("__valid__" not in retval) + ): + # if the dataset does not have a validation set, we use the unaugmented + # training set as validation set + retval["__valid__"] = retval["train"] + + return retval + + +def get_samples_weights(dataset): + """Compute the weights of all the samples of the dataset to balance it + using the sampler of the dataloader. + + This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + and computes the weights to balance each class in the dataset and the + datasets themselves if we have a ConcatDataset. 
+ + + Parameters + ---------- + + dataset : torch.utils.data.dataset.Dataset + An instance of torch.utils.data.dataset.Dataset + ConcatDataset are supported + + + Returns + ------- + + samples_weights : :py:class:`torch.Tensor` + the weights for all the samples in the dataset given as input + """ + samples_weights = [] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + # Weighting only for binary labels + if isinstance(ds._samples[0].label, int): + # Groundtruth + targets = [] + for s in ds._samples: + targets.append(s.label) + targets = torch.tensor(targets) + + # Count number of samples per class + class_sample_count = torch.tensor( + [ + (targets == t).sum() + for t in torch.unique(targets, sorted=True) + ] + ) + + weight = 1.0 / class_sample_count.float() + + samples_weights.append( + torch.tensor([weight[t] for t in targets]) + ) + + else: + # We only weight to sample equally from each dataset + samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) + + # Concatenate sample weights from all the datasets + samples_weights = torch.cat(samples_weights) + + else: + # Weighting only for binary labels + if isinstance(dataset._samples[0].label, int): + # Groundtruth + targets = [] + for s in dataset._samples: + targets.append(s.label) + targets = torch.tensor(targets) + + # Count number of samples per class + class_sample_count = torch.tensor( + [ + (targets == t).sum() + for t in torch.unique(targets, sorted=True) + ] + ) + + weight = 1.0 / class_sample_count.float() + + samples_weights = torch.tensor([weight[t] for t in targets]) + + else: + # Equal weights for non-binary labels + samples_weights = torch.ones(len(dataset._samples)) + + return samples_weights + + +def get_positive_weights(dataset): + """Compute the positive weights of each class of the dataset to balance the + BCEWithLogitsLoss criterion. 
+ + This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + and computes the positive weights of each class to use them to have + a balanced loss. + + + Parameters + ---------- + + dataset : torch.utils.data.dataset.Dataset + An instance of torch.utils.data.dataset.Dataset + ConcatDataset are supported + + + Returns + ------- + + positive_weights : :py:class:`torch.Tensor` + the positive weight of each class in the dataset given as input + """ + targets = [] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + for s in ds._samples: + targets.append(s.label) + + else: + for s in dataset._samples: + targets.append(s.label) + + targets = torch.tensor(targets) + + # Binary labels + if len(list(targets.shape)) == 1: + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + positive_weights = torch.tensor( + [class_sample_count[0] / class_sample_count[1]] + ).reshape(-1) + + # Multiclass labels + else: + class_sample_count = torch.sum(targets, dim=0) + negative_class_sample_count = ( + torch.full((targets.size()[1],), float(targets.size()[0])) + - class_sample_count + ) + + positive_weights = negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) + + return positive_weights diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py index c1488160f553aed8e3e0fab4c882d46d8cffc799..60a004a3afc76239ffae144556feaaa888850ab7 100644 --- a/src/ptbench/data/shenzhen/__init__.py +++ b/src/ptbench/data/shenzhen/__init__.py @@ -22,11 +22,19 @@ the daily routine using Philips DR Digital Diagnose systems. 
import importlib.resources import os +import random + +import torch + +from clapper.logging import setup from ...utils.rc import load_rc from ..dataset import JSONDataset from ..loader import load_pil_baw, make_delayed +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + _protocols = [ importlib.resources.files(__name__).joinpath("default.json.bz2"), importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), @@ -57,9 +65,367 @@ def _loader(context, sample): return make_delayed(sample, _raw_data_loader) -dataset = JSONDataset( - protocols=_protocols, - fieldnames=("data", "label"), - loader=_loader, +json_dataset = JSONDataset( + protocols=_protocols, fieldnames=("data", "label"), loader=_loader ) """Shenzhen dataset object.""" + + +def _maker(protocol, resize_size=512, cc_size=512, RGB=False): + from torchvision import transforms + + from ..transforms import ElasticDeformation, RemoveBlackBorders + + post_transforms = [] + if RGB: + post_transforms = [ + transforms.Lambda(lambda x: x.convert("RGB")), + transforms.ToTensor(), + ] + + return make_dataset( + [json_dataset.subsets(protocol)], + [ + RemoveBlackBorders(), + transforms.Resize(resize_size), + transforms.CenterCrop(cc_size), + ], + [ElasticDeformation(p=0.8)], + post_transforms, + ) + + +def make_subset(samples, transforms=[], prefixes=[], suffixes=[]): + """Creates a new data set, applying transforms. + + .. note:: + + This is a convenience function for our own dataset definitions inside + this module, guaranteeting homogenity between dataset definitions + provided in this package. It assumes certain strategies for data + augmentation that may not be translatable to other applications. 
+ + + Parameters + ---------- + + samples : list + List of delayed samples + + transforms : list + A list of transforms that needs to be applied to all samples in the set + + prefixes : list + A list of data augmentation operations that needs to be applied + **before** the transforms above + + suffixes : list + A list of data augmentation operations that needs to be applied + **after** the transforms above + + + Returns + ------- + + subset : :py:class:`ptbench.data.utils.SampleListDataset` + A pre-formatted dataset that can be fed to one of our engines + """ + from ...data.utils import SampleListDataset as wrapper + + return wrapper(samples, prefixes + transforms + suffixes) + + +def make_dataset( + subsets_groups, transforms=[], t_transforms=[], post_transforms=[] +): + """Creates a new configuration dataset from a list of dictionaries and + transforms. + + This function takes as input a list of dictionaries as those that can be + returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets` + mapping protocol names (such as ``train``, ``dev`` and ``test``) to + :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of + transforms, and returns a dictionary applying + :py:class:`ptbench.data.utils.SampleListDataset` to these + lists, and our standard data augmentation if a ``train`` set exists. + + For example, if ``subsets`` is composed of two sets named ``train`` and + ``test``, this function will yield a dictionary with the following entries: + + * ``__train__``: Wraps the ``train`` subset, includes data augmentation + (note: datasets with names starting with ``_`` (underscore) are excluded + from prediction and evaluation by default, as they contain data + augmentation transformations.) + * ``train``: Wraps the ``train`` subset, **without** data augmentation + * ``test``: Wraps the ``test`` subset, **without** data augmentation + + .. 
note:: + + This is a convenience function for our own dataset definitions inside + this module, guaranteeing homogeneity between dataset definitions + provided in this package. It assumes certain strategies for data + augmentation that may not be translatable to other applications. + + + Parameters + ---------- + + subsets_groups : list + A list of dictionaries that contains the delayed sample lists + for a number of named lists. The subsets will be aggregated in one + final subset. If one of the keys is ``train``, our standard dataset + augmentation transforms are appended to the definition of that subset. + All other subsets remain un-augmented. + + transforms : list + A list of transforms that needs to be applied to all samples in the set + + t_transforms : list + A list of transforms that needs to be applied to the train samples + + post_transforms : list + A list of transforms that needs to be applied to all samples in the set + after all the other transforms + + + Returns + ------- + + dataset : dict + A pre-formatted dataset that can be fed to one of our engines. It maps + string names to :py:class:`ptbench.data.utils.SampleListDataset`'s. 
+ """ + + retval = {} + + if len(subsets_groups) == 1: + subsets = subsets_groups[0] + else: + # If multiple subsets groups: aggregation + aggregated_subsets = {} + for subsets in subsets_groups: + for key in subsets.keys(): + if key in aggregated_subsets: + aggregated_subsets[key] += subsets[key] + # Shuffle if data comes from multiple datasets + random.shuffle(aggregated_subsets[key]) + else: + aggregated_subsets[key] = subsets[key] + subsets = aggregated_subsets + + # Add post_transforms after t_transforms for the train set + t_transforms += post_transforms + + for key in subsets.keys(): + retval[key] = make_subset( + subsets[key], transforms=transforms, suffixes=post_transforms + ) + if key == "train": + retval["__train__"] = make_subset( + subsets[key], transforms=transforms, suffixes=(t_transforms) + ) + if key == "validation": + # also use it for validation during training + retval["__valid__"] = retval[key] + + if ( + ("__train__" in retval) + and ("train" in retval) + and ("__valid__" not in retval) + ): + # if the dataset does not have a validation set, we use the unaugmented + # training set as validation set + retval["__valid__"] = retval["train"] + + return retval + + +def get_samples_weights(dataset): + """Compute the weights of all the samples of the dataset to balance it + using the sampler of the dataloader. + + This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + and computes the weights to balance each class in the dataset and the + datasets themselves if we have a ConcatDataset. 
+ + + Parameters + ---------- + + dataset : torch.utils.data.dataset.Dataset + An instance of torch.utils.data.dataset.Dataset + ConcatDataset are supported + + + Returns + ------- + + samples_weights : :py:class:`torch.Tensor` + the weights for all the samples in the dataset given as input + """ + samples_weights = [] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + # Weighting only for binary labels + if isinstance(ds._samples[0].label, int): + # Groundtruth + targets = [] + for s in ds._samples: + targets.append(s.label) + targets = torch.tensor(targets) + + # Count number of samples per class + class_sample_count = torch.tensor( + [ + (targets == t).sum() + for t in torch.unique(targets, sorted=True) + ] + ) + + weight = 1.0 / class_sample_count.float() + + samples_weights.append( + torch.tensor([weight[t] for t in targets]) + ) + + else: + # We only weight to sample equally from each dataset + samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) + + # Concatenate sample weights from all the datasets + samples_weights = torch.cat(samples_weights) + + else: + # Weighting only for binary labels + if isinstance(dataset._samples[0].label, int): + # Groundtruth + targets = [] + for s in dataset._samples: + targets.append(s.label) + targets = torch.tensor(targets) + + # Count number of samples per class + class_sample_count = torch.tensor( + [ + (targets == t).sum() + for t in torch.unique(targets, sorted=True) + ] + ) + + weight = 1.0 / class_sample_count.float() + + samples_weights = torch.tensor([weight[t] for t in targets]) + + else: + # Equal weights for non-binary labels + samples_weights = torch.ones(len(dataset._samples)) + + return samples_weights + + +def get_positive_weights(dataset): + """Compute the positive weights of each class of the dataset to balance the + BCEWithLogitsLoss criterion. 
+ + This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + and computes the positive weights of each class to use them to have + a balanced loss. + + + Parameters + ---------- + + dataset : torch.utils.data.dataset.Dataset + An instance of torch.utils.data.dataset.Dataset + ConcatDataset are supported + + + Returns + ------- + + positive_weights : :py:class:`torch.Tensor` + the positive weight of each class in the dataset given as input + """ + targets = [] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + for s in ds._samples: + targets.append(s.label) + + else: + for s in dataset._samples: + targets.append(s.label) + + targets = torch.tensor(targets) + + # Binary labels + if len(list(targets.shape)) == 1: + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + positive_weights = torch.tensor( + [class_sample_count[0] / class_sample_count[1]] + ).reshape(-1) + + # Multiclass labels + else: + class_sample_count = torch.sum(targets, dim=0) + negative_class_sample_count = ( + torch.full((targets.size()[1],), float(targets.size()[0])) + - class_sample_count + ) + + positive_weights = negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) + + return positive_weights + + +def return_subsets(dataset): + train_dataset = None + validation_dataset = None + extra_validation_datasets = None + predict_dataset = None + + if "__train__" in dataset: + logger.info("Found (dedicated) '__train__' set for training") + train_dataset = dataset["__train__"] + else: + train_dataset = dataset["train"] + + if "__valid__" in dataset: + logger.info("Found (dedicated) '__valid__' set for validation") + validation_dataset = dataset["__valid__"] + + if "__extra_valid__" in dataset: + if not isinstance(dataset["__extra_valid__"], list): + raise RuntimeError( + f"If present, dataset['__extra_valid__'] must be 
a list, " + f"but you passed a {type(dataset['__extra_valid__'])}, " + f"which is invalid." + ) + logger.info( + f"Found {len(dataset['__extra_valid__'])} extra validation " + f"set(s) to be tracked during training" + ) + logger.info( + "Extra validation sets are NOT used for model checkpointing!" + ) + extra_validation_datasets = dataset["__extra_valid__"] + else: + extra_validation_datasets = None + + predict_dataset = dataset + + return ( + train_dataset, + validation_dataset, + extra_validation_datasets, + predict_dataset, + ) diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py new file mode 100644 index 0000000000000000000000000000000000000000..8347cd56350e6de2901ca3793c5db1c00b9af55a --- /dev/null +++ b/src/ptbench/data/shenzhen/default.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from clapper.logging import setup + +from ..base_datamodule import BaseDataModule +from . 
import _maker, return_subsets + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +class DefaultModule(BaseDataModule): + def __init__( + self, + train_batch_size=1, + predict_batch_size=1, + drop_incomplete_batch=False, + multiproc_kwargs=None, + ): + super().__init__( + train_batch_size=train_batch_size, + predict_batch_size=predict_batch_size, + drop_incomplete_batch=drop_incomplete_batch, + multiproc_kwargs=multiproc_kwargs, + ) + + def setup(self, stage: str): + self.dataset = _maker("default") + ( + self.train_dataset, + self.validation_dataset, + self.extra_validation_datasets, + self.predict_dataset, + ) = return_subsets(self.dataset) + + +datamodule = DefaultModule diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index fe5a2b851806b108ad63e48edd9f4f0809536ebc..083d85d9566ca07c5dce3bf22a1d5265b5971dbf 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -8,7 +8,6 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup from lightning.pytorch import seed_everything -from ..data.datamodule import DataModule from ..utils.checkpointer import get_checkpoint logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -45,7 +44,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @click.option( - "--dataset", + "--datamodule", "-d", help="A dictionary mapping string keys to " "torch.utils.data.dataset.Dataset instances implementing datasets " @@ -233,7 +232,7 @@ def train( drop_incomplete_batch, criterion, criterion_valid, - dataset, + datamodule, checkpoint_period, accelerator, seed, @@ -290,9 +289,9 @@ def train( else: batch_chunk_size = batch_size // batch_chunk_count - datamodule = DataModule( - dataset, + datamodule = datamodule( train_batch_size=batch_chunk_size, + drop_incomplete_batch=drop_incomplete_batch, multiproc_kwargs=multiproc_kwargs, ) # Manually 
calling these as we need to access some values to reweight the criterion