Commit 833335da authored by André Anjos

[all] Passed black on all python files

parent c3dc9915
1 merge request: !12 Streamlining
Pipeline #38202 passed
Showing with 760 additions and 476 deletions
@@ -8,6 +8,7 @@ import torch
import torchvision.transforms.functional as VF
import bob.io.base


def get_file_lists(data_path, glob):
    """
    Recursively retrieves file lists from a given path, matching a given glob
@@ -20,6 +21,7 @@ def get_file_lists(data_path, glob):
    image_file_names = np.array(sorted(list(data_path.rglob(glob))))
    return image_file_names


class ImageFolderInference(Dataset):
    """
    Generic ImageFolder containing images for inference
@@ -43,7 +45,8 @@ class ImageFolderInference(Dataset):
        List of transformations to apply to every input sample
    """

    def __init__(self, path, glob="*", transform=None):
        self.transform = transform
        self.path = path
        self.img_file_list = get_file_lists(path, glob)
@@ -57,7 +60,7 @@ class ImageFolderInference(Dataset):
        """
        return len(self.img_file_list)

    def __getitem__(self, index):
        """
        Parameters
        ----------
@@ -74,9 +77,9 @@ class ImageFolderInference(Dataset):
        sample = [img]

        if self.transform:
            sample = self.transform(*sample)

        sample.insert(0, img_name)

        return sample
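A minimal usage sketch of the class above (the path and glob are illustrative, not taken from this commit): __getitem__ prepends the file name to the (optionally transformed) sample.

from pathlib import Path

# Sketch: iterate a folder of images for inference.
dataset = ImageFolderInference(Path("/path/to/images"), glob="*.png")
name, img = dataset[0]  # a PIL image when no transform is given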
@@ -22,17 +22,18 @@ import collections
import bob.core

_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}

Iterable = collections.abc.Iterable


# Compose
class Compose:
    """Composes several transforms.
@@ -51,15 +52,17 @@ class Compose:
        return args

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
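Unlike torchvision's Compose, this variant threads the whole tuple of samples through each transform together, which keeps image and ground truth spatially aligned. A short sketch, assuming img and gt are same-size PIL images (hypothetical names):

# Sketch: all positional samples travel through the pipeline together.
pipeline = Compose([RandomHFlip(prob=1.0), ToTensor()])
img_t, gt_t = pipeline(img, gt)  # both flipped identically, both tensors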
# Preprocessing
class CenterCrop:
    """
    Crop at the center.
@@ -69,6 +72,7 @@ class CenterCrop:
    size : int
        target size
    """

    def __init__(self, size):
        self.size = size
@@ -91,6 +95,7 @@ class Crop:
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i = i
        self.j = j
@@ -98,7 +103,10 @@ class Crop:
        self.w = w

    def __call__(self, *args):
        return [
            img.crop((self.j, self.i, self.j + self.w, self.i + self.h)) for img in args
        ]


class Pad:
    """
@@ -115,12 +123,17 @@ class Pad:
        pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
        This value is only used when the padding_mode is constant
    """

    def __init__(self, padding, fill=0):
        self.padding = padding
        self.fill = fill

    def __call__(self, *args):
        return [
            VF.pad(img, self.padding, self.fill, padding_mode="constant")
            for img in args
        ]


class AutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level"
@@ -131,13 +144,16 @@ class AutoLevel16to8:
    consider such a range should be mapped to the [0,255] range of the
    destination image.
    """

    def _process_one(self, img):
        return Image.fromarray(
            bob.core.convert(img, "uint8", (0, 255), img.getextrema())
        )

    def __call__(self, *args):
        return [self._process_one(img) for img in args]


class ToRGB:
    """Converts from any input format to RGB, using an ADAPTIVE conversion.
@@ -146,17 +162,21 @@ class ToRGB:
    defaults. This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    def __call__(self, *args):
        return [img.convert(mode="RGB") for img in args]


class ToTensor:
    """Converts :py:class:`PIL.Image.Image` to :py:class:`torch.Tensor` """

    def __call__(self, *args):
        return [VF.to_tensor(img) for img in args]
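These preprocessing classes are designed to chain inside Compose. A hedged sketch of a typical chain for 16-bit grayscale input (the crop size is illustrative):

# Sketch: auto-level 16-bit input to 8-bit, force RGB, crop, and tensorize.
preprocess = Compose([AutoLevel16to8(), ToRGB(), CenterCrop(544), ToTensor()])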
# Augmentations
class RandomHFlip:
    """
    Flips horizontally
@@ -166,7 +186,8 @@ class RandomHFlip:
    prob : float
        probability at which image is flipped. Defaults to ``0.5``
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, *args):
@@ -186,7 +207,8 @@ class RandomVFlip:
    prob : float
        probability at which image is flipped. Defaults to ``0.5``
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, *args):
@@ -208,17 +230,19 @@ class RandomRotation:
    prob : float
        probability at which image is rotated. Defaults to ``0.5``
    """

    def __init__(self, degree_range=(-15, +15), prob=0.5):
        self.prob = prob
        self.degree_range = degree_range

    def __call__(self, *args):
        if random.random() < self.prob:
            degree = random.randint(*self.degree_range)
            return [VF.rotate(img, degree, resample=Image.BILINEAR) for img in args]
        else:
            return args
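Because a single random draw gates the whole argument list, an image and its ground-truth mask always receive the same geometric change. A sketch with hypothetical img/gt inputs:

# Sketch: one draw of random.random() decides for the whole sample tuple,
# and the same angle is applied to every element.
rot = RandomRotation(degree_range=(-15, 15), prob=0.5)
img_r, gt_r = rot(img, gt)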
class ColorJitter(object):
    """
    Randomly change the brightness, contrast, saturation and hue
@@ -240,7 +264,10 @@ class ColorJitter(object):
    prob : float
        probability at which the operation is applied
    """

    def __init__(
        self, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02, prob=0.5
    ):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
@@ -252,15 +279,21 @@ class ColorJitter(object):
        transforms = []
        if brightness > 0:
            brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(
                Lambda(lambda img: VF.adjust_brightness(img, brightness_factor))
            )
        if contrast > 0:
            contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(
                Lambda(lambda img: VF.adjust_contrast(img, contrast_factor))
            )
        if saturation > 0:
            saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(
                Lambda(lambda img: VF.adjust_saturation(img, saturation_factor))
            )
        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
@@ -273,8 +306,9 @@ class ColorJitter(object):
    def __call__(self, *args):
        if random.random() < self.prob:
            transform = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
            trans_img = transform(args[0])
            return [trans_img, *args[1:]]
        else:
@@ -301,7 +335,14 @@ class RandomResizedCrop:
        probability at which the operation is applied. Defaults to ``0.5``
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=Image.BILINEAR,
        prob=0.5,
    ):
        if isinstance(size, tuple):
            self.size = size
        else:
@@ -333,10 +374,10 @@ class RandomResizedCrop:
        # Fallback to central crop
        in_ratio = img.size[0] / img.size[1]
        if in_ratio < min(ratio):
            w = img.size[0]
            h = w / min(ratio)
        elif in_ratio > max(ratio):
            h = img.size[1]
            w = h * max(ratio)
        else:  # whole image
@@ -359,10 +400,10 @@ class RandomResizedCrop:
    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
        format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
        format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
        format_string += ", interpolation={0})".format(interpolate_str)
        return format_string
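A worked example of the central-crop fallback with the default ratio bounds (3/4, 4/3): a 400×800 portrait image gives in_ratio = 0.5 < 0.75, so w = 400 and h = 400 / 0.75 ≈ 533; an 800×400 landscape image gives in_ratio = 2 > 4/3, so h = 400 and w = 400 × 4/3 ≈ 533. Either way the crop respects the requested aspect-ratio bounds before resizing.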
@@ -391,4 +432,6 @@ class Resize:
    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
            self.size, interpolate_str
        )
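Taken together, a training-time augmentation pipeline might look as follows; sizes and probabilities here are illustrative, not prescribed by this commit:

# Sketch: joint geometric augmentation plus image-only color jitter.
train_transforms = Compose(
    [
        RandomHFlip(prob=0.5),
        RandomRotation(degree_range=(-15, 15), prob=0.5),
        ColorJitter(prob=0.5),  # ColorJitter touches only args[0], the image
        ToTensor(),
    ]
)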
@@ -54,8 +54,17 @@ class AdaBound(torch.optim.Optimizer):
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
@@ -68,16 +77,23 @@ class AdaBound(torch.optim.Optimizer):
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0.0 <= gamma < 1.0:
            raise ValueError("Invalid gamma parameter: {}".format(gamma))
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super(AdaBound, self).__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super(AdaBound, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -94,37 +110,38 @@ class AdaBound(torch.optim.Optimizer):
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(group["weight_decay"], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
@@ -133,19 +150,19 @@ class AdaBound(torch.optim.Optimizer):
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
                upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
@@ -153,6 +170,7 @@ class AdaBound(torch.optim.Optimizer):
        return loss
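The clamp interval around final_lr tightens as training proceeds, so AdaBound starts Adam-like and ends SGD-like. A small check of the bound arithmetic above (plain values, not package API):

# Sketch: bounds for final_lr=0.1, gamma=1e-3, with lr unchanged from base_lr.
final_lr, gamma = 0.1, 1e-3
for step in (1000, 10000, 100000):
    lower = final_lr * (1 - 1 / (gamma * step + 1))
    upper = final_lr * (1 + 1 / (gamma * step))
    print(step, round(lower, 4), round(upper, 4))
# 1000 -> 0.05, 0.2; 10000 -> 0.0909, 0.11; 100000 -> 0.099, 0.101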
class AdaBoundW(torch.optim.Optimizer):
    """Implements AdaBound algorithm with Decoupled Weight Decay
    (See https://arxiv.org/abs/1711.05101)
@@ -187,8 +205,17 @@ class AdaBoundW(torch.optim.Optimizer):
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
@@ -202,16 +229,23 @@ class AdaBoundW(torch.optim.Optimizer):
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0.0 <= gamma < 1.0:
            raise ValueError("Invalid gamma parameter: {}".format(gamma))
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super(AdaBoundW, self).__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super(AdaBoundW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.
@@ -229,34 +263,35 @@ class AdaBoundW(torch.optim.Optimizer):
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
@@ -265,25 +300,25 @@ class AdaBoundW(torch.optim.Optimizer):
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to
                # apply lr decay
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
                upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)

                if group["weight_decay"] != 0:
                    decayed_weights = torch.mul(p.data, group["weight_decay"])
                    p.data.add_(-step_size)
                    p.data.sub_(decayed_weights)
                else:
...
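AdaBoundW differs from AdaBound only in the tail of step(): weight decay is applied to the weights directly (decoupled, as in AdamW) rather than being folded into the gradient. A hedged instantiation sketch, with illustrative hyper-parameter values:

# Sketch, assuming `model` is any torch.nn.Module:
optimizer = AdaBoundW(model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=2e-4)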
@@ -18,7 +18,6 @@ from bob.ip.binseg.utils.plot import precision_recall_f1iso_confintval
from bob.ip.binseg.utils.summary import summary


def batch_metrics(predictions, ground_truths, names, output_folder, logger):
    """
    Calculates metrics on the batch and saves it to disc
@@ -51,21 +50,23 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
        file_name = "{}.csv".format(names[j])
        logger.info("saving {}".format(file_name))

        with open(os.path.join(output_folder, file_name), "w+") as outfile:

            outfile.write(
                "threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n"
            )

            for threshold in np.arange(0.0, 1.0, step_size):
                # threshold
                binary_pred = torch.gt(predictions[j], threshold).byte()

                # equals and not-equals
                equals = torch.eq(binary_pred, gts).type(torch.uint8)  # tensor
                notequals = torch.ne(binary_pred, gts).type(torch.uint8)  # tensor

                # true positives
                tp_tensor = gts * binary_pred  # tensor
                tp_count = torch.sum(tp_tensor).item()  # scalar

                # false positives
                fp_tensor = torch.eq((binary_pred + tp_tensor), 1)
@@ -83,10 +84,13 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                metrics = base_metrics(tp_count, fp_count, tn_count, fn_count)

                # write to disk
                outfile.write(
                    "{:.2f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f} \n".format(
                        threshold, *metrics
                    )
                )

                batch_metrics.append([names[j], threshold, *metrics])

    return batch_metrics
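base_metrics is imported from elsewhere in the package and its body is not part of this diff; judging from the CSV header written above, it returns six metrics computed from the four counts. A hypothetical reimplementation, for orientation only:

def base_metrics_sketch(tp, fp, tn, fn):
    # Illustrative only; the package's real base_metrics may differ.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0  # sensitivity
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    jaccard = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return precision, recall, specificity, accuracy, jaccard, f1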
@@ -106,16 +110,18 @@ def save_probability_images(predictions, names, output_folder, logger):
    logger : :py:class:`logging.Logger`
        python logger
    """
    images_subfolder = os.path.join(output_folder, "images")
    for j in range(predictions.size()[0]):
        img = VF.to_pil_image(predictions.cpu().data[j])
        filename = "{}.png".format(names[j].split(".")[0])
        fullpath = os.path.join(images_subfolder, filename)
        logger.info("saving {}".format(fullpath))
        fulldir = os.path.dirname(fullpath)
        if not os.path.exists(fulldir):
            os.makedirs(fulldir)
        img.save(fullpath)
def save_hdf(predictions, names, output_folder, logger):
    """
    Saves probability maps as image in the same format as the test image
@@ -131,23 +137,21 @@ def save_hdf(predictions, names, output_folder, logger):
    logger : :py:class:`logging.Logger`
        python logger
    """
    hdf5_subfolder = os.path.join(output_folder, "hdf5")
    if not os.path.exists(hdf5_subfolder):
        os.makedirs(hdf5_subfolder)
    for j in range(predictions.size()[0]):
        img = predictions.cpu().data[j].squeeze(0).numpy()
        filename = "{}.hdf5".format(names[j].split(".")[0])
        fullpath = os.path.join(hdf5_subfolder, filename)
        logger.info("saving {}".format(filename))
        fulldir = os.path.dirname(fullpath)
        if not os.path.exists(fulldir):
            os.makedirs(fulldir)
        bob.io.base.save(img, fullpath)
def do_inference(model, data_loader, device, output_folder=None):
    """
    Run inference and calculate metrics
@@ -164,8 +168,8 @@ def do_inference(
    logger = logging.getLogger("bob.ip.binseg.engine.inference")
    logger.info("Start evaluation")
    logger.info("Output folder: {}, Device: {}".format(output_folder, device))
    results_subfolder = os.path.join(output_folder, "results")
    os.makedirs(results_subfolder, exist_ok=True)

    model.eval().to(device)
    # Sigmoid for probabilities
@@ -189,7 +193,7 @@ def do_inference(
            # necessary check for hed architecture that uses several outputs
            # for loss calculation instead of just the last concatfuse block
            if isinstance(outputs, list):
                outputs = outputs[-1]
            probabilities = sigmoid(outputs)
@@ -198,7 +202,9 @@ def do_inference(
            times.append(batch_time)
            logger.info("Batch time: {:.5f} s".format(batch_time))

            b_metrics = batch_metrics(
                probabilities, ground_truths, names, results_subfolder, logger
            )
            metrics.extend(b_metrics)

            # Create probability images
@@ -207,74 +213,94 @@ def do_inference(
            save_hdf(probabilities, names, output_folder, logger)

    # DataFrame
    df_metrics = pd.DataFrame(
        metrics,
        columns=[
            "name",
            "threshold",
            "precision",
            "recall",
            "specificity",
            "accuracy",
            "jaccard",
            "f1_score",
        ],
    )

    # Report and Averages
    metrics_file = "Metrics.csv".format(model.name)
    metrics_path = os.path.join(results_subfolder, metrics_file)
    logger.info("Saving average over all input images: {}".format(metrics_file))
    avg_metrics = df_metrics.groupby("threshold").mean()
    std_metrics = df_metrics.groupby("threshold").std()

    # Uncomment below for F1-score calculation based on average precision and metrics instead of
    # F1-scores of individual images. This method is in line with Maninis et. al. (2016)
    # avg_metrics["f1_score"] = (2* avg_metrics["precision"]*avg_metrics["recall"])/ \
    #     (avg_metrics["precision"]+avg_metrics["recall"])

    avg_metrics["std_pr"] = std_metrics["precision"]
    avg_metrics["pr_upper"] = avg_metrics["precision"] + avg_metrics["std_pr"]
    avg_metrics["pr_lower"] = avg_metrics["precision"] - avg_metrics["std_pr"]
    avg_metrics["std_re"] = std_metrics["recall"]
    avg_metrics["re_upper"] = avg_metrics["recall"] + avg_metrics["std_re"]
    avg_metrics["re_lower"] = avg_metrics["recall"] - avg_metrics["std_re"]
    avg_metrics["std_f1"] = std_metrics["f1_score"]

    avg_metrics.to_csv(metrics_path)
    maxf1 = avg_metrics["f1_score"].max()
    optimal_f1_threshold = avg_metrics["f1_score"].idxmax()

    logger.info(
        "Highest F1-score of {:.5f}, achieved at threshold {}".format(
            maxf1, optimal_f1_threshold
        )
    )

    # Plotting
    np_avg_metrics = avg_metrics.to_numpy().T
    fig_name = "precision_recall.pdf"
    logger.info("saving {}".format(fig_name))
    fig = precision_recall_f1iso_confintval(
        [np_avg_metrics[0]],
        [np_avg_metrics[1]],
        [np_avg_metrics[7]],
        [np_avg_metrics[8]],
        [np_avg_metrics[10]],
        [np_avg_metrics[11]],
        [model.name, None],
        title=output_folder,
    )
    fig_filename = os.path.join(results_subfolder, fig_name)
    fig.savefig(fig_filename)

    # Report times
    total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
    average_batch_inference_time = np.mean(times)
    total_evalution_time = str(
        datetime.timedelta(seconds=int(time.time() - start_total_time))
    )

    logger.info(
        "Average batch inference time: {:.5f}s".format(average_batch_inference_time)
    )

    times_file = "Times.txt"
    logger.info("saving {}".format(times_file))

    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
        date = datetime.datetime.now()
        outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
        outfile.write("Total evaluation run-time: {} \n".format(total_evalution_time))
        outfile.write(
            "Average batch inference time: {} \n".format(average_batch_inference_time)
        )
        outfile.write("Total inference time: {} \n".format(total_inference_time))

    # Save model summary
    summary_file = "ModelSummary.txt"
    logger.info("saving {}".format(summary_file))
    with open(os.path.join(results_subfolder, summary_file), "w+") as outfile:
        summary(model, outfile)
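A hedged sketch of driving do_inference. The loader contract (file names, images and ground truths per batch) is consumed in a collapsed part of the diff above, so the names here are assumptions:

# Sketch: evaluate a trained model over a test loader and write results/.
from torch.utils.data import DataLoader

loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
do_inference(model, loader, device="cpu", output_folder="./eval")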
@@ -15,12 +15,7 @@ from bob.ip.binseg.engine.inferencer import save_probability_images
from bob.ip.binseg.engine.inferencer import save_hdf


def do_predict(model, data_loader, device, output_folder=None):
    """
    Run inference and calculate metrics
@@ -37,8 +32,8 @@ def do_predict(
    logger = logging.getLogger("bob.ip.binseg.engine.inference")
    logger.info("Start evaluation")
    logger.info("Output folder: {}, Device: {}".format(output_folder, device))
    results_subfolder = os.path.join(output_folder, "results")
    os.makedirs(results_subfolder, exist_ok=True)

    model.eval().to(device)
    # Sigmoid for probabilities
@@ -58,7 +53,7 @@ def do_predict(
            # necessary check for hed architecture that uses several outputs
            # for loss calculation instead of just the last concatfuse block
            if isinstance(outputs, list):
                outputs = outputs[-1]
            probabilities = sigmoid(outputs)
@@ -72,22 +67,25 @@ def do_predict(
            # Save hdf5
            save_hdf(probabilities, names, output_folder, logger)

    # Report times
    total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
    average_batch_inference_time = np.mean(times)
    total_evalution_time = str(
        datetime.timedelta(seconds=int(time.time() - start_total_time))
    )

    logger.info(
        "Average batch inference time: {:.5f}s".format(average_batch_inference_time)
    )

    times_file = "Times.txt"
    logger.info("saving {}".format(times_file))

    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
        date = datetime.datetime.now()
        outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
        outfile.write("Total evaluation run-time: {} \n".format(total_evalution_time))
        outfile.write(
            "Average batch inference time: {} \n".format(average_batch_inference_time)
        )
        outfile.write("Total inference time: {} \n".format(total_inference_time))
@@ -13,10 +13,12 @@ import numpy as np
from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve


def sharpen(x, T):
    temp = x ** (1 / T)
    return temp / temp.sum(dim=1, keepdim=True)
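sharpen lowers the temperature of per-pixel class distributions: with T < 1, exponentiation by 1/T > 1 amplifies the dominant class before renormalization. A quick numeric check:

# Sketch: T=0.5 squares the probabilities, then renormalizes along dim=1.
import torch

x = torch.tensor([[0.6, 0.4]])
print(sharpen(x, 0.5))  # tensor([[0.6923, 0.3077]]) -- sharper than the input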
def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
    """Applies mix up as described in [MIXMATCH_19].
@@ -41,21 +43,30 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
    """
    # TODO:
    with torch.no_grad():
        l = np.random.beta(alpha, alpha)  # Eq (8)
        l = max(l, 1 - l)  # Eq (9)
        # Shuffle and concat. Alg. 1 Line: 12
        w_inputs = torch.cat([input, unlabeled_input], 0)
        w_targets = torch.cat([target, unlabled_target], 0)
        idx = torch.randperm(w_inputs.size(0))  # get random index

        # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]

        # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
        unlabeled_input_mixedup = (
            l * unlabeled_input + (1 - l) * w_inputs[idx[: len(unlabeled_input)]]
        )
        unlabled_target_mixedup = (
            l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
        )
        return (
            input_mixedup,
            target_mixedup,
            unlabeled_input_mixedup,
            unlabled_target_mixedup,
        )
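Note that re-drawing the mixing coefficient as max(l, 1 - l) guarantees it is at least 0.5, so every mixed-up sample stays dominated by its original batch: a Beta(0.75, 0.75) draw of 0.3, for instance, becomes 0.7, leaving each mixed input at least 70% original.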
def square_rampup(current, rampup_length=16):
@@ -80,9 +91,10 @@ def square_rampup(current, rampup_length=16):
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
    return float(current)


def linear_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u``
@@ -107,6 +119,7 @@ def linear_rampup(current, rampup_length=16):
    current = np.clip(current / rampup_length, 0.0, 1.0)
    return float(current)
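Worked numbers for the default rampup_length=16: at epoch 4, square_rampup yields (4/16)² = 0.0625 while linear_rampup yields 4/16 = 0.25; both saturate at 1.0 from epoch 16 onward. The squared variant therefore keeps the unsupervised loss weight small for longer at the start of training.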
def guess_labels(unlabeled_images, model):
    """
    Calculate the average predictions by 2 augmentations: horizontal and vertical flips
@@ -130,15 +143,16 @@ def guess_labels(unlabeled_images, model):
    guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
    # Horizontal flip and unsqueeze to work with batches (increase flip dimension by 1)
    hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
    guess2 = hflip.flip(3)
    # Vertical flip and unsqueeze to work with batches (increase flip dimension by 1)
    vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
    guess3 = vflip.flip(4)
    # Concat
    concat = torch.cat([guess1, guess2, guess3], 0)
    avg_guess = torch.mean(concat, 0)
    return avg_guess
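The axis bookkeeping here follows the in-code comments: predictions are unsqueezed with a new leading dimension, so each flip is undone on the original axis index plus one. Concretely, flip(2) on the NCHW input is reversed by flip(3) on the 1×N×C×H×W prediction, and flip(3) by flip(4).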
def do_ssltrain(
    model,
    data_loader,
@@ -150,7 +164,7 @@ def do_ssltrain(
    device,
    arguments,
    output_folder,
    rampup_length,
):
    """
    Train model and save to disk.
@@ -196,7 +210,9 @@ def do_ssltrain(
    max_epoch = arguments["max_epoch"]

    # Log to file
    with open(
        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1
    ) as outfile:

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
@@ -226,11 +242,17 @@ def do_ssltrain(
                unlabeled_outputs = model(unlabeled_images)
                # guessed unlabeled outputs
                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
                # unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
                # images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
                ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)

                loss, ll, ul = criterion(
                    outputs,
                    ground_truths,
                    unlabeled_outputs,
                    unlabeled_ground_truths,
                    ramp_up_factor,
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
@@ -247,60 +269,77 @@ def do_ssltrain(
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            outfile.write(
                (
                    "{epoch}, "
                    "{avg_loss:.6f}, "
                    "{median_loss:.6f}, "
                    "{median_labeled_loss},"
                    "{median_unlabeled_loss},"
                    "{lr:.6f}, "
                    "{memory:.0f}"
                    "\n"
                ).format(
                    eta=eta_string,
                    epoch=epoch,
                    avg_loss=losses.avg,
                    median_loss=losses.median,
                    median_labeled_loss=labeled_loss.median,
                    median_unlabeled_loss=unlabeled_loss.median,
                    lr=optimizer.param_groups[0]["lr"],
                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                    if torch.cuda.is_available()
                    else 0.0,
                )
            )
            logger.info(
                (
                    "eta: {eta}, "
                    "epoch: {epoch}, "
                    "avg. loss: {avg_loss:.6f}, "
                    "median loss: {median_loss:.6f}, "
                    "labeled loss: {median_labeled_loss}, "
                    "unlabeled loss: {median_unlabeled_loss}, "
                    "lr: {lr:.6f}, "
                    "max mem: {memory:.0f}"
                ).format(
                    eta=eta_string,
                    epoch=epoch,
                    avg_loss=losses.avg,
                    median_loss=losses.median,
                    median_labeled_loss=labeled_loss.median,
                    median_unlabeled_loss=unlabeled_loss.median,
                    lr=optimizer.param_groups[0]["lr"],
                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                    if torch.cuda.is_available()
                    else 0.0,
                )
            )

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / epoch)".format(
            total_time_str, total_training_time / (max_epoch)
        )
    )

    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
    logdf = pd.read_csv(
        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
        header=None,
        names=[
            "avg. loss",
            "median loss",
            "labeled loss",
            "unlabeled loss",
            "lr",
            "max memory",
        ],
    )
    fig = loss_curve(logdf, output_folder)
    logger.info("saving {}".format(log_plot_file))
    fig.savefig(log_plot_file)
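The criterion used above receives labeled outputs and targets, unlabeled outputs and guessed targets, plus the ramp-up factor, and returns the total loss with its two components. Its implementation is not part of this diff; a hypothetical sketch of such a combination, assuming the unsupervised term is scaled by the ramp-up factor:

import torch

def ssl_criterion_sketch(outputs, targets, u_outputs, u_targets, ramp_up_factor):
    # Illustrative only; the package's real SSL criterion lives elsewhere.
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    ll = bce(outputs, targets)      # supervised term on labeled data
    ul = bce(u_outputs, u_targets)  # consistency term on guessed labels
    return ll + ramp_up_factor * ul, ll, ul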
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import logging
import time
import datetime
@@ -23,7 +23,7 @@ def do_train(
    checkpoint_period,
    device,
    arguments,
    output_folder,
):
    """
    Train model and save to disk.
@@ -55,8 +55,10 @@ def do_train(
    max_epoch = arguments["max_epoch"]

    # Log to file
    with open(
        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+"
    ) as outfile:

        model.train().to(device)
        for state in optimizer.state.values():
            for k, v in state.items():
@@ -70,7 +72,7 @@ def do_train(
            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()
@@ -81,9 +83,9 @@ def do_train(
                masks = None
                if len(samples) == 4:
                    masks = samples[-1].to(device)

                outputs = model(images)
                loss = criterion(outputs, ground_truths, masks)
                optimizer.zero_grad()
                loss.backward()
@@ -100,51 +102,62 @@ def do_train(
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            outfile.write(
                (
                    "{epoch}, "
                    "{avg_loss:.6f}, "
                    "{median_loss:.6f}, "
                    "{lr:.6f}, "
                    "{memory:.0f}"
                    "\n"
                ).format(
                    eta=eta_string,
                    epoch=epoch,
                    avg_loss=losses.avg,
                    median_loss=losses.median,
                    lr=optimizer.param_groups[0]["lr"],
                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                    if torch.cuda.is_available()
                    else 0.0,
                )
            )
            logger.info(
                (
                    "eta: {eta}, "
                    "epoch: {epoch}, "
                    "avg. loss: {avg_loss:.6f}, "
                    "median loss: {median_loss:.6f}, "
                    "lr: {lr:.6f}, "
                    "max mem: {memory:.0f}"
                ).format(
                    eta=eta_string,
                    epoch=epoch,
                    avg_loss=losses.avg,
                    median_loss=losses.median,
                    lr=optimizer.param_groups[0]["lr"],
                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                    if torch.cuda.is_available()
                    else 0.0,
                )
            )

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / epoch)".format(
            total_time_str, total_training_time / (max_epoch)
        )
    )

    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
    logdf = pd.read_csv(
        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
        header=None,
        names=["avg. loss", "median loss", "lr", "max memory"],
    )
    fig = loss_curve(logdf, output_folder)
    logger.info("saving {}".format(log_plot_file))
    fig.savefig(log_plot_file)
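The ETA reported by the loop above is simply one epoch's wall time multiplied by the epochs still to go; a tiny self-contained sketch of that arithmetic:

import datetime

# one epoch took 12.3 s and 990 epochs remain
epoch, max_epoch, epoch_time = 10, 1000, 12.3
eta_seconds = epoch_time * (max_epoch - epoch)
print(str(datetime.timedelta(seconds=int(eta_seconds))))  # 3:22:57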
@@ -12,7 +12,7 @@ def conv_bn(inp, oup, stride):
    return torch.nn.Sequential(
        torch.nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        torch.nn.BatchNorm2d(oup),
        torch.nn.ReLU6(inplace=True),
    )
@@ -20,7 +20,7 @@ def conv_1x1_bn(inp, oup):
    return torch.nn.Sequential(
        torch.nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        torch.nn.BatchNorm2d(oup),
        torch.nn.ReLU6(inplace=True),
    )
@@ -36,7 +36,9 @@ class InvertedResidual(torch.nn.Module):
        if expand_ratio == 1:
            self.conv = torch.nn.Sequential(
                # dw
                torch.nn.Conv2d(
                    hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
                ),
                torch.nn.BatchNorm2d(hidden_dim),
                torch.nn.ReLU6(inplace=True),
                # pw-linear
@@ -50,7 +52,9 @@ class InvertedResidual(torch.nn.Module):
                torch.nn.BatchNorm2d(hidden_dim),
                torch.nn.ReLU6(inplace=True),
                # dw
                torch.nn.Conv2d(
                    hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
                ),
                torch.nn.BatchNorm2d(hidden_dim),
                torch.nn.ReLU6(inplace=True),
                # pw-linear
@@ -66,7 +70,14 @@ class InvertedResidual(torch.nn.Module):

class MobileNetV2(torch.nn.Module):
    def __init__(
        self,
        n_class=1000,
        input_size=224,
        width_mult=1.0,
        return_features=None,
        m2u=True,
    ):
        super(MobileNetV2, self).__init__()
        self.return_features = return_features
        self.m2u = m2u
@@ -80,34 +91,38 @@ class MobileNetV2(torch.nn.Module):
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        # self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(
                        block(input_channel, output_channel, s, expand_ratio=t)
                    )
                else:
                    self.features.append(
                        block(input_channel, output_channel, 1, expand_ratio=t)
                    )
                input_channel = output_channel
        # building last several layers
        # self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it torch.nn.Sequential
        self.features = torch.nn.Sequential(*self.features)
        # building classifier
        # self.classifier = torch.nn.Sequential(
        #    torch.nn.Dropout(0.2),
        #    torch.nn.Linear(self.last_channel, n_class),
        # )
        self._initialize_weights()
@@ -117,7 +132,7 @@ class MobileNetV2(torch.nn.Module):
        outputs.append(x.shape[2:4])
        if self.m2u:
            outputs.append(x)
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self.return_features:
@@ -128,7 +143,7 @@ class MobileNetV2(torch.nn.Module):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
...
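A quick sketch of this backbone used as a feature extractor (assumes the bob.ip.binseg package is importable; the indices are the ones build_m2unet() uses further below):

import torch
from bob.ip.binseg.modeling.backbones.mobilenetv2 import MobileNetV2

net = MobileNetV2(return_features=[1, 3, 6, 13], m2u=True)
with torch.no_grad():
    out = net(torch.rand(1, 3, 224, 224))  # input size must be divisible by 32
# out[0]: the input height/width; out[1]: the input itself (m2u=True);
# out[2:]: feature maps taken at indices 1, 3, 6 and 13, i.e. at strides
# 2, 4, 8 and 16 of the input resolution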
@@ -18,20 +18,13 @@ model_urls = {

def _conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False,
    )


def _conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class _BasicBlock(nn.Module):
@@ -105,9 +98,7 @@ class _Bottleneck(nn.Module):

class ResNet(nn.Module):
    def __init__(self, block, layers, return_features, zero_init_residual=False):
        """
        Generic ResNet network with layer return.

        Attributes
@@ -118,9 +109,7 @@ class ResNet(nn.Module):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.return_features = return_features
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -142,9 +131,7 @@ class ResNet(nn.Module):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
@@ -229,9 +216,7 @@ def shaperesnet50(pretrained=False, **kwargs):
    if pretrained:
        model.load_state_dict(
            model_zoo.load_url(
                model_urls["resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN"]
            )
        )
    return model
...
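Sketch of the matching feature-extractor usage (the resnet50 factory is the one resunet.py imports below; judging by the channel list in build_res50unet(), the return indices correspond to the stem and the four residual stages):

import torch
from bob.ip.binseg.modeling.backbones.resnet import resnet50

net = resnet50(pretrained=False, return_features=[2, 4, 5, 6, 7])
with torch.no_grad():
    features = net(torch.rand(1, 3, 224, 224))
# build_res50unet() below pairs these maps with ResUNet([64, 256, 512, 1024, 2048])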
@@ -8,19 +8,18 @@ import torch.utils.model_zoo as model_zoo

model_urls = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}


class VGG(nn.Module):
    def __init__(self, features, return_features, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
@@ -32,7 +31,7 @@ class VGG(nn.Module):
        outputs = []
        # hw of input, needed for DRIU and HED
        outputs.append(x.shape[2:4])
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self.return_features:
@@ -42,7 +41,7 @@ class VGG(nn.Module):
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
@@ -57,7 +56,7 @@ def _make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
@@ -70,10 +69,51 @@ def _make_layers(cfg, batch_norm=False):

_cfg = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [
        64,
        64,
        "M",
        128,
        128,
        "M",
        256,
        256,
        256,
        "M",
        512,
        512,
        512,
        "M",
        512,
        512,
        512,
        "M",
    ],
    "E": [
        64,
        64,
        "M",
        128,
        128,
        "M",
        256,
        256,
        256,
        256,
        "M",
        512,
        512,
        512,
        512,
        "M",
        512,
        512,
        512,
        512,
        "M",
    ],
}
@@ -83,10 +123,10 @@ def vgg11(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["A"]), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg11"]))
    return model
@@ -96,10 +136,10 @@ def vgg11_bn(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["A"], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg11_bn"]))
    return model
@@ -109,10 +149,10 @@ def vgg13(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["B"]), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg13"]))
    return model
@@ -122,10 +162,10 @@ def vgg13_bn(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["B"], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg13_bn"]))
    return model
@@ -135,10 +175,10 @@ def vgg16(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["D"]), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg16"]), strict=False)
    return model
@@ -148,10 +188,10 @@ def vgg16_bn(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["D"], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg16_bn"]))
    return model
@@ -161,10 +201,10 @@ def vgg19(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["E"]), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg19"]))
    return model
@@ -174,8 +214,8 @@ def vgg19_bn(pretrained=False, **kwargs):
       pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(_make_layers(_cfg["E"], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["vgg19_bn"]))
    return model
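Sketch of the VGG backbone contract (grounded in the forward() above: the first list entry is the input height/width, which the DRIU/HED upsample-crop blocks need later; the 544x544 input size is an assumption):

import torch
from bob.ip.binseg.modeling.backbones.vgg import vgg16

net = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
with torch.no_grad():
    hw, *features = net(torch.rand(1, 3, 544, 544))
print(hw)                              # torch.Size([544, 544])
print([f.shape[1] for f in features])  # [64, 128, 256, 512, 512]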
@@ -43,12 +43,7 @@ class DRIU(torch.nn.Module):
    def __init__(self, in_channels_list=None):
        super(DRIU, self).__init__()
        (in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8,) = in_channels_list

        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers
...
@@ -5,24 +5,31 @@ import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16_bn
from bob.ip.binseg.modeling.make_layers import (
    conv_with_kaiming_uniform,
    convtrans_with_kaiming_uniform,
    UpsampleCropBlock,
)


class ConcatFuseBlock(torch.nn.Module):
    """
    Takes in four feature maps with 16 channels each, concatenates them
    and applies a 1x1 convolution with 1 output channel.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0), torch.nn.BatchNorm2d(1)
        )

    def forward(self, x1, x2, x3, x4):
        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.conv(x_cat)
        return x


class DRIU(torch.nn.Module):
    """
    DRIU head module
@@ -34,6 +41,7 @@ class DRIU(torch.nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(self, in_channels_list=None):
        super(DRIU, self).__init__()
        in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list
@@ -47,7 +55,7 @@ class DRIU(torch.nn.Module):
        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
@@ -62,12 +70,13 @@ class DRIU(torch.nn.Module):
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
        return out


def build_driu():
    """
    Adds backbone and head together
@@ -78,9 +87,11 @@ def build_driu():
    module : :py:class:`torch.nn.Module`
    """
    backbone = vgg16_bn(pretrained=False, return_features=[5, 12, 19, 29])
    driu_head = DRIU([64, 128, 256, 512])
    model = torch.nn.Sequential(
        OrderedDict([("backbone", backbone), ("head", driu_head)])
    )
    model.name = "DRIUBN"
    return model
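End-to-end sketch for the assembled model (the module path and the 544x544 DRIVE-like input size are assumptions for illustration):

import torch
from bob.ip.binseg.modeling.driubn import build_driu  # path assumed

model = build_driu()  # vgg16_bn backbone + DRIU head, model.name == "DRIUBN"
with torch.no_grad():
    logits = model(torch.rand(1, 3, 544, 544))  # one-channel logit map
probabilities = torch.sigmoid(logits)  # the *WithLogits losses below expect raw logits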
@@ -5,22 +5,29 @@ import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
    conv_with_kaiming_uniform,
    convtrans_with_kaiming_uniform,
    UpsampleCropBlock,
)


class ConcatFuseBlock(torch.nn.Module):
    """
    Takes in four feature maps with 16 channels each, concatenates them
    and applies a 1x1 convolution with 1 output channel.
    """

    def __init__(self):
        super().__init__()
        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4):
        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.conv(x_cat)
        return x


class DRIUOD(torch.nn.Module):
    """
    DRIU head module
@@ -30,6 +37,7 @@ class DRIUOD(torch.nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(self, in_channels_list=None):
        super(DRIUOD, self).__init__()
        in_upsample2, in_upsample_4, in_upsample_8, in_upsample_16 = in_channels_list
@@ -40,11 +48,10 @@ class DRIUOD(torch.nn.Module):
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 16, 32, 16, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
@@ -59,12 +66,13 @@ class DRIUOD(torch.nn.Module):
        """
        hw = x[0]
        upsample2 = self.upsample2(x[1], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[2], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[3], hw)  # side-multi4-up
        upsample16 = self.upsample16(x[4], hw)  # side-multi5-up
        out = self.concatfuse(upsample2, upsample4, upsample8, upsample16)
        return out


def build_driuod():
    """
    Adds backbone and head together
@@ -74,9 +82,11 @@ def build_driuod():
    module : :py:class:`torch.nn.Module`
    """
    backbone = vgg16(pretrained=False, return_features=[8, 14, 22, 29])
    driu_head = DRIUOD([128, 256, 512, 512])
    model = torch.nn.Sequential(
        OrderedDict([("backbone", backbone), ("head", driu_head)])
    )
    model.name = "DRIUOD"
    return model
@@ -5,22 +5,29 @@ import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
    conv_with_kaiming_uniform,
    convtrans_with_kaiming_uniform,
    UpsampleCropBlock,
)


class ConcatFuseBlock(torch.nn.Module):
    """
    Takes in four feature maps with 16 channels each, concatenates them
    and applies a 1x1 convolution with 1 output channel.
    """

    def __init__(self):
        super().__init__()
        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4):
        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.conv(x_cat)
        return x


class DRIUPIX(torch.nn.Module):
    """
    DRIUPIX head module. DRIU with pixelshuffle instead of ConvTrans2D
@@ -30,6 +37,7 @@ class DRIUPIX(torch.nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(self, in_channels_list=None):
        super(DRIUPIX, self).__init__()
        in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list
@@ -37,13 +45,17 @@ class DRIUPIX(torch.nn.Module):
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers
        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0, pixelshuffle=True)
        self.upsample4 = UpsampleCropBlock(
            in_upsample_4, 16, 8, 4, 0, pixelshuffle=True
        )
        self.upsample8 = UpsampleCropBlock(
            in_upsample_8, 16, 16, 8, 0, pixelshuffle=True
        )
        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
@@ -58,12 +70,13 @@ class DRIUPIX(torch.nn.Module):
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
        return out


def build_driupix():
    """
    Adds backbone and head together
@@ -73,9 +86,11 @@ def build_driupix():
    module : :py:class:`torch.nn.Module`
    """
    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22])
    driu_head = DRIUPIX([64, 128, 256, 512])
    model = torch.nn.Sequential(
        OrderedDict([("backbone", backbone), ("head", driu_head)])
    )
    model.name = "DRIUPIX"
    return model
@@ -5,22 +5,29 @@ import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
    conv_with_kaiming_uniform,
    convtrans_with_kaiming_uniform,
    UpsampleCropBlock,
)


class ConcatFuseBlock(torch.nn.Module):
    """
    Takes in five feature maps with one channel each, concatenates them
    and applies a 1x1 convolution with 1 output channel.
    """

    def __init__(self):
        super().__init__()
        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4, x5):
        x_cat = torch.cat([x1, x2, x3, x4, x5], dim=1)
        x = self.conv(x_cat)
        return x


class HED(torch.nn.Module):
    """
    HED head module
@@ -30,20 +37,27 @@ class HED(torch.nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(self, in_channels_list=None):
        super(HED, self).__init__()
        (
            in_conv_1_2_16,
            in_upsample2,
            in_upsample_4,
            in_upsample_8,
            in_upsample_16,
        ) = in_channels_list

        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
        # Upsample
        self.upsample2 = UpsampleCropBlock(in_upsample2, 1, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
@@ -58,15 +72,18 @@ class HED(torch.nn.Module):
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])
        upsample2 = self.upsample2(x[2], hw)
        upsample4 = self.upsample4(x[3], hw)
        upsample8 = self.upsample8(x[4], hw)
        upsample16 = self.upsample16(x[5], hw)
        concatfuse = self.concatfuse(
            conv1_2_16, upsample2, upsample4, upsample8, upsample16
        )

        out = [upsample2, upsample4, upsample8, upsample16, concatfuse]
        return out


def build_hed():
    """
    Adds backbone and head together
@@ -75,9 +92,11 @@ def build_hed():
    -------
    module : :py:class:`torch.nn.Module`
    """
    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
    hed_head = HED([64, 128, 256, 512, 512])
    model = torch.nn.Sequential(
        OrderedDict([("backbone", backbone), ("head", hed_head)])
    )
    model.name = "HED"
    return model
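HED is deeply supervised: forward() returns the four upsampled side outputs plus the fused map, so a loss term can be attached to every element of the list. A sketch (module path and input size assumed):

import torch
from bob.ip.binseg.modeling.hed import build_hed  # path assumed

model = build_hed()
with torch.no_grad():
    side2, side4, side8, side16, fused = model(torch.rand(1, 3, 544, 544))
# the HED-specific loss below pairs with this list of outputs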
@@ -32,9 +32,7 @@ class WeightedBCELogitsLoss(_Loss):
        reduction="mean",
        pos_weight=None,
    ):
        super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)
@@ -56,9 +54,7 @@ class WeightedBCELogitsLoss(_Loss):
            torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)
        )  # torch.Size([n, 1])
        if hasattr(masks, "dtype"):
            num_mask_neg = c * h * w - torch.sum(masks, dim=[1, 2, 3]).float().reshape(
                n, 1
            )  # torch.Size([n, 1])
            num_neg = c * h * w - num_pos - num_mask_neg
@@ -97,9 +93,7 @@ class SoftJaccardBCELogitsLoss(_Loss):
        reduction="mean",
        pos_weight=None,
    ):
        super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.alpha = alpha

    @weak_script_method
@@ -145,9 +139,7 @@ class HEDWeightedBCELogitsLoss(_Loss):
        reduction="mean",
        pos_weight=None,
    ):
        super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)
@@ -185,9 +177,7 @@ class HEDWeightedBCELogitsLoss(_Loss):
            numnegnumtotal = torch.ones_like(target) * (
                num_neg / (num_pos + num_neg)
            ).unsqueeze(1).unsqueeze(2)
            weight = torch.where((target <= 0.5), numposnumtotal, numnegnumtotal)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, weight=weight, reduction=self.reduction
            )
@@ -278,9 +268,7 @@ class MixJacLoss(_Loss):
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    @weak_script_method
    def forward(self, input, target, unlabeled_input, unlabeled_traget, ramp_up_factor):
        """
        Parameters
        ----------
...
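A sketch of the calling convention the trainer above uses, criterion(outputs, ground_truths, masks); that masks may be passed as None is an assumption based on the hasattr(masks, "dtype") guard, and the module path is assumed:

import torch
from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss  # path assumed

criterion = WeightedBCELogitsLoss()
logits = torch.randn(2, 1, 64, 64)                    # raw model outputs, pre-sigmoid
target = torch.randint(0, 2, (2, 1, 64, 64)).float()  # binary ground truth
loss = criterion(logits, target, None)                # no dataset mask available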
@@ -8,35 +8,46 @@ import torch
import torch.nn
from bob.ip.binseg.modeling.backbones.mobilenetv2 import MobileNetV2, InvertedResidual


class DecoderBlock(torch.nn.Module):
    """
    Decoder block: upsample and concatenate with feature maps from the encoder part
    """

    def __init__(self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            up_in_c + x_in_c,
            (x_in_c + up_in_c) // 2,
            stride=1,
            expand_ratio=expand_ratio,
        )

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x


class LastDecoderBlock(torch.nn.Module):
    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(x_in_c, 1, stride=1, expand_ratio=expand_ratio)

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x


class M2U(torch.nn.Module):
    """
    M2U-Net head module
@@ -46,14 +57,17 @@ class M2U(torch.nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(
        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super(M2U, self).__init__()

        # Decoder
        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)

        # initialize weights
        self._initialize_weights()
@@ -68,7 +82,7 @@ class M2U(torch.nn.Module):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """
        Parameters
        ----------
@@ -80,13 +94,14 @@ class M2U(torch.nn.Module):
        -------
        tensor : :py:class:`torch.Tensor`
        """
        decode4 = self.decode4(x[5], x[4])  # 96, 32
        decode3 = self.decode3(decode4, x[3])  # 64, 24
        decode2 = self.decode2(decode3, x[2])  # 44, 16
        decode1 = self.decode1(decode2, x[1])  # 30, 3
        return decode1


def build_m2unet():
    """
    Adds backbone and head together
@@ -95,9 +110,11 @@ def build_m2unet():
    -------
    module : :py:class:`torch.nn.Module`
    """
    backbone = MobileNetV2(return_features=[1, 3, 6, 13], m2u=True)
    m2u_head = M2U(in_channels_list=[16, 24, 32, 96])
    model = torch.nn.Sequential(
        OrderedDict([("backbone", backbone), ("head", m2u_head)])
    )
    model.name = "M2UNet"
    return model
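Sketch: M2U-Net preserves the input resolution, since the four decoder stages undo the encoder's stride-16 downsampling and the last block emits a single channel (module path assumed):

import torch
from bob.ip.binseg.modeling.m2u import build_m2unet  # path assumed

model = build_m2unet()  # model.name == "M2UNet"
with torch.no_grad():
    out = model(torch.rand(1, 3, 544, 544))
print(out.shape)  # torch.Size([1, 1, 544, 544])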
@@ -6,7 +6,10 @@ import torch.nn
from torch.nn import Conv2d
from torch.nn import ConvTranspose2d


def conv_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    conv = Conv2d(
        in_channels,
        out_channels,
@@ -14,16 +17,18 @@ def conv_with_kaiming_uniform(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv


def convtrans_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    conv = ConvTranspose2d(
        in_channels,
        out_channels,
@@ -31,10 +36,10 @@ def convtrans_with_kaiming_uniform(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
@@ -63,15 +68,24 @@ class UpsampleCropBlock(torch.nn.Module):
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        up_kernel_size,
        up_stride,
        up_padding,
        pixelshuffle=False,
    ):
        super().__init__()
        # NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl.
        self.conv = conv_with_kaiming_uniform(in_channels, out_channels, 3, 1, 1)
        if pixelshuffle:
            self.upconv = PixelShuffle_ICNR(out_channels, out_channels, scale=up_stride)
        else:
            self.upconv = convtrans_with_kaiming_uniform(
                out_channels, out_channels, up_kernel_size, up_stride, up_padding
            )

    def forward(self, x, input_res):
        """Forward pass of UpsampleBlock.
@@ -98,39 +112,40 @@ class UpsampleCropBlock(torch.nn.Module):
        # height
        up_h = x.shape[2]
        h_crop = up_h - img_h
        h_s = h_crop // 2
        h_e = up_h - (h_crop - h_s)
        # width
        up_w = x.shape[3]
        w_crop = up_w - img_w
        w_s = w_crop // 2
        w_e = up_w - (w_crop - w_s)
        # perform crop
        # needs explicit ranges for onnx export
        x = x[:, :, h_s:h_e, w_s:w_e]  # crop to input size

        return x


def ifnone(a, b):
    "``a`` if ``a`` is not None, otherwise ``b``."
    return b if a is None else a


def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR

    ICNR init of ``x``, with ``scale`` and ``init`` function.
    """
    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale ** 2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale ** 2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)


class PixelShuffle_ICNR(torch.nn.Module):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR
@@ -138,47 +153,52 @@ class PixelShuffle_ICNR(torch.nn.Module):
    ``torch.nn.PixelShuffle``, ``icnr`` init, and ``weight_norm``.
    """

    def __init__(self, ni: int, nf: int = None, scale: int = 2):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_with_kaiming_uniform(ni, nf * (scale ** 2), 1)
        icnr(self.conv.weight)
        self.shuf = torch.nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = torch.nn.AvgPool2d(2, stride=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        x = self.blur(self.pad(x))
        return x


class UnetBlock(torch.nn.Module):
    def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False):
        super().__init__()
        # middle block for VGG based U-Net
        if middle_block:
            up_out_c = up_in_c
        else:
            up_out_c = up_in_c // 2
        cat_channels = x_in_c + up_out_c
        inner_channels = cat_channels // 2

        if pixel_shuffle:
            self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)
        else:
            self.upsample = convtrans_with_kaiming_uniform(up_in_c, up_out_c, 2, 2)
        self.convtrans1 = convtrans_with_kaiming_uniform(
            cat_channels, inner_channels, 3, 1, 1
        )
        self.convtrans2 = convtrans_with_kaiming_uniform(
            inner_channels, inner_channels, 3, 1, 1
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.relu(self.convtrans1(cat_x))
        x = self.relu(self.convtrans2(x))
        return x
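Sketch: PixelShuffle_ICNR doubles the spatial resolution while avoiding the checkerboard artifacts of plain transposed convolutions, via the ICNR init plus the replication-pad/average-pool blur cited above:

import torch
from bob.ip.binseg.modeling.make_layers import PixelShuffle_ICNR

up = PixelShuffle_ICNR(64, 32, scale=2)  # 1x1 conv to 32*4 channels, shuffle, blur
with torch.no_grad():
    y = up(torch.rand(1, 64, 17, 17))
print(y.shape)  # torch.Size([1, 32, 34, 34])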
@@ -4,11 +4,15 @@
import torch.nn as nn
import torch
from collections import OrderedDict
from bob.ip.binseg.modeling.make_layers import (
    conv_with_kaiming_uniform,
    convtrans_with_kaiming_uniform,
    PixelShuffle_ICNR,
    UnetBlock,
)
from bob.ip.binseg.modeling.backbones.resnet import resnet50


class ResUNet(nn.Module):
    """
    UNet head module for ResNet backbones
@@ -18,12 +22,13 @@ class ResUNet(nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super(ResUNet, self).__init__()
        # number of channels
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
        # number of channels for last upsampling operation
        c_decode0 = (c_decode1 + c_decode2 // 2) // 2

        # build layers
        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle)
@@ -36,7 +41,7 @@ class ResUNet(nn.Module):
        self.decode0 = convtrans_with_kaiming_uniform(c_decode0, c_decode0, 2, 2)
        self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)

    def forward(self, x):
        """
        Parameters
        ----------
@@ -54,6 +59,7 @@ class ResUNet(nn.Module):
        out = self.final(decode0)
        return out


def build_res50unet():
    """
    Adds backbone and head together
@@ -62,8 +68,8 @@ def build_res50unet():
    -------
    model : :py:class:`torch.nn.Module`
    """
    backbone = resnet50(pretrained=False, return_features=[2, 4, 5, 6, 7])
    unet_head = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)
    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
    model.name = "ResUNet"
    return model
@@ -4,11 +4,15 @@
import torch.nn as nn
import torch
from collections import OrderedDict
from bob.ip.binseg.modeling.make_layers import (
    conv_with_kaiming_uniform,
    convtrans_with_kaiming_uniform,
    PixelShuffle_ICNR,
    UnetBlock,
)
from bob.ip.binseg.modeling.backbones.vgg import vgg16


class UNet(nn.Module):
    """
    UNet head module
@@ -18,6 +22,7 @@ class UNet(nn.Module):
    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super(UNet, self).__init__()
        # number of channels
@@ -30,7 +35,7 @@ class UNet(nn.Module):
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)

    def forward(self, x):
        """
        Parameters
        ----------
@@ -47,6 +52,7 @@ class UNet(nn.Module):
        out = self.final(decode1)
        return out


def build_unet():
    """
    Adds backbone and head together
@@ -56,7 +62,7 @@ def build_unet():
    module : :py:class:`torch.nn.Module`
    """
    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
    unet_head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)
    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
...
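Sketch: the VGG16 U-Net factory mirrors build_res50unet() above, with only the encoder differing, and likewise ends in a one-channel map at input resolution (module path assumed):

import torch
from bob.ip.binseg.modeling.unet import build_unet  # path assumed

model = build_unet()
with torch.no_grad():
    out = model(torch.rand(1, 3, 544, 544))
print(out.shape)  # torch.Size([1, 1, 544, 544])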