Commit a61efdb4 authored by Daniel CARRON

Added pytorch-lightning to pasa model and trainer

parent ee7b9636
1 merge request: !4 Moved code to lightning
@@ -11,19 +11,20 @@ Screening and Visualization".
Reference: [PASA-2019]_
"""
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.pasa import build_pasa
from ...models.pasa import PASA
# config
lr = 8e-5
# model
model = build_pasa()
optimizer_configs = {"lr": 8e-5}
# optimizer
optimizer = Adam(model.parameters(), lr=lr)
optimizer = "Adam"
# criterion
criterion = BCEWithLogitsLoss()
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
# model
model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
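Note: in the updated configuration the optimizer is given as a string plus a dictionary of keyword arguments, and the model instantiates it lazily in configure_optimizers() (see the PASA hunk further below). A minimal sketch of what that resolution amounts to, for illustration only (not part of the commit):

import torch

# What PASA.configure_optimizers() effectively does with
# optimizer = "Adam" and optimizer_configs = {"lr": 8e-5}:
optimizer_cls = getattr(torch.optim, "Adam")            # -> torch.optim.Adam
optimizer = optimizer_cls(model.parameters(), lr=8e-5)  # "model" as built above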
import csv
import time
import numpy
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import BasePredictionWriter
# This ensures CSVLogger logs training and evaluation metrics on the same line
# CSVLogger only accepts numerical values, not strings
class LoggingCallback(Callback):
def __init__(self, resource_monitor):
super().__init__()
self.training_loss = []
self.validation_loss = []
self.start_training_time = 0
self.start_epoch_time = 0
self.resource_monitor = resource_monitor
self.max_queue_retries = 2
def on_train_start(self, trainer, pl_module):
self.start_training_time = time.time()
def on_train_epoch_start(self, trainer, pl_module):
self.start_epoch_time = time.time()
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.training_loss.append(outputs["loss"].item())
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx
):
self.validation_loss.append(outputs["validation_loss"].item())
def on_validation_epoch_end(self, trainer, pl_module):
self.resource_monitor.trigger_summary()
self.epoch_time = time.time() - self.start_epoch_time
eta_seconds = self.epoch_time * (
trainer.max_epochs - trainer.current_epoch
)
current_time = time.time() - self.start_training_time
self.log("total_time", current_time)
self.log("eta", eta_seconds)
self.log("loss", numpy.average(self.training_loss))
self.log("learning_rate", pl_module.lr)
self.log("validation_loss", numpy.average(self.validation_loss))
queue_retries = 0
# If the resource monitor needs more time to put data on the queue, wait and retry.
# Give up after self.resource_monitor.interval * self.max_queue_retries seconds
# if the metrics still cannot be retrieved.
while (
self.resource_monitor.data is None
and queue_retries < self.max_queue_retries
):
queue_retries = queue_retries + 1
print(
f"Monitor queue is empty, retrying in {self.resource_monitor.interval}s"
)
time.sleep(self.resource_monitor.interval)
if queue_retries >= self.max_queue_retries:
print(
f"Unable to fetch monitoring information from queue after {queue_retries} retries"
)
assert self.resource_monitor.q.empty()
for metric_name, metric_value in self.resource_monitor.data:
self.log(metric_name, metric_value)
self.resource_monitor.data = None
self.training_loss = []
self.validation_loss = []
class PredictionsWriter(BasePredictionWriter):
def __init__(self, logfile_name, logfile_fields, write_interval):
super().__init__(write_interval)
self.logfile_name = logfile_name
self.logfile_fields = logfile_fields
def write_on_epoch_end(
self, trainer, pl_module, predictions, batch_indices
):
with open(self.logfile_name, "w") as logfile:
logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
logwriter.writeheader()
# We should only get a single epoch here
for epoch in predictions:
for prediction in epoch:
logwriter.writerow(
{
"filename": prediction[0],
"likelihood": prediction[1].numpy(),
"ground_truth": prediction[2].numpy(),
}
)
logfile.flush()
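A possible usage sketch for PredictionsWriter (illustration only; "model" and "predict_loader" are placeholders, and it assumes the LightningModule's predict_step returns (filename, likelihood, ground_truth) tuples, which is not defined in this commit):

from pytorch_lightning import Trainer

writer = PredictionsWriter(
    logfile_name="predictions.csv",
    logfile_fields=("filename", "likelihood", "ground_truth"),
    write_interval="epoch",  # buffer all predictions, write once at the end
)
trainer = Trainer(callbacks=[writer])
trainer.predict(model, dataloaders=predict_loader)  # triggers write_on_epoch_end()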
@@ -9,16 +9,18 @@ import logging
import os
import shutil
import sys
import time
import numpy
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
from tqdm import tqdm
# from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.summary import summary
from .callbacks import LoggingCallback
logger = logging.getLogger(__name__)
@@ -127,10 +129,9 @@ def save_model_summary(output_folder, model):
summary_path = os.path.join(output_folder, "model_summary.txt")
logger.info(f"Saving model summary at {summary_path}...")
with open(summary_path, "w") as f:
r, n = summary(model)
logger.info(f"Model has {n} parameters...")
f.write(r)
return r, n
summary = str(ModelSummary(model, max_depth=-1))
f.write(summary)
return summary
def static_information_to_csv(static_logfile_name, device, n):
@@ -582,7 +583,6 @@ def run(
specific loss function for the validation set
"""
start_epoch = arguments["epoch"]
max_epoch = arguments["max_epoch"]
check_gpu(device)
@@ -590,9 +590,40 @@ def run(
os.makedirs(output_folder, exist_ok=True)
# Save model summary
r, n = save_model_summary(output_folder, model)
_ = save_model_summary(output_folder, model)
# write static information to a CSV file
csv_logger = CSVLogger(output_folder, "logs_csv")
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
resource_monitor = ResourceMonitor(
interval=5.0,
has_gpu=(device.type == "cuda"),
main_pid=os.getpid(),
logging_level=logging.ERROR,
)
with resource_monitor:
trainer = Trainer(
accelerator="auto",
devices="auto",
max_epochs=max_epoch,
logger=[csv_logger, tensorboard_logger],
check_val_every_n_epoch=1,
callbacks=[
LoggingCallback(resource_monitor),
ModelCheckpoint(
output_folder,
monitor="validation_loss",
mode="min",
save_on_train_epoch_end=False,
every_n_epochs=checkpoint_period,
),
],
)
_ = trainer.fit(model, data_loader)
"""# write static information to a CSV file
static_logfile_name = os.path.join(output_folder, "constants.csv")
static_information_to_csv(static_logfile_name, device, n)
@@ -710,4 +741,4 @@ def run(
total_training_time = time.time() - start_training_time
logger.info(
f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
)
)"""
@@ -2,23 +2,49 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
from collections import OrderedDict
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from .normalizer import TorchVisionNormalizer
class PASA(nn.Module):
colors = [
[(47, 79, 79), "Cardiomegaly"],
[(255, 0, 0), "Emphysema"],
[(0, 128, 0), "Pleural effusion"],
[(0, 0, 128), "Hernia"],
[(255, 84, 0), "Infiltration"],
[(222, 184, 135), "Mass"],
[(0, 255, 0), "Nodule"],
[(0, 191, 255), "Atelectasis"],
[(0, 0, 255), "Pneumothorax"],
[(255, 0, 255), "Pleural thickening"],
[(255, 255, 0), "Pneumonia"],
[(126, 0, 255), "Fibrosis"],
[(255, 20, 147), "Edema"],
[(0, 255, 180), "Consolidation"],
]
class PASA(pl.LightningModule):
"""PASA module.
Based on paper by [PASA-2019]_.
"""
def __init__(self):
def __init__(self, criterion, criterion_valid, optimizer, optimizer_params):
super().__init__()
self.save_hyperparameters()
self.name = "pasa"
self.criterion = criterion
self.criterion_valid = criterion_valid
self.normalizer = TorchVisionNormalizer(nb_channels=1)
# First convolution block
self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -82,6 +108,8 @@ class PASA(nn.Module):
tensor : :py:class:`torch.Tensor`
"""
x = self.normalizer(x)
# First convolution block
_x = x
x = F.relu(self.batchNorm2d_4(self.fc1(x))) # 1st convolution
@@ -127,21 +155,41 @@ class PASA(nn.Module):
return x
def build_pasa():
"""Build pasa CNN.
Returns
-------
module : :py:class:`torch.nn.Module`
"""
model = PASA()
model = [
("normalizer", TorchVisionNormalizer(nb_channels=1)),
("model", model),
]
model = nn.Sequential(OrderedDict(model))
model.name = "pasa"
return model
def training_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
# Increase label dimension if too low
# Allows single and multiclass usage
if labels.ndim == 1:
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
training_loss = self.criterion(outputs, labels.double())
return {"loss": training_loss}
def validation_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
# Increase label dimension if too low
# Allows single and multiclass usage
if labels.ndim == 1:
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
validation_loss = self.criterion_valid(outputs, labels.double())
return {"validation_loss": validation_loss}
def configure_optimizers(self):
# Dynamically instantiates the optimizer given the configs
optimizer = getattr(torch.optim, self.hparams.optimizer)(
self.parameters(), **self.hparams.optimizer_params
)
return optimizer
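A quick illustration (not part of the commit) of why the configuration can pass the optimizer as a plain string: save_hyperparameters() in __init__ records the constructor arguments under their own names, which is exactly what configure_optimizers() reads back. It assumes the PASA class defined above is importable in the current scope:

from torch.nn import BCEWithLogitsLoss

model = PASA(BCEWithLogitsLoss(), BCEWithLogitsLoss(), "Adam", {"lr": 8e-5})
assert model.hparams.optimizer == "Adam"
assert model.hparams.optimizer_params == {"lr": 8e-5}
# configure_optimizers() then builds torch.optim.Adam(self.parameters(), lr=8e-5)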
@@ -233,6 +233,7 @@ class CPULogger:
cpu_percent = []
open_files = []
gone = set()
for k in self.cluster:
try:
memory_info.append(k.memory_info())
@@ -243,7 +244,6 @@
# it is too late to update any intermediate list
# at this point, but ensures to update counts later on
gone.add(k)
return (
("cpu_memory_used", psutil.virtual_memory().used / GB),
("cpu_rss", sum([k.rss for k in memory_info]) / GB),
@@ -288,6 +288,10 @@ class _InformationGatherer:
for i, k in enumerate(gpu_log()):
self.data[i + self.cpu_keys_len].append(k[1])
def clear(self):
"""Clears accumulated data."""
self.data = [[] for _ in self.keys]
def summary(self):
"""Returns the current data."""
if len(self.data[0]) == 0:
@@ -298,7 +302,9 @@
return tuple(retval)
def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
def _monitor_worker(
interval, has_gpu, main_pid, stop, summary_event, queue, logging_level
):
"""A monitoring worker that measures resources and returns lists.
Parameters
@@ -329,6 +335,12 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
while not stop.is_set():
try:
ra.acc() # guarantees at least an entry will be available
if summary_event.is_set():
queue.put(ra.summary())
ra.clear()
summary_event.clear()
time.sleep(interval)
except Exception:
logger.warning(
@@ -337,8 +349,6 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
)
time.sleep(0.5) # wait half a second, and try again!
queue.put(ra.summary())
class ResourceMonitor:
"""An external, non-blocking CPU/GPU resource monitor.
@@ -364,7 +374,8 @@ class ResourceMonitor:
self.interval = interval
self.has_gpu = has_gpu
self.main_pid = main_pid
self.event = multiprocessing.Event()
self.stop_event = multiprocessing.Event()
self.summary_event = multiprocessing.Event()
self.q = multiprocessing.Queue()
self.logging_level = logging_level
@@ -375,7 +386,8 @@
self.interval,
self.has_gpu,
self.main_pid,
self.event,
self.stop_event,
self.summary_event,
self.q,
self.logging_level,
),
@@ -390,19 +402,9 @@
def __enter__(self):
"""Starts the monitoring process."""
self.monitor.start()
return self
def __exit__(self, *exc):
"""Stops the monitoring process and returns the summary of
observations."""
self.event.set()
self.monitor.join()
if self.monitor.exitcode != 0:
logger.error(
f"CPU/GPU resource monitor process exited with code "
f"{self.monitor.exitcode}. Check logs for errors!"
)
def trigger_summary(self):
self.summary_event.set()
try:
data = self.q.get(timeout=2 * self.interval)
@@ -426,3 +428,15 @@
else:
summary.append((k, 0.0))
self.data = tuple(summary)
def __exit__(self, *exc):
"""Stops the monitoring process and returns the summary of
observations."""
self.stop_event.set()
self.monitor.join()
if self.monitor.exitcode != 0:
logger.error(
f"CPU/GPU resource monitor process exited with code "
f"{self.monitor.exitcode}. Check logs for errors!"
)
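The new summary_event protocol, as driven from the trainer and LoggingCallback above, amounts to the following minimal sketch (illustration only; the interval and device values are examples):

import logging
import os

monitor = ResourceMonitor(
    interval=5.0, has_gpu=False, main_pid=os.getpid(), logging_level=logging.ERROR
)
with monitor:                     # __enter__ starts the monitoring subprocess
    # ... run some work, e.g. one training epoch ...
    monitor.trigger_summary()     # worker puts a summary on the queue; .data is filled
    if monitor.data is not None:  # tuple of (metric_name, value) pairs, or None on timeout
        for name, value in monitor.data:
            print(name, value)
# __exit__ sets stop_event and joins the worker process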