diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index d09161d0408118a3fd2dccefc1b082319f99051b..3ee0b92164b5531b65049b94e71b01b07e2ad27e 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -11,19 +11,20 @@ Screening and Visualization".
 
 Reference: [PASA-2019]_
 """
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.pasa import build_pasa
+from ...models.pasa import PASA
 
 # config
-lr = 8e-5
-
-# model
-model = build_pasa()
+optimizer_configs = {"lr": 8e-5}
 
 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d730cf68c95e69fe4191dd5e6bdf95ed8b9ee596
--- /dev/null
+++ b/src/ptbench/engine/callbacks.py
@@ -0,0 +1,104 @@
+import csv
+import time
+
+import numpy
+
+from pytorch_lightning import Callback
+from pytorch_lightning.callbacks import BasePredictionWriter
+
+
+# This ensures CSVLogger logs training and evaluation metrics on the same line
+# CSVLogger only accepts numerical values, not strings
+class LoggingCallback(Callback):
+    def __init__(self, resource_monitor):
+        super().__init__()
+        self.training_loss = []
+        self.validation_loss = []
+        self.start_training_time = 0
+        self.start_epoch_time = 0
+
+        self.resource_monitor = resource_monitor
+        self.max_queue_retries = 2
+
+    def on_train_start(self, trainer, pl_module):
+        self.start_training_time = time.time()
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        self.start_epoch_time = time.time()
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        self.training_loss.append(outputs["loss"].item())
+
+    def on_validation_batch_end(
+        self, trainer, pl_module, outputs, batch, batch_idx
+    ):
+        self.validation_loss.append(outputs["validation_loss"].item())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        self.resource_monitor.trigger_summary()
+
+        self.epoch_time = time.time() - self.start_epoch_time
+        eta_seconds = self.epoch_time * (
+            trainer.max_epochs - trainer.current_epoch
+        )
+        current_time = time.time() - self.start_training_time
+
+        pl_module.log("total_time", current_time)
+        pl_module.log("eta", eta_seconds)
+        pl_module.log("loss", numpy.average(self.training_loss))
+        pl_module.log("learning_rate", pl_module.hparams.optimizer_params["lr"])
+        pl_module.log("validation_loss", numpy.average(self.validation_loss))
+
+        queue_retries = 0
+        # In case the resource monitor takes longer to fetch data from the queue, we wait
+        # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
+        while (
+            self.resource_monitor.data is None
+            and queue_retries < self.max_queue_retries
+        ):
+            queue_retries = queue_retries + 1
+            print(
+                f"Monitor queue is empty, retrying in {self.resource_monitor.interval}s"
+            )
+            time.sleep(self.resource_monitor.interval)
+
+        if queue_retries >= self.max_queue_retries:
+            print(
+                f"Unable to fetch monitoring information from queue after {queue_retries} retries"
+            )
+
+        assert self.resource_monitor.q.empty()
+
+        for metric_name, metric_value in (self.resource_monitor.data or ()):
+            pl_module.log(metric_name, metric_value)
+
+        self.resource_monitor.data = None
+
+        self.training_loss = []
+        self.validation_loss = []
+
+
+class PredictionsWriter(BasePredictionWriter):
+    def __init__(self, logfile_name, logfile_fields, write_interval):
+        super().__init__(write_interval)
+        self.logfile_name = logfile_name
+        self.logfile_fields = logfile_fields
+
+    def write_on_epoch_end(
+        self, trainer, pl_module, predictions, batch_indices
+    ):
+        with open(self.logfile_name, "w") as logfile:
+            logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
+            logwriter.writeheader()
+
+            # We should only get a single epoch here
+            for epoch in predictions:
+                for prediction in epoch:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+            logfile.flush()
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index fd48d4f1fdd82eea05d28dfff7d035e5694a08fd..965a89cd7ee82f5c68ea079368b7a414dd264a95 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -9,16 +9,18 @@ import logging
 import os
 import shutil
 import sys
-import time
 
 import numpy
 import torch
 
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
+from pytorch_lightning.utilities.model_summary import ModelSummary
 from tqdm import tqdm
 
-# from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
-from ..utils.summary import summary
+from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
 
@@ -127,10 +129,9 @@ def save_model_summary(output_folder, model):
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        r, n = summary(model)
-        logger.info(f"Model has {n} parameters...")
-        f.write(r)
-    return r, n
+        summary = str(ModelSummary(model, max_depth=-1))
+        f.write(summary)
+    return summary
 
 
 def static_information_to_csv(static_logfile_name, device, n):
@@ -582,7 +583,6 @@ def run(
         specific loss function for the validation set
 
     """
-    start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
     check_gpu(device)
@@ -590,9 +590,40 @@ def run(
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    r, n = save_model_summary(output_folder, model)
+    _ = save_model_summary(output_folder, model)
 
-    # write static information to a CSV file
+    csv_logger = CSVLogger(output_folder, "logs_csv")
+    tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
+
+    resource_monitor = ResourceMonitor(
+        interval=5.0,
+        has_gpu=(device.type == "cuda"),
+        main_pid=os.getpid(),
+        logging_level=logging.ERROR,
+    )
+
+    with resource_monitor:
+        trainer = Trainer(
+            accelerator="auto",
+            devices="auto",
+            max_epochs=max_epoch,
+            logger=[csv_logger, tensorboard_logger],
+            check_val_every_n_epoch=1,
+            callbacks=[
+                LoggingCallback(resource_monitor),
+                ModelCheckpoint(
+                    output_folder,
+                    monitor="validation_loss",
+                    mode="min",
+                    save_on_train_epoch_end=False,
+                    every_n_epochs=checkpoint_period,
+                ),
+            ],
+        )
+
+        _ = trainer.fit(model, data_loader)
+
+    """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
     static_information_to_csv(static_logfile_name, device, n)
 
@@ -710,4 +741,4 @@ def run(
     total_training_time = time.time() - start_training_time
     logger.info(
         f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
-    )
+    )"""
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 10e6cedeb672094f62f5c1d16dbaeb9d6983ce34..4fd816e17ab4a3dc85ef6769c9d2325657b91310 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -2,23 +2,49 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from collections import OrderedDict
-
+import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from .normalizer import TorchVisionNormalizer
 
-
-class PASA(nn.Module):
+colors = [
+    [(47, 79, 79), "Cardiomegaly"],
+    [(255, 0, 0), "Emphysema"],
+    [(0, 128, 0), "Pleural effusion"],
+    [(0, 0, 128), "Hernia"],
+    [(255, 84, 0), "Infiltration"],
+    [(222, 184, 135), "Mass"],
+    [(0, 255, 0), "Nodule"],
+    [(0, 191, 255), "Atelectasis"],
+    [(0, 0, 255), "Pneumothorax"],
+    [(255, 0, 255), "Pleural thickening"],
+    [(255, 255, 0), "Pneumonia"],
+    [(126, 0, 255), "Fibrosis"],
+    [(255, 20, 147), "Edema"],
+    [(0, 255, 180), "Consolidation"],
+]
+
+
+class PASA(pl.LightningModule):
     """PASA module.
 
     Based on paper by [PASA-2019]_.
     """
 
-    def __init__(self):
+    def __init__(self, criterion, criterion_valid, optimizer, optimizer_params):
         super().__init__()
+
+        self.save_hyperparameters()
+
+        self.name = "pasa"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -82,6 +108,8 @@ class PASA(nn.Module):
         tensor : :py:class:`torch.Tensor`
 
         """
+        x = self.normalizer(x)
+
         # First convolution block
         _x = x
         x = F.relu(self.batchNorm2d_4(self.fc1(x)))  # 1st convolution
@@ -127,21 +155,41 @@ class PASA(nn.Module):
         return x
 
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
 
-def build_pasa():
-    """Build pasa CNN.
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
 
-    Returns
-    -------
+        # Forward pass on the network
+        outputs = self(images)
 
-    module : :py:class:`torch.nn.Module`
-
-    """
-    model = PASA()
-    model = [
-        ("normalizer", TorchVisionNormalizer(nb_channels=1)),
-        ("model", model),
-    ]
-    model = nn.Sequential(OrderedDict(model))
-
-    model.name = "pasa"
-    return model
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_params
+        )
+
+        return optimizer
diff --git a/src/ptbench/utils/resources.py b/src/ptbench/utils/resources.py
index be23ee452a1823555220c5d92d80a2f7c6a9223f..fa0ac3dd2b2332a8d938850d97739b457eeed13a 100644
--- a/src/ptbench/utils/resources.py
+++ b/src/ptbench/utils/resources.py
@@ -233,6 +233,7 @@ class CPULogger:
         cpu_percent = []
         open_files = []
         gone = set()
+
         for k in self.cluster:
             try:
                 memory_info.append(k.memory_info())
@@ -243,7 +244,6 @@ class CPULogger:
                 # it is too late to update any intermediate list
                 # at this point, but ensures to update counts later on
                 gone.add(k)
-
         return (
             ("cpu_memory_used", psutil.virtual_memory().used / GB),
             ("cpu_rss", sum([k.rss for k in memory_info]) / GB),
@@ -288,6 +288,10 @@ class _InformationGatherer:
             for i, k in enumerate(gpu_log()):
                 self.data[i + self.cpu_keys_len].append(k[1])
 
+    def clear(self):
+        """Clears accumulated data."""
+        self.data = [[] for _ in self.keys]
+
     def summary(self):
         """Returns the current data."""
         if len(self.data[0]) == 0:
@@ -298,7 +302,9 @@ class _InformationGatherer:
         return tuple(retval)
 
 
-def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
+def _monitor_worker(
+    interval, has_gpu, main_pid, stop, summary_event, queue, logging_level
+):
     """A monitoring worker that measures resources and returns lists.
 
     Parameters
@@ -329,6 +335,12 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
     while not stop.is_set():
         try:
             ra.acc()  # guarantees at least an entry will be available
+
+            if summary_event.is_set():
+                queue.put(ra.summary())
+                ra.clear()
+                summary_event.clear()
+
             time.sleep(interval)
         except Exception:
             logger.warning(
@@ -337,8 +349,6 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
             )
             time.sleep(0.5)  # wait half a second, and try again!
 
-    queue.put(ra.summary())
-
 
 class ResourceMonitor:
     """An external, non-blocking CPU/GPU resource monitor.
@@ -364,7 +374,8 @@ class ResourceMonitor:
         self.interval = interval
         self.has_gpu = has_gpu
         self.main_pid = main_pid
-        self.event = multiprocessing.Event()
+        self.stop_event = multiprocessing.Event()
+        self.summary_event = multiprocessing.Event()
         self.q = multiprocessing.Queue()
         self.logging_level = logging_level
 
@@ -375,7 +386,8 @@ class ResourceMonitor:
                 self.interval,
                 self.has_gpu,
                 self.main_pid,
-                self.event,
+                self.stop_event,
+                self.summary_event,
                 self.q,
                 self.logging_level,
             ),
@@ -390,19 +402,9 @@ class ResourceMonitor:
     def __enter__(self):
         """Starts the monitoring process."""
         self.monitor.start()
-        return self
 
-    def __exit__(self, *exc):
-        """Stops the monitoring process and returns the summary of
-        observations."""
-
-        self.event.set()
-        self.monitor.join()
-        if self.monitor.exitcode != 0:
-            logger.error(
-                f"CPU/GPU resource monitor process exited with code "
-                f"{self.monitor.exitcode}. Check logs for errors!"
-            )
+    def trigger_summary(self):
+        self.summary_event.set()
 
         try:
             data = self.q.get(timeout=2 * self.interval)
@@ -426,3 +428,15 @@ class ResourceMonitor:
             else:
                 summary.append((k, 0.0))
         self.data = tuple(summary)
+
+    def __exit__(self, *exc):
+        """Stops the monitoring process and waits for it to terminate,
+        logging an error if it exited abnormally."""
+
+        self.stop_event.set()
+        self.monitor.join()
+        if self.monitor.exitcode != 0:
+            logger.error(
+                f"CPU/GPU resource monitor process exited with code "
+                f"{self.monitor.exitcode}. Check logs for errors!"
+            )
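Usage sketch (not part of the patch): the snippet below shows how the reworked ResourceMonitor is meant to be driven after this change — trigger_summary() asks the worker process for a summary of the samples collected so far, the result appears on monitor.data, and __exit__ now only stops the worker. Everything referenced comes from this patch except the interval value, the sleep and the print loop, which are illustrative choices.

import logging
import os
import time

from ptbench.utils.resources import ResourceMonitor

monitor = ResourceMonitor(
    interval=1.0,  # illustrative; run() above uses 5.0
    has_gpu=False,
    main_pid=os.getpid(),
    logging_level=logging.ERROR,
)

with monitor:  # __enter__ starts the monitoring subprocess
    time.sleep(2.0)  # let the worker accumulate a couple of samples
    monitor.trigger_summary()  # sets summary_event; the worker replies on the queue
    if monitor.data is not None:  # stays None if the queue read timed out
        for key, value in monitor.data:
            print(f"{key}: {value}")
# leaving the block sets stop_event and joins the worker process

The Event/Queue pair keeps sampling in a separate process; LoggingCallback.on_validation_epoch_end() calls trigger_summary() once per validation epoch and forwards whatever lands in monitor.data to the Lightning loggers.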