Commit e193fbe6 authored by Tim Laibacher

Update checkpointer, metrics and inference script. Add cli

parent 82cb7b06
Pipeline #29551 failed
Showing with 521 additions and 56 deletions
@@ -21,6 +21,7 @@ src/
record.txt
core
output_temp
output
### JupyterNotebook ###
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from bob.ip.binseg.data.transforms import ToTensor
from bob.ip.binseg.data.binsegdataset import BinSegDataset
from torch.utils.data import DataLoader
from bob.db.drive import Database as DRIVE
import torch
#### Config ####
# bob.db.dataset init
bobdb = DRIVE()
# transforms
transforms = ToTensor()
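
A minimal usage sketch (not part of the committed file; it mirrors the CLI train command added later in this commit) of how `bobdb` and `transforms` are consumed, using the imports already at the top of this config:

# Sketch only: wrap the bob.db database and the transform into the PyTorch pipeline
bsdataset = BinSegDataset(bobdb, split='train', transform=transforms)
data_loader = DataLoader(dataset=bsdataset, batch_size=2, shuffle=True,
                         pin_memory=torch.cuda.is_available())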
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driu import build_driu
import torch.optim as optim
from torch.nn import BCEWithLogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
##### Config #####
pretrained_weight = 'vgg16'
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0
amsgrad = False
scheduler_milestones = [150]
scheduler_gamma = 0.1
# model
model = build_driu()
# pretrained backbone
pretrained_backbone = modelurls[pretrained_weight]
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
# criterion
criterion = BCEWithLogitsLoss()
# scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
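
A sanity sketch (not in the commit) of the schedule defined above: Adam starts at lr=0.001 and MultiStepLR multiplies it by gamma=0.1 once the milestone epoch 150 is reached.

# Sketch only: step the scheduler once per epoch and watch the decay at the milestone.
for epoch in range(200):
    # ... one training epoch would run here ...
    scheduler.step()
    if epoch in (149, 150):
        # prints show the 1e-3 -> 1e-4 drop around epoch 150 (the exact
        # off-by-one depends on where step() is called in the loop)
        print(epoch, optimizer.param_groups[0]['lr'])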
@@ -24,6 +24,7 @@ class BinSegDataset(Dataset):
    def __init__(self, bobdb, split = None, transform = None):
        self.database = bobdb.samples(split)
        self.transform = transform
        self.split = split

    def __len__(self):
        """
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from bob.db.drive import Database as DRIVE
from bob.ip.binseg.data.binsegdataset import BinSegDataset
from bob.ip.binseg.data.transforms import ToTensor
...
@@ -42,10 +42,10 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logger):
        # ground truth byte
        gts = ground_truths[j].byte()

        file_name = "{}.csv".format(names[j])
        logger.info("saving {}".format(file_name))

        with open(os.path.join(output_folder, file_name), "w+") as outfile:
            outfile.write("threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n")
@@ -88,11 +88,13 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logger):

def save_probability_images(predictions, names, output_folder, logger):
    images_subfolder = os.path.join(output_folder, 'images')
    if not os.path.exists(images_subfolder): os.makedirs(images_subfolder)
    for j in range(predictions.size()[0]):
        img = VF.to_pil_image(predictions.cpu().data[j])
        filename = '{}_prob.gif'.format(names[j])
        logger.info("saving {}".format(filename))
        img.save(os.path.join(images_subfolder, filename))
@@ -104,8 +106,11 @@ def do_inference(
):
    logger = logging.getLogger("bob.ip.binseg.engine.inference")
    logger.info("Start evaluation")
    logger.info("Split: {}, Output folder: {}, Device: {}".format(data_loader.dataset.split, output_folder, device))
    results_subfolder = os.path.join(output_folder, 'results')
    if not os.path.exists(results_subfolder): os.makedirs(results_subfolder)
    model.eval().to(device)
    # Sigmoid for probabilities
    sigmoid = torch.nn.Sigmoid()
@@ -129,14 +134,12 @@ def do_inference(
        times.append(batch_time)
        logger.info("Batch time: {:.5f} s".format(batch_time))

        b_metrics = batch_metrics(probabilities, ground_truths, masks, names, results_subfolder, logger)
        metrics.extend(b_metrics)

        # Create probability images
        save_probability_images(probabilities, names, output_folder, logger)

    df_metrics = pd.DataFrame(metrics, columns= \
        ["name",
@@ -148,12 +151,11 @@ def do_inference(
         "jaccard",
         "f1_score"])
    # Report and Averages
    metrics_file = "Metrics_{}.csv".format(model.name)
    metrics_path = os.path.join(results_subfolder, metrics_file)
    logger.info("Saving average over all input images: {}".format(metrics_file))

    avg_metrics = df_metrics.groupby('threshold').mean()
    avg_metrics.to_csv(metrics_path)
@@ -163,12 +165,14 @@ def do_inference(
    maxf1 = avg_metrics['f1_score'].max()
    optimal_f1_threshold = avg_metrics['f1_score'].idxmax()
    logger.info("Highest F1-score of {:.5f}, achieved at threshold {}".format(maxf1, optimal_f1_threshold))

    # Plotting
    np_avg_metrics = avg_metrics.to_numpy().T
    fig_name = "precision_recall_{}.pdf".format(model.name)
    logger.info("saving {}".format(fig_name))
    fig = precision_recall_f1iso([np_avg_metrics[0]], [np_avg_metrics[1]], model.name)
    fig_filename = os.path.join(results_subfolder, fig_name)
    fig.savefig(fig_filename)
    # Report times
@@ -176,9 +180,16 @@ def do_inference(
    average_batch_inference_time = np.mean(times)
    total_evaluation_time = str(datetime.timedelta(seconds=int(time.time() - start_total_time)))

    # Logging
    logger.info("Total evaluation run-time: {}".format(total_evaluation_time))
    logger.info("Average batch inference time: {:.5f}s".format(average_batch_inference_time))
    logger.info("Total inference time: {}".format(total_inference_time))

    times_file = "Times_{}.txt".format(model.name)
    logger.info("saving {}".format(times_file))
    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
        date = datetime.datetime.now()
        outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
        outfile.write("Total evaluation run-time: {} \n".format(total_evaluation_time))
        outfile.write("Average batch inference time: {} \n".format(average_batch_inference_time))
        outfile.write("Total inference time: {} \n".format(total_inference_time))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from bob.db.drive import Database as DRIVE
from bob.ip.binseg.data.binsegdataset import BinSegDataset
from bob.ip.binseg.data.transforms import ToTensor
@@ -6,7 +9,7 @@ from bob.ip.binseg.modeling.driu import build_driu
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.utils.checkpointer import Checkpointer, DetectronCheckpointer
from torch.nn import BCEWithLogitsLoss
import logging
@@ -45,11 +48,14 @@ def train():
    scheduler = MultiStepLR(optimizer, milestones=[150], gamma=0.1)
    # checkpointer
    checkpointer = DetectronCheckpointer(model, optimizer, scheduler, save_dir="./output_temp", save_to_disk=True)
    # checkpoint period
    checkpoint_period = 2
    # pretrained backbone
    pretrained_backbone = model_urls['vgg16']
    # device
    device = "cpu"
@@ -57,6 +63,8 @@ def train():
    arguments = {}
    arguments["epoch"] = 0
    arguments["max_epoch"] = 6
    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
    arguments.update(extra_checkpoint_data)
    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
    do_train(model
...
@@ -20,12 +20,13 @@ def do_train(
    device,
    arguments
):
    """ Trains the model """
    logger = logging.getLogger("bob.ip.binseg.engine.trainer")
    logger.info("Start training")
    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    model.train().to(device)
    # Total training timer
    start_training_time = time.time()
...
@@ -17,6 +17,7 @@ model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnet50_SIN_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
}
@@ -222,4 +223,4 @@ def resnet152(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
\ No newline at end of file
@@ -60,7 +60,7 @@ class VGG(nn.Module):

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
@@ -68,11 +68,11 @@ def make_layers(cfg, batch_norm=False):
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
cfg = {
@@ -144,7 +144,7 @@ def vgg16(pretrained=False, **kwargs):
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']), strict=False)
    return model
...
@@ -53,7 +53,8 @@ class DRIU(nn.Module):
        return out

def build_driu():
    #backbone = vgg16(pretrained=False, return_features = [1, 4, 8, 12])
    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22])
    driu_head = DRIU([64, 128, 256, 512])
    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", driu_head)]))
...
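
The new tap indices follow from the make_layers change above: every Conv2d/ReLU/MaxPool2d is now its own entry of one flat nn.Sequential, so (assuming torchvision's cfg['D'] for VGG16 and its standard 2x2 max-pooling) a quick sketch confirms that [3, 8, 14, 22] select feature maps with 64, 128, 256 and 512 channels, matching DRIU([64, 128, 256, 512]):

# Sketch only: rebuild the flat VGG16 feature stack and inspect the tap points.
import torch.nn as nn
cfg_d = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
layers, in_channels = [], 3
for v in cfg_d:
    if v == 'M':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
    else:
        layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
        in_channels = v
features = nn.Sequential(*layers)
for i in [3, 8, 14, 22]:
    print(i, features[i])  # the ReLU/Conv2d modules whose outputs feed the DRIU head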
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
"""The main entry for bob ip binseg (click-based) scripts."""
import os
import time
import numpy
import collections
import pkg_resources
import click
from click_plugins import with_plugins
import logging
import torch
import bob.extension
from bob.extension.scripts.click_helper import (verbosity_option,
ConfigCommand, ResourceOption, AliasedGroup)
from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
from bob.ip.binseg.data.binsegdataset import BinSegDataset
from torch.utils.data import DataLoader
from bob.ip.binseg.engine.trainer import do_train
from bob.ip.binseg.engine.inferencer import do_inference
logger = logging.getLogger(__name__)
@with_plugins(pkg_resources.iter_entry_points('bob.ip.binseg.cli'))
@click.group(cls=AliasedGroup)
def binseg():
"""Binary 2D Fundus Image Segmentation Benchmark commands."""
pass
# Train
@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
@click.option(
    '--output-path',
    '-o',
    required=True,
    default="output",
    cls=ResourceOption
    )
@click.option(
    '--model',
    '-m',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--optimizer',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--criterion',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--scheduler',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--pretrained-backbone',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--bobdb',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--split',
    '-s',
    required=True,
    default='train',
    cls=ResourceOption)
@click.option(
    '--transforms',
    required=True,
    cls=ResourceOption)
@click.option(
    '--batch-size',
    '-b',
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--epochs',
    '-e',
    help='Number of epochs used for training',
    show_default=True,
    required=True,
    default=6,
    cls=ResourceOption)
@click.option(
    '--checkpoint-period',
    '-p',
    help='Number of epochs after which a checkpoint is saved',
    show_default=True,
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--device',
    '-d',
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default='cpu',
    cls=ResourceOption)
@verbosity_option(cls=ResourceOption)
def train(model
        ,optimizer
        ,scheduler
        ,output_path
        ,epochs
        ,pretrained_backbone
        ,split
        ,batch_size
        ,criterion
        ,bobdb
        ,transforms
        ,checkpoint_period
        ,device
        ,**kwargs):

    if not os.path.exists(output_path): os.makedirs(output_path)

    # PyTorch dataset
    bsdataset = BinSegDataset(bobdb, split=split, transform=transforms)

    # PyTorch dataloader
    data_loader = DataLoader(
        dataset = bsdataset
        ,batch_size = batch_size
        ,shuffle = True
        ,pin_memory = torch.cuda.is_available()
        )

    # Checkpointer
    checkpointer = DetectronCheckpointer(model, optimizer, scheduler, save_dir=output_path, save_to_disk=True)
    arguments = {}
    arguments["epoch"] = 0
    arguments["max_epoch"] = epochs
    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
    arguments.update(extra_checkpoint_data)

    # Train
    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
    do_train(model
            , data_loader
            , optimizer
            , criterion
            , scheduler
            , checkpointer
            , checkpoint_period
            , device
            , arguments
            )
# Inference
@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
@click.option(
    '--output-path',
    '-o',
    required=True,
    default="output",
    cls=ResourceOption
    )
@click.option(
    '--model',
    '-m',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--bobdb',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--transforms',
    required=True,
    cls=ResourceOption)
@click.option(
    '--split',
    '-s',
    required=True,
    default='test',
    cls=ResourceOption)
@click.option(
    '--batch-size',
    '-b',
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--device',
    '-d',
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default='cpu',
    cls=ResourceOption)
@verbosity_option(cls=ResourceOption)
@verbosity_option(cls=ResourceOption)
def test(model
,output_path
,device
,split
,batch_size
,bobdb
,transforms
, **kwargs):
# PyTorch dataset
bsdataset = BinSegDataset(bobdb, split=split, transform=transforms)
# PyTorch dataloader
data_loader = DataLoader(
dataset = bsdataset
,batch_size = batch_size
,shuffle= False
,pin_memory = torch.cuda.is_available()
)
# checkpointer, load last model in dir
checkpointer = DetectronCheckpointer(model, save_dir = output_path, save_to_disk=False)
checkpointer.load()
do_inference(model, data_loader, device, output_path)
\ No newline at end of file
# see https://docs.python.org/3/library/pkgutil.html
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Adapted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/engine/trainer.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
@@ -5,6 +8,7 @@ import logging
import torch
import os
from bob.ip.binseg.utils.model_serialization import load_state_dict
from bob.ip.binseg.utils.model_zoo import cache_url

class Checkpointer:
    def __init__(
@@ -42,8 +46,8 @@ class Checkpointer:
        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        self.tag_last_checkpoint(save_file)

    def load(self, f=None):
        if self.has_checkpoint():
@@ -91,4 +95,32 @@ class Checkpointer:
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint):
        load_state_dict(self.model, checkpoint.pop("model"))
class DetectronCheckpointer(Checkpointer):
    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        super(DetectronCheckpointer, self).__init__(
            model, optimizer, scheduler, save_dir, save_to_disk, logger
        )

    def _load_file(self, f):
        # download url files
        if f.startswith("http"):
            # if the file is a url path, download it and cache it
            cached_f = cache_url(f)
            self.logger.info("url {} cached in {}".format(f, cached_f))
            f = cached_f
        # load checkpoint
        loaded = super(DetectronCheckpointer, self)._load_file(f)
        if "model" not in loaded:
            loaded = dict(model=loaded)
        return loaded
\ No newline at end of file
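
A minimal usage sketch of the new subclass (model, optimizer and scheduler assumed to exist, as in the updated train script): pointing load() at a model-zoo URL makes _load_file() download and cache the weights before they are mapped onto the model.

# Sketch only: bootstrap training from a pretrained VGG16 backbone URL.
from bob.ip.binseg.utils.model_zoo import modelurls
checkpointer = DetectronCheckpointer(model, optimizer, scheduler,
                                     save_dir="./output", save_to_disk=True)
extra_checkpoint_data = checkpointer.load(modelurls['vgg16'])  # downloads once, then hits the cache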
@@ -20,7 +20,9 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.
    """
    current_keys = sorted(list(model_state_dict.keys()))
    print(current_keys)
    loaded_keys = sorted(list(loaded_state_dict.keys()))
    print(loaded_keys)
    # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
    # loaded_key string, if it matches
    match_matrix = [
...
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Adapted from:
# https://github.com/pytorch/pytorch/blob/master/torch/hub.py
# https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/checkpoint.py
import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
import torch
import warnings
import zipfile
from urllib.request import urlopen
from urllib.parse import urlparse
from tqdm import tqdm
modelurls = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    "resnet50_SIN_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
}
def _download_url_to_file(url, dst, hash_prefix, progress):
    file_size = None
    u = urlopen(url)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with tqdm(total=file_size, disable=not progress) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                   .format(hash_prefix, digest))
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
def cache_url(url, model_dir=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
        model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
\ No newline at end of file
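
A short usage sketch of cache_url (not part of the committed file): resolve one of the modelurls entries above to a local file and load it as a plain state dict.

# Sketch only: download-or-reuse the cached VGG16 weights.
weights_path = cache_url(modelurls['vgg16'])  # e.g. ~/.torch/models/vgg16-397923af.pth
state_dict = torch.load(weights_path, map_location=torch.device("cpu"))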
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author='Andre Anjos',
# author_email='andre.anjos@idiap.ch',
import numpy as np

def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds500=False):
...
@@ -20,22 +20,24 @@ build:
  - install -d "${PREFIX}/share/doc/{{ name }}"
  - cp -R README.rst requirements.txt doc "${PREFIX}/share/doc/{{ name }}/"

requirements:
  # place your build dependencies before the 'host' section
  host:
    - python {{ python }}
    - setuptools {{ setuptools }}
    - torchvision {{ torchvision }}
    - pytorch {{ pytorch }}
    - numpy {{ numpy }}
    - bob.extension
    # place your other host dependencies here
  run:
    - python
    - setuptools
    - {{ pin_compatible('pytorch') }}
    - {{ pin_compatible('torchvision') }}
    - {{ pin_compatible('numpy') }}
    - pandas
    - matplotlib
    - bob.db.drive
    - bob.db.stare
    - bob.db.chasedb1
@@ -45,15 +47,6 @@ requirements:
    - bob.db.drishtigs1
    - bob.db.refuge
    - bob.db.iostar
    # place other runtime dependencies here (same as requirements.txt)

test:
@@ -61,6 +54,8 @@ test:
    - {{ name }}
  commands:
    # test commands ("script" entry-points) from your package here
    - bob binseg --help
    - bob binseg train --help
    - nosetests --with-coverage --cover-package={{ name }} -sv {{ name }}
    - sphinx-build -aEW {{ project_dir }}/doc {{ project_dir }}/sphinx
    - sphinx-build -aEb doctest {{ project_dir }}/doc sphinx
@@ -76,7 +71,7 @@ test:

about:
  summary: Binary Segmentation Benchmark Package for Bob
  home: https://www.idiap.ch/software/bob/
  license: GNU General Public License v3 (GPLv3)
  license_family: GPL
  license_file: ../COPYING
\ No newline at end of file
@@ -233,6 +233,7 @@ if os.path.exists(sphinx_requirements):
else:
    intersphinx_mapping = link_documentation()

intersphinx_mapping['torch'] = ('https://pytorch.org/docs/stable/', None)

# We want to remove all private (i.e. _. or __.__) members
# that are not in the list of accepted functions
accepted_private_functions = ['__array__']
...