From e1215e9975f400befed7f956a16be20a3401e6aa Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 12 Apr 2023 10:18:03 +0200 Subject: [PATCH] Removed unused functions --- src/ptbench/engine/trainer.py | 422 ---------------------------------- 1 file changed, 422 deletions(-) diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 05b37feb..cf8fc101 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -2,22 +2,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import contextlib import csv -import datetime import logging import os import shutil -import sys - -import numpy -import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger from pytorch_lightning.utilities.model_summary import ModelSummary -from tqdm import tqdm from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from .callbacks import LoggingCallback @@ -25,32 +18,6 @@ from .callbacks import LoggingCallback logger = logging.getLogger(__name__) -@contextlib.contextmanager -def torch_evaluation(model): - """Context manager to turn ON/OFF model evaluation. - - This context manager will turn evaluation mode ON on entry and turn it OFF - when exiting the ``with`` statement block. - - - Parameters - ---------- - - model : :py:class:`torch.nn.Module` - Network - - - Yields - ------ - - model : :py:class:`torch.nn.Module` - Network - """ - model.eval() - yield model - model.train() - - def check_gpu(device): """Check the device type and the availability of GPU. @@ -67,45 +34,6 @@ def check_gpu(device): ), f"Device set to '{device}', but nvidia-smi is not installed" -def initialize_lowest_validation_loss(logfile_name, arguments): - """Initialize the lowest validation loss from the logfile if it exists and - if the training does not start from epoch 0, which means that a previous - training session is resumed. - - Parameters - ---------- - - logfile_name : str - The logfile_name which is a join between the output_folder and trainlog.csv - - arguments : dict - start and end epochs - """ - - if arguments["epoch"] != 0 and os.path.exists(logfile_name): - # Open the CSV file - with open(logfile_name) as file: - reader = csv.DictReader(file) - column_name = "validation_loss" - - if column_name not in reader.fieldnames: - return sys.float_info.max - - # Get the values of the desired column as a list - values = [float(row[column_name]) for row in reader] - - if not values: - return sys.float_info.max - - lowest_value = min(values) - logger.info( - f"Found lowest validation loss from previous session: {lowest_value}" - ) - return lowest_value - - return sys.float_info.max - - def save_model_summary(output_folder, model): """Save a little summary of the model in a txt file. @@ -220,236 +148,6 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device): return logfile_fields -def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count): - """Trains the model for a single epoch (through all batches) - - Parameters - ---------- - - loader : :py:class:`torch.utils.data.DataLoader` - To be used to train the model - - model : :py:class:`torch.nn.Module` - Network (e.g. driu, hed, unet) - - optimizer : :py:mod:`torch.optim` - - device : :py:class:`torch.device` - device to use - - criterion : :py:class:`torch.nn.modules.loss._Loss` - - batch_chunk_count: int - If this number is different than 1, then each batch will be divided in - this number of chunks. Gradients will be accumulated to perform each - mini-batch. This is particularly interesting when one has limited RAM - on the GPU, but would like to keep training with larger batches. One - exchanges for longer processing times in this case. To better understand - gradient accumulation, read - https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch. - - - Returns - ------- - - loss : float - A floating-point value corresponding the weighted average of this - epoch's loss - """ - losses_in_epoch = [] - samples_in_epoch = [] - losses_in_batch = [] - samples_in_batch = [] - - # progress bar only on interactive jobs - for idx, samples in enumerate( - tqdm(loader, desc="train", leave=False, disable=None) - ): - images = samples[1].to( - device=device, non_blocking=torch.cuda.is_available() - ) - labels = samples[2].to( - device=device, non_blocking=torch.cuda.is_available() - ) - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = model(images) - - loss = criterion(outputs, labels.double()) - - losses_in_batch.append(loss.item()) - samples_in_batch.append(len(samples)) - - # Normalize loss to account for batch accumulation - loss = loss / batch_chunk_count - - # Accumulate gradients - does not update weights just yet... - loss.backward() - - # Weight update on the network - if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)): - # Advances optimizer to the "next" state and applies weight update - # over the whole model - optimizer.step() - - # Zeroes gradients for the next batch - optimizer.zero_grad() - - # Normalize loss for current batch - batch_loss = numpy.average( - losses_in_batch, weights=samples_in_batch - ) - losses_in_epoch.append(batch_loss.item()) - samples_in_epoch.append(len(samples)) - - losses_in_batch.clear() - samples_in_batch.clear() - logger.debug(f"batch loss: {batch_loss.item()}") - - return numpy.average(losses_in_epoch, weights=samples_in_epoch) - - -def validate_epoch(loader, model, device, criterion, pbar_desc): - """Processes input samples and returns loss (scalar) - - Parameters - ---------- - - loader : :py:class:`torch.utils.data.DataLoader` - To be used to validate the model - - model : :py:class:`torch.nn.Module` - Network (e.g. driu, hed, unet) - - optimizer : :py:mod:`torch.optim` - - device : :py:class:`torch.device` - device to use - - criterion : :py:class:`torch.nn.modules.loss._Loss` - loss function - - pbar_desc : str - A string for the progress bar descriptor - - - Returns - ------- - - loss : float - A floating-point value corresponding the weighted average of this - epoch's loss - """ - batch_losses = [] - samples_in_batch = [] - - with torch.no_grad(), torch_evaluation(model): - for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None): - images = samples[1].to( - device=device, - non_blocking=torch.cuda.is_available(), - ) - labels = samples[2].to( - device=device, - non_blocking=torch.cuda.is_available(), - ) - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # data forwarding on the existing network - outputs = model(images) - loss = criterion(outputs, labels.double()) - - batch_losses.append(loss.item()) - samples_in_batch.append(len(samples)) - - return numpy.average(batch_losses, weights=samples_in_batch) - - -def write_log_info( - epoch, - current_time, - eta_seconds, - loss, - valid_loss, - extra_valid_losses, - optimizer, - logwriter, - logfile, - resource_data, -): - """Write log info in trainlog.csv. - - Parameters - ---------- - - epoch : int - Current epoch - - current_time : float - Current training time - - eta_seconds : float - estimated time-of-arrival taking into consideration previous epoch performance - - loss : float - Current epoch's training loss - - valid_loss : :py:class:`float`, None - Current epoch's validation loss - - extra_valid_losses : :py:class:`list` of :py:class:`float` - Validation losses from other validation datasets being currently - tracked - - optimizer : :py:mod:`torch.optim` - - logwriter : csv.DictWriter - Dictionary writer that give the ability to write on the trainlog.csv - - logfile : io.TextIOWrapper - - resource_data : tuple - Monitored resources at the machine (CPU and GPU) - """ - - logdata = ( - ("epoch", f"{epoch}"), - ( - "total_time", - f"{datetime.timedelta(seconds=int(current_time))}", - ), - ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"), - ("loss", f"{loss:.6f}"), - ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"), - ) - - if valid_loss is not None: - logdata += (("validation_loss", f"{valid_loss:.6f}"),) - - if extra_valid_losses: - entry = numpy.array_str( - numpy.array(extra_valid_losses), - max_line_width=sys.maxsize, - precision=6, - ) - logdata += (("extra_validation_losses", entry),) - - logdata += resource_data - - logwriter.writerow(dict(k for k in logdata)) - logfile.flush() - tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]])) - - def run( model, data_loader, @@ -562,123 +260,3 @@ def run( ) _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint) - - """# write static information to a CSV file - static_logfile_name = os.path.join(output_folder, "constants.csv") - - static_information_to_csv(static_logfile_name, device, n) - - # Log continous information to (another) file - logfile_name = os.path.join(output_folder, "trainlog.csv") - - check_exist_logfile(logfile_name, arguments) - - logfile_fields = create_logfile_fields( - valid_loader, extra_valid_loaders, device - ) - - # the lowest validation loss obtained so far - this value is updated only - # if a validation set is available - lowest_validation_loss = initialize_lowest_validation_loss( - logfile_name, arguments - ) - - # set a specific validation criterion if the user has set one - criterion_valid = criterion_valid or criterion - - with open(logfile_name, "a+", newline="") as logfile: - logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields) - - if arguments["epoch"] == 0: - logwriter.writeheader() - - model.train() # set training mode - - model.to(device) # set/cast parameters to device - for state in optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.to(device) - - # Total training timer - start_training_time = time.time() - - for epoch in tqdm( - range(start_epoch, max_epoch), - desc="epoch", - leave=False, - disable=None, - ): - with ResourceMonitor( - interval=monitoring_interval, - has_gpu=(device.type == "cuda"), - main_pid=os.getpid(), - logging_level=logging.ERROR, - ) as resource_monitor: - epoch = epoch + 1 - arguments["epoch"] = epoch - - # Epoch time - start_epoch_time = time.time() - - train_loss = train_epoch( - data_loader, - model, - optimizer, - device, - criterion, - batch_chunk_count, - ) - - valid_loss = ( - validate_epoch( - valid_loader, model, device, criterion_valid, "valid" - ) - if valid_loader is not None - else None - ) - - extra_valid_losses = [] - for pos, extra_valid_loader in enumerate(extra_valid_loaders): - loss = validate_epoch( - extra_valid_loader, - model, - device, - criterion_valid, - f"xval@{pos+1}", - ) - extra_valid_losses.append(loss) - - lowest_validation_loss = checkpointer_process( - checkpointer, - checkpoint_period, - valid_loss, - lowest_validation_loss, - arguments, - epoch, - max_epoch, - ) - - # computes ETA (estimated time-of-arrival; end of training) taking - # into consideration previous epoch performance - epoch_time = time.time() - start_epoch_time - eta_seconds = epoch_time * (max_epoch - epoch) - current_time = time.time() - start_training_time - - write_log_info( - epoch, - current_time, - eta_seconds, - train_loss, - valid_loss, - extra_valid_losses, - optimizer, - logwriter, - logfile, - resource_monitor.data, - ) - - total_training_time = time.time() - start_training_time - logger.info( - f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)" - )""" -- GitLab