From 79b59ab84311edf0bf8628b81b39757aa929b9fc Mon Sep 17 00:00:00 2001 From: Tim Laibacher <tim.laibacher@idiap.ch> Date: Wed, 19 Jun 2019 09:56:20 +0200 Subject: [PATCH] Add more ds configs --- .../binseg/configs/datasets/chasedb11024.py | 24 +++++++ .../binseg/configs/datasets/chasedb11168.py | 24 +++++++ bob/ip/binseg/configs/datasets/chasedb1544.py | 24 +++++++ bob/ip/binseg/configs/datasets/drive1024.py | 4 +- .../{drive608test.py => drive1168.py} | 12 ++-- .../configs/datasets/drive1168sslhrf.py | 47 +++++++++++++ .../configs/datasets/drive2336sslhrf.py | 10 +-- bob/ip/binseg/configs/datasets/drive608.py | 4 +- bob/ip/binseg/configs/datasets/drive960.py | 4 +- .../binseg/configs/datasets/drive960test.py | 20 ------ .../datasets/drivestarechasedb11168.py | 9 +++ .../datasets/drivestarechasedb1iostar1168.py | 10 +++ bob/ip/binseg/configs/datasets/hrf1168.py | 24 +++++++ .../{stare544test.py => hrf1168test.py} | 8 +-- .../configs/datasets/iostarvessel1168.py | 25 +++++++ .../configs/datasets/iostarvessel960.py | 23 +++++++ bob/ip/binseg/configs/datasets/stare1024.py | 25 +++++++ bob/ip/binseg/configs/datasets/stare1168.py | 25 +++++++ .../configs/datasets/stare1168sslhrf.py | 47 +++++++++++++ .../configs/datasets/stare1168sslhrfrefuge.py | 69 +++++++++++++++++++ bob/ip/binseg/configs/datasets/stare544.py | 24 +++++++ bob/ip/binseg/configs/datasets/stare960.py | 25 +++++++ bob/ip/binseg/configs/models/m2unetssl.py | 2 +- bob/ip/binseg/configs/models/m2unetssl0703.py | 39 +++++++++++ bob/ip/binseg/engine/ssltrainer.py | 30 +++++++- bob/ip/binseg/modeling/losses.py | 2 +- bob/ip/binseg/script/binseg.py | 10 +++ 27 files changed, 527 insertions(+), 43 deletions(-) create mode 100644 bob/ip/binseg/configs/datasets/chasedb11024.py create mode 100644 bob/ip/binseg/configs/datasets/chasedb11168.py create mode 100644 bob/ip/binseg/configs/datasets/chasedb1544.py rename bob/ip/binseg/configs/datasets/{drive608test.py => drive1168.py} (51%) create mode 100644 bob/ip/binseg/configs/datasets/drive1168sslhrf.py delete mode 100644 bob/ip/binseg/configs/datasets/drive960test.py create mode 100644 bob/ip/binseg/configs/datasets/drivestarechasedb11168.py create mode 100644 bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py create mode 100644 bob/ip/binseg/configs/datasets/hrf1168.py rename bob/ip/binseg/configs/datasets/{stare544test.py => hrf1168test.py} (69%) create mode 100644 bob/ip/binseg/configs/datasets/iostarvessel1168.py create mode 100644 bob/ip/binseg/configs/datasets/iostarvessel960.py create mode 100644 bob/ip/binseg/configs/datasets/stare1024.py create mode 100644 bob/ip/binseg/configs/datasets/stare1168.py create mode 100644 bob/ip/binseg/configs/datasets/stare1168sslhrf.py create mode 100644 bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py create mode 100644 bob/ip/binseg/configs/datasets/stare544.py create mode 100644 bob/ip/binseg/configs/datasets/stare960.py create mode 100644 bob/ip/binseg/configs/models/m2unetssl0703.py diff --git a/bob/ip/binseg/configs/datasets/chasedb11024.py b/bob/ip/binseg/configs/datasets/chasedb11024.py new file mode 100644 index 00000000..028f10fb --- /dev/null +++ b/bob/ip/binseg/configs/datasets/chasedb11024.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.chasedb1 import Database as CHASEDB1 +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(0,18,960,960) + ,Resize(1024) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = CHASEDB1(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/chasedb11168.py b/bob/ip/binseg/configs/datasets/chasedb11168.py new file mode 100644 index 00000000..d221ea48 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/chasedb11168.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.chasedb1 import Database as CHASEDB1 +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(140,18,680,960) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = CHASEDB1(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/chasedb1544.py b/bob/ip/binseg/configs/datasets/chasedb1544.py new file mode 100644 index 00000000..9d94cd3c --- /dev/null +++ b/bob/ip/binseg/configs/datasets/chasedb1544.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.chasedb1 import Database as CHASEDB1 +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Resize(544) + ,Crop(0,12,544,544) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = CHASEDB1(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive1024.py b/bob/ip/binseg/configs/datasets/drive1024.py index d5f08ff8..dae199f5 100644 --- a/bob/ip/binseg/configs/datasets/drive1024.py +++ b/bob/ip/binseg/configs/datasets/drive1024.py @@ -8,11 +8,11 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((540,540)) + RandomRotation() + ,CenterCrop((540,540)) ,Resize(1024) ,RandomHFlip() ,RandomVFlip() - ,RandomRotation() ,ColorJitter() ,ToTensor() ]) diff --git a/bob/ip/binseg/configs/datasets/drive608test.py b/bob/ip/binseg/configs/datasets/drive1168.py similarity index 51% rename from bob/ip/binseg/configs/datasets/drive608test.py rename to bob/ip/binseg/configs/datasets/drive1168.py index 7c597136..3f0f0537 100644 --- a/bob/ip/binseg/configs/datasets/drive608test.py +++ b/bob/ip/binseg/configs/datasets/drive1168.py @@ -8,9 +8,13 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((470,544)) - ,Pad((10,9,10,8)) - ,Resize(608) + RandomRotation() + ,Crop(75,10,416,544) + ,Pad((21,0,22,0)) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() ,ToTensor() ]) @@ -18,4 +22,4 @@ transforms = Compose([ bobdb = DRIVE(protocol = 'default') # PyTorch dataset -dataset = BinSegDataset(bobdb, split='test', transform=transforms) \ No newline at end of file +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive1168sslhrf.py b/bob/ip/binseg/configs/datasets/drive1168sslhrf.py new file mode 100644 index 00000000..75f742d4 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drive1168sslhrf.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.drive import Database as DRIVE +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset + +#### Config #### + +#### Unlabeled HRF TRAIN #### +unlabeled_transforms = Compose([ + Crop(0,108,2336,3296) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +hrfbobdb = HRF(protocol = 'default') + +# PyTorch dataset +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + + +#### Labeled #### +labeled_transforms = Compose([ + Crop(75,10,416,544) + ,Pad((21,0,22,0)) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = DRIVE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive2336sslhrf.py b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py index 6ab1e77d..9b6102ff 100644 --- a/bob/ip/binseg/configs/datasets/drive2336sslhrf.py +++ b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py @@ -19,10 +19,10 @@ unlabeled_transforms = Compose([ ]) # bob.db.dataset init -sslbobdb = HRF(protocol = 'default') +hrfbobdb = HRF(protocol = 'default') # PyTorch dataset -unlabeled_dataset = UnLabeledBinSegDataset(sslbobdb, split='train', transform=unlabeled_transforms) +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) #### Labeled #### @@ -39,6 +39,8 @@ labeled_transforms = Compose([ # bob.db.dataset init bobdb = DRIVE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) -# PyTorch dataset -dataset = SSLBinSegDataset(bobdb, unlabeled_dataset, split='train', transform=labeled_transforms) +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive608.py b/bob/ip/binseg/configs/datasets/drive608.py index 64037a07..65bc5e65 100644 --- a/bob/ip/binseg/configs/datasets/drive608.py +++ b/bob/ip/binseg/configs/datasets/drive608.py @@ -8,12 +8,12 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((470,544)) + RandomRotation() + ,CenterCrop((470,544)) ,Pad((10,9,10,8)) ,Resize(608) ,RandomHFlip() ,RandomVFlip() - ,RandomRotation() ,ColorJitter() ,ToTensor() ]) diff --git a/bob/ip/binseg/configs/datasets/drive960.py b/bob/ip/binseg/configs/datasets/drive960.py index f0d06b78..ab3ac5a9 100644 --- a/bob/ip/binseg/configs/datasets/drive960.py +++ b/bob/ip/binseg/configs/datasets/drive960.py @@ -8,11 +8,11 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((544,544)) + RandomRotation() + ,CenterCrop((544,544)) ,Resize(960) ,RandomHFlip() ,RandomVFlip() - ,RandomRotation() ,ColorJitter() ,ToTensor() ]) diff --git a/bob/ip/binseg/configs/datasets/drive960test.py b/bob/ip/binseg/configs/datasets/drive960test.py deleted file mode 100644 index 041fa775..00000000 --- a/bob/ip/binseg/configs/datasets/drive960test.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from bob.db.drive import Database as DRIVE -from bob.ip.binseg.data.transforms import * -from bob.ip.binseg.data.binsegdataset import BinSegDataset - -#### Config #### - -transforms = Compose([ - CenterCrop((544,544)) - ,Resize(960) - ,ToTensor() - ]) - -# bob.db.dataset init -bobdb = DRIVE(protocol = 'default') - -# PyTorch dataset -dataset = BinSegDataset(bobdb, split='test', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py b/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py new file mode 100644 index 00000000..0e36eff7 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py @@ -0,0 +1,9 @@ +from bob.ip.binseg.configs.datasets.drive1168 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1168 import dataset as stare +from bob.ip.binseg.configs.datasets.chasedb11168 import dataset as chase +import torch + +#### Config #### + +# PyTorch dataset +dataset = torch.utils.data.ConcatDataset([drive,stare,chase]) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py new file mode 100644 index 00000000..62e2972d --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py @@ -0,0 +1,10 @@ +from bob.ip.binseg.configs.datasets.drive1168 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1168 import dataset as stare +from bob.ip.binseg.configs.datasets.chasedb11168 import dataset as chase +from bob.ip.binseg.configs.datasets.iostarvessel1168 import dataset as iostar +import torch + +#### Config #### + +# PyTorch dataset +dataset = torch.utils.data.ConcatDataset([drive,stare,chase,iostar]) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/hrf1168.py b/bob/ip/binseg/configs/datasets/hrf1168.py new file mode 100644 index 00000000..4d0c4d9e --- /dev/null +++ b/bob/ip/binseg/configs/datasets/hrf1168.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + Crop(0,108,2336,3296) + ,Resize((1168)) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = HRF(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare544test.py b/bob/ip/binseg/configs/datasets/hrf1168test.py similarity index 69% rename from bob/ip/binseg/configs/datasets/stare544test.py rename to bob/ip/binseg/configs/datasets/hrf1168test.py index 09b26873..86014b75 100644 --- a/bob/ip/binseg/configs/datasets/stare544test.py +++ b/bob/ip/binseg/configs/datasets/hrf1168test.py @@ -1,20 +1,20 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from bob.db.stare import Database as STARE +from bob.db.hrf import Database as HRF from bob.ip.binseg.data.transforms import * from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - Resize(471) - ,Pad((0,37,0,36)) + Crop(0,108,2336,3296) + ,Resize((1168)) ,ToTensor() ]) # bob.db.dataset init -bobdb = STARE(protocol = 'default') +bobdb = HRF(protocol = 'default') # PyTorch dataset dataset = BinSegDataset(bobdb, split='test', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/iostarvessel1168.py b/bob/ip/binseg/configs/datasets/iostarvessel1168.py new file mode 100644 index 00000000..5da5ed1e --- /dev/null +++ b/bob/ip/binseg/configs/datasets/iostarvessel1168.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.iostar import Database as IOSTAR +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(144,0,768,1024) + ,Pad((30,0,30,0)) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = IOSTAR(protocol='default_vessel') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/iostarvessel960.py b/bob/ip/binseg/configs/datasets/iostarvessel960.py new file mode 100644 index 00000000..32feec85 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/iostarvessel960.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.iostar import Database as IOSTAR +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + Resize(960) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = IOSTAR(protocol='default_vessel') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1024.py b/bob/ip/binseg/configs/datasets/stare1024.py new file mode 100644 index 00000000..8f6df507 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1024.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Pad((0,32,0,32)) + ,Resize(1024) + ,CenterCrop(1024) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1168.py b/bob/ip/binseg/configs/datasets/stare1168.py new file mode 100644 index 00000000..77e934bf --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1168.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(50,0,500,705) + ,Resize(1168) + ,Pad((1,0,1,0)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1168sslhrf.py b/bob/ip/binseg/configs/datasets/stare1168sslhrf.py new file mode 100644 index 00000000..a7e3e201 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1168sslhrf.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset + +#### Config #### + +#### Unlabeled HRF TRAIN #### +unlabeled_transforms = Compose([ + RandomRotation() + ,Crop(0,108,2336,3296) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +hrfbobdb = HRF(protocol = 'default') + +# PyTorch dataset +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + + +#### Labeled #### +labeled_transforms = Compose([ + RandomRotation() + ,Crop(50,0,500,705) + ,Resize(1168) + ,Pad((1,0,1,0)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py b/bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py new file mode 100644 index 00000000..a116997a --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.db.refuge import Database as REFUGE +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset +import torch +#### Config #### + +#### Unlabeled HRF TRAIN #### +unlabeled_transforms = Compose([ + RandomRotation() + ,Crop(0,108,2336,3296) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +#### Unlabeled REFUGE Test #### + +unlabeled_transforms_refuge = Compose([ + RandomRotation() + ,Crop(220,11,1150,1623) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +## bob.db.dataset init +# hrf +hrfbobdb = HRF(protocol = 'default') +# refuge +refugebobdb = REFUGE() + + +# PyTorch dataset +unlabeled_dataset_1 = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + +unlabeled_dataset_2 = UnLabeledBinSegDataset(refugebobdb, split='test', transform=unlabeled_transforms_refuge) + +# Compose +unlabeled_dataset = torch.utils.data.ConcatDataset([unlabeled_dataset_1,unlabeled_dataset_2]) + + +#### Labeled #### +labeled_transforms = Compose([ + RandomRotation() + ,Crop(50,0,500,705) + ,Resize(1168) + ,Pad((1,0,1,0)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare544.py b/bob/ip/binseg/configs/datasets/stare544.py new file mode 100644 index 00000000..08c2ad4f --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare544.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ RandomRotation() + ,Resize(471) + ,Pad((0,37,0,36)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare960.py b/bob/ip/binseg/configs/datasets/stare960.py new file mode 100644 index 00000000..0d1ed788 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare960.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Pad((0,32,0,32)) + ,Resize(960) + ,CenterCrop(960) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py index ac8847ab..8280e52f 100644 --- a/bob/ip/binseg/configs/models/m2unetssl.py +++ b/bob/ip/binseg/configs/models/m2unetssl.py @@ -33,7 +33,7 @@ optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, eps=eps, weight_decay=weight_decay, amsbound=amsbound) # criterion -criterion = MixJacLoss(lambda_u=0.01, jacalpha=0.7, unlabeledjacalpha=0.7) +criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7, unlabeledjacalpha=0.3) # scheduler scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) diff --git a/bob/ip/binseg/configs/models/m2unetssl0703.py b/bob/ip/binseg/configs/models/m2unetssl0703.py new file mode 100644 index 00000000..d5a16082 --- /dev/null +++ b/bob/ip/binseg/configs/models/m2unetssl0703.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from torch.optim.lr_scheduler import MultiStepLR +from bob.ip.binseg.modeling.m2u import build_m2unet +import torch.optim as optim +from torch.nn import BCEWithLogitsLoss +from bob.ip.binseg.utils.model_zoo import modelurls +from bob.ip.binseg.modeling.losses import MixJacLoss +from bob.ip.binseg.engine.adabound import AdaBound + +##### Config ##### +lr = 0.001 +betas = (0.9, 0.999) +eps = 1e-08 +weight_decay = 0 +final_lr = 0.1 +gamma = 1e-3 +eps = 1e-8 +amsbound = False + +scheduler_milestones = [900] +scheduler_gamma = 0.1 + +# model +model = build_m2unet() + +# pretrained backbone +pretrained_backbone = modelurls['mobilenetv2'] + +# optimizer +optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, + eps=eps, weight_decay=weight_decay, amsbound=amsbound) + +# criterion +criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.3) + +# scheduler +scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 8fb6d2f1..38243117 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -51,6 +51,27 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target): return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup +def square_rampup(current, rampup_length=16): + """slowly ramp-up ``lambda_u`` + + Parameters + ---------- + current : int + current epoch + rampup_length : int, optional + how long to ramp up, by default 16 + + Returns + ------- + float + ramp up factor + """ + if rampup_length == 0: + return 1.0 + else: + current = np.clip((current/ float(rampup_length))**2, 0.0, 1.0) + return float(current) + def linear_rampup(current, rampup_length=16): """slowly ramp-up ``lambda_u`` @@ -109,7 +130,8 @@ def do_ssltrain( checkpoint_period, device, arguments, - output_folder + output_folder, + rampup_length ): """ Train model and save to disk. @@ -134,6 +156,8 @@ def do_ssltrain( start end end epochs output_folder : str output path + rampup_Length : int + rampup epochs """ logger = logging.getLogger("bob.ip.binseg.engine.trainer") logger.info("Start training") @@ -173,7 +197,7 @@ def do_ssltrain( unlabeled_ground_truths = guess_labels(unlabeled_images, model) #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5) #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths) - ramp_up_factor = linear_rampup(epoch,rampup_length=500) + ramp_up_factor = square_rampup(epoch,rampup_length=rampup_length) loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor) optimizer.zero_grad() @@ -244,7 +268,7 @@ def do_ssltrain( )) log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name)) - logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss","lr","max memory"]) + logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss", "labeled loss", "unlabeled loss", "lr","max memory"]) fig = loss_curve(logdf,output_folder) logger.info("saving {}".format(log_plot_file)) fig.savefig(log_plot_file) diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py index da2b5f5e..bb6578af 100644 --- a/bob/ip/binseg/modeling/losses.py +++ b/bob/ip/binseg/modeling/losses.py @@ -171,7 +171,7 @@ class MixJacLoss(_Loss): lambda_u : int determines the weighting of SoftJaccard and BCE. """ - def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None): + def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None): super(MixJacLoss, self).__init__(size_average, reduce, reduction) self.lambda_u = lambda_u self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py index 16b59f2d..6af9d438 100644 --- a/bob/ip/binseg/script/binseg.py +++ b/bob/ip/binseg/script/binseg.py @@ -466,6 +466,14 @@ def visualize(dataset, output_path, **kwargs): required=True, default='cpu', cls=ResourceOption) +@click.option( + '--rampup', + '-r', + help='Ramp-up length in epochs', + show_default=True, + required=True, + default='900', + cls=ResourceOption) @verbosity_option(cls=ResourceOption) def ssltrain(model @@ -479,6 +487,7 @@ def ssltrain(model ,dataset ,checkpoint_period ,device + ,rampup ,**kwargs): """ Train a model """ @@ -513,4 +522,5 @@ def ssltrain(model , device , arguments , output_path + , rampup ) \ No newline at end of file -- GitLab