diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 965a89cd7ee82f5c68ea079368b7a414dd264a95..41ade3f310d43d12a5ed113d633e27b73b6cb29f 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -512,15 +512,12 @@ def run(
     valid_loader,
     extra_valid_loaders,
     optimizer,
-    criterion,
-    checkpointer,
     checkpoint_period,
     device,
     arguments,
     output_folder,
     monitoring_interval,
     batch_chunk_count,
-    criterion_valid,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -549,12 +546,6 @@ def run(
 
     optimizer : :py:mod:`torch.optim`
 
-    criterion : :py:class:`torch.nn.modules.loss._Loss`
-        loss function
-
-    checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer`
-        checkpointer implementation
-
     checkpoint_period : int
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints
@@ -578,9 +569,6 @@ def run(
         mini-batch.  This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches.  One
         exchanges for longer processing times in this case.
-
-    criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
-        specific loss function for the validation set
     """
 
     max_epoch = arguments["max_epoch"]
@@ -621,7 +609,7 @@ def run(
         ],
     )
 
-    _ = trainer.fit(model, data_loader)
+    _ = trainer.fit(model, data_loader, valid_loader)
 
     """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index bafeb0303899086cd1b00cfe52d4d0896f7045b8..12e80ae131e4a82c53e0cd7a042006cfd3bb707e 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -8,6 +8,7 @@ import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
+from pytorch_lightning import seed_everything
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -53,42 +54,6 @@ def setup_pytorch_device(name):
     return torch.device(name)
 
 
-def set_seeds(value, all_gpus):
-    """Sets up all relevant random seeds (numpy, python, cuda)
-
-    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
-    ``True`` to force all GPU seeds to be initialized.
-
-    Reference: `PyTorch page for reproducibility
-    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
-
-
-    Parameters
-    ----------
-
-    value : int
-        The random seed value to use
-
-    all_gpus : :py:class:`bool`, Optional
-        If set, then reset the seed on all GPUs available at once.  This is
-        normally **not** what you want if running on a single GPU
-    """
-    import random
-
-    import numpy.random
-    import torch
-    import torch.cuda
-
-    random.seed(value)
-    numpy.random.seed(value)
-    torch.manual_seed(value)
-    torch.cuda.manual_seed(value)  # noop if cuda not available
-
-    # set seeds for all gpus
-    if all_gpus:
-        torch.cuda.manual_seed_all(value)  # noop if cuda not available
-
-
 def set_reproducible_cuda():
     """Turns-off all CUDA optimizations that would affect reproducibility.
 
@@ -252,13 +217,14 @@ def set_reproducible_cuda():
     "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
-    required=True,
-    default=0,
+    required=False,
+    default=None,
    type=click.IntRange(min=0),
    cls=ResourceOption,
 )
 @click.option(
     "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
@@ -288,6 +254,13 @@ def set_reproducible_cuda():
     default=-1,
     cls=ResourceOption,
 )
+@click.option(
+    "--weight",
+    "-w",
+    help="Path or URL to pretrained model file (.pth extension)",
+    required=False,
+    cls=ResourceOption,
+)
 @click.option(
     "--normalization",
     "-n",
@@ -330,6 +303,7 @@ def train(
     device,
     seed,
     parallel,
+    weight,
     normalization,
     monitoring_interval,
     **_,
@@ -354,11 +328,10 @@ def train(
 
     from ..configs.datasets import get_positive_weights, get_samples_weights
     from ..engine.trainer import run
-    from ..utils.checkpointer import Checkpointer
 
     device = setup_pytorch_device(device)
 
-    set_seeds(seed, all_gpus=False)
+    seed_everything(seed)
 
     use_dataset = dataset
     validation_dataset = None
@@ -418,9 +391,6 @@ def train(
 
     # Create weighted random sampler
     train_samples_weights = get_samples_weights(use_dataset)
-    train_samples_weights = train_samples_weights.to(
-        device=device, non_blocking=torch.cuda.is_available()
-    )
     train_sampler = WeightedRandomSampler(
         train_samples_weights, len(train_samples_weights), replacement=True
     )
@@ -428,10 +398,7 @@ def train(
 
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
         positive_weights = get_positive_weights(use_dataset)
-        positive_weights = positive_weights.to(
-            device=device, non_blocking=torch.cuda.is_available()
-        )
-        criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
+        model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
 
@@ -454,10 +421,9 @@ def train(
             or criterion_valid is None
         ):
             positive_weights = get_positive_weights(validation_dataset)
-            positive_weights = positive_weights.to(
-                device=device, non_blocking=torch.cuda.is_available()
+            model.criterion_valid = BCEWithLogitsLoss(
+                pos_weight=positive_weights
             )
-            criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
         else:
             logger.warning("Weighted valid criterion not supported")
 
@@ -513,14 +479,8 @@ def train(
     )
     logger.info(f"Z-normalization with mean {mean} and std {std}")
 
-    # Checkpointer
-    checkpointer = Checkpointer(model, optimizer, path=output_folder)
-
-    # Initialize epoch information
     arguments = {}
     arguments["epoch"] = 0
-    extra_checkpoint_data = checkpointer.load()
-    arguments.update(extra_checkpoint_data)
     arguments["max_epoch"] = epochs
 
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
@@ -532,13 +492,10 @@ def train(
         valid_loader=valid_loader,
         extra_valid_loaders=extra_valid_loaders,
         optimizer=optimizer,
-        criterion=criterion,
-        checkpointer=checkpointer,
         checkpoint_period=checkpoint_period,
         device=device,
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
-        criterion_valid=criterion_valid,
     )