From e1215e9975f400befed7f956a16be20a3401e6aa Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 12 Apr 2023 10:18:03 +0200
Subject: [PATCH] Removed unused functions

---
 src/ptbench/engine/trainer.py | 422 ----------------------------------
 1 file changed, 422 deletions(-)

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 05b37feb..cf8fc101 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -2,22 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import contextlib
 import csv
-import datetime
 import logging
 import os
 import shutil
-import sys
-
-import numpy
-import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
 from pytorch_lightning.utilities.model_summary import ModelSummary
-from tqdm import tqdm
 
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
@@ -25,32 +18,6 @@ from .callbacks import LoggingCallback
 logger = logging.getLogger(__name__)
 
 
-@contextlib.contextmanager
-def torch_evaluation(model):
-    """Context manager to turn ON/OFF model evaluation.
-
-    This context manager will turn evaluation mode ON on entry and turn it OFF
-    when exiting the ``with`` statement block.
-
-
-    Parameters
-    ----------
-
-    model : :py:class:`torch.nn.Module`
-        Network
-
-
-    Yields
-    ------
-
-    model : :py:class:`torch.nn.Module`
-        Network
-    """
-    model.eval()
-    yield model
-    model.train()
-
-
 def check_gpu(device):
     """Check the device type and the availability of GPU.
 
@@ -67,45 +34,6 @@ def check_gpu(device):
         ), f"Device set to '{device}', but nvidia-smi is not installed"
 
 
-def initialize_lowest_validation_loss(logfile_name, arguments):
-    """Initialize the lowest validation loss from the logfile if it exists and
-    if the training does not start from epoch 0, which means that a previous
-    training session is resumed.
-
-    Parameters
-    ----------
-
-    logfile_name : str
-        The logfile_name which is a join between the output_folder and trainlog.csv
-
-    arguments : dict
-        start and end epochs
-    """
-
-    if arguments["epoch"] != 0 and os.path.exists(logfile_name):
-        # Open the CSV file
-        with open(logfile_name) as file:
-            reader = csv.DictReader(file)
-            column_name = "validation_loss"
-
-            if column_name not in reader.fieldnames:
-                return sys.float_info.max
-
-            # Get the values of the desired column as a list
-            values = [float(row[column_name]) for row in reader]
-
-            if not values:
-                return sys.float_info.max
-
-        lowest_value = min(values)
-        logger.info(
-            f"Found lowest validation loss from previous session: {lowest_value}"
-        )
-        return lowest_value
-
-    return sys.float_info.max
-
-
 def save_model_summary(output_folder, model):
     """Save a little summary of the model in a txt file.
 
@@ -220,236 +148,6 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
     return logfile_fields
 
 
