Commit 915f4e79 authored by Tim Laibacher's avatar Tim Laibacher

Add valtrainer

parent 0ad73ba6
Pipeline #31351 failed
from bob.ip.binseg.configs.datasets.drive1024 import dataset as drive
from bob.ip.binseg.configs.datasets.stare1024 import dataset as stare
from bob.ip.binseg.configs.datasets.hrf1024 import dataset as hrf
from bob.ip.binseg.configs.datasets.chasedb11024 import dataset as chase
import torch
#### Config ####
# PyTorch dataset
dataset = torch.utils.data.ConcatDataset([drive, stare, hrf, chase])
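For intuition: ConcatDataset chains the four sub-datasets end to end, so the combined length is the sum of the parts and an index that runs past one sub-dataset falls through to the next. A quick sanity check (illustrative only, not part of the commit):

# illustrative only: length and index fall-through of the combined dataset
assert len(dataset) == len(drive) + len(stare) + len(hrf) + len(chase)
first_stare_sample = dataset[len(drive)]  # first index past drive lands in stare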
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from bob.db.hrf import Database as HRF
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config ####
transforms = Compose([
    Pad((0, 584, 0, 584))  # pad to square: HRF images are 3504x2336, padding adds 1168 px of height
    ,Resize(1024)          # square images resize to 1024x1024
    ,RandomRotation()
    ,RandomHFlip()
    ,RandomVFlip()
    ,ColorJitter()
    ,ToTensor()
])
# bob.db.dataset init
bobdb = HRF(protocol='default')
# PyTorch dataset
dataset = BinSegDataset(bobdb, split='train', transform=transforms)
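Worth noting: these transforms come from bob.ip.binseg.data.transforms rather than torchvision; for segmentation they are assumed to act jointly on the image and its ground-truth vessel map, so random flips and rotations keep the pair pixel-aligned. A minimal sketch of that joint-transform idea (hypothetical code, not this library's actual API):

import random
from PIL import Image

class PairedHFlip:
    # Sketch of a joint transform: one coin flip drives both images,
    # so image and ground truth stay aligned after augmentation.
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img, gt):
        if random.random() < self.prob:
            return (img.transpose(Image.FLIP_LEFT_RIGHT),
                    gt.transpose(Image.FLIP_LEFT_RIGHT))
        return img, gt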
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import logging
import time
import datetime
import torch
import pandas as pd
from tqdm import tqdm
from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve
from bob.ip.binseg.engine.inferencer import do_inference
def do_valtrain(
    model,
    data_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    val_loader=None,
):
    """
    Train model and save to disk, running inference on the validation set
    at every checkpoint.

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        Network (e.g. DRIU, HED, UNet)
    data_loader : :py:class:`torch.utils.data.DataLoader`
        training data loader
    optimizer : :py:mod:`torch.optim`
    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function
    scheduler : :py:mod:`torch.optim`
        learning rate scheduler
    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
        checkpointer
    checkpoint_period : int
        save a checkpoint every n epochs
    device : str
        device to use, ``'cpu'`` or ``'cuda'``
    arguments : dict
        start and end epochs
    output_folder : str
        output path
    val_loader : :py:class:`torch.utils.data.DataLoader`
        validation data loader
    """
    torch.manual_seed(42)
    logger = logging.getLogger("bob.ip.binseg.engine.trainer")
    logger.info("Start training")
    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    # Log to file (line-buffered so rows appear as they are written)
    with open(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1) as outfile:
        model.train().to(device)
        # Total training timer
        start_training_time = time.time()

        for epoch in range(start_epoch, max_epoch):
            scheduler.step()
            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch
            # Epoch timer
            start_epoch_time = time.time()

            for samples in tqdm(data_loader):
                images = samples[1].to(device)
                ground_truths = samples[2].to(device)
                masks = None
                if len(samples) == 4:
                    masks = samples[-1].to(device)
                outputs = model(images)
                loss = criterion(outputs, ground_truths, masks)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.update(loss)
                logger.debug("batch loss: {}".format(loss.item()))

            if epoch % checkpoint_period == 0:
                checkpointer.save("model_{:03d}".format(epoch), **arguments)
                if val_loader is not None:
                    # run inference on the validation set at every checkpoint
                    val_folder = os.path.join(output_folder, str(epoch))
                    do_inference(model, val_loader, device, val_folder)
                    # restore training mode in case do_inference switched to eval
                    model.train()

            if epoch == max_epoch:
                checkpointer.save("model_final", **arguments)

            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            outfile.write(("{epoch}, "
                           "{avg_loss:.6f}, "
                           "{median_loss:.6f}, "
                           "{lr:.6f}, "
                           "{memory:.0f}"
                           "\n"
                           ).format(
                               epoch=epoch,
                               avg_loss=losses.avg,
                               median_loss=losses.median,
                               lr=optimizer.param_groups[0]["lr"],
                               memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.0,
                           ))
            logger.info(("eta: {eta}, "
                         "epoch: {epoch}, "
                         "avg. loss: {avg_loss:.6f}, "
                         "median loss: {median_loss:.6f}, "
                         "lr: {lr:.6f}, "
                         "max mem: {memory:.0f}"
                         ).format(
                             eta=eta_string,
                             epoch=epoch,
                             avg_loss=losses.avg,
                             median_loss=losses.median,
                             lr=optimizer.param_groups[0]["lr"],
                             memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.0,
                         ))

        total_training_time = time.time() - start_training_time
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info(
            "Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str, total_training_time / max_epoch
            ))

    # plot the training log; the first CSV column (epoch) becomes the index
    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
    logdf = pd.read_csv(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
                        header=None, names=["avg. loss", "median loss", "lr", "max memory"])
    fig = loss_curve(logdf, output_folder)
    logger.info("saving {}".format(log_plot_file))
    fig.savefig(log_plot_file)
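SmoothedValue itself is not shown in this commit; judging from how it is used above (constructed with a window of len(data_loader), updated per batch, read through .avg and .median), a rough stand-in could look like the sketch below. This is an assumption about its behavior, not the actual implementation:

from collections import deque
import torch

class SmoothedValueSketch:
    # Hypothetical stand-in: sliding window of batch losses exposing
    # the average and the median, mirroring how SmoothedValue is consumed.
    def __init__(self, window_size):
        self.values = deque(maxlen=window_size)

    def update(self, value):
        self.values.append(float(value))  # accepts tensors or plain floats

    @property
    def avg(self):
        return sum(self.values) / len(self.values)

    @property
    def median(self):
        return torch.tensor(list(self.values)).median().item()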
@@ -24,12 +24,14 @@ from bob.extension.scripts.click_helper import (verbosity_option,
from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
from torch.utils.data import DataLoader
from bob.ip.binseg.engine.trainer import do_train
from bob.ip.binseg.engine.valtrainer import do_valtrain
from bob.ip.binseg.engine.inferencer import do_inference
from bob.ip.binseg.utils.plot import plot_overview
from bob.ip.binseg.utils.click import OptionEatAll
from bob.ip.binseg.utils.pdfcreator import create_pdf, get_paths
from bob.ip.binseg.utils.rsttable import create_overview_grid
from bob.ip.binseg.utils.plot import metricsviz, overlay
import torch
from torch.utils.data import SubsetRandomSampler
logger = logging.getLogger(__name__)
@@ -308,12 +310,17 @@ def testcheckpoints(model
    '-o',
    required=True,
    )
@click.option(
    '--title',
    '-t',
    required=False,
    )
@verbosity_option(cls=ResourceOption)
def compare(output_path_list, output_path, title, **kwargs):
    """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """
    logger.debug("Output paths: {}".format(output_path_list))
    logger.info('Plotting precision vs recall curves for {}'.format(output_path_list))
    fig = plot_overview(output_path_list, title)
    if not os.path.exists(output_path): os.makedirs(output_path)
    fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf')
    logger.info('saving {}'.format(fig_filename))
@@ -390,4 +397,142 @@ def visualize(dataset, output_path, **kwargs):
    logger.info('Creating TP, FP, FN visualizations for {}'.format(output_path))
    metricsviz(dataset=dataset, output_path=output_path)
    logger.info('Creating overlay visualizations for {}'.format(output_path))
    overlay(dataset=dataset, output_path=output_path)
# Validation Train
@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
@click.option(
    '--output-path',
    '-o',
    required=True,
    default="output",
    cls=ResourceOption
    )
@click.option(
    '--model',
    '-m',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--dataset',
    '-d',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--optimizer',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--criterion',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--scheduler',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--pretrained-backbone',
    '-t',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--batch-size',
    '-b',
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--epochs',
    '-e',
    help='Number of epochs used for training',
    show_default=True,
    required=True,
    default=6,
    cls=ResourceOption)
@click.option(
    '--checkpoint-period',
    '-p',
    help='Number of epochs after which a checkpoint is saved',
    show_default=True,
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--device',
    '-d',
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default='cpu',
    cls=ResourceOption)
@click.option(
    '--valsize',
    '-a',
    help='Number of samples to hold out for validation',
    show_default=True,
    required=True,
    default=5,
    cls=ResourceOption)
@verbosity_option(cls=ResourceOption)
def valtrain(model
        ,optimizer
        ,scheduler
        ,output_path
        ,epochs
        ,pretrained_backbone
        ,batch_size
        ,criterion
        ,dataset
        ,checkpoint_period
        ,device
        ,valsize
        ,**kwargs):
    """ Train a model, holding out part of the dataset for validation """
    if not os.path.exists(output_path): os.makedirs(output_path)

    # Training set size (the rest is held out for validation)
    train_size = len(dataset) - valsize

    # PyTorch dataloaders: permute the indices once, then split
    indices = torch.randperm(len(dataset))
    train_indices = indices[:train_size]
    valid_indices = indices[train_size:] if valsize else None
    train_loader = torch.utils.data.DataLoader(dataset, pin_memory=torch.cuda.is_available(), batch_size=batch_size,
                                               sampler=SubsetRandomSampler(train_indices))
    valid_loader = None
    if valid_indices is not None:
        valid_loader = torch.utils.data.DataLoader(dataset, pin_memory=torch.cuda.is_available(), batch_size=batch_size,
                                                   sampler=SubsetRandomSampler(valid_indices))

    # Checkpointer
    checkpointer = DetectronCheckpointer(model, optimizer, scheduler, save_dir=output_path, save_to_disk=True)
    arguments = {}
    arguments["epoch"] = 0
    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    # Train
    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
    do_valtrain(model
        , train_loader
        , optimizer
        , criterion
        , scheduler
        , checkpointer
        , checkpoint_period
        , device
        , arguments
        , output_path
        , valid_loader
        )
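One caveat: the torch.randperm split above runs before do_valtrain seeds the RNG with torch.manual_seed(42), so the held-out validation indices can differ from run to run. If a reproducible split were wanted, a dedicated seeded generator would pin it down (a sketch, not part of this commit):

g = torch.Generator().manual_seed(42)  # hypothetical fixed seed for the split only
indices = torch.randperm(len(dataset), generator=g)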
@@ -158,7 +158,7 @@ def read_metricscsv(file):
    return np.array(precision), np.array(recall)
def plot_overview(outputfolders, title):
"""
Plots comparison chart of all trained models
@@ -166,6 +166,8 @@ def plot_overview(outputfolders):
    ----------
    outputfolders : list
        list containing output paths of all evaluated models (e.g. ``['DRIVE/model1', 'DRIVE/model2']``)
    title : str
        title of plot

    Returns
    -------
    matplotlib.figure.Figure
@@ -181,15 +183,16 @@ def plot_overview(outputfolders):
        precisions.append(pr)
        recalls.append(re)
        modelname = folder.split('/')[-1]
        datasetname = folder.split('/')[-2]
        # parameters
        summary_path = os.path.join(folder, 'results/ModelSummary.txt')
        with open(summary_path, "r") as outfile:
            rows = outfile.readlines()
            lastrow = rows[-1]
            parameter = int(lastrow.split()[1].replace(',', ''))
        # 100**3 == 1e6, so this reports the parameter count in millions
        name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, datasetname)
        names.append(name)
        #title = folder.split('/')[-4]
    fig = precision_recall_f1iso(precisions, recalls, names, title)
    return fig
@@ -78,10 +78,12 @@ setup(
            'DRIVE1024 = bob.ip.binseg.configs.datasets.drive1024',
            'DRIVE1168 = bob.ip.binseg.configs.datasets.drive1168',
            'DRIVETEST = bob.ip.binseg.configs.datasets.drivetest',
            'DRIVESTARECHASEDB1HRF1024 = bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024',
            'DRIVESTARECHASEDB1IOSTAR1168 = bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168',
            'DRIVESTAREIOSTARHRF960 = bob.ip.binseg.configs.datasets.drivestareiostarhrf960',
            'HRF = bob.ip.binseg.configs.datasets.hrf',
            'HRF960 = bob.ip.binseg.configs.datasets.hrf960',
            'HRF1024 = bob.ip.binseg.configs.datasets.hrf1024',
            'HRF1168 = bob.ip.binseg.configs.datasets.hrf1168',
            'HRFTEST = bob.ip.binseg.configs.datasets.hrftest',
            'IOSTAROD = bob.ip.binseg.configs.datasets.iostarod',