diff --git a/bob/ip/binseg/configs/datasets/chasedb11024.py b/bob/ip/binseg/configs/datasets/chasedb11024.py new file mode 100644 index 0000000000000000000000000000000000000000..028f10fb443cd4f1c9c728203f13641cc67b95ad --- /dev/null +++ b/bob/ip/binseg/configs/datasets/chasedb11024.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.chasedb1 import Database as CHASEDB1 +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(0,18,960,960) + ,Resize(1024) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = CHASEDB1(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/chasedb11168.py b/bob/ip/binseg/configs/datasets/chasedb11168.py new file mode 100644 index 0000000000000000000000000000000000000000..d221ea4879c4e7378e0f9ad6bcce2cd5a77bf04c --- /dev/null +++ b/bob/ip/binseg/configs/datasets/chasedb11168.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.chasedb1 import Database as CHASEDB1 +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(140,18,680,960) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = CHASEDB1(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/chasedb1544.py b/bob/ip/binseg/configs/datasets/chasedb1544.py new file mode 100644 index 0000000000000000000000000000000000000000..9d94cd3ca6e9e1504aaba5bbb78422f547e102c2 --- /dev/null +++ 
b/bob/ip/binseg/configs/datasets/chasedb1544.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.chasedb1 import Database as CHASEDB1 +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Resize(544) + ,Crop(0,12,544,544) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = CHASEDB1(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive1024.py b/bob/ip/binseg/configs/datasets/drive1024.py index d5f08ff8742d49652de783183d44b8e3eec1b3a3..dae199f50dc59c194a5ada24ff8e99f9aa4fd642 100644 --- a/bob/ip/binseg/configs/datasets/drive1024.py +++ b/bob/ip/binseg/configs/datasets/drive1024.py @@ -8,11 +8,11 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((540,540)) + RandomRotation() + ,CenterCrop((540,540)) ,Resize(1024) ,RandomHFlip() ,RandomVFlip() - ,RandomRotation() ,ColorJitter() ,ToTensor() ]) diff --git a/bob/ip/binseg/configs/datasets/drive608test.py b/bob/ip/binseg/configs/datasets/drive1168.py similarity index 51% rename from bob/ip/binseg/configs/datasets/drive608test.py rename to bob/ip/binseg/configs/datasets/drive1168.py index 7c59713624bbfab1f95fd5f4b079951258e40a99..3f0f0537e1ba67d7beb3d77af492ba8f9fc539a2 100644 --- a/bob/ip/binseg/configs/datasets/drive608test.py +++ b/bob/ip/binseg/configs/datasets/drive1168.py @@ -8,9 +8,13 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((470,544)) - ,Pad((10,9,10,8)) - ,Resize(608) + RandomRotation() + ,Crop(75,10,416,544) + ,Pad((21,0,22,0)) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() ,ToTensor() ]) @@ -18,4 +22,4 @@ transforms = 
Compose([ bobdb = DRIVE(protocol = 'default') # PyTorch dataset -dataset = BinSegDataset(bobdb, split='test', transform=transforms) \ No newline at end of file +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive1168sslhrf.py b/bob/ip/binseg/configs/datasets/drive1168sslhrf.py new file mode 100644 index 0000000000000000000000000000000000000000..75f742d4a07882d979cd9a09fe498dbeaee04665 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drive1168sslhrf.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.drive import Database as DRIVE +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset + +#### Config #### + +#### Unlabeled HRF TRAIN #### +unlabeled_transforms = Compose([ + Crop(0,108,2336,3296) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +hrfbobdb = HRF(protocol = 'default') + +# PyTorch dataset +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + + +#### Labeled #### +labeled_transforms = Compose([ + Crop(75,10,416,544) + ,Pad((21,0,22,0)) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = DRIVE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive2336sslhrf.py b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py index 6ab1e77daf8195c18d3853bcb2e66bc843c6ca35..9b6102ffc533f0088ffba7d228a008a081922259 100644 --- a/bob/ip/binseg/configs/datasets/drive2336sslhrf.py +++ 
b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py @@ -19,10 +19,10 @@ unlabeled_transforms = Compose([ ]) # bob.db.dataset init -sslbobdb = HRF(protocol = 'default') +hrfbobdb = HRF(protocol = 'default') # PyTorch dataset -unlabeled_dataset = UnLabeledBinSegDataset(sslbobdb, split='train', transform=unlabeled_transforms) +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) #### Labeled #### @@ -39,6 +39,8 @@ labeled_transforms = Compose([ # bob.db.dataset init bobdb = DRIVE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) -# PyTorch dataset -dataset = SSLBinSegDataset(bobdb, unlabeled_dataset, split='train', transform=labeled_transforms) +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drive608.py b/bob/ip/binseg/configs/datasets/drive608.py index 64037a07feeef8cff2c85b1a787eb0bafd1cd168..65bc5e6521dcdd3bbbf66ea77a741c6270dcdd58 100644 --- a/bob/ip/binseg/configs/datasets/drive608.py +++ b/bob/ip/binseg/configs/datasets/drive608.py @@ -8,12 +8,12 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((470,544)) + RandomRotation() + ,CenterCrop((470,544)) ,Pad((10,9,10,8)) ,Resize(608) ,RandomHFlip() ,RandomVFlip() - ,RandomRotation() ,ColorJitter() ,ToTensor() ]) diff --git a/bob/ip/binseg/configs/datasets/drive960.py b/bob/ip/binseg/configs/datasets/drive960.py index f0d06b781a7135f929ae28a53b9c83ea268c3f20..ab3ac5a9d03916dce4bbb162f51455a5bb05bcba 100644 --- a/bob/ip/binseg/configs/datasets/drive960.py +++ b/bob/ip/binseg/configs/datasets/drive960.py @@ -8,11 +8,11 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - CenterCrop((544,544)) + RandomRotation() + ,CenterCrop((544,544)) ,Resize(960) ,RandomHFlip() ,RandomVFlip() - ,RandomRotation() 
,ColorJitter() ,ToTensor() ]) diff --git a/bob/ip/binseg/configs/datasets/drive960test.py b/bob/ip/binseg/configs/datasets/drive960test.py deleted file mode 100644 index 041fa775f9ce6a47c6f8db145fed22d009e86cde..0000000000000000000000000000000000000000 --- a/bob/ip/binseg/configs/datasets/drive960test.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from bob.db.drive import Database as DRIVE -from bob.ip.binseg.data.transforms import * -from bob.ip.binseg.data.binsegdataset import BinSegDataset - -#### Config #### - -transforms = Compose([ - CenterCrop((544,544)) - ,Resize(960) - ,ToTensor() - ]) - -# bob.db.dataset init -bobdb = DRIVE(protocol = 'default') - -# PyTorch dataset -dataset = BinSegDataset(bobdb, split='test', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py b/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py new file mode 100644 index 0000000000000000000000000000000000000000..0e36eff7195de38bd2122a0bb75942a210fba1db --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py @@ -0,0 +1,9 @@ +from bob.ip.binseg.configs.datasets.drive1168 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1168 import dataset as stare +from bob.ip.binseg.configs.datasets.chasedb11168 import dataset as chase +import torch + +#### Config #### + +# PyTorch dataset +dataset = torch.utils.data.ConcatDataset([drive,stare,chase]) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py new file mode 100644 index 0000000000000000000000000000000000000000..62e2972d3526fd5ff138a5ea99f0855849beda12 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py @@ -0,0 +1,10 @@ +from bob.ip.binseg.configs.datasets.drive1168 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1168 import dataset as 
stare +from bob.ip.binseg.configs.datasets.chasedb11168 import dataset as chase +from bob.ip.binseg.configs.datasets.iostarvessel1168 import dataset as iostar +import torch + +#### Config #### + +# PyTorch dataset +dataset = torch.utils.data.ConcatDataset([drive,stare,chase,iostar]) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/hrf1168.py b/bob/ip/binseg/configs/datasets/hrf1168.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0c4d9eb3097a3c918a24bd2621bb3284f41215 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/hrf1168.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + Crop(0,108,2336,3296) + ,Resize((1168)) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = HRF(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare544test.py b/bob/ip/binseg/configs/datasets/hrf1168test.py similarity index 69% rename from bob/ip/binseg/configs/datasets/stare544test.py rename to bob/ip/binseg/configs/datasets/hrf1168test.py index 09b268737d08be21890ddde620769ec4e37bc874..86014b75bd7ea428a5f48f85776189d6eeccb619 100644 --- a/bob/ip/binseg/configs/datasets/stare544test.py +++ b/bob/ip/binseg/configs/datasets/hrf1168test.py @@ -1,20 +1,20 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from bob.db.stare import Database as STARE +from bob.db.hrf import Database as HRF from bob.ip.binseg.data.transforms import * from bob.ip.binseg.data.binsegdataset import BinSegDataset #### Config #### transforms = Compose([ - Resize(471) - ,Pad((0,37,0,36)) + Crop(0,108,2336,3296) + ,Resize((1168)) ,ToTensor() ]) # bob.db.dataset init 
-bobdb = STARE(protocol = 'default') +bobdb = HRF(protocol = 'default') # PyTorch dataset dataset = BinSegDataset(bobdb, split='test', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/iostarvessel1168.py b/bob/ip/binseg/configs/datasets/iostarvessel1168.py new file mode 100644 index 0000000000000000000000000000000000000000..5da5ed1e912065ca4ea2a81e4bd0f8b4e8d5475d --- /dev/null +++ b/bob/ip/binseg/configs/datasets/iostarvessel1168.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.iostar import Database as IOSTAR +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(144,0,768,1024) + ,Pad((30,0,30,0)) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = IOSTAR(protocol='default_vessel') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/iostarvessel960.py b/bob/ip/binseg/configs/datasets/iostarvessel960.py new file mode 100644 index 0000000000000000000000000000000000000000..32feec853882cbdcadc9fea91de4a1d61e168cc0 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/iostarvessel960.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.iostar import Database as IOSTAR +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + Resize(960) + ,RandomHFlip() + ,RandomVFlip() + ,RandomRotation() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = IOSTAR(protocol='default_vessel') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1024.py 
b/bob/ip/binseg/configs/datasets/stare1024.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6df507b16aeffb485bc448c57cf8b21f47bda1 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1024.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Pad((0,32,0,32)) + ,Resize(1024) + ,CenterCrop(1024) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1168.py b/bob/ip/binseg/configs/datasets/stare1168.py new file mode 100644 index 0000000000000000000000000000000000000000..77e934bf6b6f387105df08519932822bcb11cf09 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1168.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Crop(50,0,500,705) + ,Resize(1168) + ,Pad((1,0,1,0)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1168sslhrf.py b/bob/ip/binseg/configs/datasets/stare1168sslhrf.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e3e201464c8262dc946e4ba114635299223295 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1168sslhrf.py @@ -0,0 +1,47 @@ +#!/usr/bin/env 
python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset + +#### Config #### + +#### Unlabeled HRF TRAIN #### +unlabeled_transforms = Compose([ + RandomRotation() + ,Crop(0,108,2336,3296) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +hrfbobdb = HRF(protocol = 'default') + +# PyTorch dataset +unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + + +#### Labeled #### +labeled_transforms = Compose([ + RandomRotation() + ,Crop(50,0,500,705) + ,Resize(1168) + ,Pad((1,0,1,0)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py b/bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py new file mode 100644 index 0000000000000000000000000000000000000000..a116997a8e63a3ac38e9b4fb2b43d0115fc56251 --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare1168sslhrfrefuge.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.db.refuge import Database as REFUGE +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset +import torch +#### Config #### + +#### Unlabeled HRF TRAIN #### +unlabeled_transforms = Compose([ + RandomRotation() + ,Crop(0,108,2336,3296) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + 
,ToTensor() + ]) + +#### Unlabeled REFUGE Test #### + +unlabeled_transforms_refuge = Compose([ + RandomRotation() + ,Crop(220,11,1150,1623) + ,Resize(1168) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +## bob.db.dataset init +# hrf +hrfbobdb = HRF(protocol = 'default') +# refuge +refugebobdb = REFUGE() + + +# PyTorch dataset +unlabeled_dataset_1 = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms) + +unlabeled_dataset_2 = UnLabeledBinSegDataset(refugebobdb, split='test', transform=unlabeled_transforms_refuge) + +# Compose +unlabeled_dataset = torch.utils.data.ConcatDataset([unlabeled_dataset_1,unlabeled_dataset_2]) + + +#### Labeled #### +labeled_transforms = Compose([ + RandomRotation() + ,Crop(50,0,500,705) + ,Resize(1168) + ,Pad((1,0,1,0)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') +labeled_dataset = BinSegDataset(bobdb, split='train', transform=labeled_transforms) + +# SSL Dataset + +dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/stare544.py b/bob/ip/binseg/configs/datasets/stare544.py new file mode 100644 index 0000000000000000000000000000000000000000..08c2ad4fb8f4610d6213cb3357d0e4bd491a4f3f --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare544.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ RandomRotation() + ,Resize(471) + ,Pad((0,37,0,36)) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git 
a/bob/ip/binseg/configs/datasets/stare960.py b/bob/ip/binseg/configs/datasets/stare960.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1ed7883cb746f469534ad2a29f491501e7566e --- /dev/null +++ b/bob/ip/binseg/configs/datasets/stare960.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.stare import Database as STARE +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + RandomRotation() + ,Pad((0,32,0,32)) + ,Resize(960) + ,CenterCrop(960) + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + ,ToTensor() + ]) + +# bob.db.dataset init +bobdb = STARE(protocol = 'default') + +# PyTorch dataset +dataset = BinSegDataset(bobdb, split='train', transform=transforms) \ No newline at end of file diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py index ac8847ab64cf2e948ef77c6cf2ad9a5e2a2eedb8..8280e52f8d94de8b43cc1f4da6e8afb0b4f42878 100644 --- a/bob/ip/binseg/configs/models/m2unetssl.py +++ b/bob/ip/binseg/configs/models/m2unetssl.py @@ -33,7 +33,7 @@ optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, eps=eps, weight_decay=weight_decay, amsbound=amsbound) # criterion -criterion = MixJacLoss(lambda_u=0.01, jacalpha=0.7, unlabeledjacalpha=0.7) +criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7, unlabeledjacalpha=0.3) # scheduler scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) diff --git a/bob/ip/binseg/configs/models/m2unetssl0703.py b/bob/ip/binseg/configs/models/m2unetssl0703.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a160821deaea8436d533938c69b9f535fc763d --- /dev/null +++ b/bob/ip/binseg/configs/models/m2unetssl0703.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from torch.optim.lr_scheduler import MultiStepLR +from bob.ip.binseg.modeling.m2u import 
build_m2unet +import torch.optim as optim +from torch.nn import BCEWithLogitsLoss +from bob.ip.binseg.utils.model_zoo import modelurls +from bob.ip.binseg.modeling.losses import MixJacLoss +from bob.ip.binseg.engine.adabound import AdaBound + +##### Config ##### +lr = 0.001 +betas = (0.9, 0.999) +eps = 1e-08 +weight_decay = 0 +final_lr = 0.1 +gamma = 1e-3 +eps = 1e-8 +amsbound = False + +scheduler_milestones = [900] +scheduler_gamma = 0.1 + +# model +model = build_m2unet() + +# pretrained backbone +pretrained_backbone = modelurls['mobilenetv2'] + +# optimizer +optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, + eps=eps, weight_decay=weight_decay, amsbound=amsbound) + +# criterion +criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.3) + +# scheduler +scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 8fb6d2f1c5fd5172e08e2d5eb034698ce47f1218..382431176fc33e4bf98e3cc38d4440c6f532c30d 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -51,6 +51,27 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target): return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup +def square_rampup(current, rampup_length=16): + """slowly ramp-up ``lambda_u`` + + Parameters + ---------- + current : int + current epoch + rampup_length : int, optional + how long to ramp up, by default 16 + + Returns + ------- + float + ramp up factor + """ + if rampup_length == 0: + return 1.0 + else: + current = np.clip((current/ float(rampup_length))**2, 0.0, 1.0) + return float(current) + def linear_rampup(current, rampup_length=16): """slowly ramp-up ``lambda_u`` @@ -109,7 +130,8 @@ def do_ssltrain( checkpoint_period, device, arguments, - output_folder + output_folder, + rampup_length ): """ Train model and save to disk. 
@@ -134,6 +156,8 @@ def do_ssltrain( start end end epochs output_folder : str output path + rampup_length : int + rampup epochs """ logger = logging.getLogger("bob.ip.binseg.engine.trainer") logger.info("Start training") @@ -173,7 +197,7 @@ def do_ssltrain( unlabeled_ground_truths = guess_labels(unlabeled_images, model) #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5) #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths) - ramp_up_factor = linear_rampup(epoch,rampup_length=500) + ramp_up_factor = square_rampup(epoch,rampup_length=rampup_length) loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor) optimizer.zero_grad() @@ -244,7 +268,7 @@ def do_ssltrain( )) log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name)) - logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss","lr","max memory"]) + logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss", "labeled loss", "unlabeled loss", "lr","max memory"]) fig = loss_curve(logdf,output_folder) logger.info("saving {}".format(log_plot_file)) fig.savefig(log_plot_file) diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py index da2b5f5ed6a518b9f6e0aafe24a5edc1f247237b..bb6578af297893252e06f7cfdafe4d950da3790a 100644 --- a/bob/ip/binseg/modeling/losses.py +++ b/bob/ip/binseg/modeling/losses.py @@ -171,7 +171,7 @@ class MixJacLoss(_Loss): lambda_u : int determines the weighting of SoftJaccard and BCE. 
""" - def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None): + def __init__(self, lambda_u=100, jacalpha=0.7, unlabeledjacalpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None): super(MixJacLoss, self).__init__(size_average, reduce, reduction) self.lambda_u = lambda_u self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py index 16b59f2d83c468b10786e793a3d4a266a9df13b5..6af9d438a317a280d1b2c3e0c5fbe65fbd50068d 100644 --- a/bob/ip/binseg/script/binseg.py +++ b/bob/ip/binseg/script/binseg.py @@ -466,6 +466,14 @@ def visualize(dataset, output_path, **kwargs): required=True, default='cpu', cls=ResourceOption) +@click.option( + '--rampup', + '-r', + help='Ramp-up length in epochs', + show_default=True, + required=True, + default='900', + cls=ResourceOption) @verbosity_option(cls=ResourceOption) def ssltrain(model @@ -479,6 +487,7 @@ def ssltrain(model ,dataset ,checkpoint_period ,device + ,rampup ,**kwargs): """ Train a model """ @@ -513,4 +522,5 @@ def ssltrain(model , device , arguments , output_path + , rampup ) \ No newline at end of file