diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py new file mode 100644 index 0000000000000000000000000000000000000000..74628abe77801ef572cd3ad2ff379e9af9287506 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py @@ -0,0 +1,10 @@ +from bob.ip.binseg.configs.datasets.drive1024 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1024 import dataset as stare +from bob.ip.binseg.configs.datasets.hrf1024 import dataset as hrf +from bob.ip.binseg.configs.datasets.chasedb11024 import dataset as chase +import torch + +#### Config #### + +# PyTorch dataset +dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,chase]) diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py new file mode 100644 index 0000000000000000000000000000000000000000..4dbd23d2cf52f40e4b03c38bb603f4139900a1c1 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py @@ -0,0 +1,34 @@ +from bob.ip.binseg.configs.datasets.drive1024 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1024 import dataset as stare +from bob.ip.binseg.configs.datasets.hrf1024 import dataset as hrf +from bob.ip.binseg.configs.datasets.chasedb11024 import dataset as chasedb +from bob.db.iostar import Database as IOSTAR +from bob.ip.binseg.data.transforms import * +import torch +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset + + +#### Config #### + +# PyTorch dataset +labeled_dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,chasedb]) + +#### Unlabeled CHASE TRAIN #### +unlabeled_transforms = Compose([ + Crop(0,18,960,960) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +iostarbobdb = IOSTAR(protocol='default_vessel') + +# PyTorch dataset +unlabeled_dataset = UnLabeledBinSegDataset(iostarbobdb, split='train', transform=unlabeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py new file mode 100644 index 0000000000000000000000000000000000000000..376aae1482ac3c8bd6911ae61e83f498bb5ad6c6 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py @@ -0,0 +1,35 @@ +from bob.ip.binseg.configs.datasets.drive1168 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1168 import dataset as stare +from bob.ip.binseg.configs.datasets.chasedb11168 import dataset as chasedb +from bob.ip.binseg.configs.datasets.iostarvessel1168 import dataset as iostar +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +import torch +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset + + +#### Config #### + +# PyTorch dataset +labeled_dataset = torch.utils.data.ConcatDataset([drive,stare,iostar,chasedb]) + +#### Unlabeled CHASE TRAIN #### +unlabeled_transforms = Compose([ + Crop(0,108,2336,3296) + ,Resize((1168)) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +hrfbobdb = HRF(protocol='default') + +# PyTorch dataset +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/hrf1024.py b/bob/ip/binseg/configs/datasets/hrf1024.py new file mode 100644 index 0000000000000000000000000000000000000000..48168445f5689b95a1f55ffda0633d2acdbb619a --- /dev/null +++ b/bob/ip/binseg/configs/datasets/hrf1024.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + Pad((0,584,0,584)) + ,Resize((1024)) + ,RandomRotation() + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = HRF(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py index 6af9d438a317a280d1b2c3e0c5fbe65fbd50068d..5b01d03960693f7c9cffe2f0207f3758c799605e 100644 --- a/bob/ip/binseg/script/binseg.py +++ b/bob/ip/binseg/script/binseg.py @@ -309,12 +309,17 @@ def testcheckpoints(model '-o', required=True, ) +@click.option( + '--title', + '-t', + required=False, + ) @verbosity_option(cls=ResourceOption) -def compare(output_path_list, output_path, **kwargs): +def compare(output_path_list, output_path, title, **kwargs): """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """ logger.debug("Output paths: {}".format(output_path_list)) logger.info('Plotting precision vs recall curves for {}'.format(output_path_list)) - fig = plot_overview(output_path_list) + fig = plot_overview(output_path_list,title) if not os.path.exists(output_path): os.makedirs(output_path) fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf') logger.info('saving {}'.format(fig_filename)) diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py index bedb7625a0fbdb98b6e72538bd0e9f81ed313d6e..db6269377054f5084e0fe84b8fb14c30733ba9a1 100644 --- a/bob/ip/binseg/utils/plot.py +++ b/bob/ip/binseg/utils/plot.py @@ -158,7 +158,7 @@ def read_metricscsv(file): return np.array(precision), np.array(recall) -def plot_overview(outputfolders): +def plot_overview(outputfolders,title): """ Plots comparison chart of all trained models @@ -166,6 +166,8 @@ def plot_overview(outputfolders): ---------- outputfolder : list list containing output paths of all evaluated models (e.g. ``['DRIVE/model1', 'DRIVE/model2']``) + title : str + title of plot Returns ------- matplotlib.figure.Figure @@ -181,15 +183,16 @@ def plot_overview(outputfolders): precisions.append(pr) recalls.append(re) modelname = folder.split('/')[-1] + datasetname = folder.split('/')[-2] # parameters summary_path = os.path.join(folder,'results/ModelSummary.txt') with open (summary_path, "r") as outfile: rows = outfile.readlines() lastrow = rows[-1] parameter = int(lastrow.split()[1].replace(',','')) - name = '[P={:.2f}M] {}'.format(parameter/100**3, modelname) + name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, datasetname) names.append(name) - title = folder.split('/')[-2] + #title = folder.split('/')[-4] fig = precision_recall_f1iso(precisions,recalls,names,title) return fig @@ -286,4 +289,4 @@ def overlay(dataset, output_path): # save to disk overlayed_path = os.path.join(output_path,'overlayed') if not os.path.exists(overlayed_path): os.makedirs(overlayed_path) - overlayed.save(os.path.join(overlayed_path,name)) \ No newline at end of file + overlayed.save(os.path.join(overlayed_path,name))