Skip to content
Snippets Groups Projects
Commit 833335da authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[all] Passed black on all python files

parent c3dc9915
No related branches found
No related tags found
1 merge request!12Streamlining
Pipeline #38202 passed
Showing
with 273 additions and 164 deletions
...@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ transforms = Compose([Pad((8, 8, 8, 8)), ToTensor()])
Pad((8,8,8,8))
,ToTensor()
])
# bob.db.dataset init # bob.db.dataset init
bobdb = RIMONER3(protocol = 'default_od') bobdb = RIMONER3(protocol="default_od")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='test', transform=transforms) dataset = BinSegDataset(bobdb, split="test", transform=transforms)
\ No newline at end of file
...@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ transforms = Compose(
Pad((2,1,2,2)) [
,RandomHFlip() Pad((2, 1, 2, 2)),
,RandomVFlip() RandomHFlip(),
,RandomRotation() RandomVFlip(),
,ColorJitter() RandomRotation(),
,ToTensor() ColorJitter(),
]) ToTensor(),
]
)
# bob.db.dataset init # bob.db.dataset init
bobdb = STARE(protocol = 'default') bobdb = STARE(protocol="default")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='train', transform=transforms) dataset = BinSegDataset(bobdb, split="train", transform=transforms)
\ No newline at end of file
...@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ transforms = Compose(
RandomRotation() [
,Pad((0,32,0,32)) RandomRotation(),
,Resize(1024) Pad((0, 32, 0, 32)),
,CenterCrop(1024) Resize(1024),
,RandomHFlip() CenterCrop(1024),
,RandomVFlip() RandomHFlip(),
,ColorJitter() RandomVFlip(),
,ToTensor() ColorJitter(),
]) ToTensor(),
]
)
# bob.db.dataset init # bob.db.dataset init
bobdb = STARE(protocol = 'default') bobdb = STARE(protocol="default")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='train', transform=transforms) dataset = BinSegDataset(bobdb, split="train", transform=transforms)
\ No newline at end of file
...@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ transforms = Compose(
RandomRotation() [
,Crop(50,0,500,705) RandomRotation(),
,Resize(1168) Crop(50, 0, 500, 705),
,Pad((1,0,1,0)) Resize(1168),
,RandomHFlip() Pad((1, 0, 1, 0)),
,RandomVFlip() RandomHFlip(),
,ColorJitter() RandomVFlip(),
,ToTensor() ColorJitter(),
]) ToTensor(),
]
)
# bob.db.dataset init # bob.db.dataset init
bobdb = STARE(protocol = 'default') bobdb = STARE(protocol="default")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='train', transform=transforms) dataset = BinSegDataset(bobdb, split="train", transform=transforms)
\ No newline at end of file
...@@ -7,17 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,17 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ RandomRotation() transforms = Compose(
,Resize(471) [
,Pad((0,37,0,36)) RandomRotation(),
,RandomHFlip() Resize(471),
,RandomVFlip() Pad((0, 37, 0, 36)),
,ColorJitter() RandomHFlip(),
,ToTensor() RandomVFlip(),
]) ColorJitter(),
ToTensor(),
]
)
# bob.db.dataset init # bob.db.dataset init
bobdb = STARE(protocol = 'default') bobdb = STARE(protocol="default")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='train', transform=transforms) dataset = BinSegDataset(bobdb, split="train", transform=transforms)
\ No newline at end of file
...@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ transforms = Compose(
RandomRotation() [
,Pad((0,32,0,32)) RandomRotation(),
,Resize(960) Pad((0, 32, 0, 32)),
,CenterCrop(960) Resize(960),
,RandomHFlip() CenterCrop(960),
,RandomVFlip() RandomHFlip(),
,ColorJitter() RandomVFlip(),
,ToTensor() ColorJitter(),
]) ToTensor(),
]
)
# bob.db.dataset init # bob.db.dataset init
bobdb = STARE(protocol = 'default') bobdb = STARE(protocol="default")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='train', transform=transforms) dataset = BinSegDataset(bobdb, split="train", transform=transforms)
\ No newline at end of file
...@@ -7,4 +7,4 @@ import torch ...@@ -7,4 +7,4 @@ import torch
#### Config #### #### Config ####
# PyTorch dataset # PyTorch dataset
dataset = torch.utils.data.ConcatDataset([stare,chase,hrf,iostar]) dataset = torch.utils.data.ConcatDataset([stare, chase, hrf, iostar])
\ No newline at end of file
...@@ -5,30 +5,38 @@ from bob.ip.binseg.configs.datasets.hrf544 import dataset as hrf ...@@ -5,30 +5,38 @@ from bob.ip.binseg.configs.datasets.hrf544 import dataset as hrf
from bob.db.drive import Database as DRIVE from bob.db.drive import Database as DRIVE
from bob.ip.binseg.data.transforms import * from bob.ip.binseg.data.transforms import *
import torch import torch
from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset from bob.ip.binseg.data.binsegdataset import (
BinSegDataset,
SSLBinSegDataset,
UnLabeledBinSegDataset,
)
#### Config #### #### Config ####
# PyTorch dataset # PyTorch dataset
labeled_dataset = torch.utils.data.ConcatDataset([stare,chase,iostar,hrf]) labeled_dataset = torch.utils.data.ConcatDataset([stare, chase, iostar, hrf])
#### Unlabeled STARE TRAIN #### #### Unlabeled STARE TRAIN ####
unlabeled_transforms = Compose([ unlabeled_transforms = Compose(
CenterCrop((544,544)) [
,RandomHFlip() CenterCrop((544, 544)),
,RandomVFlip() RandomHFlip(),
,RandomRotation() RandomVFlip(),
,ColorJitter() RandomRotation(),
,ToTensor() ColorJitter(),
]) ToTensor(),
]
)
# bob.db.dataset init # bob.db.dataset init
drivebobdb = DRIVE(protocol = 'default') drivebobdb = DRIVE(protocol="default")
# PyTorch dataset # PyTorch dataset
unlabeled_dataset = UnLabeledBinSegDataset(drivebobdb, split='train', transform=unlabeled_transforms) unlabeled_dataset = UnLabeledBinSegDataset(
drivebobdb, split="train", transform=unlabeled_transforms
)
# SSL Dataset # SSL Dataset
dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
\ No newline at end of file
...@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset ...@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### #### Config ####
transforms = Compose([ transforms = Compose([Pad((2, 1, 2, 2)), ToTensor()])
Pad((2,1,2,2))
,ToTensor()
])
# bob.db.dataset init # bob.db.dataset init
bobdb = STARE(protocol = 'default') bobdb = STARE(protocol="default")
# PyTorch dataset # PyTorch dataset
dataset = BinSegDataset(bobdb, split='test', transform=transforms) dataset = BinSegDataset(bobdb, split="test", transform=transforms)
\ No newline at end of file
...@@ -26,13 +26,23 @@ scheduler_gamma = 0.1 ...@@ -26,13 +26,23 @@ scheduler_gamma = 0.1
model = build_driu() model = build_driu()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['vgg16_bn'] pretrained_backbone = modelurls["vgg16_bn"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = SoftJaccardBCELogitsLoss(alpha=0.7) criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1 ...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
model = build_driu() model = build_driu()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['vgg16_bn'] pretrained_backbone = modelurls["vgg16_bn"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7) criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,13 +26,23 @@ scheduler_gamma = 0.1 ...@@ -26,13 +26,23 @@ scheduler_gamma = 0.1
model = build_driuod() model = build_driuod()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['vgg16'] pretrained_backbone = modelurls["vgg16"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = SoftJaccardBCELogitsLoss(alpha=0.7) criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1 ...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
model = build_driu() model = build_driu()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['vgg16'] pretrained_backbone = modelurls["vgg16"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7) criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -27,13 +27,23 @@ scheduler_gamma = 0.1 ...@@ -27,13 +27,23 @@ scheduler_gamma = 0.1
model = build_hed() model = build_hed()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['vgg16'] pretrained_backbone = modelurls["vgg16"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = HEDSoftJaccardBCELogitsLoss(alpha=0.7) criterion = HEDSoftJaccardBCELogitsLoss(alpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1 ...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
model = build_m2unet() model = build_m2unet()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['mobilenetv2'] pretrained_backbone = modelurls["mobilenetv2"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = SoftJaccardBCELogitsLoss(alpha=0.7) criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1 ...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
model = build_m2unet() model = build_m2unet()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['mobilenetv2'] pretrained_backbone = modelurls["mobilenetv2"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7) criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1 ...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
model = build_res50unet() model = build_res50unet()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['resnet50'] pretrained_backbone = modelurls["resnet50"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = SoftJaccardBCELogitsLoss(alpha=0.7) criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1 ...@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
model = build_unet() model = build_unet()
# pretrained backbone # pretrained backbone
pretrained_backbone = modelurls['vgg16'] pretrained_backbone = modelurls["vgg16"]
# optimizer # optimizer
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, optimizer = AdaBound(
eps=eps, weight_decay=weight_decay, amsbound=amsbound) model.parameters(),
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
# criterion # criterion
criterion = SoftJaccardBCELogitsLoss(alpha=0.7) criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
# scheduler # scheduler
scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) scheduler = MultiStepLR(
optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from torch.utils.data import Dataset from torch.utils.data import Dataset
import random import random
class BinSegDataset(Dataset): class BinSegDataset(Dataset):
"""PyTorch dataset wrapper around bob.db binary segmentation datasets. """PyTorch dataset wrapper around bob.db binary segmentation datasets.
A transform object can be passed that will be applied to the image, ground truth and mask (if present). A transform object can be passed that will be applied to the image, ground truth and mask (if present).
...@@ -19,18 +20,19 @@ class BinSegDataset(Dataset): ...@@ -19,18 +20,19 @@ class BinSegDataset(Dataset):
mask : bool mask : bool
whether dataset contains masks or not whether dataset contains masks or not
""" """
def __init__(self, bobdb, split = 'train', transform = None,index_to = None):
def __init__(self, bobdb, split="train", transform=None, index_to=None):
if index_to: if index_to:
self.database = bobdb.samples(split)[:index_to] self.database = bobdb.samples(split)[:index_to]
else: else:
self.database = bobdb.samples(split) self.database = bobdb.samples(split)
self.transform = transform self.transform = transform
self.split = split self.split = split
@property @property
def mask(self): def mask(self):
# check if first sample contains a mask # check if first sample contains a mask
return hasattr(self.database[0], 'mask') return hasattr(self.database[0], "mask")
def __len__(self): def __len__(self):
""" """
...@@ -40,8 +42,8 @@ class BinSegDataset(Dataset): ...@@ -40,8 +42,8 @@ class BinSegDataset(Dataset):
size of the dataset size of the dataset
""" """
return len(self.database) return len(self.database)
def __getitem__(self,index): def __getitem__(self, index):
""" """
Parameters Parameters
---------- ----------
...@@ -56,12 +58,12 @@ class BinSegDataset(Dataset): ...@@ -56,12 +58,12 @@ class BinSegDataset(Dataset):
gt = self.database[index].gt.pil_image() gt = self.database[index].gt.pil_image()
img_name = self.database[index].img.basename img_name = self.database[index].img.basename
sample = [img, gt] sample = [img, gt]
if self.transform : if self.transform:
sample = self.transform(*sample) sample = self.transform(*sample)
sample.insert(0,img_name) sample.insert(0, img_name)
return sample return sample
...@@ -77,10 +79,10 @@ class SSLBinSegDataset(Dataset): ...@@ -77,10 +79,10 @@ class SSLBinSegDataset(Dataset):
unlabeled_dataset : :py:class:`torch.utils.data.Dataset` unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
UnLabeledBinSegDataset with unlabeled data UnLabeledBinSegDataset with unlabeled data
""" """
def __init__(self, labeled_dataset, unlabeled_dataset): def __init__(self, labeled_dataset, unlabeled_dataset):
self.labeled_dataset = labeled_dataset self.labeled_dataset = labeled_dataset
self.unlabeled_dataset = unlabeled_dataset self.unlabeled_dataset = unlabeled_dataset
def __len__(self): def __len__(self):
""" """
...@@ -90,8 +92,8 @@ class SSLBinSegDataset(Dataset): ...@@ -90,8 +92,8 @@ class SSLBinSegDataset(Dataset):
size of the dataset size of the dataset
""" """
return len(self.labeled_dataset) return len(self.labeled_dataset)
def __getitem__(self,index): def __getitem__(self, index):
""" """
Parameters Parameters
---------- ----------
...@@ -123,13 +125,14 @@ class UnLabeledBinSegDataset(Dataset): ...@@ -123,13 +125,14 @@ class UnLabeledBinSegDataset(Dataset):
transform : :py:mod:`bob.ip.binseg.data.transforms`, optional transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
A transform or composition of transfroms. Defaults to ``None``. A transform or composition of transfroms. Defaults to ``None``.
""" """
def __init__(self, db, split = 'train', transform = None,index_from= None):
def __init__(self, db, split="train", transform=None, index_from=None):
if index_from: if index_from:
self.database = db.samples(split)[index_from:] self.database = db.samples(split)[index_from:]
else: else:
self.database = db.samples(split) self.database = db.samples(split)
self.transform = transform self.transform = transform
self.split = split self.split = split
def __len__(self): def __len__(self):
""" """
...@@ -139,8 +142,8 @@ class UnLabeledBinSegDataset(Dataset): ...@@ -139,8 +142,8 @@ class UnLabeledBinSegDataset(Dataset):
size of the dataset size of the dataset
""" """
return len(self.database) return len(self.database)
def __getitem__(self,index): def __getitem__(self, index):
""" """
Parameters Parameters
---------- ----------
...@@ -155,9 +158,9 @@ class UnLabeledBinSegDataset(Dataset): ...@@ -155,9 +158,9 @@ class UnLabeledBinSegDataset(Dataset):
img = self.database[index].img.pil_image() img = self.database[index].img.pil_image()
img_name = self.database[index].img.basename img_name = self.database[index].img.basename
sample = [img] sample = [img]
if self.transform : if self.transform:
sample = self.transform(img) sample = self.transform(img)
sample.insert(0,img_name) sample.insert(0, img_name)
return sample return sample
\ No newline at end of file
...@@ -8,16 +8,18 @@ import torch ...@@ -8,16 +8,18 @@ import torch
import torchvision.transforms.functional as VF import torchvision.transforms.functional as VF
import bob.io.base import bob.io.base
def get_file_lists(data_path): def get_file_lists(data_path):
data_path = Path(data_path) data_path = Path(data_path)
image_path = data_path.joinpath('images')
image_file_names = np.array(sorted(list(image_path.glob('*'))))
gt_path = data_path.joinpath('gt') image_path = data_path.joinpath("images")
gt_file_names = np.array(sorted(list(gt_path.glob('*')))) image_file_names = np.array(sorted(list(image_path.glob("*"))))
gt_path = data_path.joinpath("gt")
gt_file_names = np.array(sorted(list(gt_path.glob("*"))))
return image_file_names, gt_file_names return image_file_names, gt_file_names
class ImageFolder(Dataset): class ImageFolder(Dataset):
""" """
Generic ImageFolder dataset, that contains two folders: Generic ImageFolder dataset, that contains two folders:
...@@ -32,7 +34,8 @@ class ImageFolder(Dataset): ...@@ -32,7 +34,8 @@ class ImageFolder(Dataset):
full path to root of dataset full path to root of dataset
""" """
def __init__(self, path, transform = None):
def __init__(self, path, transform=None):
self.transform = transform self.transform = transform
self.img_file_list, self.gt_file_list = get_file_lists(path) self.img_file_list, self.gt_file_list = get_file_lists(path)
...@@ -44,8 +47,8 @@ class ImageFolder(Dataset): ...@@ -44,8 +47,8 @@ class ImageFolder(Dataset):
size of the dataset size of the dataset
""" """
return len(self.img_file_list) return len(self.img_file_list)
def __getitem__(self,index): def __getitem__(self, index):
""" """
Parameters Parameters
---------- ----------
...@@ -58,22 +61,22 @@ class ImageFolder(Dataset): ...@@ -58,22 +61,22 @@ class ImageFolder(Dataset):
""" """
img_path = self.img_file_list[index] img_path = self.img_file_list[index]
img_name = img_path.name img_name = img_path.name
img = Image.open(img_path).convert(mode='RGB') img = Image.open(img_path).convert(mode="RGB")
gt_path = self.gt_file_list[index] gt_path = self.gt_file_list[index]
if gt_path.suffix == '.hdf5': if gt_path.suffix == ".hdf5":
gt = bob.io.base.load(str(gt_path)).astype('float32') gt = bob.io.base.load(str(gt_path)).astype("float32")
# not elegant but since transforms require PIL images we do this hacky workaround here # not elegant but since transforms require PIL images we do this hacky workaround here
gt = torch.from_numpy(gt) gt = torch.from_numpy(gt)
gt = VF.to_pil_image(gt).convert(mode='1', dither=None) gt = VF.to_pil_image(gt).convert(mode="1", dither=None)
else: else:
gt = Image.open(gt_path).convert(mode='1', dither=None) gt = Image.open(gt_path).convert(mode="1", dither=None)
sample = [img, gt] sample = [img, gt]
if self.transform : if self.transform:
sample = self.transform(*sample) sample = self.transform(*sample)
sample.insert(0,img_name) sample.insert(0, img_name)
return sample return sample
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment