diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py
index b0beafe6acef86a74e8955a7de7c2c6c04502037..ac8847ab64cf2e948ef77c6cf2ad9a5e2a2eedb8 100644
--- a/bob/ip/binseg/configs/models/m2unetssl.py
+++ b/bob/ip/binseg/configs/models/m2unetssl.py
@@ -33,7 +33,9 @@ optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr,
                  eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
     
 # criterion
-criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7)
+criterion = MixJacLoss(lambda_u=0.01, jacalpha=0.7, unlabeledjacalpha=0.7)
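+# lambda_u weights the unlabeled part of the loss; the trainer additionally
+# scales it with a linear ramp-up factor, so it stays small in early epochs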
 
 # scheduler
 scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/m2unetssl0703.py b/bob/ip/binseg/configs/models/m2unetssl0703.py
deleted file mode 100644
index d5a160821deaea8436d533938c69b9f535fc763d..0000000000000000000000000000000000000000
--- a/bob/ip/binseg/configs/models/m2unetssl0703.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.m2u import build_m2unet
-import torch.optim as optim
-from torch.nn import BCEWithLogitsLoss
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import MixJacLoss
-from bob.ip.binseg.engine.adabound import AdaBound
-
-##### Config #####
-lr = 0.001
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-
-scheduler_milestones = [900]
-scheduler_gamma = 0.1
-
-# model
-model = build_m2unet()
-
-# pretrained backbone
-pretrained_backbone = modelurls['mobilenetv2']
-
-# optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
-    
-# criterion
-criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.3)
-
-# scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index 0f3ca24730c7f3b83880ee42137555057e9218eb..2917203c7b530bee796431e0dbe5e7af1f85a2b9 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -19,8 +19,12 @@ class BinSegDataset(Dataset):
-    mask : bool
-        whether dataset contains masks or not
+    index_to : int, optional
+        if given, only the first ``index_to`` samples of the split are used
     """
-    def __init__(self, bobdb, split = 'train', transform = None):
-        self.database = bobdb.samples(split)
+    def __init__(self, bobdb, split = 'train', transform = None, index_to = None):
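+        # optionally truncate the split (e.g. to carve out a labeled subset for SSL)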
+        if index_to:
+            self.database = bobdb.samples(split)[:index_to]
+        else:
+            self.database = bobdb.samples(split)
         self.transform = transform
         self.split = split
     
@@ -47,15 +50,12 @@ class BinSegDataset(Dataset):
         Returns
         -------
         list
-            dataitem [img_name, img, gt, mask]
+            dataitem [img_name, img, gt]
         """
         img = self.database[index].img.pil_image()
         gt = self.database[index].gt.pil_image()
         img_name = self.database[index].img.basename
         sample = [img, gt]
-        if self.mask:
-            mask = self.database[index].mask.pil_image()
-            sample.append(mask)
         
         if self.transform :
             sample = self.transform(*sample)
@@ -72,20 +72,15 @@ class SSLBinSegDataset(Dataset):
     
     Parameters
     ---------- 
-    bobdb : :py:mod:`bob.db.base`
-        Binary segmentation bob database (e.g. bob.db.drive) 
+    labeled_dataset : :py:class:`torch.utils.data.Dataset`
+        BinSegDataset with labeled samples
     unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
-        dataset with unlabeled data
-    split : str 
-        ``'train'`` or ``'test'``. Defaults to ``'train'``
-    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
-        A transform or composition of transfroms. Defaults to ``None``.
+        UnLabeledBinSegDataset with unlabeled data
     """
-    def __init__(self, bobdb, unlabeled_dataset, split = 'train', transform = None):
-        self.database = bobdb.samples(split)
+    def __init__(self, labeled_dataset, unlabeled_dataset):
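+        # each labeled sample is paired with an unlabeled one in __getitem__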
+        self.labeled_dataset = labeled_dataset
         self.unlabeled_dataset = unlabeled_dataset
-        self.transform = transform
-        self.split = split
     
 
     def __len__(self):
@@ -95,7 +89,7 @@ class SSLBinSegDataset(Dataset):
         int
             size of the dataset
         """
-        return len(self.database)
+        return len(self.labeled_dataset)
     
     def __getitem__(self,index):
         """
@@ -108,17 +102,9 @@ class SSLBinSegDataset(Dataset):
         list
             dataitem [img_name, img, gt, unlabeled_img_name, unlabeled_img]
         """
-        img = self.database[index].img.pil_image()
-        gt = self.database[index].gt.pil_image()
-        img_name = self.database[index].img.basename
-        sample = [img, gt]
-
-        
-        if self.transform :
-            sample = self.transform(*sample)
-    
-        sample.insert(0,img_name)
-        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[index]
+        sample = self.labeled_dataset[index]
+        # wrap around the unlabeled pool, which may be smaller than the labeled set
+        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[index % len(self.unlabeled_dataset)]
         sample.extend([unlabeled_img_name, unlabeled_img])
         return sample
 
@@ -138,8 +123,14 @@ class UnLabeledBinSegDataset(Dataset):
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
-        A transform or composition of transfroms. Defaults to ``None``.
+        A transform or composition of transforms. Defaults to ``None``.
+    index_from : int, optional
+        if given, only samples of the split from ``index_from`` onwards are used
     """
-    def __init__(self, db, split = 'train', transform = None):
-        self.database = db.samples(split)
+    def __init__(self, db, split = 'train', transform = None, index_from = None):
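+        # optionally drop the head of the split (e.g. so the tail serves as the unlabeled pool)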
+        if index_from:
+            self.database = db.samples(split)[index_from:]
+        else:
+            self.database = db.samples(split)
         self.transform = transform
         self.split = split   
 
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 54f8519471a5b770656fec4ef0714394c94f863e..8fb6d2f1c5fd5172e08e2d5eb034698ce47f1218 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -13,6 +13,12 @@ import numpy as np
 from bob.ip.binseg.utils.metric import SmoothedValue
 from bob.ip.binseg.utils.plot import loss_curve
 
+def sharpen(x, T):
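+    # temperature sharpening as in [MIXMATCH_19] (Eq. (7)): raising the
+    # predictions to the power 1/T and renormalising lowers their entropy for T < 1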
+    temp = x**(1/T)
+    return temp / temp.sum(dim=1, keepdim=True)
+
 def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].
     
@@ -28,21 +32,23 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     -------
     list
     """
-    l = np.random.beta(alpha, alpha) # Eq (8)
-    l = max(l, 1 - l) # Eq (9)
-    # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input,unlabeled_input],0)
-    w_targets = torch.cat([target,unlabled_target],0)
-    idx = torch.randperm(w_inputs.size(0)) # get random index 
-     
-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
-    input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]] 
-    target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
-    
-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
-    unlabled_target_mixedup =  l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
-    return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup
+    # build the mixed-up batches under no_grad: they are constants w.r.t. the model
+    with torch.no_grad():
+        l = np.random.beta(alpha, alpha) # Eq (8)
+        l = max(l, 1 - l) # Eq (9)
+        # Shuffle and concat. Alg. 1 Line: 12
+        w_inputs = torch.cat([input,unlabeled_input],0)
+        w_targets = torch.cat([target,unlabled_target],0)
+        idx = torch.randperm(w_inputs.size(0)) # get random index 
+        
+        # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]] 
+        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
+        
+        # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
+        unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
+        unlabled_target_mixedup =  l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
+        return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup
 
 
 def linear_rampup(current, rampup_length=16):
@@ -135,7 +141,7 @@ def do_ssltrain(
     max_epoch = arguments["max_epoch"]
 
     # Logg to file
-    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
+    with open(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1) as outfile:  # buffering=1: line-buffered
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
@@ -165,9 +171,11 @@ def do_ssltrain(
                 unlabeled_outputs = model(unlabeled_images)
                 # guessed unlabeled outputs
                 unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-                ramp_up_factor = linear_rampup(epoch,rampup_length=16)
+                #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
+                #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+                ramp_up_factor = linear_rampup(epoch,rampup_length=500)
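+                # the unlabeled loss weight ramps linearly to 1 over the first 500 epochs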
 
-                
                 loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
                 optimizer.zero_grad()
                 loss.backward()
@@ -212,8 +219,8 @@ def do_ssltrain(
                         "epoch: {epoch}, "
                         "avg. loss: {avg_loss:.6f}, "
                         "median loss: {median_loss:.6f}, "
-                        "{median_labeled_loss}, "
-                        "{median_unlabeled_loss}, "
+                        "labeled loss: {median_labeled_loss}, "
+                        "unlabeled loss: {median_unlabeled_loss}, "
                         "lr: {lr:.6f}, "
                         "max mem: {memory:.0f}"
                         ).format(
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index 4ab9175802de7e1cbcbe676d7e22693b1fba868d..da2b5f5ed6a518b9f6e0aafe24a5edc1f247237b 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -171,11 +171,12 @@ class MixJacLoss(_Loss):
-    lambda_u : int
-        determines the weighting of SoftJaccard and BCE.
+    lambda_u : float
+        weight of the unlabeled loss term relative to the labeled loss.
     """
-    def __init__(self, lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+    def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(MixJacLoss, self).__init__(size_average, reduce, reduction)
         self.lambda_u = lambda_u
         self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
-        self.unlabeled_loss = SoftJaccardBCELogitsLoss(alpha=unlabeledjacalpha)
+        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
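+        # the guessed labels are soft probabilities, so plain BCE-with-logits is used here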
 
 
     @weak_script_method