diff --git a/bob/ip/binseg/configs/datasets/drive1024ssliostar.py b/bob/ip/binseg/configs/datasets/drive1024ssliostar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0158c4a9c4595c9baf558afa6dd97d5ee1ca64b6
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drive1024ssliostar.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.db.iostar import Database as IOSTAR
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+
+#### Config ####
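+# Semi-supervised setup: labeled DRIVE training images (center-cropped and
+# resized to 1024 pixels) are combined with unlabeled IOSTAR training images.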
+
+#### Unlabeled IOSTAR TRAIN ####
+unlabeled_transforms = Compose([  
+                        RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+sslbobdb = IOSTAR(protocol = 'default_vessel')
+
+# PyTorch dataset
+unlabeled_dataset = UnLabeledBinSegDataset(sslbobdb, split='train', transform=unlabeled_transforms)
+
+
+#### Labeled ####
+labeled_transforms = Compose([
+                        CenterCrop((540,540))  
+                        ,Resize(1024)
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol = 'default')
+
+# PyTorch dataset
+dataset = SSLBinSegDataset(bobdb, unlabeled_dataset, split='train', transform=labeled_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive2336sslhrf.py b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab1e77daf8195c18d3853bcb2e66bc843c6ca35
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.db.hrf import Database as HRF
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+
+#### Config ####
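+# Semi-supervised setup: labeled DRIVE training images (cropped, padded and
+# resized to 2336 pixels) are combined with unlabeled HRF training images.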
+
+#### Unlabeled HRF TRAIN ####
+unlabeled_transforms = Compose([  
+                        Crop(0,108,2336,3296)
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+sslbobdb = HRF(protocol = 'default')
+
+# PyTorch dataset
+unlabeled_dataset = UnLabeledBinSegDataset(sslbobdb, split='train', transform=unlabeled_transforms)
+
+
+#### Labeled ####
+labeled_transforms = Compose([  
+                        Crop(75,10,416,544)
+                        ,Pad((21,0,22,0))
+                        ,Resize(2336)
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol = 'default')
+
+# PyTorch dataset
+dataset = SSLBinSegDataset(bobdb, unlabeled_dataset, split='train', transform=labeled_transforms)
diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0beafe6acef86a74e8955a7de7c2c6c04502037
--- /dev/null
+++ b/bob/ip/binseg/configs/models/m2unetssl.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.m2u import build_m2unet
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+scheduler_milestones = [900]
+scheduler_gamma = 0.1
+
+# model
+model = build_m2unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['mobilenetv2']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound)
+
+# criterion
+criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7)
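+# Note: MixJacLoss returns a 3-tuple ``(loss, labeled loss, unlabeled loss)``,
+# consumed by the SSL trainer as:
+#   loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)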
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
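+
+# This config is meant to be consumed by the ``ssltrain`` command together
+# with an SSL dataset config; a usage sketch (the entry-point names below are
+# assumptions, check ``setup.py`` for the registered names):
+#
+#   bob binseg ssltrain m2unetssl drive1024ssliostar -o <output_path> -vv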
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index 86fc97df482707a9715b38790741aff65bfe2b52..0f3ca24730c7f3b83880ee42137555057e9218eb 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 from torch.utils.data import Dataset
+import random
 
 class BinSegDataset(Dataset):
     """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
@@ -62,3 +63,113 @@ class BinSegDataset(Dataset):
         sample.insert(0,img_name)
         
         return sample
+
+
+class SSLBinSegDataset(Dataset):
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
+    A transform object can be passed that will be applied to the image, ground truth and mask (if present). 
+    It supports indexing such that dataset[i] can be used to get ith sample.
+    
+    Parameters
+    ---------- 
+    bobdb : :py:mod:`bob.db.base`
+        Binary segmentation bob database (e.g. bob.db.drive) 
+    unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
+        dataset with unlabeled data
+    split : str 
+        ``'train'`` or ``'test'``. Defaults to ``'train'``
+    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
+        A transform or composition of transfroms. Defaults to ``None``.
+    """
+    def __init__(self, bobdb, unlabeled_dataset, split = 'train', transform = None):
+        self.database = bobdb.samples(split)
+        self.unlabeled_dataset = unlabeled_dataset
+        self.transform = transform
+        self.split = split
+    
+
+    def __len__(self):
+        """
+        Returns
+        -------
+        int
+            size of the dataset
+        """
+        return len(self.database)
+    
+    def __getitem__(self,index):
+        """
+        Parameters
+        ----------
+        index : int
+        
+        Returns
+        -------
+        list
+            dataitem [img_name, img, gt, unlabeled_img_name, unlabeled_img]
+        """
+        img = self.database[index].img.pil_image()
+        gt = self.database[index].gt.pil_image()
+        img_name = self.database[index].img.basename
+        sample = [img, gt]
+
+        if self.transform :
+            sample = self.transform(*sample)
+
+        sample.insert(0,img_name)
+        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[index]
+        sample.extend([unlabeled_img_name, unlabeled_img])
+        return sample
+
+
+class UnLabeledBinSegDataset(Dataset):
+    # TODO: add a switch to handle the case where a path to a directory of
+    # images is passed instead of a bob.db object
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets,
+    exposing only the (unlabeled) images. A transform object can be passed
+    that will be applied to the image. It supports indexing such that
+    dataset[i] can be used to get the ith sample.
+
+    Parameters
+    ----------
+    db : :py:mod:`bob.db.base` or str
+        Binary segmentation bob database (e.g. bob.db.drive) or path to folder containing unlabeled images
+    split : str
+        ``'train'`` or ``'test'``. Defaults to ``'train'``
+    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
+        A transform or composition of transforms. Defaults to ``None``.
+    """
+    def __init__(self, db, split = 'train', transform = None):
+        self.database = db.samples(split)
+        self.transform = transform
+        self.split = split   
+
+    def __len__(self):
+        """
+        Returns
+        -------
+        int
+            size of the dataset
+        """
+        return len(self.database)
+    
+    def __getitem__(self,index):
+        """
+        Parameters
+        ----------
+        index : int
+        
+        Returns
+        -------
+        list
+            dataitem [img_name, img]
+        """
+        random.shuffle(self.database)
+        img = self.database[index].img.pil_image()
+        img_name = self.database[index].img.basename
+        sample = [img]
+        if self.transform :
+            sample = self.transform(img)
+        
+        sample.insert(0,img_name)
+        
+        return sample
\ No newline at end of file
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0489558e075b7fb873d21a4b3cf50db954eacbb
--- /dev/null
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -0,0 +1,239 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os 
+import logging
+import time
+import datetime
+import torch
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+
+from bob.ip.binseg.utils.metric import SmoothedValue
+from bob.ip.binseg.utils.plot import loss_curve
+
+def mix_up(alpha, input, target, unlabeled_input, unlabeled_target):
+    """Applies mix up as described in [MIXMATCH_19]_.
+
+    Parameters
+    ----------
+    alpha : float
+    input : :py:class:`torch.Tensor`
+    target : :py:class:`torch.Tensor`
+    unlabeled_input : :py:class:`torch.Tensor`
+    unlabeled_target : :py:class:`torch.Tensor`
+
+    Returns
+    -------
+    tuple
+        the mixed-up ``(input, target, unlabeled_input, unlabeled_target)``
+    """
+    l = np.random.beta(alpha, alpha) # Eq (8)
+    l = max(l, 1 - l) # Eq (9)
+    # Shuffle and concat. Alg. 1 Line: 12
+    w_inputs = torch.cat([input, unlabeled_input], 0)
+    w_targets = torch.cat([target, unlabeled_target], 0)
+    idx = torch.randperm(w_inputs.size(0)) # get random index
+
+    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+    input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
+    target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
+
+    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
+    unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
+    unlabeled_target_mixedup = l * unlabeled_target + (1 - l) * w_targets[idx[:len(unlabeled_target)]]
+    return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabeled_target_mixedup
+
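+# A minimal usage sketch (shapes are assumptions, e.g. ``[n,c,h,w]``): given a
+# labeled batch ``(x, y)`` and an unlabeled batch ``u`` with guessed labels
+# ``q``:
+#
+#   x_m, y_m, u_m, q_m = mix_up(0.75, x, y, u, q)
+#
+# ``alpha=0.75`` follows the MixMatch paper's suggestion and is an assumption.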
+
+def linear_rampup(current, rampup_length=16):
+    """slowly ramp-up ``lambda_u``
+    
+    Parameters
+    ----------
+    current : int
+        current epoch
+    rampup_length : int, optional
+        how long to ramp up, by default 16
+    
+    Returns
+    -------
+    float
+        ramp up factor
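+
+    Examples
+    --------
+    >>> linear_rampup(4, rampup_length=16)
+    0.25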
+    """
+    if rampup_length == 0:
+        return 1.0
+    else:
+        current = np.clip(current / rampup_length, 0.0, 1.0)
+    return float(current)
+
+def guess_labels(unlabeled_images, model):
+    """
+    Calculate the average prediction over the original images and two
+    augmentations (flips along the height and width dimensions).
+
+    Parameters
+    ----------
+    unlabeled_images : :py:class:`torch.Tensor`
+        shape: ``[n,c,h,w]``
+    model : :py:class:`torch.nn.Module`
+
+    Returns
+    -------
+    :py:class:`torch.Tensor`
+        shape: ``[n,c,h,w]``.
+    """
+    with torch.no_grad():
+        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
+        # Flip along the height dimension (dim 2), predict, then flip the
+        # prediction back (height is dim 3 after the unsqueeze)
+        hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
+        guess2 = hflip.flip(3)
+        # Flip along the width dimension (dim 3), predict, then flip the
+        # prediction back (width is dim 4 after the unsqueeze)
+        vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
+        guess3 = vflip.flip(4)
+        # Average the three predictions
+        concat = torch.cat([guess1, guess2, guess3], 0)
+        avg_guess = torch.mean(concat, 0)
+        return avg_guess
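+
+# The averaged probabilities act as pseudo ground truths for the unlabeled
+# branch; ``torch.no_grad`` ensures no gradients flow through the label
+# guessing, only through the unlabeled outputs in the training loop.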
+
+def do_ssltrain(
+    model,
+    data_loader,
+    optimizer,
+    criterion,
+    scheduler,
+    checkpointer,
+    checkpoint_period,
+    device,
+    arguments,
+    output_folder
+):
+    """ 
+    Train model and save to disk.
+    
+    Parameters
+    ----------
+    model : :py:class:`torch.nn.Module` 
+        Network (e.g. DRIU, HED, UNet)
+    data_loader : :py:class:`torch.utils.data.DataLoader`
+    optimizer : :py:mod:`torch.optim`
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+    scheduler : :py:mod:`torch.optim`
+        learning rate scheduler
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+        checkpointer
+    checkpoint_period : int
+        save a checkpoint every n epochs
+    device : str  
+        device to use ``'cpu'`` or ``'cuda'``
+    arguments : dict
+        start and end epochs
+    output_folder : str 
+        output path
+    """
+    logger = logging.getLogger("bob.ip.binseg.engine.ssltrainer")
+    logger.info("Start training")
+    start_epoch = arguments["epoch"]
+    max_epoch = arguments["max_epoch"]
+
+    # Log to file
+    with open(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+") as outfile:
+        
+        model.train().to(device)
+        # Total training timer
+        start_training_time = time.time()
+        for epoch in range(start_epoch, max_epoch):
+            scheduler.step()
+            losses = SmoothedValue(len(data_loader))
+            labeled_loss = SmoothedValue(len(data_loader))
+            unlabeled_loss = SmoothedValue(len(data_loader))
+            epoch = epoch + 1
+            arguments["epoch"] = epoch
+            
+            # Epoch time
+            start_epoch_time = time.time()
+
+            for samples in tqdm(data_loader):
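+                # samples layout: [img_name, img, gt, unlabeled_img_name, unlabeled_img]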
+                # labeled
+                images = samples[1].to(device)
+                ground_truths = samples[2].to(device)
+                unlabeled_images = samples[4].to(device)
+                # labeled outputs
+                outputs = model(images)
+                unlabeled_outputs = model(unlabeled_images)
+                # guessed unlabeled outputs
+                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
+                ramp_up_factor = linear_rampup(epoch, rampup_length=16)
+
+                loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                losses.update(loss)
+                labeled_loss.update(ll)
+                unlabeled_loss.update(ul)
+                logger.debug("batch loss: {}".format(loss.item()))
+
+            if epoch % checkpoint_period == 0:
+                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+
+            if epoch == max_epoch:
+                checkpointer.save("model_final", **arguments)
+
+            epoch_time = time.time() - start_epoch_time
+
+
+            eta_seconds = epoch_time * (max_epoch - epoch)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+            outfile.write(("{epoch}, "
+                        "{avg_loss:.6f}, "
+                        "{median_loss:.6f}, "
+                        "{median_labeled_loss},"
+                        "{median_unlabeled_loss},"
+                        "{lr:.6f}, "
+                        "{memory:.0f}"
+                        "\n"
+                        ).format(
+                    eta=eta_string,
+                    epoch=epoch,
+                    avg_loss=losses.avg,
+                    median_loss=losses.median,
+                    median_labeled_loss = labeled_loss.median,
+                    median_unlabeled_loss = unlabeled_loss.median,
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0,
+                    )
+                )  
+            logger.info(("eta: {eta}, " 
+                        "epoch: {epoch}, "
+                        "avg. loss: {avg_loss:.6f}, "
+                        "median loss: {median_loss:.6f}, "
+                        "{median_labeled_loss}, "
+                        "{median_unlabeled_loss}, "
+                        "lr: {lr:.6f}, "
+                        "max mem: {memory:.0f}"
+                        ).format(
+                    eta=eta_string,
+                    epoch=epoch,
+                    avg_loss=losses.avg,
+                    median_loss=losses.median,
+                    median_labeled_loss = labeled_loss.median,
+                    median_unlabeled_loss = unlabeled_loss.median,
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0
+                    )
+                )
+
+
+        total_training_time = time.time() - start_training_time
+        total_time_str = str(datetime.timedelta(seconds=total_training_time))
+        logger.info(
+            "Total training time: {} ({:.4f} s / epoch)".format(
+                total_time_str, total_training_time / (max_epoch)
+            ))
+        
+    log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name))
+    # 6 names for 7 CSV columns: the first column (epoch) becomes the index
+    logdf = pd.read_csv(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), header=None, names=["avg. loss", "median loss", "median labeled loss", "median unlabeled loss", "lr", "max memory"])
+    fig = loss_curve(logdf,output_folder)
+    logger.info("saving {}".format(log_plot_file))
+    fig.savefig(log_plot_file)
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index 5eeb7950ebe344dadc08fcc8b21c057a081c08ed..4ab9175802de7e1cbcbe676d7e22693b1fba868d 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -49,9 +49,9 @@ class SoftJaccardBCELogitsLoss(_Loss):
     Attributes
     ----------
     alpha : float
-        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
+        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
     """
-    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+    def __init__(self, alpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction) 
         self.alpha = alpha   
 
@@ -160,4 +160,41 @@ class HEDSoftJaccardBCELogitsLoss(_Loss):
             loss = self.alpha * bceloss + (1 - self.alpha) * (1-softjaccard)
             loss_over_all_inputs.append(loss.unsqueeze(0))
         final_loss = torch.cat(loss_over_all_inputs).mean()
-        return loss
\ No newline at end of file
+        return final_loss
+
+
+
+class MixJacLoss(_Loss):
+    """
+    Weighted sum of a labeled and an unlabeled loss term.
+
+    Attributes
+    ----------
+    lambda_u : float
+        determines the weighting of the unlabeled loss.
+    jacalpha : float
+        SoftJaccard/BCE weighting (``alpha``) for the labeled loss.
+    unlabeledjacalpha : float
+        SoftJaccard/BCE weighting (``alpha``) for the unlabeled loss.
+    """
+    def __init__(self, lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
+        self.lambda_u = lambda_u
+        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
+        self.unlabeled_loss = SoftJaccardBCELogitsLoss(alpha=unlabeledjacalpha)
+
+
+    @weak_script_method
+    def forward(self, input, target, unlabeled_input, unlabeled_target, ramp_up_factor):
+        """
+        Parameters
+        ----------
+        input : :py:class:`torch.Tensor`
+        target : :py:class:`torch.Tensor`
+        unlabeled_input : :py:class:`torch.Tensor`
+        unlabeled_target : :py:class:`torch.Tensor`
+        ramp_up_factor : float
+
+        Returns
+        -------
+        tuple
+            ``(loss, labeled loss, unlabeled loss)``
+        """
+        ll = self.labeled_loss(input, target)
+        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)
+
+        loss = ll + self.lambda_u * ramp_up_factor * ul
+        return loss, ll, ul
\ No newline at end of file
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 2a90d5ed652dcab7779a97770caf02fec1f39622..16b59f2d83c468b10786e793a3d4a266a9df13b5 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -24,6 +24,7 @@ from bob.extension.scripts.click_helper import (verbosity_option,
 from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
 from torch.utils.data import DataLoader
 from bob.ip.binseg.engine.trainer import do_train
+from bob.ip.binseg.engine.ssltrainer import do_ssltrain
 from bob.ip.binseg.engine.inferencer import do_inference
 from bob.ip.binseg.utils.plot import plot_overview
 from bob.ip.binseg.utils.click import OptionEatAll
@@ -390,4 +391,126 @@ def visualize(dataset, output_path, **kwargs):
     logger.info('Creating TP, FP, FN visualizations for {}'.format(output_path))
     metricsviz(dataset=dataset, output_path=output_path)
     logger.info('Creating overlay visualizations for {}'.format(output_path))
-    overlay(dataset=dataset, output_path=output_path)
\ No newline at end of file
+    overlay(dataset=dataset, output_path=output_path)
+
+
+# SSLTrain
+@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@click.option(
+    '--output-path',
+    '-o',
+    required=True,
+    default="output",
+    cls=ResourceOption
+    )
+@click.option(
+    '--model',
+    '-m',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--dataset',
+    '-d',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--optimizer',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--criterion',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--scheduler',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--pretrained-backbone',
+    '-t',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--batch-size',
+    '-b',
+    required=True,
+    default=2,
+    cls=ResourceOption)
+@click.option(
+    '--epochs',
+    '-e',
+    help='Number of epochs used for training',
+    show_default=True,
+    required=True,
+    default=6,
+    cls=ResourceOption)
+@click.option(
+    '--checkpoint-period',
+    '-p',
+    help='Number of epochs after which a checkpoint is saved',
+    show_default=True,
+    required=True,
+    default=2,
+    cls=ResourceOption)
+@click.option(
+    '--device',
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    show_default=True,
+    required=True,
+    default='cpu',
+    cls=ResourceOption)
+
+@verbosity_option(cls=ResourceOption)
+def ssltrain(model
+        ,optimizer
+        ,scheduler
+        ,output_path
+        ,epochs
+        ,pretrained_backbone
+        ,batch_size
+        ,criterion
+        ,dataset
+        ,checkpoint_period
+        ,device
+        ,**kwargs):
+    """ Train a model """
+    
+    if not os.path.exists(output_path): os.makedirs(output_path)
+    
+    # PyTorch dataloader
+    data_loader = DataLoader(
+        dataset = dataset
+        ,batch_size = batch_size
+        ,shuffle= True
+        ,pin_memory = torch.cuda.is_available()
+        )
+
+    # Checkpointer
+    checkpointer = DetectronCheckpointer(model, optimizer, scheduler, save_dir=output_path, save_to_disk=True)
+    arguments = {}
+    arguments["epoch"] = 0 
+    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
+    arguments.update(extra_checkpoint_data)
+    arguments["max_epoch"] = epochs
+    
+    # Train
+    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
+    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
+    do_ssltrain(model
+            , data_loader
+            , optimizer
+            , criterion
+            , scheduler
+            , checkpointer
+            , checkpoint_period
+            , device
+            , arguments
+            , output_path
+            )
\ No newline at end of file