Commit 49401e67 authored by Tim Laibacher

Fix SSL to work with unlabeled datasets smaller than the labeled dataset

parent 26c20533
Merge request !1 (Ssl) · Pipeline #30725 failed
@@ -33,7 +33,7 @@ optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr,
                      eps=eps, weight_decay=weight_decay, amsbound=amsbound)
 # criterion
-criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7)
+criterion = MixJacLoss(lambda_u=0.01, jacalpha=0.7, unlabeledjacalpha=0.7)
 # scheduler
 scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
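The only functional change in this configuration is lambda_u dropping from 0.3 to 0.01, which strongly de-emphasizes the unsupervised term. MixJacLoss.forward is not part of this diff, so the sketch below is an assumption based on the call site in do_ssltrain further down (loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)); ToyMixLoss is a hypothetical stand-in, not the repository class:

    import torch

    class ToyMixLoss(torch.nn.Module):  # hypothetical stand-in for MixJacLoss
        def __init__(self, lambda_u=0.01):
            super().__init__()
            self.lambda_u = lambda_u
            self.labeled_loss = torch.nn.BCEWithLogitsLoss()
            self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

        def forward(self, out, gt, u_out, u_gt, ramp_up_factor):
            ll = self.labeled_loss(out, gt)          # supervised term
            ul = self.unlabeled_loss(u_out, u_gt)    # consistency term on guessed labels
            # Assumed combination: unlabeled term scaled by lambda_u and the ramp-up factor.
            return ll + self.lambda_u * ramp_up_factor * ul, ll, ul

    crit = ToyMixLoss(lambda_u=0.01)
    loss, ll, ul = crit(torch.randn(2, 1, 8, 8), torch.rand(2, 1, 8, 8),
                        torch.randn(2, 1, 8, 8), torch.rand(2, 1, 8, 8), 0.5)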
@@ -0,0 +1,30 @@ (new file)
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.m2u import build_m2unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+scheduler_milestones = [900]
+scheduler_gamma = 0.1
+# model
+model = build_m2unet()
+# pretrained backbone
+pretrained_backbone = modelurls['mobilenetv2']
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                     eps=eps, weight_decay=weight_decay, amsbound=amsbound)
+# criterion
+criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.3)
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
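Two of these values, final_lr and gamma, only make sense relative to AdaBound's bound schedule. A small sketch, assuming the reference AdaBound implementation from Luo et al. (ICLR 2019) that this engine module mirrors: each step's learning rate is clipped into a band that converges to final_lr, moving the optimizer from Adam-like to SGD-like behavior, with gamma (here 1e-3) controlling how fast the band tightens.

    # Dynamic learning-rate band of AdaBound (assumed reference formulas).
    def adabound_band(step, final_lr=0.1, gamma=1e-3):
        lower = final_lr * (1 - 1 / (gamma * step + 1))
        upper = final_lr * (1 + 1 / (gamma * step))
        return lower, upper

    print(adabound_band(1))        # (~1e-4, ~100.1): nearly unconstrained, Adam-like
    print(adabound_band(100_000))  # (~0.099, ~0.101): pinned near final_lr, SGD-like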
@@ -19,8 +19,11 @@ class BinSegDataset(Dataset):
     mask : bool
         whether dataset contains masks or not
     """
-    def __init__(self, bobdb, split = 'train', transform = None):
-        self.database = bobdb.samples(split)
+    def __init__(self, bobdb, split = 'train', transform = None, index_to = None):
+        if index_to:
+            self.database = bobdb.samples(split)[:index_to]
+        else:
+            self.database = bobdb.samples(split)
         self.transform = transform
         self.split = split

@@ -47,15 +50,12 @@ class BinSegDataset(Dataset):
         Returns
         -------
         list
-            dataitem [img_name, img, gt, mask]
+            dataitem [img_name, img, gt]
         """
         img = self.database[index].img.pil_image()
         gt = self.database[index].gt.pil_image()
         img_name = self.database[index].img.basename
         sample = [img, gt]
-        if self.mask:
-            mask = self.database[index].mask.pil_image()
-            sample.append(mask)
         if self.transform :
             sample = self.transform(*sample)
@@ -72,20 +72,14 @@ class SSLBinSegDataset(Dataset):
     Parameters
     ----------
-    bobdb : :py:mod:`bob.db.base`
-        Binary segmentation bob database (e.g. bob.db.drive)
+    labeled_dataset : :py:class:`torch.utils.data.Dataset`
+        BinSegDataset with labeled samples
     unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
-        dataset with unlabeled data
-    split : str
-        ``'train'`` or ``'test'``. Defaults to ``'train'``
-    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
-        A transform or composition of transfroms. Defaults to ``None``.
+        UnLabeledBinSegDataset with unlabeled data
     """
-    def __init__(self, bobdb, unlabeled_dataset, split = 'train', transform = None):
-        self.database = bobdb.samples(split)
+    def __init__(self, labeled_dataset, unlabeled_dataset):
+        self.labeled_dataset = labeled_dataset
         self.unlabeled_dataset = unlabeled_dataset
-        self.transform = transform
-        self.split = split

     def __len__(self):
@@ -95,7 +89,7 @@ class SSLBinSegDataset(Dataset):
         int
             size of the dataset
         """
-        return len(self.database)
+        return len(self.labeled_dataset)

     def __getitem__(self,index):
         """
@@ -108,17 +102,8 @@ class SSLBinSegDataset(Dataset):
         list
             dataitem [img_name, img, gt, unlabeled_img_name, unlabeled_img]
         """
-        img = self.database[index].img.pil_image()
-        gt = self.database[index].gt.pil_image()
-        img_name = self.database[index].img.basename
-        sample = [img, gt]
-        if self.transform :
-            sample = self.transform(*sample)
-        sample.insert(0,img_name)
-        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[index]
+        sample = self.labeled_dataset[index]
+        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[0]
         sample.extend([unlabeled_img_name, unlabeled_img])
         return sample
@@ -138,8 +123,11 @@ class UnLabeledBinSegDataset(Dataset):
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transforms. Defaults to ``None``.
     """
-    def __init__(self, db, split = 'train', transform = None):
-        self.database = db.samples(split)
+    def __init__(self, db, split = 'train', transform = None, index_from = None):
+        if index_from:
+            self.database = db.samples(split)[index_from:]
+        else:
+            self.database = db.samples(split)
         self.transform = transform
         self.split = split
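Taken together, these hunks implement the commit message: one database split can now be partitioned into a labeled head (index_to) and an unlabeled tail (index_from), and SSLBinSegDataset takes its length from the labeled side, so an unlabeled pool smaller than the labeled set no longer truncates the epoch. Note that, as committed, __getitem__ always pairs a labeled sample with unlabeled_dataset[0], i.e. the same first unlabeled sample, which sidesteps any out-of-range index. A minimal sketch of the intended composition, using a stub in place of a real bob database such as bob.db.drive:

    # FakeDB is a hypothetical stand-in for a bob database object.
    class FakeDB:
        def samples(self, split):
            return list(range(20))  # 20 samples in the 'train' split

    db = FakeDB()
    labeled = db.samples('train')[:15]    # what index_to=15 selects
    unlabeled = db.samples('train')[15:]  # what index_from=15 selects

    # SSLBinSegDataset.__len__ now returns len(labeled_dataset), so the
    # smaller unlabeled pool (5 samples) no longer limits the epoch length.
    assert len(labeled) == 15 and len(unlabeled) == 5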
@@ -13,6 +13,10 @@
 import numpy as np
 from bob.ip.binseg.utils.metric import SmoothedValue
 from bob.ip.binseg.utils.plot import loss_curve

+def sharpen(x, T):
+    temp = x**(1/T)
+    return temp / temp.sum(dim=1, keepdim=True)
+
 def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].
@@ -28,21 +32,23 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     -------
     list
     """
-    l = np.random.beta(alpha, alpha) # Eq (8)
-    l = max(l, 1 - l) # Eq (9)
-    # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input,unlabeled_input],0)
-    w_targets = torch.cat([target,unlabled_target],0)
-    idx = torch.randperm(w_inputs.size(0)) # get random index
-
-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
-    input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
-    target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
-
-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
-    unlabled_target_mixedup = l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
-    return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup
+    # TODO:
+    with torch.no_grad():
+        l = np.random.beta(alpha, alpha) # Eq (8)
+        l = max(l, 1 - l) # Eq (9)
+        # Shuffle and concat. Alg. 1 Line: 12
+        w_inputs = torch.cat([input,unlabeled_input],0)
+        w_targets = torch.cat([target,unlabled_target],0)
+        idx = torch.randperm(w_inputs.size(0)) # get random index
+
+        # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
+        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
+
+        # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
+        unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
+        unlabled_target_mixedup = l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
+        return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup

 def linear_rampup(current, rampup_length=16):
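The hunk above only wraps mix_up in torch.no_grad(), so the MixUp interpolation itself is excluded from the autograd graph; the mixed tensors feed the model afterwards. A shape-checking sketch, assuming mix_up as defined above is in scope; alpha=0.75 matches the commented-out call in do_ssltrain, and the random tensors stand in for image batches of shape (N, C, H, W):

    import torch

    x,  y  = torch.rand(2, 1, 8, 8), torch.rand(2, 1, 8, 8)   # labeled batch
    ux, uy = torch.rand(2, 1, 8, 8), torch.rand(2, 1, 8, 8)   # unlabeled batch

    xm, ym, uxm, uym = mix_up(0.75, x, y, ux, uy)
    assert xm.shape == x.shape and uxm.shape == ux.shape
    # Because l is clamped to max(l, 1 - l) >= 0.5, each mixed sample stays
    # closer to its own original than to its shuffled partner.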
@@ -135,7 +141,7 @@ def do_ssltrain(
     max_epoch = arguments["max_epoch"]
     # Log to file
-    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
+    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+", 1) as outfile:
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
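For reference, the third positional argument to Python's built-in open() is buffering; passing 1 selects line buffering for text-mode files, so each CSV row reaches disk as soon as its line is complete rather than when a larger buffer fills, which keeps the train log usable if the run is interrupted.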
@@ -165,9 +171,10 @@ def do_ssltrain(
             unlabeled_outputs = model(unlabeled_images)
             # guessed unlabeled outputs
             unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-            ramp_up_factor = linear_rampup(epoch,rampup_length=16)
+            #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
+            #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+            ramp_up_factor = linear_rampup(epoch,rampup_length=500)
             loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
             optimizer.zero_grad()
             loss.backward()
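Raising rampup_length from 16 to 500 stretches the warm-up of the unlabeled term across most of the run. A sketch of the effective unlabeled weight, assuming linear_rampup clips current/rampup_length to [0, 1] as in MixMatch:

    # Effective weight of the unlabeled loss term at a given epoch
    # (lambda_u=0.01 as in the updated configuration above).
    def effective_unlabeled_weight(epoch, lambda_u=0.01, rampup_length=500):
        ramp = min(max(epoch / rampup_length, 0.0), 1.0)
        return lambda_u * ramp

    print(effective_unlabeled_weight(250))   # 0.005: halfway through the ramp
    print(effective_unlabeled_weight(1000))  # 0.01: full (small) weight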
@@ -212,8 +219,8 @@ def do_ssltrain(
                     "epoch: {epoch}, "
                     "avg. loss: {avg_loss:.6f}, "
                     "median loss: {median_loss:.6f}, "
-                    "{median_labeled_loss}, "
-                    "{median_unlabeled_loss}, "
+                    "labeled loss: {median_labeled_loss}, "
+                    "unlabeled loss: {median_unlabeled_loss}, "
                     "lr: {lr:.6f}, "
                     "max mem: {memory:.0f}"
                 ).format(
@@ -241,3 +248,4 @@ def do_ssltrain(
     fig = loss_curve(logdf,output_folder)
     logger.info("saving {}".format(log_plot_file))
-    fig.savefig(log_plot_file)
\ No newline at end of file
+    fig.savefig(log_plot_file)
@@ -171,11 +171,11 @@ class MixJacLoss(_Loss):
     lambda_u : int
         determines the weighting of the unlabeled loss.
     """
-    def __init__(self, lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+    def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(MixJacLoss, self).__init__(size_average, reduce, reduction)
         self.lambda_u = lambda_u
         self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
-        self.unlabeled_loss = SoftJaccardBCELogitsLoss(alpha=unlabeledjacalpha)
+        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

     @weak_script_method
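Two things change in MixJacLoss: the default lambda_u jumps from 0.3 to 100 (the SSL configurations above override it anyway), and the unlabeled branch becomes a plain BCEWithLogitsLoss. Note that the unlabeledjacalpha argument is kept in the signature but no longer influences the unlabeled term.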