-def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count):
-    """Trains the model for a single epoch (through all batches)
-
-    Parameters
-    ----------
-
-    loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to train the model
-
-    model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
-
-    optimizer : :py:mod:`torch.optim`
-
-    device : :py:class:`torch.device`
-        device to use
-
-    criterion : :py:class:`torch.nn.modules.loss._Loss`
-
-    batch_chunk_count: int
-        If this number is different than 1, then each batch will be divided in
-        this number of chunks.  Gradients will be accumulated to perform each
-        mini-batch.   This is particularly interesting when one has limited RAM
-        on the GPU, but would like to keep training with larger batches.  One
-        exchanges for longer processing times in this case.  To better understand
-        gradient accumulation, read
-        https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.
-
-
-    Returns
-    -------
-
-    loss : float
-        A floating-point value corresponding the weighted average of this
-        epoch's loss
-    """
-    losses_in_epoch = []
-    samples_in_epoch = []
-    losses_in_batch = []
-    samples_in_batch = []
-
-    # progress bar only on interactive jobs
-    for idx, samples in enumerate(
-        tqdm(loader, desc="train", leave=False, disable=None)
-    ):
-        images = samples[1].to(
-            device=device, non_blocking=torch.cuda.is_available()
-        )
-        labels = samples[2].to(
-            device=device, non_blocking=torch.cuda.is_available()
-        )
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = model(images)
-
-        loss = criterion(outputs, labels.double())
-
-        losses_in_batch.append(loss.item())
-        samples_in_batch.append(len(samples))
-
-        # Normalize loss to account for batch accumulation
-        loss = loss / batch_chunk_count
-
-        # Accumulate gradients - does not update weights just yet...
-        loss.backward()
-
-        # Weight update on the network
-        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)):
-            # Advances optimizer to the "next" state and applies weight update
-            # over the whole model
-            optimizer.step()
-
-            # Zeroes gradients for the next batch
-            optimizer.zero_grad()
-
-            # Normalize loss for current batch
-            batch_loss = numpy.average(
-                losses_in_batch, weights=samples_in_batch
-            )
-            losses_in_epoch.append(batch_loss.item())
-            samples_in_epoch.append(len(samples))
-
-            losses_in_batch.clear()
-            samples_in_batch.clear()
-            logger.debug(f"batch loss: {batch_loss.item()}")
-
-    return numpy.average(losses_in_epoch, weights=samples_in_epoch)
-
-
-def validate_epoch(loader, model, device, criterion, pbar_desc):
-    """Processes input samples and returns loss (scalar)
-
-    Parameters
-    ----------
-
-    loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model
-
-    model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
-
-    optimizer : :py:mod:`torch.optim`
-
-    device : :py:class:`torch.device`
-        device to use
-
-    criterion : :py:class:`torch.nn.modules.loss._Loss`
-        loss function
-
-    pbar_desc : str
-        A string for the progress bar descriptor
-
-
-    Returns
-    -------
-
-    loss : float
-        A floating-point value corresponding the weighted average of this
-        epoch's loss
-    """
-    batch_losses = []
-    samples_in_batch = []
-
-    with torch.no_grad(), torch_evaluation(model):
-        for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None):
-            images = samples[1].to(
-                device=device,
-                non_blocking=torch.cuda.is_available(),
-            )
-            labels = samples[2].to(
-                device=device,
-                non_blocking=torch.cuda.is_available(),
-            )
-
-            # Increase label dimension if too low
-            # Allows single and multiclass usage
-            if labels.ndim == 1:
-                labels = torch.reshape(labels, (labels.shape[0], 1))
-
-            # data forwarding on the existing network
-            outputs = model(images)
-            loss = criterion(outputs, labels.double())
-
-            batch_losses.append(loss.item())
-            samples_in_batch.append(len(samples))
-
-    return numpy.average(batch_losses, weights=samples_in_batch)
-
-
-def write_log_info(
-    epoch,
-    current_time,
-    eta_seconds,
-    loss,
-    valid_loss,
-    extra_valid_losses,
-    optimizer,
-    logwriter,
-    logfile,
-    resource_data,
-):
-    """Write log info in trainlog.csv.
-
-    Parameters
-    ----------
-
-    epoch : int
-        Current epoch
-
-    current_time : float
-        Current training time
-
-    eta_seconds : float
-        estimated time-of-arrival taking into consideration previous epoch performance
-
-    loss : float
-        Current epoch's training loss
-
-    valid_loss : :py:class:`float`, None
-        Current epoch's validation loss
-
-    extra_valid_losses : :py:class:`list` of :py:class:`float`
-        Validation losses from other validation datasets being currently
-        tracked
-
-    optimizer : :py:mod:`torch.optim`
-
-    logwriter : csv.DictWriter
-        Dictionary writer that give the ability to write on the trainlog.csv
-
-    logfile : io.TextIOWrapper
-
-    resource_data : tuple
-        Monitored resources at the machine (CPU and GPU)
-    """
-
-    logdata = (
-        ("epoch", f"{epoch}"),
-        (
-            "total_time",
-            f"{datetime.timedelta(seconds=int(current_time))}",
-        ),
-        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-        ("loss", f"{loss:.6f}"),
-        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-    )
-
-    if valid_loss is not None:
-        logdata += (("validation_loss", f"{valid_loss:.6f}"),)
-
-    if extra_valid_losses:
-        entry = numpy.array_str(
-            numpy.array(extra_valid_losses),
-            max_line_width=sys.maxsize,
-            precision=6,
-        )
-        logdata += (("extra_validation_losses", entry),)
-
-    logdata += resource_data
-
-    logwriter.writerow(dict(k for k in logdata))
-    logfile.flush()
-    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
-
-
 def run(
     model,
     data_loader,
@@ -562,123 +260,3 @@ def run(
         )
 
         _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
-
-    """# write static information to a CSV file
-    static_logfile_name = os.path.join(output_folder, "constants.csv")
-
-    static_information_to_csv(static_logfile_name, device, n)
-
-    # Log continous information to (another) file
-    logfile_name = os.path.join(output_folder, "trainlog.csv")
-
-    check_exist_logfile(logfile_name, arguments)
-
-    logfile_fields = create_logfile_fields(
-        valid_loader, extra_valid_loaders, device
-    )
-
-    # the lowest validation loss obtained so far - this value is updated only
-    # if a validation set is available
-    lowest_validation_loss = initialize_lowest_validation_loss(
-        logfile_name, arguments
-    )
-
-    # set a specific validation criterion if the user has set one
-    criterion_valid = criterion_valid or criterion
-
-    with open(logfile_name, "a+", newline="") as logfile:
-        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
-
-        if arguments["epoch"] == 0:
-            logwriter.writeheader()
-
-        model.train()  # set training mode
-
-        model.to(device)  # set/cast parameters to device
-        for state in optimizer.state.values():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor):
-                    state[k] = v.to(device)
-
-        # Total training timer
-        start_training_time = time.time()
-
-        for epoch in tqdm(
-            range(start_epoch, max_epoch),
-            desc="epoch",
-            leave=False,
-            disable=None,
-        ):
-            with ResourceMonitor(
-                interval=monitoring_interval,
-                has_gpu=(device.type == "cuda"),
-                main_pid=os.getpid(),
-                logging_level=logging.ERROR,
-            ) as resource_monitor:
-                epoch = epoch + 1
-                arguments["epoch"] = epoch
-
-                # Epoch time
-                start_epoch_time = time.time()
-
-                train_loss = train_epoch(
-                    data_loader,
-                    model,
-                    optimizer,
-                    device,
-                    criterion,
-                    batch_chunk_count,
-                )
-
-                valid_loss = (
-                    validate_epoch(
-                        valid_loader, model, device, criterion_valid, "valid"
-                    )
-                    if valid_loader is not None
-                    else None
-                )
-
-                extra_valid_losses = []
-                for pos, extra_valid_loader in enumerate(extra_valid_loaders):
-                    loss = validate_epoch(
-                        extra_valid_loader,
-                        model,
-                        device,
-                        criterion_valid,
-                        f"xval@{pos+1}",
-                    )
-                    extra_valid_losses.append(loss)
-
-            lowest_validation_loss = checkpointer_process(
-                checkpointer,
-                checkpoint_period,
-                valid_loss,
-                lowest_validation_loss,
-                arguments,
-                epoch,
-                max_epoch,
-            )
-
-            # computes ETA (estimated time-of-arrival; end of training) taking
-            # into consideration previous epoch performance
-            epoch_time = time.time() - start_epoch_time
-            eta_seconds = epoch_time * (max_epoch - epoch)
-            current_time = time.time() - start_training_time
-
-            write_log_info(
-                epoch,
-                current_time,
-                eta_seconds,
-                train_loss,
-                valid_loss,
-                extra_valid_losses,
-                optimizer,
-                logwriter,
-                logfile,
-                resource_monitor.data,
-            )
-
-        total_training_time = time.time() - start_training_time
-        logger.info(
-            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
-        )"""
-- 
GitLab