#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""PyTorch dataset wrappers around bob.db binary segmentation databases."""

import random

from torch.utils.data import Dataset


class BinSegDataset(Dataset):
    """PyTorch dataset wrapper around bob.db binary segmentation datasets.

    A transform object can be passed that will be applied to the image and
    the ground truth. It supports indexing such that ``dataset[i]`` can be
    used to get the ith sample.

    Parameters
    ----------
    bobdb : :py:mod:`bob.db.base`
        Binary segmentation bob database (e.g. bob.db.drive). Only its
        ``samples(split)`` method is used.
    split : str
        ``'train'`` or ``'test'``. Defaults to ``'train'``
    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
        A transform or composition of transforms. It is called as
        ``transform(img, gt)`` and must return a list. Defaults to ``None``.
    index_to : int, optional
        If given, only the samples ``[:index_to]`` of the split are kept.
        Defaults to ``None`` (keep all samples).

    Attributes
    ----------
    mask : bool
        Whether the first sample of the dataset has a ``mask`` attribute.
    """

    def __init__(self, bobdb, split="train", transform=None, index_to=None):
        samples = bobdb.samples(split)
        # Compare against None so that ``index_to=0`` yields the empty slice
        # instead of silently selecting the whole split.
        if index_to is not None:
            samples = samples[:index_to]
        self.database = samples
        self.transform = transform
        self.split = split

    @property
    def mask(self):
        """bool: ``True`` if the first sample carries a ``mask`` attribute.

        An empty dataset has no sample to inspect and reports ``False``.
        """
        return len(self.database) > 0 and hasattr(self.database[0], "mask")

    def __len__(self):
        """
        Returns
        -------
        int
            size of the dataset
        """
        return len(self.database)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index : int

        Returns
        -------
        list
            dataitem [img_name, img, gt]
        """
        entry = self.database[index]
        img = entry.img.pil_image()
        gt = entry.gt.pil_image()
        img_name = entry.img.basename

        sample = [img, gt]
        if self.transform:
            sample = self.transform(*sample)

        # The image name always goes first, in front of the (possibly
        # transformed) image and ground truth.
        sample.insert(0, img_name)
        return sample


class SSLBinSegDataset(Dataset):
    """PyTorch dataset combining a labeled and an unlabeled dataset.

    Every item consists of the labeled sample at the requested index,
    extended by one sample taken from the unlabeled dataset.

    Parameters
    ----------
    labeled_dataset : :py:class:`torch.utils.data.Dataset`
        BinSegDataset with labeled samples
    unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
        UnLabeledBinSegDataset with unlabeled data
    """

    def __init__(self, labeled_dataset, unlabeled_dataset):
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset

    def __len__(self):
        """
        Returns
        -------
        int
            size of the dataset (the number of labeled samples)
        """
        return len(self.labeled_dataset)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index : int

        Returns
        -------
        list
            dataitem [img_name, img, gt, unlabeled_img_name, unlabeled_img]
        """
        sample = self.labeled_dataset[index]
        # Index 0 is always requested here; UnLabeledBinSegDataset answers
        # any valid index with a randomly drawn sample.
        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[0]
        sample.extend([unlabeled_img_name, unlabeled_img])
        return sample


class UnLabeledBinSegDataset(Dataset):
    """PyTorch dataset wrapper around bob.db datasets used without labels.

    A transform object can be passed that will be applied to the image.
    Indexing does NOT return the ith sample: ``dataset[i]`` returns a sample
    drawn uniformly at random, ``i`` is only checked to be a valid index.

    Parameters
    ----------
    db : :py:mod:`bob.db.base`
        Binary segmentation bob database (e.g. bob.db.drive). Only its
        ``samples(split)`` method is used; a path to a folder of images is
        not supported yet (see TODO below).
    split : str
        ``'train'`` or ``'test'``. Defaults to ``'train'``
    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
        A transform or composition of transforms. It is called as
        ``transform(img)`` and must return a list. Defaults to ``None``.
    index_from : int, optional
        If given, only the samples ``[index_from:]`` of the split are kept.
        Defaults to ``None`` (keep all samples).
    """

    # TODO: add a switch to handle the case where a path to a directory is
    # passed instead of a bob.db object

    def __init__(self, db, split="train", transform=None, index_from=None):
        samples = db.samples(split)
        if index_from is not None:
            samples = samples[index_from:]
        self.database = samples
        self.transform = transform
        self.split = split

    def __len__(self):
        """
        Returns
        -------
        int
            size of the dataset
        """
        return len(self.database)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index : int
            Only validated; the returned sample is chosen at random.

        Returns
        -------
        list
            dataitem [img_name, img]

        Raises
        ------
        IndexError
            If ``index`` is out of range for the dataset.
        """
        # Keep raising IndexError for invalid indices, as plain indexing did.
        _ = self.database[index]
        # Draw a random sample without reordering the underlying sequence,
        # which may be shared with the database object or other datasets.
        entry = random.choice(self.database)
        img = entry.img.pil_image()
        img_name = entry.img.basename

        sample = [img]
        if self.transform:
            sample = self.transform(img)

        sample.insert(0, img_name)
        return sample