diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py
index b0beafe6acef86a74e8955a7de7c2c6c04502037..ac8847ab64cf2e948ef77c6cf2ad9a5e2a2eedb8 100644
--- a/bob/ip/binseg/configs/models/m2unetssl.py
+++ b/bob/ip/binseg/configs/models/m2unetssl.py
@@ -33,7 +33,7 @@ optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr,
                      eps=eps, weight_decay=weight_decay, amsbound=amsbound)
 
 # criterion
-criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7)
+criterion = MixJacLoss(lambda_u=0.01, jacalpha=0.7, unlabeledjacalpha=0.7)
 
 # scheduler
 scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/m2unetssl0703.py b/bob/ip/binseg/configs/models/m2unetssl0703.py
deleted file mode 100644
index d5a160821deaea8436d533938c69b9f535fc763d..0000000000000000000000000000000000000000
--- a/bob/ip/binseg/configs/models/m2unetssl0703.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.m2u import build_m2unet
-import torch.optim as optim
-from torch.nn import BCEWithLogitsLoss
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import MixJacLoss
-from bob.ip.binseg.engine.adabound import AdaBound
-
-##### Config #####
-lr = 0.001
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-
-scheduler_milestones = [900]
-scheduler_gamma = 0.1
-
-# model
-model = build_m2unet()
-
-# pretrained backbone
-pretrained_backbone = modelurls['mobilenetv2']
-
-# optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                     eps=eps, weight_decay=weight_decay, amsbound=amsbound)
-
-# criterion
-criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.3)
-
-# scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index 0f3ca24730c7f3b83880ee42137555057e9218eb..2917203c7b530bee796431e0dbe5e7af1f85a2b9 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -19,8 +19,11 @@ class BinSegDataset(Dataset):
     mask : bool
         whether dataset contains masks or not
     """
-    def __init__(self, bobdb, split = 'train', transform = None):
-        self.database = bobdb.samples(split)
+    def __init__(self, bobdb, split = 'train', transform = None, index_to = None):
+        if index_to:
+            self.database = bobdb.samples(split)[:index_to]
+        else:
+            self.database = bobdb.samples(split)
         self.transform = transform
         self.split = split
 
@@ -47,15 +50,12 @@ class BinSegDataset(Dataset):
         Returns
         -------
         list
-            dataitem [img_name, img, gt, mask]
+            dataitem [img_name, img, gt]
         """
         img = self.database[index].img.pil_image()
         gt = self.database[index].gt.pil_image()
         img_name = self.database[index].img.basename
         sample = [img, gt]
-        if self.mask:
-            mask = self.database[index].mask.pil_image()
-            sample.append(mask)
 
         if self.transform :
             sample = self.transform(*sample)
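The new `index_to` argument caps the labeled pool at the first `index_to` samples of a split, while omitting it keeps the previous behavior. A minimal usage sketch; the `drive` database object and `transforms` pipeline here are placeholders, not part of this patch:

    from bob.ip.binseg.data.binsegdataset import BinSegDataset

    # keep only the first 10 annotated samples as the labeled pool
    # (`drive` and `transforms` are assumed to exist in the caller's scope)
    labeled = BinSegDataset(drive, split='train', transform=transforms, index_to=10)
    assert len(labeled) == 10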
@@ -72,20 +72,14 @@ class SSLBinSegDataset(Dataset):
 
     Parameters
     ----------
-    bobdb : :py:mod:`bob.db.base`
-        Binary segmentation bob database (e.g. bob.db.drive)
+    labeled_dataset : :py:class:`torch.utils.data.Dataset`
+        BinSegDataset with labeled samples
     unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
-        dataset with unlabeled data
-    split : str
-        ``'train'`` or ``'test'``. Defaults to ``'train'``
-    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
-        A transform or composition of transforms. Defaults to ``None``.
+        UnLabeledBinSegDataset with unlabeled data
     """
-    def __init__(self, bobdb, unlabeled_dataset, split = 'train', transform = None):
-        self.database = bobdb.samples(split)
+    def __init__(self, labeled_dataset, unlabeled_dataset):
+        self.labeled_dataset = labeled_dataset
         self.unlabeled_dataset = unlabeled_dataset
-        self.transform = transform
-        self.split = split
 
 
     def __len__(self):
@@ -95,7 +89,7 @@
         int
             size of the dataset
         """
-        return len(self.database)
+        return len(self.labeled_dataset)
 
     def __getitem__(self,index):
         """
@@ -108,17 +102,8 @@
         list
             dataitem [img_name, img, gt, unlabeled_img_name, unlabeled_img]
         """
-        img = self.database[index].img.pil_image()
-        gt = self.database[index].gt.pil_image()
-        img_name = self.database[index].img.basename
-        sample = [img, gt]
-
-
-        if self.transform :
-            sample = self.transform(*sample)
-
-        sample.insert(0,img_name)
-        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[index]
+        sample = self.labeled_dataset[index]
+        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[0]
         sample.extend([unlabeled_img_name, unlabeled_img])
         return sample
 
@@ -138,8 +123,11 @@ class UnLabeledBinSegDataset(Dataset):
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
        A transform or composition of transforms. Defaults to ``None``.
     """
-    def __init__(self, db, split = 'train', transform = None):
-        self.database = db.samples(split)
+    def __init__(self, db, split = 'train', transform = None, index_from = None):
+        if index_from:
+            self.database = db.samples(split)[index_from:]
+        else:
+            self.database = db.samples(split)
         self.transform = transform
         self.split = split
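After this refactoring, SSLBinSegDataset is a thin wrapper that pairs each labeled sample with an image from the unlabeled set (as written, always `unlabeled_dataset[0]`), rather than building its own sample list from a bob database. A sketch of how the three classes wire together under the new interface; `drive` and `transforms` are again placeholders:

    from bob.ip.binseg.data.binsegdataset import (
        BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset)

    # first 10 samples serve as the labeled pool, the remainder as unlabeled
    labeled = BinSegDataset(drive, split='train', transform=transforms, index_to=10)
    unlabeled = UnLabeledBinSegDataset(drive, split='train', transform=transforms, index_from=10)
    ssl = SSLBinSegDataset(labeled, unlabeled)

    # each item: [img_name, img, gt, unlabeled_img_name, unlabeled_img]
    img_name, img, gt, unlabeled_img_name, unlabeled_img = ssl[0]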
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 54f8519471a5b770656fec4ef0714394c94f863e..8fb6d2f1c5fd5172e08e2d5eb034698ce47f1218 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -13,6 +13,10 @@ import numpy as np
 from bob.ip.binseg.utils.metric import SmoothedValue
 from bob.ip.binseg.utils.plot import loss_curve
 
+def sharpen(x, T):
+    # temperature sharpening: raise to 1/T, then renormalize over dim 1
+    temp = x**(1/T)
+    return temp / temp.sum(dim=1, keepdim=True)
+
 def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].
 
     Parameters
     ----------
     alpha : float
         concentration parameter of the Beta distribution
     input, target : :py:class:`torch.Tensor`
         labeled batch and its ground truth
     unlabeled_input, unlabled_target : :py:class:`torch.Tensor`
         unlabeled batch and its guessed labels
 
     Returns
     -------
     list
     """
-    l = np.random.beta(alpha, alpha) # Eq (8)
-    l = max(l, 1 - l) # Eq (9)
-    # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input,unlabeled_input],0)
-    w_targets = torch.cat([target,unlabled_target],0)
-    idx = torch.randperm(w_inputs.size(0)) # get random index
-
-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
-    input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
-    target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
-
-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
-    unlabled_target_mixedup = l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
-    return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup
+    # TODO:
+    with torch.no_grad():
+        l = np.random.beta(alpha, alpha) # Eq (8)
+        l = max(l, 1 - l) # Eq (9)
+        # Shuffle and concat. Alg. 1 Line: 12
+        w_inputs = torch.cat([input,unlabeled_input],0)
+        w_targets = torch.cat([target,unlabled_target],0)
+        idx = torch.randperm(w_inputs.size(0)) # get random index
+
+        # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
+        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
+
+        # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
+        unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
+        unlabled_target_mixedup = l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
+        return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup
 
 
 def linear_rampup(current, rampup_length=16):
@@ -135,7 +141,7 @@ def do_ssltrain(
     max_epoch = arguments["max_epoch"]
 
     # Log to file
-    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
+    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+", 1) as outfile:  # buffering=1: line-buffered
 
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
@@ -165,9 +171,10 @@
             unlabeled_outputs = model(unlabeled_images)
             # guessed unlabeled outputs
             unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-            ramp_up_factor = linear_rampup(epoch,rampup_length=16)
+            #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
+            #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+            ramp_up_factor = linear_rampup(epoch,rampup_length=500)
-
             loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
             optimizer.zero_grad()
             loss.backward()
@@ -212,8 +219,8 @@
                 "epoch: {epoch}, "
                 "avg. loss: {avg_loss:.6f}, "
                 "median loss: {median_loss:.6f}, "
-                "{median_labeled_loss}, "
-                "{median_unlabeled_loss}, "
+                "labeled loss: {median_labeled_loss}, "
+                "unlabeled loss: {median_unlabeled_loss}, "
                 "lr: {lr:.6f}, "
                 "max mem: {memory:.0f}"
             ).format(
@@ -241,3 +248,4 @@
         fig = loss_curve(logdf,output_folder)
         logger.info("saving {}".format(log_plot_file))
         fig.savefig(log_plot_file)
+
\ No newline at end of file
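`linear_rampup` itself is untouched by this patch, but the call site now uses `rampup_length=500`, so the unlabeled term only reaches full weight at epoch 500 instead of 16. For context, a MixMatch-style linear ramp-up is typically `current / rampup_length` clamped to [0, 1]; a sketch consistent with the signature and call sites above (the actual body lives outside these hunks and may differ):

    import numpy as np

    def linear_rampup(current, rampup_length=16):
        # ramps linearly from 0 at epoch 0 to 1 at epoch `rampup_length`
        if rampup_length == 0:
            return 1.0
        return float(np.clip(current / rampup_length, 0.0, 1.0))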
""" - def __init__(self, lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None): + def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None): super(MixJacLoss, self).__init__(size_average, reduce, reduction) self.lambda_u = lambda_u self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) - self.unlabeled_loss = SoftJaccardBCELogitsLoss(alpha=unlabeledjacalpha) + self.unlabeled_loss = torch.nn.BCEWithLogitsLoss() @weak_script_method