diff --git a/conda/meta.yaml b/conda/meta.yaml index 771b4a3feed6194c6ccb1599ba52ed2f02f4504f..74ab9d7b21af02d6c2d1e74780fa9ba6d906d213 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -36,7 +36,7 @@ requirements: - torchvision {{ torchvision }} - tqdm {{ tqdm }} - tensorboard {{ tensorboard }} - - pytorch-lightning {{ pytorch_lightning }} + - lightning {{ lightning }} - clapper run: - python >=3.9 @@ -53,7 +53,7 @@ requirements: - {{ pin_compatible('torchvision') }} - {{ pin_compatible('tqdm') }} - {{ pin_compatible('tensorboard') }} - - {{ pin_compatible('pytorch-lightning') }} + - {{ pin_compatible('lightning') }} - clapper test: diff --git a/doc/api.rst b/doc/api.rst index 7314356253fefa05fc6de69fadefccbd61471aca..0114d9ff62b83dfc2178ade21426a7a2cb334dfe 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -53,6 +53,10 @@ Direct data-access through iterators. ptbench.data.indian ptbench.data.nih_cxr14_re ptbench.data.padchest_RS + ptbench.data.tbx11k_simplified + ptbench.data.tbx11k_simplified_RS + ptbench.data.tbx11k_simplified_v2 + ptbench.data.tbx11k_simplified_v2_RS .. 
_ptbench.api.models: diff --git a/doc/catalog.json b/doc/catalog.json index 795fb7d5bf8ce777d2ea3834485aafadc9ccb166..6c65691e90743e6d8c5ad405902b6a9e494a1c88 100644 --- a/doc/catalog.json +++ b/doc/catalog.json @@ -5,5 +5,13 @@ "latest": "https://clapper.readthedocs.io/en/latest/" }, "sources": {} + }, + "lightning": { + "versions": { + "stable": "https://lightning.ai/docs/pytorch/stable/" + }, + "sources": { + "environment": "lightning" + } } } diff --git a/doc/conf.py b/doc/conf.py index c46a3dadcb7ecfc431589c397ec3e9a5c02b0dd8..49b7ceacf910e8ea1e1d2c5c602654a678b64908 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -122,6 +122,7 @@ auto_intersphinx_packages = [ "psutil", "torch", "torchvision", + "lightning", ("clapper", "latest"), ("python", "3"), ] diff --git a/doc/config.rst b/doc/config.rst index af8fa995a3247f5bd735627b9d59e6de436025c4..828bba07a4c3db58cdc4854c7ab1c4e407f21f17 100644 --- a/doc/config.rst +++ b/doc/config.rst @@ -50,6 +50,12 @@ if applicable. Use these datasets for training and evaluating your models. ptbench.configs.datasets.mc_ch_in.default ptbench.configs.datasets.mc_ch_in.rgb ptbench.configs.datasets.mc_ch_in_RS.default + ptbench.configs.datasets.mc_ch_in_11k.default + ptbench.configs.datasets.mc_ch_in_11k.rgb + ptbench.configs.datasets.mc_ch_in_11k_RS.default + ptbench.configs.datasets.mc_ch_in_11kv2.default + ptbench.configs.datasets.mc_ch_in_11kv2.rgb + ptbench.configs.datasets.mc_ch_in_11kv2_RS.default ptbench.configs.datasets.mc_ch_in_pc.default ptbench.configs.datasets.mc_ch_in_pc.rgb ptbench.configs.datasets.mc_ch_in_pc_RS.default @@ -68,6 +74,12 @@ if applicable. Use these datasets for training and evaluating your models. 
ptbench.configs.datasets.shenzhen.default ptbench.configs.datasets.shenzhen.rgb ptbench.configs.datasets.shenzhen_RS.default + ptbench.configs.datasets.tbx11k_simplified.default + ptbench.configs.datasets.tbx11k_simplified.rgb + ptbench.configs.datasets.tbx11k_simplified_RS.default + ptbench.configs.datasets.tbx11k_simplified_v2.default + ptbench.configs.datasets.tbx11k_simplified_v2.rgb + ptbench.configs.datasets.tbx11k_simplified_v2_RS.default .. _ptbench.configs.datasets.folds: @@ -97,6 +109,12 @@ datasets. Nine other folds are available for every configuration (from 1 to ptbench.configs.datasets.mc_ch_in.fold_0 ptbench.configs.datasets.mc_ch_in.fold_0_rgb ptbench.configs.datasets.mc_ch_in_RS.fold_0 + ptbench.configs.datasets.mc_ch_in_11k.fold_0 + ptbench.configs.datasets.mc_ch_in_11k.fold_0_rgb + ptbench.configs.datasets.mc_ch_in_11k_RS.fold_0 + ptbench.configs.datasets.mc_ch_in_11kv2.fold_0 + ptbench.configs.datasets.mc_ch_in_11kv2.fold_0_rgb + ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_0 ptbench.configs.datasets.montgomery.fold_0 ptbench.configs.datasets.montgomery.fold_0_rgb ptbench.configs.datasets.montgomery_RS.fold_0 @@ -106,6 +124,12 @@ datasets. Nine other folds are available for every configuration (from 1 to ptbench.configs.datasets.tbpoc.fold_0 ptbench.configs.datasets.tbpoc.fold_0_rgb ptbench.configs.datasets.tbpoc_RS.fold_0 + ptbench.configs.datasets.tbx11k_simplified.fold_0 + ptbench.configs.datasets.tbx11k_simplified.fold_0_rgb + ptbench.configs.datasets.tbx11k_simplified_RS.fold_0 + ptbench.configs.datasets.tbx11k_simplified_v2.fold_0 + ptbench.configs.datasets.tbx11k_simplified_v2.fold_0_rgb + ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_0 .. 
include:: links.rst diff --git a/doc/extras.inv b/doc/extras.inv index 88973215f3227495564e345c34603dad42a9b532..d053cdcf2c67163e69d16b9ab3cff2f0aaaf4f37 100644 --- a/doc/extras.inv +++ b/doc/extras.inv @@ -2,5 +2,5 @@ # Project: extras # Version: stable # The remainder of this file is compressed using zlib. -xÚEËÁ € лSti¼² * - PÒ~MØÞÞ߃è–îlYšƒ†f‡h5êÃWÙ¯i¡tóÌ}àÅNôäo°!¬%ò]B-4OÎŒ ã \ No newline at end of file +xÚA +1E÷ž¢ [[ÜÎôR§± ´MH£2žÞ‡A\ˆ.ÃÏ{OIútņTŠ¯íLRšá¡+.ÌÎ$Uns<èôlI¢› ×ÔŸ2¸h“–l¶«Œ1iÅíBõ$`g§Ý/ëa¾Gôæ%<« ‰ÂXõŒîƒåŸšëÍ×UØë±)ðÏibÅ‚wÿèĘ/ \ No newline at end of file diff --git a/doc/extras.txt b/doc/extras.txt index e827f8fa4af7ea94634d044dec1d282029babcc4..77fd0ca6112dcd91c2c303eefe87b279f3f64035 100644 --- a/doc/extras.txt +++ b/doc/extras.txt @@ -3,3 +3,6 @@ # Version: stable # The remainder of this file is compressed using zlib. torchvision.transforms py:module 1 https://pytorch.org/vision/stable/transforms.html - +lightning.pytorch.core.module.LightningModule.forward py:method 1 api/lightning.pytorch.core.LightningModule.html#$ - +lightning.pytorch.core.module.LightningModule.predict_step py:method 1 api/lightning.pytorch.core.LightningModule.html#$ - +optimizer_step py:method 1 api/lightning.pytorch.core.LightningModule.html#$ - diff --git a/doc/install.rst b/doc/install.rst index 0ec7ad7d7d9dfbefb6b3106ddcfc164d41d7f435..170f1883325ffccb41ad78d2eafbd87ceeb7528e 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -66,6 +66,7 @@ Here is an example configuration file that may be useful as a starting point: montgomery = "/Users/myself/dbs/montgomery-xrayset" shenzhen = "/Users/myself/dbs/shenzhen" nih_cxr14_re = "/Users/myself/dbs/nih-cxr14-re" + tbx11k_simplified = "/Users/myself/dbs/tbx11k-simplified" [nih_cxr14_re] idiap_folder_structure = false # set to `true` if at Idiap @@ -145,6 +146,42 @@ In addition to the splits presented in the following table, 10 folds .. 
_ptbench.setup.datasets.tb+signs: +Tuberculosis multilabel dataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following dataset contains the labels healthy, sick & non-TB, active TB, +and latent TB. The implemented tbx11k dataset in this package is based on +the simplified version, which is just a more compact version of the original. +In addition to the splits presented in the following table, 10 folds +(for cross-validation) randomly generated are available for these datasets. + +.. list-table:: + + * - Dataset + - Reference + - H x W + - Samples + - Training + - Validation + - Test + * - TBX11K_ + - [TBX11K-2020]_ + - 512 x 512 + - 11'200 + - 6600 + - 1800 + - 2800 + * - TBX11K_SIMPLIFIED_ + - [TBX11K-SIMPLIFIED-2020]_ + - 512 x 512 + - 11'200 + - 6600 + - 1800 + - 2800 + + +.. _ptbench.setup.datasets.tbmultilabel+signs: + Tuberculosis + radiological findings dataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/links.rst b/doc/links.rst index df1ef5ea53ee13e8d6c41c02bdaf2367b3940bd6..92ecafd5a8e0fd128efaaccd637d2b2731f2f788 100644 --- a/doc/links.rst +++ b/doc/links.rst @@ -18,3 +18,5 @@ .. _indian: https://sourceforge.net/projects/tbxpredict/ .. _NIH_CXR14_re: https://nihcc.app.box.com/v/ChestXray-NIHCC .. _PadChest: https://bimcv.cipf.es/bimcv-projects/padchest/ +.. _TBX11K: https://mmcheng.net/tb/ +.. _TBX11K_simplified: https://www.kaggle.com/datasets/vbookshelf/tbx11k-simplified diff --git a/doc/references.rst b/doc/references.rst index 4897ee5bebf8a8937fa247d2f4fc8a199279ff56..5a349fa9dd960aa393f8f9d5dfd6f696676317e6 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -59,3 +59,13 @@ performance and interobserver agreement of urine lipoarabinomannan in diagnosing HIV-Associated tuberculosis in an emergency center.**, J. Acquir. Immune Defic. Syndr. 1999 81, e10–e14 (2019). + +.. 
[TBX11K-2020] *Liu, Y., Wu, Y.-H., Ban, Y., Wang, H., and Cheng, M.-*, + **Rethinking computer-aided tuberculosis diagnosis.**, + In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern + Recognition, pages 2646–2655. + +.. [TBX11K-SIMPLIFIED-2020] *Liu, Y., Wu, Y.-H., Ban, Y., Wang, H., and Cheng, M.-*, + **Rethinking computer-aided tuberculosis diagnosis.**, + In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern + Recognition, pages 2646–2655. diff --git a/pyproject.toml b/pyproject.toml index 47820ed74c5103622503d734f73e199b8e8c38d3..7407826a8270d9d1d909a92811d4644b42aec52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "pillow", "torch>=1.8", "torchvision>=0.10", - "pytorch-lightning", + "lightning", "tensorboard", ] @@ -186,6 +186,76 @@ indian_rs_f6 = "ptbench.configs.datasets.indian_RS.fold_6" indian_rs_f7 = "ptbench.configs.datasets.indian_RS.fold_7" indian_rs_f8 = "ptbench.configs.datasets.indian_RS.fold_8" indian_rs_f9 = "ptbench.configs.datasets.indian_RS.fold_9" +# TBX11K simplified dataset split 1 (and cross-validation folds) +tbx11k_simplified = "ptbench.configs.datasets.tbx11k_simplified.default" +tbx11k_simplified_rgb = "ptbench.configs.datasets.tbx11k_simplified.rgb" +tbx11k_simplified_f0 = "ptbench.configs.datasets.tbx11k_simplified.fold_0" +tbx11k_simplified_f1 = "ptbench.configs.datasets.tbx11k_simplified.fold_1" +tbx11k_simplified_f2 = "ptbench.configs.datasets.tbx11k_simplified.fold_2" +tbx11k_simplified_f3 = "ptbench.configs.datasets.tbx11k_simplified.fold_3" +tbx11k_simplified_f4 = "ptbench.configs.datasets.tbx11k_simplified.fold_4" +tbx11k_simplified_f5 = "ptbench.configs.datasets.tbx11k_simplified.fold_5" +tbx11k_simplified_f6 = "ptbench.configs.datasets.tbx11k_simplified.fold_6" +tbx11k_simplified_f7 = "ptbench.configs.datasets.tbx11k_simplified.fold_7" +tbx11k_simplified_f8 = "ptbench.configs.datasets.tbx11k_simplified.fold_8" +tbx11k_simplified_f9 = 
"ptbench.configs.datasets.tbx11k_simplified.fold_9" +tbx11k_simplified_f0_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_0_rgb" +tbx11k_simplified_f1_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_1_rgb" +tbx11k_simplified_f2_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_2_rgb" +tbx11k_simplified_f3_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_3_rgb" +tbx11k_simplified_f4_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_4_rgb" +tbx11k_simplified_f5_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_5_rgb" +tbx11k_simplified_f6_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_6_rgb" +tbx11k_simplified_f7_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_7_rgb" +tbx11k_simplified_f8_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_8_rgb" +tbx11k_simplified_f9_rgb = "ptbench.configs.datasets.tbx11k_simplified.fold_9_rgb" +# extended TBX11K simplified dataset split 1 (with radiological signs) +tbx11k_simplified_rs = "ptbench.configs.datasets.tbx11k_simplified_RS.default" +tbx11k_simplified_rs_f0 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_0" +tbx11k_simplified_rs_f1 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_1" +tbx11k_simplified_rs_f2 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_2" +tbx11k_simplified_rs_f3 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_3" +tbx11k_simplified_rs_f4 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_4" +tbx11k_simplified_rs_f5 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_5" +tbx11k_simplified_rs_f6 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_6" +tbx11k_simplified_rs_f7 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_7" +tbx11k_simplified_rs_f8 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_8" +tbx11k_simplified_rs_f9 = "ptbench.configs.datasets.tbx11k_simplified_RS.fold_9" +# TBX11K simplified dataset split 2 (and cross-validation folds) +tbx11k_simplified_v2 = 
"ptbench.configs.datasets.tbx11k_simplified_v2.default" +tbx11k_simplified_v2_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.rgb" +tbx11k_simplified_v2_f0 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_0" +tbx11k_simplified_v2_f1 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_1" +tbx11k_simplified_v2_f2 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_2" +tbx11k_simplified_v2_f3 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_3" +tbx11k_simplified_v2_f4 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_4" +tbx11k_simplified_v2_f5 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_5" +tbx11k_simplified_v2_f6 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_6" +tbx11k_simplified_v2_f7 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_7" +tbx11k_simplified_v2_f8 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_8" +tbx11k_simplified_v2_f9 = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_9" +tbx11k_simplified_v2_f0_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_0_rgb" +tbx11k_simplified_v2_f1_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_1_rgb" +tbx11k_simplified_v2_f2_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_2_rgb" +tbx11k_simplified_v2_f3_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_3_rgb" +tbx11k_simplified_v2_f4_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_4_rgb" +tbx11k_simplified_v2_f5_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_5_rgb" +tbx11k_simplified_v2_f6_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_6_rgb" +tbx11k_simplified_v2_f7_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_7_rgb" +tbx11k_simplified_v2_f8_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_8_rgb" +tbx11k_simplified_v2_f9_rgb = "ptbench.configs.datasets.tbx11k_simplified_v2.fold_9_rgb" +# extended TBX11K simplified dataset split 2 (with radiological signs) +tbx11k_simplified_v2_rs = 
"ptbench.configs.datasets.tbx11k_simplified_v2_RS.default" +tbx11k_simplified_v2_rs_f0 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_0" +tbx11k_simplified_v2_rs_f1 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_1" +tbx11k_simplified_v2_rs_f2 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_2" +tbx11k_simplified_v2_rs_f3 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_3" +tbx11k_simplified_v2_rs_f4 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_4" +tbx11k_simplified_v2_rs_f5 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_5" +tbx11k_simplified_v2_rs_f6 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_6" +tbx11k_simplified_v2_rs_f7 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_7" +tbx11k_simplified_v2_rs_f8 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_8" +tbx11k_simplified_v2_rs_f9 = "ptbench.configs.datasets.tbx11k_simplified_v2_RS.fold_9" # montgomery-shenzhen aggregated dataset mc_ch = "ptbench.configs.datasets.mc_ch.default" mc_ch_rgb = "ptbench.configs.datasets.mc_ch.rgb" @@ -258,6 +328,78 @@ mc_ch_in_rs_f6 = "ptbench.configs.datasets.mc_ch_in_RS.fold_6" mc_ch_in_rs_f7 = "ptbench.configs.datasets.mc_ch_in_RS.fold_7" mc_ch_in_rs_f8 = "ptbench.configs.datasets.mc_ch_in_RS.fold_8" mc_ch_in_rs_f9 = "ptbench.configs.datasets.mc_ch_in_RS.fold_9" +# montgomery-shenzhen-indian-tbx11k aggregated dataset +mc_ch_in_11k = "ptbench.configs.datasets.mc_ch_in_11k.default" +mc_ch_in_11k_rgb = "ptbench.configs.datasets.mc_ch_in_11k.rgb" +mc_ch_in_11k_f0 = "ptbench.configs.datasets.mc_ch_in_11k.fold_0" +mc_ch_in_11k_f1 = "ptbench.configs.datasets.mc_ch_in_11k.fold_1" +mc_ch_in_11k_f2 = "ptbench.configs.datasets.mc_ch_in_11k.fold_2" +mc_ch_in_11k_f3 = "ptbench.configs.datasets.mc_ch_in_11k.fold_3" +mc_ch_in_11k_f4 = "ptbench.configs.datasets.mc_ch_in_11k.fold_4" +mc_ch_in_11k_f5 = "ptbench.configs.datasets.mc_ch_in_11k.fold_5" +mc_ch_in_11k_f6 = "ptbench.configs.datasets.mc_ch_in_11k.fold_6" 
+mc_ch_in_11k_f7 = "ptbench.configs.datasets.mc_ch_in_11k.fold_7" +mc_ch_in_11k_f8 = "ptbench.configs.datasets.mc_ch_in_11k.fold_8" +mc_ch_in_11k_f9 = "ptbench.configs.datasets.mc_ch_in_11k.fold_9" +mc_ch_in_11k_f0_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_0_rgb" +mc_ch_in_11k_f1_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_1_rgb" +mc_ch_in_11k_f2_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_2_rgb" +mc_ch_in_11k_f3_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_3_rgb" +mc_ch_in_11k_f4_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_4_rgb" +mc_ch_in_11k_f5_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_5_rgb" +mc_ch_in_11k_f6_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_6_rgb" +mc_ch_in_11k_f7_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_7_rgb" +mc_ch_in_11k_f8_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_8_rgb" +mc_ch_in_11k_f9_rgb = "ptbench.configs.datasets.mc_ch_in_11k.fold_9_rgb" +# extended montgomery-shenzhen-indian-tbx11k aggregated dataset +# (with radiological signs) +mc_ch_in_11k_rs = "ptbench.configs.datasets.mc_ch_in_11k_RS.default" +mc_ch_in_11k_rs_f0 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_0" +mc_ch_in_11k_rs_f1 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_1" +mc_ch_in_11k_rs_f2 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_2" +mc_ch_in_11k_rs_f3 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_3" +mc_ch_in_11k_rs_f4 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_4" +mc_ch_in_11k_rs_f5 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_5" +mc_ch_in_11k_rs_f6 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_6" +mc_ch_in_11k_rs_f7 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_7" +mc_ch_in_11k_rs_f8 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_8" +mc_ch_in_11k_rs_f9 = "ptbench.configs.datasets.mc_ch_in_11k_RS.fold_9" +# montgomery-shenzhen-indian-tbx11kv2 aggregated dataset +mc_ch_in_11kv2 = "ptbench.configs.datasets.mc_ch_in_11kv2.default" +mc_ch_in_11kv2_rgb = 
"ptbench.configs.datasets.mc_ch_in_11kv2.rgb" +mc_ch_in_11kv2_f0 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_0" +mc_ch_in_11kv2_f1 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_1" +mc_ch_in_11kv2_f2 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_2" +mc_ch_in_11kv2_f3 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_3" +mc_ch_in_11kv2_f4 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_4" +mc_ch_in_11kv2_f5 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_5" +mc_ch_in_11kv2_f6 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_6" +mc_ch_in_11kv2_f7 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_7" +mc_ch_in_11kv2_f8 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_8" +mc_ch_in_11kv2_f9 = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_9" +mc_ch_in_11kv2_f0_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_0_rgb" +mc_ch_in_11kv2_f1_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_1_rgb" +mc_ch_in_11kv2_f2_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_2_rgb" +mc_ch_in_11kv2_f3_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_3_rgb" +mc_ch_in_11kv2_f4_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_4_rgb" +mc_ch_in_11kv2_f5_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_5_rgb" +mc_ch_in_11kv2_f6_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_6_rgb" +mc_ch_in_11kv2_f7_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_7_rgb" +mc_ch_in_11kv2_f8_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_8_rgb" +mc_ch_in_11kv2_f9_rgb = "ptbench.configs.datasets.mc_ch_in_11kv2.fold_9_rgb" +# extended montgomery-shenzhen-indian-tbx11kv2 aggregated dataset +# (with radiological signs) +mc_ch_in_11kv2_rs = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.default" +mc_ch_in_11kv2_rs_f0 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_0" +mc_ch_in_11kv2_rs_f1 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_1" +mc_ch_in_11kv2_rs_f2 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_2" +mc_ch_in_11kv2_rs_f3 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_3" 
+mc_ch_in_11kv2_rs_f4 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_4" +mc_ch_in_11kv2_rs_f5 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_5" +mc_ch_in_11kv2_rs_f6 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_6" +mc_ch_in_11kv2_rs_f7 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_7" +mc_ch_in_11kv2_rs_f8 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_8" +mc_ch_in_11kv2_rs_f9 = "ptbench.configs.datasets.mc_ch_in_11kv2_RS.fold_9" # tbpoc dataset (and cross-validation folds) tbpoc_f0 = "ptbench.configs.datasets.tbpoc.fold_0" tbpoc_f1 = "ptbench.configs.datasets.tbpoc.fold_1" diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/__init__.py b/src/ptbench/configs/datasets/mc_ch_in_11k/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8970bd744280c9fb07f6f09d5b4b844a9f57993 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/__init__.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from torch.utils.data.dataset import ConcatDataset + + +def _maker(protocol): + if protocol == "default": + from ..indian import default as indian + from ..montgomery import default as mc + from ..shenzhen import default as ch + from ..tbx11k_simplified import default as tbx11k + elif protocol == "rgb": + from ..indian import rgb as indian + from ..montgomery import rgb as mc + from ..shenzhen import rgb as ch + from ..tbx11k_simplified import rgb as tbx11k + elif protocol == "fold_0": + from ..indian import fold_0 as indian + from ..montgomery import fold_0 as mc + from ..shenzhen import fold_0 as ch + from ..tbx11k_simplified import fold_0 as tbx11k + elif protocol == "fold_1": + from ..indian import fold_1 as indian + from ..montgomery import fold_1 as mc + from ..shenzhen import fold_1 as ch + from ..tbx11k_simplified import fold_1 as tbx11k + elif protocol == "fold_2": + from ..indian import fold_2 as indian + 
from ..montgomery import fold_2 as mc + from ..shenzhen import fold_2 as ch + from ..tbx11k_simplified import fold_2 as tbx11k + elif protocol == "fold_3": + from ..indian import fold_3 as indian + from ..montgomery import fold_3 as mc + from ..shenzhen import fold_3 as ch + from ..tbx11k_simplified import fold_3 as tbx11k + elif protocol == "fold_4": + from ..indian import fold_4 as indian + from ..montgomery import fold_4 as mc + from ..shenzhen import fold_4 as ch + from ..tbx11k_simplified import fold_4 as tbx11k + elif protocol == "fold_5": + from ..indian import fold_5 as indian + from ..montgomery import fold_5 as mc + from ..shenzhen import fold_5 as ch + from ..tbx11k_simplified import fold_5 as tbx11k + elif protocol == "fold_6": + from ..indian import fold_6 as indian + from ..montgomery import fold_6 as mc + from ..shenzhen import fold_6 as ch + from ..tbx11k_simplified import fold_6 as tbx11k + elif protocol == "fold_7": + from ..indian import fold_7 as indian + from ..montgomery import fold_7 as mc + from ..shenzhen import fold_7 as ch + from ..tbx11k_simplified import fold_7 as tbx11k + elif protocol == "fold_8": + from ..indian import fold_8 as indian + from ..montgomery import fold_8 as mc + from ..shenzhen import fold_8 as ch + from ..tbx11k_simplified import fold_8 as tbx11k + elif protocol == "fold_9": + from ..indian import fold_9 as indian + from ..montgomery import fold_9 as mc + from ..shenzhen import fold_9 as ch + from ..tbx11k_simplified import fold_9 as tbx11k + elif protocol == "fold_0_rgb": + from ..indian import fold_0_rgb as indian + from ..montgomery import fold_0_rgb as mc + from ..shenzhen import fold_0_rgb as ch + from ..tbx11k_simplified import fold_0_rgb as tbx11k + elif protocol == "fold_1_rgb": + from ..indian import fold_1_rgb as indian + from ..montgomery import fold_1_rgb as mc + from ..shenzhen import fold_1_rgb as ch + from ..tbx11k_simplified import fold_1_rgb as tbx11k + elif protocol == "fold_2_rgb": + from ..indian 
import fold_2_rgb as indian + from ..montgomery import fold_2_rgb as mc + from ..shenzhen import fold_2_rgb as ch + from ..tbx11k_simplified import fold_2_rgb as tbx11k + elif protocol == "fold_3_rgb": + from ..indian import fold_3_rgb as indian + from ..montgomery import fold_3_rgb as mc + from ..shenzhen import fold_3_rgb as ch + from ..tbx11k_simplified import fold_3_rgb as tbx11k + elif protocol == "fold_4_rgb": + from ..indian import fold_4_rgb as indian + from ..montgomery import fold_4_rgb as mc + from ..shenzhen import fold_4_rgb as ch + from ..tbx11k_simplified import fold_4_rgb as tbx11k + elif protocol == "fold_5_rgb": + from ..indian import fold_5_rgb as indian + from ..montgomery import fold_5_rgb as mc + from ..shenzhen import fold_5_rgb as ch + from ..tbx11k_simplified import fold_5_rgb as tbx11k + elif protocol == "fold_6_rgb": + from ..indian import fold_6_rgb as indian + from ..montgomery import fold_6_rgb as mc + from ..shenzhen import fold_6_rgb as ch + from ..tbx11k_simplified import fold_6_rgb as tbx11k + elif protocol == "fold_7_rgb": + from ..indian import fold_7_rgb as indian + from ..montgomery import fold_7_rgb as mc + from ..shenzhen import fold_7_rgb as ch + from ..tbx11k_simplified import fold_7_rgb as tbx11k + elif protocol == "fold_8_rgb": + from ..indian import fold_8_rgb as indian + from ..montgomery import fold_8_rgb as mc + from ..shenzhen import fold_8_rgb as ch + from ..tbx11k_simplified import fold_8_rgb as tbx11k + elif protocol == "fold_9_rgb": + from ..indian import fold_9_rgb as indian + from ..montgomery import fold_9_rgb as mc + from ..shenzhen import fold_9_rgb as ch + from ..tbx11k_simplified import fold_9_rgb as tbx11k + + mc = mc.dataset + ch = ch.dataset + indian = indian.dataset + tbx11k = tbx11k.dataset + + dataset = {} + dataset["__train__"] = ConcatDataset( + [ + mc["__train__"], + ch["__train__"], + indian["__train__"], + tbx11k["__train__"], + ] + ) + dataset["train"] = ConcatDataset( + [mc["train"], 
ch["train"], indian["train"], tbx11k["train"]] + ) + dataset["__valid__"] = ConcatDataset( + [ + mc["__valid__"], + ch["__valid__"], + indian["__valid__"], + tbx11k["__valid__"], + ] + ) + dataset["validation"] = ConcatDataset( + [ + mc["validation"], + ch["validation"], + indian["validation"], + tbx11k["validation"], + ] + ) + dataset["test"] = ConcatDataset( + [mc["test"], ch["test"], indian["test"], tbx11k["test"]] + ) + + return dataset diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/default.py b/src/ptbench/configs/datasets/mc_ch_in_11k/default.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f16bda48b05e7e9302ffc9c689d8393b3e495 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/default.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets.""" + +from . import _maker + +dataset = _maker("default") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_0.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_0.py new file mode 100644 index 0000000000000000000000000000000000000000..757a0eb98214ba020d76095363d424b9209540e7 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_0.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0)""" + +from . 
import _maker + +dataset = _maker("fold_0") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_0_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_0_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..48e05ff3f71f13976190d04cfaf59c5c36996bac --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_0_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0, RGB)""" + +from . import _maker + +dataset = _maker("fold_0_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_1.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5657958934b926879bd26503c9b383e775bc724d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_1.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1)""" + +from . import _maker + +dataset = _maker("fold_1") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_1_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_1_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c782d68de247c876ddd6826100cbb7908342b928 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_1_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_1_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_2.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_2.py new file mode 100644 index 0000000000000000000000000000000000000000..10a597bcb8e0485db63f0d7500b15b3e78877066 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_2.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2)""" + +from . import _maker + +dataset = _maker("fold_2") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_2_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_2_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..d624f3af53abcf053c7bf17a9822a86cb53e2923 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_2_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2, RGB)""" + +from . import _maker + +dataset = _maker("fold_2_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_3.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_3.py new file mode 100644 index 0000000000000000000000000000000000000000..39bee4fec99e81eecc22a365183283bcd2ec3d98 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_3.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3)""" + +from . 
import _maker + +dataset = _maker("fold_3") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_3_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_3_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..7b26e4257e61013843e3a62c3bc419003e23b645 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_3_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3, RGB)""" + +from . import _maker + +dataset = _maker("fold_3_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_4.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_4.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb56292fd97636f452cde06c87bb34c89f01b1c --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_4.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4)""" + +from . import _maker + +dataset = _maker("fold_4") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_4_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_4_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc4f0cfd9edc602fbe5665aca0465b29c5183b5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_4_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_4_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_5.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_5.py new file mode 100644 index 0000000000000000000000000000000000000000..679bb9b3cbbdede06cd87834239609720f439296 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_5.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5)""" + +from . import _maker + +dataset = _maker("fold_5") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_5_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_5_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..747d510ecd1c7bd2f32ab7b139a53603d5bbee88 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_5_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5, RGB)""" + +from . import _maker + +dataset = _maker("fold_5_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_6.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_6.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8e4cd571b8c796bad3221584870888c5186d3d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_6.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6)""" + +from . 
import _maker + +dataset = _maker("fold_6") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_6_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_6_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..86f112c3aae0c1c1dd48002347f78ce565797d47 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_6_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6, RGB)""" + +from . import _maker + +dataset = _maker("fold_6_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_7.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_7.py new file mode 100644 index 0000000000000000000000000000000000000000..98241531d3e15720f07ef9174687c47db7d737f1 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_7.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7)""" + +from . import _maker + +dataset = _maker("fold_7") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_7_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_7_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..981fe19180e0d8d4e1b21653f52a92a567723a63 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_7_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_7_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_8.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_8.py new file mode 100644 index 0000000000000000000000000000000000000000..dab1a234a3842ab450706d86060651d4383ddbfc --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_8.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8)""" + +from . import _maker + +dataset = _maker("fold_8") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_8_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_8_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..798b8de64761ef0d87f491ef08b43426f55898f2 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_8_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8, RGB)""" + +from . import _maker + +dataset = _maker("fold_8_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_9.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_9.py new file mode 100644 index 0000000000000000000000000000000000000000..097724b9446c4c2f0bef8ee6f838c1c11ff627a5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_9.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9)""" + +from . 
import _maker + +dataset = _maker("fold_9") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/fold_9_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_9_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c564a40b957b562a37bb30b5809f7cf680e896 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/fold_9_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9, RGB)""" + +from . import _maker + +dataset = _maker("fold_9_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k/rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k/rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..f47796a89c31a5a31c0f972d81b5d97c7f8742b4 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k/rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (RGB)""" + +from . 
import _maker + +dataset = _maker("rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/__init__.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7fa4f30799fd886ca7f792acb51d77870bede69c --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/__init__.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from torch.utils.data.dataset import ConcatDataset + + +def _maker(protocol): + if protocol == "default": + from ..indian_RS import default as indian + from ..montgomery_RS import default as mc + from ..shenzhen_RS import default as ch + from ..tbx11k_simplified_RS import default as tbx11k + elif protocol == "rgb": + from ..indian_RS import rgb as indian + from ..montgomery_RS import rgb as mc + from ..shenzhen_RS import rgb as ch + from ..tbx11k_simplified_RS import rgb as tbx11k + elif protocol == "fold_0": + from ..indian_RS import fold_0 as indian + from ..montgomery_RS import fold_0 as mc + from ..shenzhen_RS import fold_0 as ch + from ..tbx11k_simplified_RS import fold_0 as tbx11k + elif protocol == "fold_1": + from ..indian_RS import fold_1 as indian + from ..montgomery_RS import fold_1 as mc + from ..shenzhen_RS import fold_1 as ch + from ..tbx11k_simplified_RS import fold_1 as tbx11k + elif protocol == "fold_2": + from ..indian_RS import fold_2 as indian + from ..montgomery_RS import fold_2 as mc + from ..shenzhen_RS import fold_2 as ch + from ..tbx11k_simplified_RS import fold_2 as tbx11k + elif protocol == "fold_3": + from ..indian_RS import fold_3 as indian + from ..montgomery_RS import fold_3 as mc + from ..shenzhen_RS import fold_3 as ch + from ..tbx11k_simplified_RS import fold_3 as tbx11k + elif protocol == "fold_4": + from ..indian_RS import fold_4 as indian + from ..montgomery_RS import fold_4 as mc + from ..shenzhen_RS import fold_4 as ch + 
from ..tbx11k_simplified_RS import fold_4 as tbx11k + elif protocol == "fold_5": + from ..indian_RS import fold_5 as indian + from ..montgomery_RS import fold_5 as mc + from ..shenzhen_RS import fold_5 as ch + from ..tbx11k_simplified_RS import fold_5 as tbx11k + elif protocol == "fold_6": + from ..indian_RS import fold_6 as indian + from ..montgomery_RS import fold_6 as mc + from ..shenzhen_RS import fold_6 as ch + from ..tbx11k_simplified_RS import fold_6 as tbx11k + elif protocol == "fold_7": + from ..indian_RS import fold_7 as indian + from ..montgomery_RS import fold_7 as mc + from ..shenzhen_RS import fold_7 as ch + from ..tbx11k_simplified_RS import fold_7 as tbx11k + elif protocol == "fold_8": + from ..indian_RS import fold_8 as indian + from ..montgomery_RS import fold_8 as mc + from ..shenzhen_RS import fold_8 as ch + from ..tbx11k_simplified_RS import fold_8 as tbx11k + elif protocol == "fold_9": + from ..indian_RS import fold_9 as indian + from ..montgomery_RS import fold_9 as mc + from ..shenzhen_RS import fold_9 as ch + from ..tbx11k_simplified_RS import fold_9 as tbx11k + elif protocol == "fold_0_rgb": + from ..indian_RS import fold_0_rgb as indian + from ..montgomery_RS import fold_0_rgb as mc + from ..shenzhen_RS import fold_0_rgb as ch + from ..tbx11k_simplified_RS import fold_0_rgb as tbx11k + elif protocol == "fold_1_rgb": + from ..indian_RS import fold_1_rgb as indian + from ..montgomery_RS import fold_1_rgb as mc + from ..shenzhen_RS import fold_1_rgb as ch + from ..tbx11k_simplified_RS import fold_1_rgb as tbx11k + elif protocol == "fold_2_rgb": + from ..indian_RS import fold_2_rgb as indian + from ..montgomery_RS import fold_2_rgb as mc + from ..shenzhen_RS import fold_2_rgb as ch + from ..tbx11k_simplified_RS import fold_2_rgb as tbx11k + elif protocol == "fold_3_rgb": + from ..indian_RS import fold_3_rgb as indian + from ..montgomery_RS import fold_3_rgb as mc + from ..shenzhen_RS import fold_3_rgb as ch + from ..tbx11k_simplified_RS 
import fold_3_rgb as tbx11k + elif protocol == "fold_4_rgb": + from ..indian_RS import fold_4_rgb as indian + from ..montgomery_RS import fold_4_rgb as mc + from ..shenzhen_RS import fold_4_rgb as ch + from ..tbx11k_simplified_RS import fold_4_rgb as tbx11k + elif protocol == "fold_5_rgb": + from ..indian_RS import fold_5_rgb as indian + from ..montgomery_RS import fold_5_rgb as mc + from ..shenzhen_RS import fold_5_rgb as ch + from ..tbx11k_simplified_RS import fold_5_rgb as tbx11k + elif protocol == "fold_6_rgb": + from ..indian_RS import fold_6_rgb as indian + from ..montgomery_RS import fold_6_rgb as mc + from ..shenzhen_RS import fold_6_rgb as ch + from ..tbx11k_simplified_RS import fold_6_rgb as tbx11k + elif protocol == "fold_7_rgb": + from ..indian_RS import fold_7_rgb as indian + from ..montgomery_RS import fold_7_rgb as mc + from ..shenzhen_RS import fold_7_rgb as ch + from ..tbx11k_simplified_RS import fold_7_rgb as tbx11k + elif protocol == "fold_8_rgb": + from ..indian_RS import fold_8_rgb as indian + from ..montgomery_RS import fold_8_rgb as mc + from ..shenzhen_RS import fold_8_rgb as ch + from ..tbx11k_simplified_RS import fold_8_rgb as tbx11k + elif protocol == "fold_9_rgb": + from ..indian_RS import fold_9_rgb as indian + from ..montgomery_RS import fold_9_rgb as mc + from ..shenzhen_RS import fold_9_rgb as ch + from ..tbx11k_simplified_RS import fold_9_rgb as tbx11k + + mc = mc.dataset + ch = ch.dataset + indian = indian.dataset + tbx11k = tbx11k.dataset + + dataset = {} + dataset["__train__"] = ConcatDataset( + [ + mc["__train__"], + ch["__train__"], + indian["__train__"], + tbx11k["__train__"], + ] + ) + dataset["train"] = ConcatDataset( + [mc["train"], ch["train"], indian["train"], tbx11k["train"]] + ) + dataset["__valid__"] = ConcatDataset( + [ + mc["__valid__"], + ch["__valid__"], + indian["__valid__"], + tbx11k["__valid__"], + ] + ) + dataset["validation"] = ConcatDataset( + [ + mc["validation"], + ch["validation"], + indian["validation"], 
+ tbx11k["validation"], + ] + ) + dataset["test"] = ConcatDataset( + [mc["test"], ch["test"], indian["test"], tbx11k["test"]] + ) + + return dataset diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/default.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/default.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f16bda48b05e7e9302ffc9c689d8393b3e495 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/default.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets.""" + +from . import _maker + +dataset = _maker("default") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_0.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_0.py new file mode 100644 index 0000000000000000000000000000000000000000..757a0eb98214ba020d76095363d424b9209540e7 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_0.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0)""" + +from . 
import _maker + +dataset = _maker("fold_0") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_0_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_0_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..48e05ff3f71f13976190d04cfaf59c5c36996bac --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_0_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0, RGB)""" + +from . import _maker + +dataset = _maker("fold_0_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_1.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5657958934b926879bd26503c9b383e775bc724d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_1.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1)""" + +from . 
import _maker + +dataset = _maker("fold_1") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_1_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_1_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c782d68de247c876ddd6826100cbb7908342b928 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_1_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1, RGB)""" + +from . import _maker + +dataset = _maker("fold_1_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_2.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_2.py new file mode 100644 index 0000000000000000000000000000000000000000..10a597bcb8e0485db63f0d7500b15b3e78877066 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_2.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2)""" + +from . 
import _maker + +dataset = _maker("fold_2") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_2_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_2_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..d624f3af53abcf053c7bf17a9822a86cb53e2923 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_2_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2, RGB)""" + +from . import _maker + +dataset = _maker("fold_2_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_3.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_3.py new file mode 100644 index 0000000000000000000000000000000000000000..39bee4fec99e81eecc22a365183283bcd2ec3d98 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_3.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3)""" + +from . 
import _maker + +dataset = _maker("fold_3") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_3_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_3_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..7b26e4257e61013843e3a62c3bc419003e23b645 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_3_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3, RGB)""" + +from . import _maker + +dataset = _maker("fold_3_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_4.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_4.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb56292fd97636f452cde06c87bb34c89f01b1c --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_4.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4)""" + +from . 
import _maker + +dataset = _maker("fold_4") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_4_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_4_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc4f0cfd9edc602fbe5665aca0465b29c5183b5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_4_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4, RGB)""" + +from . import _maker + +dataset = _maker("fold_4_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_5.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_5.py new file mode 100644 index 0000000000000000000000000000000000000000..679bb9b3cbbdede06cd87834239609720f439296 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_5.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5)""" + +from . 
import _maker + +dataset = _maker("fold_5") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_5_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_5_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..747d510ecd1c7bd2f32ab7b139a53603d5bbee88 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_5_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5, RGB)""" + +from . import _maker + +dataset = _maker("fold_5_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_6.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_6.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8e4cd571b8c796bad3221584870888c5186d3d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_6.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6)""" + +from . 
import _maker + +dataset = _maker("fold_6") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_6_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_6_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..86f112c3aae0c1c1dd48002347f78ce565797d47 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_6_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6, RGB)""" + +from . import _maker + +dataset = _maker("fold_6_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_7.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_7.py new file mode 100644 index 0000000000000000000000000000000000000000..98241531d3e15720f07ef9174687c47db7d737f1 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_7.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7)""" + +from . 
import _maker + +dataset = _maker("fold_7") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_7_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_7_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..981fe19180e0d8d4e1b21653f52a92a567723a63 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_7_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7, RGB)""" + +from . import _maker + +dataset = _maker("fold_7_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_8.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_8.py new file mode 100644 index 0000000000000000000000000000000000000000..dab1a234a3842ab450706d86060651d4383ddbfc --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_8.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8)""" + +from . 
import _maker + +dataset = _maker("fold_8") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_8_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_8_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..798b8de64761ef0d87f491ef08b43426f55898f2 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_8_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8, RGB)""" + +from . import _maker + +dataset = _maker("fold_8_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_9.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_9.py new file mode 100644 index 0000000000000000000000000000000000000000..097724b9446c4c2f0bef8ee6f838c1c11ff627a5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_9.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9)""" + +from . 
import _maker + +dataset = _maker("fold_9") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_9_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_9_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c564a40b957b562a37bb30b5809f7cf680e896 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/fold_9_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9, RGB)""" + +from . import _maker + +dataset = _maker("fold_9_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11k_RS/rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..f47796a89c31a5a31c0f972d81b5d97c7f8742b4 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11k_RS/rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (RGB)""" + +from . 
import _maker + +dataset = _maker("rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/__init__.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c36f7f60c8d1111baa7f3559c419e134b7ece62f --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/__init__.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from torch.utils.data.dataset import ConcatDataset + + +def _maker(protocol): + if protocol == "default": + from ..indian import default as indian + from ..montgomery import default as mc + from ..shenzhen import default as ch + from ..tbx11k_simplified_v2 import default as tbx11kv2 + elif protocol == "rgb": + from ..indian import rgb as indian + from ..montgomery import rgb as mc + from ..shenzhen import rgb as ch + from ..tbx11k_simplified_v2 import rgb as tbx11kv2 + elif protocol == "fold_0": + from ..indian import fold_0 as indian + from ..montgomery import fold_0 as mc + from ..shenzhen import fold_0 as ch + from ..tbx11k_simplified_v2 import fold_0 as tbx11kv2 + elif protocol == "fold_1": + from ..indian import fold_1 as indian + from ..montgomery import fold_1 as mc + from ..shenzhen import fold_1 as ch + from ..tbx11k_simplified_v2 import fold_1 as tbx11kv2 + elif protocol == "fold_2": + from ..indian import fold_2 as indian + from ..montgomery import fold_2 as mc + from ..shenzhen import fold_2 as ch + from ..tbx11k_simplified_v2 import fold_2 as tbx11kv2 + elif protocol == "fold_3": + from ..indian import fold_3 as indian + from ..montgomery import fold_3 as mc + from ..shenzhen import fold_3 as ch + from ..tbx11k_simplified_v2 import fold_3 as tbx11kv2 + elif protocol == "fold_4": + from ..indian import fold_4 as indian + from ..montgomery import fold_4 as mc + from ..shenzhen import fold_4 as ch + from ..tbx11k_simplified_v2 import fold_4 as tbx11kv2 + 
elif protocol == "fold_5": + from ..indian import fold_5 as indian + from ..montgomery import fold_5 as mc + from ..shenzhen import fold_5 as ch + from ..tbx11k_simplified_v2 import fold_5 as tbx11kv2 + elif protocol == "fold_6": + from ..indian import fold_6 as indian + from ..montgomery import fold_6 as mc + from ..shenzhen import fold_6 as ch + from ..tbx11k_simplified_v2 import fold_6 as tbx11kv2 + elif protocol == "fold_7": + from ..indian import fold_7 as indian + from ..montgomery import fold_7 as mc + from ..shenzhen import fold_7 as ch + from ..tbx11k_simplified_v2 import fold_7 as tbx11kv2 + elif protocol == "fold_8": + from ..indian import fold_8 as indian + from ..montgomery import fold_8 as mc + from ..shenzhen import fold_8 as ch + from ..tbx11k_simplified_v2 import fold_8 as tbx11kv2 + elif protocol == "fold_9": + from ..indian import fold_9 as indian + from ..montgomery import fold_9 as mc + from ..shenzhen import fold_9 as ch + from ..tbx11k_simplified_v2 import fold_9 as tbx11kv2 + elif protocol == "fold_0_rgb": + from ..indian import fold_0_rgb as indian + from ..montgomery import fold_0_rgb as mc + from ..shenzhen import fold_0_rgb as ch + from ..tbx11k_simplified_v2 import fold_0_rgb as tbx11kv2 + elif protocol == "fold_1_rgb": + from ..indian import fold_1_rgb as indian + from ..montgomery import fold_1_rgb as mc + from ..shenzhen import fold_1_rgb as ch + from ..tbx11k_simplified_v2 import fold_1_rgb as tbx11kv2 + elif protocol == "fold_2_rgb": + from ..indian import fold_2_rgb as indian + from ..montgomery import fold_2_rgb as mc + from ..shenzhen import fold_2_rgb as ch + from ..tbx11k_simplified_v2 import fold_2_rgb as tbx11kv2 + elif protocol == "fold_3_rgb": + from ..indian import fold_3_rgb as indian + from ..montgomery import fold_3_rgb as mc + from ..shenzhen import fold_3_rgb as ch + from ..tbx11k_simplified_v2 import fold_3_rgb as tbx11kv2 + elif protocol == "fold_4_rgb": + from ..indian import fold_4_rgb as indian + from 
..montgomery import fold_4_rgb as mc + from ..shenzhen import fold_4_rgb as ch + from ..tbx11k_simplified_v2 import fold_4_rgb as tbx11kv2 + elif protocol == "fold_5_rgb": + from ..indian import fold_5_rgb as indian + from ..montgomery import fold_5_rgb as mc + from ..shenzhen import fold_5_rgb as ch + from ..tbx11k_simplified_v2 import fold_5_rgb as tbx11kv2 + elif protocol == "fold_6_rgb": + from ..indian import fold_6_rgb as indian + from ..montgomery import fold_6_rgb as mc + from ..shenzhen import fold_6_rgb as ch + from ..tbx11k_simplified_v2 import fold_6_rgb as tbx11kv2 + elif protocol == "fold_7_rgb": + from ..indian import fold_7_rgb as indian + from ..montgomery import fold_7_rgb as mc + from ..shenzhen import fold_7_rgb as ch + from ..tbx11k_simplified_v2 import fold_7_rgb as tbx11kv2 + elif protocol == "fold_8_rgb": + from ..indian import fold_8_rgb as indian + from ..montgomery import fold_8_rgb as mc + from ..shenzhen import fold_8_rgb as ch + from ..tbx11k_simplified_v2 import fold_8_rgb as tbx11kv2 + elif protocol == "fold_9_rgb": + from ..indian import fold_9_rgb as indian + from ..montgomery import fold_9_rgb as mc + from ..shenzhen import fold_9_rgb as ch + from ..tbx11k_simplified_v2 import fold_9_rgb as tbx11kv2 + + mc = mc.dataset + ch = ch.dataset + indian = indian.dataset + tbx11kv2 = tbx11kv2.dataset + + dataset = {} + dataset["__train__"] = ConcatDataset( + [ + mc["__train__"], + ch["__train__"], + indian["__train__"], + tbx11kv2["__train__"], + ] + ) + dataset["train"] = ConcatDataset( + [mc["train"], ch["train"], indian["train"], tbx11kv2["train"]] + ) + dataset["__valid__"] = ConcatDataset( + [ + mc["__valid__"], + ch["__valid__"], + indian["__valid__"], + tbx11kv2["__valid__"], + ] + ) + dataset["validation"] = ConcatDataset( + [ + mc["validation"], + ch["validation"], + indian["validation"], + tbx11kv2["validation"], + ] + ) + dataset["test"] = ConcatDataset( + [mc["test"], ch["test"], indian["test"], tbx11kv2["test"]] + ) + + return 
dataset diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/default.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/default.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f16bda48b05e7e9302ffc9c689d8393b3e495 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/default.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets.""" + +from . import _maker + +dataset = _maker("default") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_0.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_0.py new file mode 100644 index 0000000000000000000000000000000000000000..757a0eb98214ba020d76095363d424b9209540e7 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_0.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0)""" + +from . import _maker + +dataset = _maker("fold_0") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_0_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_0_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..48e05ff3f71f13976190d04cfaf59c5c36996bac --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_0_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_0_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_1.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5657958934b926879bd26503c9b383e775bc724d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_1.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1)""" + +from . import _maker + +dataset = _maker("fold_1") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_1_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_1_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c782d68de247c876ddd6826100cbb7908342b928 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_1_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1, RGB)""" + +from . import _maker + +dataset = _maker("fold_1_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_2.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_2.py new file mode 100644 index 0000000000000000000000000000000000000000..10a597bcb8e0485db63f0d7500b15b3e78877066 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_2.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2)""" + +from . 
import _maker + +dataset = _maker("fold_2") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_2_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_2_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..d624f3af53abcf053c7bf17a9822a86cb53e2923 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_2_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2, RGB)""" + +from . import _maker + +dataset = _maker("fold_2_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_3.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_3.py new file mode 100644 index 0000000000000000000000000000000000000000..39bee4fec99e81eecc22a365183283bcd2ec3d98 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_3.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3)""" + +from . import _maker + +dataset = _maker("fold_3") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_3_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_3_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..7b26e4257e61013843e3a62c3bc419003e23b645 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_3_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_3_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_4.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_4.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb56292fd97636f452cde06c87bb34c89f01b1c --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_4.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4)""" + +from . import _maker + +dataset = _maker("fold_4") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_4_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_4_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc4f0cfd9edc602fbe5665aca0465b29c5183b5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_4_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4, RGB)""" + +from . import _maker + +dataset = _maker("fold_4_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_5.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_5.py new file mode 100644 index 0000000000000000000000000000000000000000..679bb9b3cbbdede06cd87834239609720f439296 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_5.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5)""" + +from . 
import _maker + +dataset = _maker("fold_5") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_5_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_5_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..747d510ecd1c7bd2f32ab7b139a53603d5bbee88 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_5_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5, RGB)""" + +from . import _maker + +dataset = _maker("fold_5_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_6.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_6.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8e4cd571b8c796bad3221584870888c5186d3d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_6.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6)""" + +from . import _maker + +dataset = _maker("fold_6") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_6_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_6_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..86f112c3aae0c1c1dd48002347f78ce565797d47 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_6_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_6_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_7.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_7.py new file mode 100644 index 0000000000000000000000000000000000000000..98241531d3e15720f07ef9174687c47db7d737f1 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_7.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7)""" + +from . import _maker + +dataset = _maker("fold_7") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_7_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_7_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..981fe19180e0d8d4e1b21653f52a92a567723a63 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_7_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7, RGB)""" + +from . import _maker + +dataset = _maker("fold_7_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_8.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_8.py new file mode 100644 index 0000000000000000000000000000000000000000..dab1a234a3842ab450706d86060651d4383ddbfc --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_8.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8)""" + +from . 
import _maker + +dataset = _maker("fold_8") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_8_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_8_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..798b8de64761ef0d87f491ef08b43426f55898f2 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_8_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8, RGB)""" + +from . import _maker + +dataset = _maker("fold_8_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_9.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_9.py new file mode 100644 index 0000000000000000000000000000000000000000..097724b9446c4c2f0bef8ee6f838c1c11ff627a5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_9.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9)""" + +from . import _maker + +dataset = _maker("fold_9") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_9_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_9_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c564a40b957b562a37bb30b5809f7cf680e896 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/fold_9_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_9_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2/rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2/rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..f47796a89c31a5a31c0f972d81b5d97c7f8742b4 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2/rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (RGB)""" + +from . import _maker + +dataset = _maker("rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/__init__.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b82769c73b5b89c8841990b33c7acc2014fed --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/__init__.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from torch.utils.data.dataset import ConcatDataset + + +def _maker(protocol): + if protocol == "default": + from ..indian_RS import default as indian + from ..montgomery_RS import default as mc + from ..shenzhen_RS import default as ch + from ..tbx11k_simplified_v2_RS import default as tbx11kv2 + elif protocol == "rgb": + from ..indian_RS import rgb as indian + from ..montgomery_RS import rgb as mc + from ..shenzhen_RS import rgb as ch + from ..tbx11k_simplified_v2_RS import rgb as tbx11kv2 + elif protocol == "fold_0": + from ..indian_RS import fold_0 as indian + from ..montgomery_RS import fold_0 as mc + from ..shenzhen_RS import fold_0 as ch + from ..tbx11k_simplified_v2_RS import fold_0 as tbx11kv2 + elif protocol == "fold_1": + from ..indian_RS import fold_1 as indian + from ..montgomery_RS import fold_1 as mc + from ..shenzhen_RS 
import fold_1 as ch + from ..tbx11k_simplified_v2_RS import fold_1 as tbx11kv2 + elif protocol == "fold_2": + from ..indian_RS import fold_2 as indian + from ..montgomery_RS import fold_2 as mc + from ..shenzhen_RS import fold_2 as ch + from ..tbx11k_simplified_v2_RS import fold_2 as tbx11kv2 + elif protocol == "fold_3": + from ..indian_RS import fold_3 as indian + from ..montgomery_RS import fold_3 as mc + from ..shenzhen_RS import fold_3 as ch + from ..tbx11k_simplified_v2_RS import fold_3 as tbx11kv2 + elif protocol == "fold_4": + from ..indian_RS import fold_4 as indian + from ..montgomery_RS import fold_4 as mc + from ..shenzhen_RS import fold_4 as ch + from ..tbx11k_simplified_v2_RS import fold_4 as tbx11kv2 + elif protocol == "fold_5": + from ..indian_RS import fold_5 as indian + from ..montgomery_RS import fold_5 as mc + from ..shenzhen_RS import fold_5 as ch + from ..tbx11k_simplified_v2_RS import fold_5 as tbx11kv2 + elif protocol == "fold_6": + from ..indian_RS import fold_6 as indian + from ..montgomery_RS import fold_6 as mc + from ..shenzhen_RS import fold_6 as ch + from ..tbx11k_simplified_v2_RS import fold_6 as tbx11kv2 + elif protocol == "fold_7": + from ..indian_RS import fold_7 as indian + from ..montgomery_RS import fold_7 as mc + from ..shenzhen_RS import fold_7 as ch + from ..tbx11k_simplified_v2_RS import fold_7 as tbx11kv2 + elif protocol == "fold_8": + from ..indian_RS import fold_8 as indian + from ..montgomery_RS import fold_8 as mc + from ..shenzhen_RS import fold_8 as ch + from ..tbx11k_simplified_v2_RS import fold_8 as tbx11kv2 + elif protocol == "fold_9": + from ..indian_RS import fold_9 as indian + from ..montgomery_RS import fold_9 as mc + from ..shenzhen_RS import fold_9 as ch + from ..tbx11k_simplified_v2_RS import fold_9 as tbx11kv2 + elif protocol == "fold_0_rgb": + from ..indian_RS import fold_0_rgb as indian + from ..montgomery_RS import fold_0_rgb as mc + from ..shenzhen_RS import fold_0_rgb as ch + from 
..tbx11k_simplified_v2_RS import fold_0_rgb as tbx11kv2 + elif protocol == "fold_1_rgb": + from ..indian_RS import fold_1_rgb as indian + from ..montgomery_RS import fold_1_rgb as mc + from ..shenzhen_RS import fold_1_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_1_rgb as tbx11kv2 + elif protocol == "fold_2_rgb": + from ..indian_RS import fold_2_rgb as indian + from ..montgomery_RS import fold_2_rgb as mc + from ..shenzhen_RS import fold_2_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_2_rgb as tbx11kv2 + elif protocol == "fold_3_rgb": + from ..indian_RS import fold_3_rgb as indian + from ..montgomery_RS import fold_3_rgb as mc + from ..shenzhen_RS import fold_3_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_3_rgb as tbx11kv2 + elif protocol == "fold_4_rgb": + from ..indian_RS import fold_4_rgb as indian + from ..montgomery_RS import fold_4_rgb as mc + from ..shenzhen_RS import fold_4_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_4_rgb as tbx11kv2 + elif protocol == "fold_5_rgb": + from ..indian_RS import fold_5_rgb as indian + from ..montgomery_RS import fold_5_rgb as mc + from ..shenzhen_RS import fold_5_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_5_rgb as tbx11kv2 + elif protocol == "fold_6_rgb": + from ..indian_RS import fold_6_rgb as indian + from ..montgomery_RS import fold_6_rgb as mc + from ..shenzhen_RS import fold_6_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_6_rgb as tbx11kv2 + elif protocol == "fold_7_rgb": + from ..indian_RS import fold_7_rgb as indian + from ..montgomery_RS import fold_7_rgb as mc + from ..shenzhen_RS import fold_7_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_7_rgb as tbx11kv2 + elif protocol == "fold_8_rgb": + from ..indian_RS import fold_8_rgb as indian + from ..montgomery_RS import fold_8_rgb as mc + from ..shenzhen_RS import fold_8_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_8_rgb as tbx11kv2 + elif protocol == "fold_9_rgb": + from ..indian_RS import 
fold_9_rgb as indian + from ..montgomery_RS import fold_9_rgb as mc + from ..shenzhen_RS import fold_9_rgb as ch + from ..tbx11k_simplified_v2_RS import fold_9_rgb as tbx11kv2 + + mc = mc.dataset + ch = ch.dataset + indian = indian.dataset + tbx11kv2 = tbx11kv2.dataset + + dataset = {} + dataset["__train__"] = ConcatDataset( + [ + mc["__train__"], + ch["__train__"], + indian["__train__"], + tbx11kv2["__train__"], + ] + ) + dataset["train"] = ConcatDataset( + [mc["train"], ch["train"], indian["train"], tbx11kv2["train"]] + ) + dataset["__valid__"] = ConcatDataset( + [ + mc["__valid__"], + ch["__valid__"], + indian["__valid__"], + tbx11kv2["__valid__"], + ] + ) + dataset["validation"] = ConcatDataset( + [ + mc["validation"], + ch["validation"], + indian["validation"], + tbx11kv2["validation"], + ] + ) + dataset["test"] = ConcatDataset( + [mc["test"], ch["test"], indian["test"], tbx11kv2["test"]] + ) + + return dataset diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/default.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/default.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f16bda48b05e7e9302ffc9c689d8393b3e495 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/default.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets.""" + +from . 
import _maker + +dataset = _maker("default") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_0.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_0.py new file mode 100644 index 0000000000000000000000000000000000000000..757a0eb98214ba020d76095363d424b9209540e7 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_0.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0)""" + +from . import _maker + +dataset = _maker("fold_0") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_0_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_0_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..48e05ff3f71f13976190d04cfaf59c5c36996bac --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_0_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 0, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_0_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_1.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5657958934b926879bd26503c9b383e775bc724d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_1.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1)""" + +from . import _maker + +dataset = _maker("fold_1") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_1_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_1_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c782d68de247c876ddd6826100cbb7908342b928 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_1_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_1_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_2.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_2.py new file mode 100644 index 0000000000000000000000000000000000000000..10a597bcb8e0485db63f0d7500b15b3e78877066 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_2.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2)""" + +from . import _maker + +dataset = _maker("fold_2") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_2_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_2_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..d624f3af53abcf053c7bf17a9822a86cb53e2923 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_2_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 2, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_2_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_3.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_3.py new file mode 100644 index 0000000000000000000000000000000000000000..39bee4fec99e81eecc22a365183283bcd2ec3d98 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_3.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3)""" + +from . import _maker + +dataset = _maker("fold_3") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_3_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_3_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..7b26e4257e61013843e3a62c3bc419003e23b645 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_3_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 3, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_3_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_4.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_4.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb56292fd97636f452cde06c87bb34c89f01b1c --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_4.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4)""" + +from . import _maker + +dataset = _maker("fold_4") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_4_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_4_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc4f0cfd9edc602fbe5665aca0465b29c5183b5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_4_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 4, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_4_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_5.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_5.py new file mode 100644 index 0000000000000000000000000000000000000000..679bb9b3cbbdede06cd87834239609720f439296 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_5.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5)""" + +from . import _maker + +dataset = _maker("fold_5") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_5_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_5_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..747d510ecd1c7bd2f32ab7b139a53603d5bbee88 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_5_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 5, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_5_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_6.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_6.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8e4cd571b8c796bad3221584870888c5186d3d --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_6.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6)""" + +from . import _maker + +dataset = _maker("fold_6") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_6_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_6_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..86f112c3aae0c1c1dd48002347f78ce565797d47 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_6_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 6, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_6_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_7.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_7.py new file mode 100644 index 0000000000000000000000000000000000000000..98241531d3e15720f07ef9174687c47db7d737f1 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_7.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7)""" + +from . import _maker + +dataset = _maker("fold_7") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_7_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_7_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..981fe19180e0d8d4e1b21653f52a92a567723a63 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_7_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 7, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_7_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_8.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_8.py new file mode 100644 index 0000000000000000000000000000000000000000..dab1a234a3842ab450706d86060651d4383ddbfc --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_8.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8)""" + +from . import _maker + +dataset = _maker("fold_8") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_8_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_8_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..798b8de64761ef0d87f491ef08b43426f55898f2 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_8_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 8, RGB)""" + +from . 
import _maker + +dataset = _maker("fold_8_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_9.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_9.py new file mode 100644 index 0000000000000000000000000000000000000000..097724b9446c4c2f0bef8ee6f838c1c11ff627a5 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_9.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9)""" + +from . import _maker + +dataset = _maker("fold_9") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_9_rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_9_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c564a40b957b562a37bb30b5809f7cf680e896 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/fold_9_rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 9, RGB)""" + +from . import _maker + +dataset = _maker("fold_9_rgb") diff --git a/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/rgb.py b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..f47796a89c31a5a31c0f972d81b5d97c7f8742b4 --- /dev/null +++ b/src/ptbench/configs/datasets/mc_ch_in_11kv2_RS/rgb.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (RGB)""" + +from . 
import _maker + +dataset = _maker("rgb") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/__init__.py b/src/ptbench/configs/datasets/tbx11k_simplified/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..595c323178507ec04f9dc1768e080310b9ca7d50 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/__init__.py @@ -0,0 +1,25 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +def _maker(protocol, RGB=False): + from torchvision import transforms + + from ....data.tbx11k_simplified import dataset as raw + from ....data.transforms import ElasticDeformation + from .. import make_dataset as mk + + post_transforms = [] + if RGB: + post_transforms = [ + transforms.Lambda(lambda x: x.convert("RGB")), + transforms.ToTensor(), + ] + + return mk( + [raw.subsets(protocol)], + [], + [ElasticDeformation(p=0.8)], + post_transforms, + ) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/default.py b/src/ptbench/configs/datasets/tbx11k_simplified/default.py new file mode 100644 index 0000000000000000000000000000000000000000..cb23f352123b276cb4de22a744018900d355bf3c --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/default.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (default protocol) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("default") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_0.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_0.py new file mode 100644 index 0000000000000000000000000000000000000000..8f907af758f60609f2992faf80b979f41f2c4807 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_0.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 0) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_0") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_0_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_0_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..141a3e954167bebd40662a3ae6481d86fb77c0f6 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_0_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 0, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_0", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_1.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_1.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7e2360f48ae3d3a631b15aac0ce19a427a582d --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_1.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 1) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_1") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_1_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_1_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cecd0b9da5d9f942d729ad97b69ed7399ec6b1 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_1_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 1, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_1", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_2.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_2.py new file mode 100644 index 0000000000000000000000000000000000000000..1cffe2fbf4ed29f04002c08097e27d7083e2eed5 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_2.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 2) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_2") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_2_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_2_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..fadd67cc66c10cdee76ecda9388ecb505540f423 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_2_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 2, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_2", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_3.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_3.py new file mode 100644 index 0000000000000000000000000000000000000000..64743fbd76490c567eff0447cd9c2221305a34e8 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_3.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 3) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_3") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_3_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_3_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..d17f3e6a71440d7b076d3145de4e548903dbb033 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_3_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 3, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_3", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_4.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_4.py new file mode 100644 index 0000000000000000000000000000000000000000..0739b98ba941e5ed19204baa233553a262404671 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_4.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 4) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_4") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_4_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_4_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..83a560c45f6a7b1f57990e2cb16ab4cb0a80284f --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_4_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 4, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_4", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_5.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_5.py new file mode 100644 index 0000000000000000000000000000000000000000..dc200e162dbef1f3b5f3ebf17cf57cb5f713c3f9 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_5.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 5) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_5") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_5_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_5_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..e7471a1fbaca1f11860befacaf798ba14cbbe878 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_5_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 5, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_5", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_6.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_6.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfb172bbba8716f38f5fca924243a794b97e9bd --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_6.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 6) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_6") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_6_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_6_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..899877819594fa70094c34aec910a4c11b254289 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_6_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 6, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_6", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_7.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_7.py new file mode 100644 index 0000000000000000000000000000000000000000..553504966744d0cf415c0a66d473b5e49dae405d --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_7.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 7) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_7") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_7_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_7_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..82b419f01cf1584b6e6a2eb39fbb02b1b02d0960 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_7_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 7, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_7", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_8.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_8.py new file mode 100644 index 0000000000000000000000000000000000000000..d4cdc774097b76adc12ffb170880887c09b8cfa7 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_8.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 8) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_8") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_8_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_8_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..8d008ce3e41b528907f57acf7e93443aa26a9bc6 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_8_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 8, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_8", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_9.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_9.py new file mode 100644 index 0000000000000000000000000000000000000000..f46b35499af3a6d3fc59789955151f7226c83287 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_9.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 9) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_9") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/fold_9_rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/fold_9_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..73507723ecd5a5794a5b4cd02f52344b39482188 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/fold_9_rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 9, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_9", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified/rgb.py b/src/ptbench/configs/datasets/tbx11k_simplified/rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..e22dff69194f5ad95eaa636aa87ce1e21ca6d13a --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified/rgb.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (default protocol, RGB) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified` for dataset details
""" + +from . import _maker + +dataset = _maker("default", RGB=True) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/__init__.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e9b1f0b62dccd9a2e66fefda238176b868178 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +def _maker(protocol): + from ....data.tbx11k_simplified_RS import dataset as raw + from .. 
import make_dataset as mk + + return mk([raw.subsets(protocol)]) diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/default.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/default.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a51eb7346e6f540be00d6f88741f308c670958 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/default.py @@ -0,0 +1,18 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (default protocol) (extended with +DensenetRS predictions) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . import _maker + +dataset = _maker("default") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_0.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_0.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c14cd56b67a66895d20eaf08846cfb69816244 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_0.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 0) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_0") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_1.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_1.py new file mode 100644 index 0000000000000000000000000000000000000000..998469b93dc4b7d09b2569905e5ffdbeabfd31c7 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_1.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 1) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_1") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_2.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_2.py new file mode 100644 index 0000000000000000000000000000000000000000..642f4ae40c3d14159d685cec776449ad5bbe9b3f --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_2.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 2) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_2") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_3.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_3.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf0e5ca9d4d00aa3e75c75c1a34fcbea93ca652 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_3.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 3) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_3") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_4.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_4.py new file mode 100644 index 0000000000000000000000000000000000000000..909d4abed9762680405533190eb8fec5e372efc9 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_4.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 4) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_4") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_5.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_5.py new file mode 100644 index 0000000000000000000000000000000000000000..969b13a4f4d9ade3efe6ec5fc88edf1618d0dd9c --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_5.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 5) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_5") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_6.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_6.py new file mode 100644 index 0000000000000000000000000000000000000000..bde32bef1bccd97d27bd2d70980d0f5402438913 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_6.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 6) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . 
import _maker + +dataset = _maker("fold_6") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_7.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_7.py new file mode 100644 index 0000000000000000000000000000000000000000..56fdd69c2b40d444b4544377ad33cf56eefef6f4 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_7.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 7) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_7") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_8.py b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_8.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdbe53d9b45315b546899af54555a00b5eaf74b --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_RS/fold_8.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 8) + +* Split reference: first 62.5% of TB and healthy CXR for "train" 15.9% for +* "validation", 21.6% for "test" +* This split only consists of healthy and active TB samples +* "Latent TB" or "sick & non-TB" samples are not included in this configuration +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_RS` for dataset details +""" + +from . 
def _maker(protocol, RGB=False):
    """Build the TBX11k-simplified-v2 dataset for a given protocol.

    Loads the raw subsets for ``protocol`` (e.g. ``"default"``,
    ``"fold_0"``) and assembles them through the package-level
    ``make_dataset`` helper, applying an elastic-deformation augmentation.
    When ``RGB`` is ``True``, post-transforms additionally convert each
    image to RGB and turn it into a tensor.
    """
    from torchvision import transforms

    from ....data.tbx11k_simplified_v2 import dataset as raw
    from ....data.transforms import ElasticDeformation
    from .. import make_dataset as mk

    if RGB:
        final_transforms = [
            transforms.Lambda(lambda image: image.convert("RGB")),
            transforms.ToTensor(),
        ]
    else:
        final_transforms = []

    augmentations = [ElasticDeformation(p=0.8)]
    return mk([raw.subsets(protocol)], [], augmentations, final_transforms)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 0, RGB).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2` for dataset details
"""

from . import _maker

# Fold 0, with images converted to RGB by the shared ``_maker``
# post-transforms (RGB=True).
dataset = _maker("fold_0", RGB=True)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 2, RGB).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2` for dataset details
"""

from . import _maker

# Fold 2, with images converted to RGB by the shared ``_maker``
# post-transforms (RGB=True).
dataset = _maker("fold_2", RGB=True)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 4, RGB).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2` for dataset details
"""

from . import _maker

# Fold 4, with images converted to RGB by the shared ``_maker``
# post-transforms (RGB=True).
dataset = _maker("fold_4", RGB=True)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 6, RGB).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2` for dataset details
"""

from . import _maker

# Fold 6, with images converted to RGB by the shared ``_maker``
# post-transforms (RGB=True).
dataset = _maker("fold_6", RGB=True)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 8, RGB).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2` for dataset details
"""

from . import _maker

# Fold 8, with images converted to RGB by the shared ``_maker``
# post-transforms (RGB=True).
dataset = _maker("fold_8", RGB=True)
def _maker(protocol):
    """Create the TBX11k-simplified-v2 radiological-signs dataset.

    Selects the raw subsets named by ``protocol`` (e.g. ``"default"``,
    ``"fold_0"``) and hands them to the package-level ``make_dataset``
    helper; no transforms or augmentations are applied for this variant.
    """
    from ....data.tbx11k_simplified_v2_RS import dataset as raw
    from .. import make_dataset as mk

    protocol_subsets = raw.subsets(protocol)
    return mk([protocol_subsets])
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 2)
(extended with DensenetRS predictions).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2_RS` for dataset details
"""

from . import _maker

# Cross-validation fold 2 of the RS (radiological signs) variant; subset
# selection is delegated to the package-level ``_maker`` helper.
dataset = _maker("fold_2")
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""TBX11k simplified dataset for TB detection (cross validation fold 6)
(extended with DensenetRS predictions).

* Split reference: first 62.6% of CXR for "train", 16% for "validation",
  21.4% for "test"
* This split consists of non-TB and active TB samples
* "healthy", "latent TB", and "sick & non-TB" samples are all merged under
  the label "non-TB"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.tbx11k_simplified_v2_RS` for dataset details
"""

from . import _maker

# Cross-validation fold 6 of the RS (radiological signs) variant; subset
# selection is delegated to the package-level ``_maker`` helper.
dataset = _maker("fold_6")
import _maker + +dataset = _maker("fold_8") diff --git a/src/ptbench/configs/datasets/tbx11k_simplified_v2_RS/fold_9.py b/src/ptbench/configs/datasets/tbx11k_simplified_v2_RS/fold_9.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5fdc6962aafe3826c3dea67977822d23272563 --- /dev/null +++ b/src/ptbench/configs/datasets/tbx11k_simplified_v2_RS/fold_9.py @@ -0,0 +1,17 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11k simplified dataset for TB detection (cross validation fold 9) + +* Split reference: first 62.6% of CXR for "train", 16% for "validation", +* 21.4% for "test" +* This split consists of non-TB and active TB samples +* "healthy", "latent TB", and "sick & non-TB" samples are all merged under the label "non-TB" +* This configuration resolution: 512 x 512 (default) +* See :py:mod:`ptbench.data.tbx11k_simplified_v2_RS` for dataset details +""" + +from . import _maker + +dataset = _maker("fold_9") diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py index 5917db2146ac2a1efa4584cbc6126a3e88ae7f86..cf8bfd35aa10ad3493be9483ba0347fc8ebb7da1 100644 --- a/src/ptbench/configs/models/alexnet.py +++ b/src/ptbench/configs/models/alexnet.py @@ -4,19 +4,21 @@ """AlexNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import SGD -from ...models.alexnet import build_alexnet +from ...models.alexnet import Alexnet # config -lr = 0.01 - -# model -model = build_alexnet(pretrained=False) +optimizer_configs = {"lr": 0.01, "momentum": 0.1} # optimizer -optimizer = SGD(model.parameters(), lr=lr, momentum=0.1) - +optimizer = "SGD" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Alexnet( + criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False +) diff --git 
a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py index f792151dc4b489c0ef27526db2bcfcccd5852be3..1d196be6f79ea5c70987c1d1a66eaf32e8e7ca4c 100644 --- a/src/ptbench/configs/models/alexnet_pretrained.py +++ b/src/ptbench/configs/models/alexnet_pretrained.py @@ -2,24 +2,23 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""AlexNet. - -Pretrained AlexNet -""" +"""AlexNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import SGD -from ...models.alexnet import build_alexnet +from ...models.alexnet import Alexnet # config -lr = 0.001 - -# model -model = build_alexnet(pretrained=True) +optimizer_configs = {"lr": 0.001, "momentum": 0.1} # optimizer -optimizer = SGD(model.parameters(), lr=lr, momentum=0.1) - +optimizer = "SGD" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Alexnet( + criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True +) diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py index 2017786a71968846a763deb5b13879bae32cf5e6..6759490854fd05ac8d8d3a9eec5b1494e0cfb0f2 100644 --- a/src/ptbench/configs/models/densenet.py +++ b/src/ptbench/configs/models/densenet.py @@ -4,19 +4,22 @@ """DenseNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.densenet import build_densenet +from ...models.densenet import Densenet # config -lr = 0.0001 - -# model -model = build_densenet(pretrained=False) +optimizer_configs = {"lr": 0.0001} # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Densenet( + criterion, criterion_valid, optimizer, 
optimizer_configs, pretrained=False +) diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py index cb729083ab89fa96e13232ad7223c1659d3cdc69..b018a52203061b847cdae9f09b5edfa713930302 100644 --- a/src/ptbench/configs/models/densenet_pretrained.py +++ b/src/ptbench/configs/models/densenet_pretrained.py @@ -4,19 +4,22 @@ """DenseNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.densenet import build_densenet +from ...models.densenet import Densenet # config -lr = 0.01 - -# model -model = build_densenet(pretrained=True) +optimizer_configs = {"lr": 0.01} # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Densenet( + criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True +) diff --git a/src/ptbench/configs/models/logistic_regression.py b/src/ptbench/configs/models/logistic_regression.py index b93935b471b3f34a371c81b7659ae12309bee05b..145dddd7a3a9c6e0b0574f3c6358f6f9f51056a0 100644 --- a/src/ptbench/configs/models/logistic_regression.py +++ b/src/ptbench/configs/models/logistic_regression.py @@ -7,20 +7,23 @@ Simple feedforward network taking radiological signs in output and predicting tuberculosis presence in output. 
""" - +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.logistic_regression import build_logistic_regression +from ...models.logistic_regression import LogisticRegression # config -lr = 1e-2 - -# model -model = build_logistic_regression(14) +optimizer_configs = {"lr": 1e-2} +input_size = 14 # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = LogisticRegression( + criterion, criterion_valid, optimizer, optimizer_configs, input_size +) diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index d09161d0408118a3fd2dccefc1b082319f99051b..3ee0b92164b5531b65049b94e71b01b07e2ad27e 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -11,19 +11,20 @@ Screening and Visualization". 
Reference: [PASA-2019]_ """ +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.pasa import build_pasa +from ...models.pasa import PASA # config -lr = 8e-5 - -# model -model = build_pasa() +optimizer_configs = {"lr": 8e-5} # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) diff --git a/src/ptbench/configs/models/signs_to_tb.py b/src/ptbench/configs/models/signs_to_tb.py index 3bd552da69e558c5bc083c75829ab2af2d69d8b3..1ce89b1299764463f3f7f9e54e2ad2172acfea6f 100644 --- a/src/ptbench/configs/models/signs_to_tb.py +++ b/src/ptbench/configs/models/signs_to_tb.py @@ -8,19 +8,22 @@ Simple feedforward network taking radiological signs in output and predicting tuberculosis presence in output. 
""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.signs_to_tb import build_signs_to_tb +from ...models.signs_to_tb import SignsToTB # config -lr = 1e-2 - -# model -model = build_signs_to_tb(14, 10) +optimizer_configs = {"lr": 1e-2} # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = SignsToTB( + criterion, criterion_valid, optimizer, optimizer_configs, 14, 10 +) diff --git a/src/ptbench/configs/models_datasets/densenet_rs.py b/src/ptbench/configs/models_datasets/densenet_rs.py index 57fc7b78674779d50fe16002dc154b0dbe13e33a..714404bf173c8868e4c67e49f257a1b221e540d7 100644 --- a/src/ptbench/configs/models_datasets/densenet_rs.py +++ b/src/ptbench/configs/models_datasets/densenet_rs.py @@ -7,10 +7,10 @@ A Densenet121 model for radiological extraction """ +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.densenet_rs import build_densenetrs +from ...models.densenet_rs import DensenetRS # Import the default protocol if none is available if "dataset" not in locals(): @@ -19,16 +19,14 @@ if "dataset" not in locals(): dataset = default.dataset # config -lr = 1e-4 - -# model -model = build_densenetrs() +optimizer_configs = {"lr": 1e-4} # optimizer -optimizer = Adam( - filter(lambda p: p.requires_grad, model.model.model_ft.parameters()), lr=lr -) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() -criterion_valid = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = DensenetRS(criterion, criterion_valid, optimizer, optimizer_configs) diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py index 
2e03aa353c5f5e95db70b53072f94071b60cf677..12a7517e9249c61fcbe58ea7339258e095feb3d0 100644 --- a/src/ptbench/data/loader.py +++ b/src/ptbench/data/loader.py @@ -101,3 +101,37 @@ def make_delayed(sample, loader, key=None): key=key or sample["data"], label=sample["label"], ) + + +def make_delayed_bbox(sample, loader, key=None): + """Returns a delayed-loading Sample object. + + Parameters + ---------- + + sample : dict + A dictionary that maps field names to sample data values (e.g. paths) + + loader : object + A function that inputs ``sample`` dictionaries and returns the loaded + data. + + key : str + A unique key identifier for this sample. If not provided, assumes + ``sample`` is a dictionary with a ``data`` entry and uses its path as + key. + + + Returns + ------- + + sample : ptbench.data.sample.DelayedSample + In which ``key`` is as provided and ``data`` can be accessed to trigger + sample loading. + """ + return DelayedSample( + functools.partial(loader, sample), + key=key or sample["data"], + label=sample["label"], + bboxes=sample["bboxes"], + ) diff --git a/src/ptbench/data/tbx11k_simplified/__init__.py b/src/ptbench/data/tbx11k_simplified/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2be35dc4d7992a66a101ad5dcc00ceb127c082 --- /dev/null +++ b/src/ptbench/data/tbx11k_simplified/__init__.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11K simplified dataset for computer-aided diagnosis. + +The TBX11K database has been established to foster research +in computer-aided diagnosis of pulmonary diseases with a special +focus on tuberculosis (aTB). The dataset was specifically +designed to be used with CNNs. It contains 11,000 chest X-ray +images, each of a unique patient. They were labeled by expert +radiologists with 5 - 10+ years of experience. 
Possible labels +are: "healthy", "active TB", "latent TB", and "sick & non-tb". +The version of the dataset used in this benchmark is a simplified. + +* Reference: [TBX11K-SIMPLIFIED-2020]_ +* Original (released) resolution (height x width or width x height): 512 x 512 +* Split reference: none +* Protocol ``default``: + + * Training samples: 62.5% of TB and healthy CXR (including labels) + * Validation samples: 15.9% of TB and healthy CXR (including labels) + * Test samples: 21.6% of TB and healthy CXR (including labels) +""" + +import importlib.resources +import os + +from ...utils.rc import load_rc +from ..dataset import JSONDataset +from ..loader import load_pil_baw, make_delayed, make_delayed_bbox + +_protocols = [ + importlib.resources.files(__name__).joinpath("default.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), +] + +_datadir = load_rc().get( + "datadir.tbx11k_simplified", os.path.realpath(os.curdir) +) + + +def _raw_data_loader(sample): + return dict( + data=load_pil_baw(os.path.join(_datadir, sample["data"])), # type: ignore + label=sample["label"], + ) + + +def _raw_data_loader_bbox(sample): + return dict( + data=load_pil_baw(os.path.join(_datadir, sample["data"])), # type: ignore + label=sample["label"], + bboxes=sample["bboxes"], + ) + + +def _loader(context, sample): + # "context" is ignored in this case - database is 
homogeneous + # we return delayed samples to avoid loading all images at once + return make_delayed(sample, _raw_data_loader) + + +def _loader_bbox(context, sample): + # "context" is ignored in this case - database is homogeneous + # we return delayed samples to avoid loading all images at once + return make_delayed_bbox(sample, _raw_data_loader_bbox) + + +dataset = JSONDataset( + protocols=_protocols, + fieldnames=("data", "label"), + loader=_loader, +) + +dataset_with_bboxes = JSONDataset( + protocols=_protocols, + fieldnames=("data", "label", "bboxes"), + loader=_loader_bbox, +) +"""TBX11K simplified dataset object.""" diff --git a/src/ptbench/data/tbx11k_simplified/default.json.bz2 b/src/ptbench/data/tbx11k_simplified/default.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..06e5a3a7011f976118ae1067785e59b9e294480c Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/default.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_0.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_0.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..267555153aebab7b1a0cfdd3dca0d4e7149cf5f3 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_0.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_1.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_1.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..ceaa92a210722a839b40852d6dbb866b09712a99 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_1.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_2.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_2.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..615ecfe9ddb562c27b14fcefd531b131a337f1b4 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_2.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_3.json.bz2 
b/src/ptbench/data/tbx11k_simplified/fold_3.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..aee0fea3dc3095b6fb703b33365248a0439888d3 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_3.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_4.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_4.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..a312ef4cb65aad35644b0e862da9be9ffce8f44f Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_4.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_5.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_5.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..926d83ed57633fe7388f744d904723821ece7bd7 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_5.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_6.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_6.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..47b8d130e7523fa9ab4a4f8927e2302dbbaeb352 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_6.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_7.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_7.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..7300fee208aff942afd5cc25845ecb7b3c55f859 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_7.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_8.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_8.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..b7056757fa5af3fccb19eb03d52f0bcd72d13571 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_8.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified/fold_9.json.bz2 b/src/ptbench/data/tbx11k_simplified/fold_9.json.bz2 new file mode 100644 index 
0000000000000000000000000000000000000000..ef2c46e58692ba23816fcd3c853ae4a702acd591 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified/fold_9.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/__init__.py b/src/ptbench/data/tbx11k_simplified_RS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8598a8c5b1a11f1a8a1866e5e4d04b83c27925ff --- /dev/null +++ b/src/ptbench/data/tbx11k_simplified_RS/__init__.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Extended TBX11K simplified dataset for computer-aided diagnosis (extended +with DensenetRS predictions) + +The TBX11K database has been established to foster research +in computer-aided diagnosis of pulmonary diseases with a special +focus on tuberculosis (aTB). The dataset was specifically +designed to be used with CNNs. It contains 11,000 chest X-ray +images, each of a unique patient. They were labeled by expert +radiologists with 5 - 10+ years of experience. Possible labels +are: "healthy", "active TB", "latent TB", and "sick & non-tb". +The version of the dataset used in this benchmark is a simplified. 
+ +* Reference: [TBX11K-SIMPLIFIED-2020]_ +* Original (released) resolution (height x width or width x height): 512 x 512 +* Split reference: none +* Protocol ``default``: + + * Training samples: 62.5% of TB and healthy CXR (including labels) + * Validation samples: 15.9% of TB and healthy CXR (including labels) + * Test samples: 21.6% of TB and healthy CXR (including labels) +""" + +import importlib.resources + +from ..dataset import JSONDataset +from ..loader import make_delayed + +_protocols = [ + importlib.resources.files(__name__).joinpath("default.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), +] + + +def _raw_data_loader(sample): + return dict(data=sample["data"], label=sample["label"]) + + +def _loader(context, sample): + # "context" is ignored in this case - database is homogeneous + # we returned delayed samples to avoid loading all images at once + return make_delayed(sample, _raw_data_loader, key=sample["filename"]) + + +dataset = JSONDataset( + protocols=_protocols, + fieldnames=("filename", "label", "data"), + loader=_loader, +) +"""Extended TBX11K simplified dataset object.""" diff --git a/src/ptbench/data/tbx11k_simplified_RS/default.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/default.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..5c192dec44e60c3fb36606cd60d0fa3505d8a96b Binary files 
/dev/null and b/src/ptbench/data/tbx11k_simplified_RS/default.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_0.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_0.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..4c5f8107c702ece1c017efdbea53de24f4a635c8 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_0.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_1.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_1.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..649066593b98927a38a29a4d9fedb29a8e6f74fc Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_1.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_2.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_2.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..1d92d9b8ac56fd1d6ac7e652f586563e95d74ada Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_2.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_3.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_3.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..ec4b689bebf364d8c95851a2e9887aee72ee63a7 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_3.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_4.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_4.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..490352c09dae504c2a60c9d662ca4c8d7692b29c Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_4.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_5.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_5.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..d20b871321b8445588a561ce833d9ca03ce925de Binary files /dev/null and 
b/src/ptbench/data/tbx11k_simplified_RS/fold_5.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_6.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_6.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..8eb8360de34f36258ed341f0bb3e89c0bc4b966d Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_6.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_7.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_7.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..3c9753344bd607e9dda547eb8c5bf4cf5536b4f3 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_7.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_8.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_8.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..433d4de4b5805dba15aed5d036ee8558b0249226 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_8.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_RS/fold_9.json.bz2 b/src/ptbench/data/tbx11k_simplified_RS/fold_9.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..824ea8aafaf7198c85726cb52d14243ee9267f00 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_RS/fold_9.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/__init__.py b/src/ptbench/data/tbx11k_simplified_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf676b00d4570c8b2f6b42724f55eeb0eab2fc4 --- /dev/null +++ b/src/ptbench/data/tbx11k_simplified_v2/__init__.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""TBX11K simplified dataset for computer-aided diagnosis. 
+ +The TBX11K database has been established to foster research +in computer-aided diagnosis of pulmonary diseases with a special +focus on tuberculosis (aTB). The dataset was specifically +designed to be used with CNNs. It contains 11,000 chest X-ray +images, each of a unique patient. They were labeled by expert +radiologists with 5 - 10+ years of experience. Possible labels +are: "healthy", "active TB", "latent TB", and "sick & non-tb". +The version of the dataset used in this benchmark is a simplified. + +* Reference: [TBX11K-SIMPLIFIED-2020]_ +* Original (released) resolution (height x width or width x height): 512 x 512 +* Split reference: none +* Protocol ``default``: + + * Training samples: 62.6% of CXR (including labels) + * Validation samples: 16% of CXR (including labels) + * Test samples: 21.4% of CXR (including labels) +""" + +import importlib.resources +import os + +from ...utils.rc import load_rc +from ..dataset import JSONDataset +from ..loader import load_pil_baw, make_delayed, make_delayed_bbox + +_protocols = [ + importlib.resources.files(__name__).joinpath("default.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), +] + +_datadir = load_rc().get( + "datadir.tbx11k_simplified", os.path.realpath(os.curdir) +) + + +def _raw_data_loader(sample): + return dict( + data=load_pil_baw(os.path.join(_datadir, sample["data"])), # type: 
ignore + label=sample["label"], + ) + + +def _raw_data_loader_bbox(sample): + return dict( + data=load_pil_baw(os.path.join(_datadir, sample["data"])), # type: ignore + label=sample["label"], + bboxes=sample["bboxes"], + ) + + +def _loader(context, sample): + # "context" is ignored in this case - database is homogeneous + # we return delayed samples to avoid loading all images at once + return make_delayed(sample, _raw_data_loader) + + +def _loader_bbox(context, sample): + # "context" is ignored in this case - database is homogeneous + # we return delayed samples to avoid loading all images at once + return make_delayed_bbox(sample, _raw_data_loader_bbox) + + +dataset = JSONDataset( + protocols=_protocols, + fieldnames=("data", "label"), + loader=_loader, +) + +dataset_with_bboxes = JSONDataset( + protocols=_protocols, + fieldnames=("data", "label", "bboxes"), + loader=_loader_bbox, +) +"""TBX11K simplified dataset object.""" diff --git a/src/ptbench/data/tbx11k_simplified_v2/default.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/default.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..1947e1ec8b6e45c82ed467367792151f3a06def3 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/default.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_0.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_0.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..beff51ae634b349f40e955ff8739385f51b7b067 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_0.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_1.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_1.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..ed2f98d133fd97135830c6f20e1e0114e1264992 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_1.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_2.json.bz2 
b/src/ptbench/data/tbx11k_simplified_v2/fold_2.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..f3ab3352a2c3a3f613f084c77458ee36555d52b8 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_2.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_3.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_3.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..f6e0c5c17e74e9171989f6e4fd0185047a78fe36 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_3.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_4.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_4.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..473793dd2bd8b0830132f54c16ff2458871122d6 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_4.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_5.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_5.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..40c492e496a4951a0aee98702b87affbec7d0af8 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_5.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_6.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_6.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..3324f8d899422800184777407159ba38ae0088d4 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_6.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_7.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_7.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..6a435bced6808f41476d8635484d9b3e7356b5ef Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_7.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_8.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_8.json.bz2 new 
file mode 100644 index 0000000000000000000000000000000000000000..126734afadde81f1d861043b63d331f8b918548a Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_8.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2/fold_9.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2/fold_9.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..cde889ebdeb3bef2bc8b5877447140003300969d Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2/fold_9.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/__init__.py b/src/ptbench/data/tbx11k_simplified_v2_RS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bcf8c2a53cd117f556c5854774d5ea81fb2f6c --- /dev/null +++ b/src/ptbench/data/tbx11k_simplified_v2_RS/__init__.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Extended TBX11K simplified dataset for computer-aided diagnosis (extended +with DensenetRS predictions) + +The TBX11K database has been established to foster research +in computer-aided diagnosis of pulmonary diseases with a special +focus on tuberculosis (aTB). The dataset was specifically +designed to be used with CNNs. It contains 11,000 chest X-ray +images, each of a unique patient. They were labeled by expert +radiologists with 5 - 10+ years of experience. Possible labels +are: "healthy", "active TB", "latent TB", and "sick & non-tb". +The version of the dataset used in this benchmark is a simplified. 
+ +* Reference: [TBX11K-SIMPLIFIED-2020]_ +* Original (released) resolution (height x width or width x height): 512 x 512 +* Split reference: none +* Protocol ``default``: + + * Training samples: 62.6% of CXR (including labels) + * Validation samples: 16% of CXR (including labels) + * Test samples: 21.4% of CXR (including labels) +""" + +import importlib.resources + +from ..dataset import JSONDataset +from ..loader import make_delayed + +_protocols = [ + importlib.resources.files(__name__).joinpath("default.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), + importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), +] + + +def _raw_data_loader(sample): + return dict(data=sample["data"], label=sample["label"]) + + +def _loader(context, sample): + # "context" is ignored in this case - database is homogeneous + # we returned delayed samples to avoid loading all images at once + return make_delayed(sample, _raw_data_loader, key=sample["filename"]) + + +dataset = JSONDataset( + protocols=_protocols, + fieldnames=("filename", "label", "data"), + loader=_loader, +) + +"""Extended TBX11K simplified dataset object.""" diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/default.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/default.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..a9df1825d2349ebaf364884839a0fd3174e27ea2 Binary files /dev/null and 
b/src/ptbench/data/tbx11k_simplified_v2_RS/default.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_0.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_0.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..0084bd7fa1149d2e1507e3333a97a714c37e554f Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_0.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_1.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_1.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..e8084fcdfd515b3c0ef78cac625b35da6ff46e88 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_1.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_2.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_2.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..2328c509c0695443a43bc1bee2b9a3d7746282c0 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_2.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_3.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_3.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..ae9712900b960ddf504c6b7544492d2c19a8c42d Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_3.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_4.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_4.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..86d9c595977c68f097e3711f005188e42090a5e4 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_4.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_5.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_5.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..84f479c72ad0667f8baa1ce666e540f6ee00762a Binary files /dev/null and 
b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_5.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_6.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_6.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..48d52e3b0ed6708953c6b591bf0d27ae64e324cb Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_6.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_7.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_7.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..8ac08974c73c84ef4356b884cf48a6e777e417e0 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_7.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_8.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_8.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..84dad2eb65508361e11d2f826494e25a0dbae6f8 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_8.json.bz2 differ diff --git a/src/ptbench/data/tbx11k_simplified_v2_RS/fold_9.json.bz2 b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_9.json.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..0c23b2ca6834ab68988d97a1ccd857f6d40c9b57 Binary files /dev/null and b/src/ptbench/data/tbx11k_simplified_v2_RS/fold_9.json.bz2 differ diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..b266ae6221cf9a925ff941f1c99bdfdd044fa23f --- /dev/null +++ b/src/ptbench/engine/callbacks.py @@ -0,0 +1,107 @@ +import csv +import time + +import numpy + +from lightning.pytorch import Callback +from lightning.pytorch.callbacks import BasePredictionWriter + + +# This ensures CSVLogger logs training and evaluation metrics on the same line +# CSVLogger only accepts numerical values, not strings +class LoggingCallback(Callback): + """Lightning callback to 
log various training metrics and device + information.""" + + def __init__(self, resource_monitor): + super().__init__() + self.training_loss = [] + self.validation_loss = [] + self.start_training_time = 0 + self.start_epoch_time = 0 + + self.resource_monitor = resource_monitor + self.max_queue_retries = 2 + + def on_train_start(self, trainer, pl_module): + self.start_training_time = time.time() + + def on_train_epoch_start(self, trainer, pl_module): + self.start_epoch_time = time.time() + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + self.training_loss.append(outputs["loss"].item()) + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx + ): + self.validation_loss.append(outputs["validation_loss"].item()) + + def on_validation_epoch_end(self, trainer, pl_module): + self.resource_monitor.trigger_summary() + + self.epoch_time = time.time() - self.start_epoch_time + eta_seconds = self.epoch_time * ( + trainer.max_epochs - trainer.current_epoch + ) + current_time = time.time() - self.start_training_time + + self.log("total_time", current_time) + self.log("eta", eta_seconds) + self.log("loss", numpy.average(self.training_loss)) + self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"]) + self.log("validation_loss", numpy.average(self.validation_loss)) + + queue_retries = 0 + # In case the resource monitor takes longer to fetch data from the queue, we wait + # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue + while ( + self.resource_monitor.data is None + and queue_retries < self.max_queue_retries + ): + queue_retries = queue_retries + 1 + print( + f"Monitor queue is empty, retrying in {self.resource_monitor.interval}s" + ) + time.sleep(self.resource_monitor.interval) + + if queue_retries >= self.max_queue_retries: + print( + f"Unable to fetch monitoring information from queue after {queue_retries} retries" + ) + + assert 
self.resource_monitor.q.empty() + + for metric_name, metric_value in self.resource_monitor.data: + self.log(metric_name, float(metric_value)) + + self.resource_monitor.data = None + + self.training_loss = [] + self.validation_loss = [] + + +class PredictionsWriter(BasePredictionWriter): + """Lightning callback to write predictions to a file.""" + + def __init__(self, logfile_name, logfile_fields, write_interval): + super().__init__(write_interval) + self.logfile_name = logfile_name + self.logfile_fields = logfile_fields + + def write_on_epoch_end( + self, trainer, pl_module, predictions, batch_indices + ): + with open(self.logfile_name, "w") as logfile: + logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields) + logwriter.writeheader() + + for prediction in predictions: + logwriter.writerow( + { + "filename": prediction[0], + "likelihood": prediction[1].numpy(), + "ground_truth": prediction[2].numpy(), + } + ) + logfile.flush() diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 6c4dd4af73bcd1ff868a2f20964f9d7980ddd2a5..dc037af0679d3493a718a26dc393a61d00659a27 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -2,272 +2,77 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import csv -import datetime import logging import os -import shutil -import time -import matplotlib.pyplot as plt -import numpy -import PIL -import torch +from lightning.pytorch import Trainer -from matplotlib.gridspec import GridSpec -from matplotlib.patches import Rectangle -from torchvision import transforms -from tqdm import tqdm - -from ..utils.grad_cams import GradCAM +from ..utils.accelerator import AcceleratorProcessor +from .callbacks import PredictionsWriter logger = logging.getLogger(__name__) -colors = [ - [(47, 79, 79), "Cardiomegaly"], - [(255, 0, 0), "Emphysema"], - [(0, 128, 0), "Pleural effusion"], - [(0, 0, 128), "Hernia"], - [(255, 84, 0), "Infiltration"], - [(222, 184, 135), "Mass"], - [(0, 255, 0), 
"Nodule"], - [(0, 191, 255), "Atelectasis"], - [(0, 0, 255), "Pneumothorax"], - [(255, 0, 255), "Pleural thickening"], - [(255, 255, 0), "Pneumonia"], - [(126, 0, 255), "Fibrosis"], - [(255, 20, 147), "Edema"], - [(0, 255, 180), "Consolidation"], -] - -def run(model, data_loader, name, device, output_folder, grad_cams=False): - """Runs inference on input data, outputs HDF5 files with predictions. +def run(model, data_loader, name, accelerator, output_folder, grad_cams=False): + """Runs inference on input data, outputs csv files with predictions. Parameters --------- model : :py:class:`torch.nn.Module` - neural network model (e.g. pasa) + Neural network model (e.g. pasa). data_loader : py:class:`torch.torch.utils.data.DataLoader` + The pytorch Dataloader used to iterate over batches. name : str - the local name of this dataset (e.g. ``train``, or ``test``), to be + The local name of this dataset (e.g. ``train``, or ``test``), to be used when saving measures files. - device : str - device to use ``cpu`` or ``cuda:0`` + accelerator : str + A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0) output_folder : str - folder where to store output prediction and model - summary + Directory in which the results will be saved. grad_cams : bool - if we export grad cams for every prediction (must be used along - a batch size of 1 with the DensenetRS model) + If we export grad cams for every prediction (must be used along + a batch size of 1 with the DensenetRS model). Returns ------- all_predictions : list - All the predictions associated with filename and groundtruth + All the predictions associated with filename and ground truth. 
""" output_folder = os.path.join(output_folder, name) logger.info(f"Output folder: {output_folder}") os.makedirs(output_folder, exist_ok=True) - logger.info(f"Device: {device}") - - logfile_name = os.path.join(output_folder, "predictions.csv") - logfile_fields = ("filename", "likelihood", "ground_truth") - - if os.path.exists(logfile_name): - backup = logfile_name + "~" - if os.path.exists(backup): - os.unlink(backup) - shutil.move(logfile_name, backup) - - if grad_cams: - grad_folder = os.path.join(output_folder, "cams") - logger.info(f"Grad cams folder: {grad_folder}") - os.makedirs(grad_folder, exist_ok=True) - - with open(logfile_name, "a+", newline="") as logfile: - logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields) - - logwriter.writeheader() - - model.eval() # set evaluation mode - model.to(device) # set/cast parameters to device - - # Setup timers - start_total_time = time.time() - times = [] - len_samples = [] - - all_predictions = [] - - for samples in tqdm( - data_loader, - desc="batches", - leave=False, - disable=None, - ): - names = samples[0] - images = samples[1].to( - device=device, non_blocking=torch.cuda.is_available() - ) - - # Gradcams generation - allowed_models = ["DensenetRS"] - if grad_cams and model.name in allowed_models: - gcam = GradCAM(model=model) - probs, ids = gcam.forward(images) - - # To store signs overlays - cams_img = dict() - - # Top k number of radiological signs for which we generate cams - topk = 1 - - for i in range(topk): - # Keep only "positive" signs - if probs[:, [i]] > 0.5: - # Grad-CAM - b = ids[:, [i]] - gcam.backward(ids=ids[:, [i]]) - regions = gcam.generate( - target_layer="model_ft.features.denseblock4.denselayer16.conv2" - ) - - for j in range(len(images)): - current_cam = regions[j, 0].cpu().numpy() - current_cam[current_cam < 0.75] = 0.0 - current_cam[current_cam >= 0.75] = 1.0 - current_cam = PIL.Image.fromarray( - numpy.uint8(current_cam * 255), "L" - ) - cams_img[b.item()] = [ - current_cam, - 
round(probs[:, [i]].item(), 2), - ] + accelerator_processor = AcceleratorProcessor(accelerator) - if len(cams_img) > 0: - # Convert original image tensor into PIL Image - original_image = transforms.ToPILImage(mode="RGB")( - images[0] - ) + if accelerator_processor.device is None: + devices = "auto" + else: + devices = accelerator_processor.device - for sign_id, label_prob in cams_img.items(): - label = label_prob[0] + logger.info(f"Device: {devices}") - # Create the colored overlay for current sign - colored_sign = PIL.ImageOps.colorize( - label.convert("L"), (0, 0, 0), colors[sign_id][0] - ) - - # blend image and label together - first blend to get signs drawn with a - # slight "label_color" tone on top, then composite with original image, to - # avoid loosing brightness. - retval = PIL.Image.blend( - original_image, colored_sign, 0.5 - ) - composite_mask = PIL.ImageOps.invert(label.convert("L")) - original_image = PIL.Image.composite( - original_image, retval, composite_mask - ) - - handles = [] - labels = [] - for i, v in enumerate(colors): - # If sign present on image - if cams_img.get(i) is not None: - handles.append( - Rectangle( - (0, 0), - 1, - 1, - color=tuple(v / 255 for v in v[0]), - ) - ) - labels.append( - v[1] + " (" + str(cams_img[i][1]) + ")" - ) - - gs = GridSpec(6, 1) - fig = plt.figure(figsize=(10, 11)) - ax1 = fig.add_subplot(gs[:-1, :]) # For the plot - ax2 = fig.add_subplot(gs[-1, :]) # For the legend - - ax1.imshow(original_image) - ax1.axis("off") - ax2.legend( - handles, labels, mode="expand", ncol=3, frameon=False - ) - ax2.axis("off") - - original_filename = ( - samples[0][0].split("/")[-1].split(".")[0] - ) - cam_filename = os.path.join( - grad_folder, original_filename + "_cam.png" - ) - fig.savefig(cam_filename) - - with torch.no_grad(): - start_time = time.perf_counter() - outputs = model(images) - probabilities = torch.sigmoid(outputs) - - # necessary check for HED architecture that uses several outputs - # for loss calculation 
instead of just the last concatfuse block - if isinstance(outputs, list): - outputs = outputs[-1] - - # predictions = sigmoid(outputs) - - batch_time = time.perf_counter() - start_time - times.append(batch_time) - len_samples.append(len(images)) - - logdata = ( - ("filename", f"{names[0]}"), - ( - "likelihood", - f"{torch.flatten(probabilities).data.cpu().numpy()}", - ), - ( - "ground_truth", - f"{torch.flatten(samples[2]).data.cpu().numpy()}", - ), - ) - - logwriter.writerow(dict(k for k in logdata)) - logfile.flush() - tqdm.write(" | ".join([f"{k}: {v}" for (k, v) in logdata[:4]])) - - # Keep prediction for relevance analysis - all_predictions.append( - [ - names[0], - torch.flatten(probabilities).data.cpu().numpy(), - torch.flatten(samples[2]).data.cpu().numpy(), - ] - ) - - # report operational summary - total_time = datetime.timedelta( - seconds=int(time.time() - start_total_time) - ) - logger.info(f"Total time: {total_time}") - - average_batch_time = numpy.mean(times) - logger.info(f"Average batch time: {average_batch_time:g}s") - - average_image_time = numpy.sum( - numpy.array(times) * len_samples - ) / float(sum(len_samples)) - logger.info(f"Average image time: {average_image_time:g}s") + logfile_name = os.path.join(output_folder, "predictions.csv") + logfile_fields = ("filename", "likelihood", "ground_truth") - return all_predictions + trainer = Trainer( + accelerator=accelerator_processor.accelerator, + devices=devices, + callbacks=[ + PredictionsWriter( + logfile_name=logfile_name, + logfile_fields=logfile_fields, + write_interval="epoch", + ), + ], + ) + + all_predictions = trainer.predict(model, data_loader) + + return all_predictions diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index fd48d4f1fdd82eea05d28dfff7d035e5694a08fd..a85a3da566691922323fbeb8a56d3472def48389 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -2,53 +2,23 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import 
contextlib import csv -import datetime import logging import os import shutil -import sys -import time -import numpy -import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger +from lightning.pytorch.utilities.model_summary import ModelSummary -from tqdm import tqdm - -# from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log +from ..utils.accelerator import AcceleratorProcessor from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants -from ..utils.summary import summary +from .callbacks import LoggingCallback logger = logging.getLogger(__name__) -@contextlib.contextmanager -def torch_evaluation(model): - """Context manager to turn ON/OFF model evaluation. - - This context manager will turn evaluation mode ON on entry and turn it OFF - when exiting the ``with`` statement block. - - - Parameters - ---------- - - model : :py:class:`torch.nn.Module` - Network - - - Yields - ------ - - model : :py:class:`torch.nn.Module` - Network - """ - model.eval() - yield model - model.train() - - def check_gpu(device): """Check the device type and the availability of GPU. @@ -58,52 +28,13 @@ def check_gpu(device): device : :py:class:`torch.device` device to use """ - if device.type == "cuda": + if device == "cuda": # asserts we do have a GPU assert bool( gpu_constants() ), f"Device set to '{device}', but nvidia-smi is not installed" -def initialize_lowest_validation_loss(logfile_name, arguments): - """Initialize the lowest validation loss from the logfile if it exists and - if the training does not start from epoch 0, which means that a previous - training session is resumed. 
- - Parameters - ---------- - - logfile_name : str - The logfile_name which is a join between the output_folder and trainlog.csv - - arguments : dict - start and end epochs - """ - - if arguments["epoch"] != 0 and os.path.exists(logfile_name): - # Open the CSV file - with open(logfile_name) as file: - reader = csv.DictReader(file) - column_name = "validation_loss" - - if column_name not in reader.fieldnames: - return sys.float_info.max - - # Get the values of the desired column as a list - values = [float(row[column_name]) for row in reader] - - if not values: - return sys.float_info.max - - lowest_value = min(values) - logger.info( - f"Found lowest validation loss from previous session: {lowest_value}" - ) - return lowest_value - - return sys.float_info.max - - def save_model_summary(output_folder, model): """Save a little summary of the model in a txt file. @@ -127,10 +58,9 @@ def save_model_summary(output_folder, model): summary_path = os.path.join(output_folder, "model_summary.txt") logger.info(f"Saving model summary at {summary_path}...") with open(summary_path, "w") as f: - r, n = summary(model) - logger.info(f"Model has {n} parameters...") - f.write(r) - return r, n + summary = ModelSummary(model, max_depth=-1) + f.write(str(summary)) + return summary, ModelSummary(model).total_parameters def static_information_to_csv(static_logfile_name, device, n): @@ -149,7 +79,7 @@ def static_information_to_csv(static_logfile_name, device, n): shutil.move(static_logfile_name, backup) with open(static_logfile_name, "w", newline="") as f: logdata = cpu_constants() - if device.type == "cuda": + if device == "cuda": logdata += gpu_constants() logdata += (("model_size", n),) logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata]) @@ -213,313 +143,22 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device): logfile_fields += ("validation_loss",) if extra_valid_loaders: logfile_fields += ("extra_validation_losses",) - logfile_fields += tuple( - 
ResourceMonitor.monitored_keys(device.type == "cuda") - ) + logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda")) return logfile_fields -def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count): - """Trains the model for a single epoch (through all batches) - - Parameters - ---------- - - loader : :py:class:`torch.utils.data.DataLoader` - To be used to train the model - - model : :py:class:`torch.nn.Module` - Network (e.g. driu, hed, unet) - - optimizer : :py:mod:`torch.optim` - - device : :py:class:`torch.device` - device to use - - criterion : :py:class:`torch.nn.modules.loss._Loss` - - batch_chunk_count: int - If this number is different than 1, then each batch will be divided in - this number of chunks. Gradients will be accumulated to perform each - mini-batch. This is particularly interesting when one has limited RAM - on the GPU, but would like to keep training with larger batches. One - exchanges for longer processing times in this case. To better understand - gradient accumulation, read - https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch. 
- - - Returns - ------- - - loss : float - A floating-point value corresponding the weighted average of this - epoch's loss - """ - losses_in_epoch = [] - samples_in_epoch = [] - losses_in_batch = [] - samples_in_batch = [] - - # progress bar only on interactive jobs - for idx, samples in enumerate( - tqdm(loader, desc="train", leave=False, disable=None) - ): - images = samples[1].to( - device=device, non_blocking=torch.cuda.is_available() - ) - labels = samples[2].to( - device=device, non_blocking=torch.cuda.is_available() - ) - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = model(images) - - loss = criterion(outputs, labels.double()) - - losses_in_batch.append(loss.item()) - samples_in_batch.append(len(samples)) - - # Normalize loss to account for batch accumulation - loss = loss / batch_chunk_count - - # Accumulate gradients - does not update weights just yet... 
- loss.backward() - - # Weight update on the network - if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)): - # Advances optimizer to the "next" state and applies weight update - # over the whole model - optimizer.step() - - # Zeroes gradients for the next batch - optimizer.zero_grad() - - # Normalize loss for current batch - batch_loss = numpy.average( - losses_in_batch, weights=samples_in_batch - ) - losses_in_epoch.append(batch_loss.item()) - samples_in_epoch.append(len(samples)) - - losses_in_batch.clear() - samples_in_batch.clear() - logger.debug(f"batch loss: {batch_loss.item()}") - - return numpy.average(losses_in_epoch, weights=samples_in_epoch) - - -def validate_epoch(loader, model, device, criterion, pbar_desc): - """Processes input samples and returns loss (scalar) - - Parameters - ---------- - - loader : :py:class:`torch.utils.data.DataLoader` - To be used to validate the model - - model : :py:class:`torch.nn.Module` - Network (e.g. driu, hed, unet) - - optimizer : :py:mod:`torch.optim` - - device : :py:class:`torch.device` - device to use - - criterion : :py:class:`torch.nn.modules.loss._Loss` - loss function - - pbar_desc : str - A string for the progress bar descriptor - - - Returns - ------- - - loss : float - A floating-point value corresponding the weighted average of this - epoch's loss - """ - batch_losses = [] - samples_in_batch = [] - - with torch.no_grad(), torch_evaluation(model): - for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None): - images = samples[1].to( - device=device, - non_blocking=torch.cuda.is_available(), - ) - labels = samples[2].to( - device=device, - non_blocking=torch.cuda.is_available(), - ) - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # data forwarding on the existing network - outputs = model(images) - loss = criterion(outputs, labels.double()) - - 
batch_losses.append(loss.item()) - samples_in_batch.append(len(samples)) - - return numpy.average(batch_losses, weights=samples_in_batch) - - -def checkpointer_process( - checkpointer, - checkpoint_period, - valid_loss, - lowest_validation_loss, - arguments, - epoch, - max_epoch, -): - """Process the checkpointer, save the final model and keep track of the - best model. - - Parameters - ---------- - - checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer` - checkpointer implementation - - checkpoint_period : int - save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do - not save intermediary checkpoints - - valid_loss : float - Current epoch validation loss - - lowest_validation_loss : float - Keeps track of the best (lowest) validation loss - - arguments : dict - start and end epochs - - max_epoch : int - end_potch - - Returns - ------- - - lowest_validation_loss : float - The lowest validation loss currently observed - """ - if checkpoint_period and (epoch % checkpoint_period == 0): - checkpointer.save("model_periodic_save", **arguments) - - if valid_loss is not None and valid_loss < lowest_validation_loss: - lowest_validation_loss = valid_loss - logger.info( - f"Found new low on validation set:" f" {lowest_validation_loss:.6f}" - ) - checkpointer.save("model_lowest_valid_loss", **arguments) - - if epoch >= max_epoch: - checkpointer.save("model_final_epoch", **arguments) - - return lowest_validation_loss - - -def write_log_info( - epoch, - current_time, - eta_seconds, - loss, - valid_loss, - extra_valid_losses, - optimizer, - logwriter, - logfile, - resource_data, -): - """Write log info in trainlog.csv. 
- - Parameters - ---------- - - epoch : int - Current epoch - - current_time : float - Current training time - - eta_seconds : float - estimated time-of-arrival taking into consideration previous epoch performance - - loss : float - Current epoch's training loss - - valid_loss : :py:class:`float`, None - Current epoch's validation loss - - extra_valid_losses : :py:class:`list` of :py:class:`float` - Validation losses from other validation datasets being currently - tracked - - optimizer : :py:mod:`torch.optim` - - logwriter : csv.DictWriter - Dictionary writer that give the ability to write on the trainlog.csv - - logfile : io.TextIOWrapper - - resource_data : tuple - Monitored resources at the machine (CPU and GPU) - """ - - logdata = ( - ("epoch", f"{epoch}"), - ( - "total_time", - f"{datetime.timedelta(seconds=int(current_time))}", - ), - ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"), - ("loss", f"{loss:.6f}"), - ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"), - ) - - if valid_loss is not None: - logdata += (("validation_loss", f"{valid_loss:.6f}"),) - - if extra_valid_losses: - entry = numpy.array_str( - numpy.array(extra_valid_losses), - max_line_width=sys.maxsize, - precision=6, - ) - logdata += (("extra_validation_losses", entry),) - - logdata += resource_data - - logwriter.writerow(dict(k for k in logdata)) - logfile.flush() - tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]])) - - def run( model, data_loader, valid_loader, extra_valid_loaders, - optimizer, - criterion, - checkpointer, checkpoint_period, - device, + accelerator, arguments, output_folder, monitoring_interval, batch_chunk_count, - criterion_valid, + checkpoint, ): """Fits a CNN model using supervised learning and save it to disk. @@ -531,10 +170,10 @@ def run( ---------- model : :py:class:`torch.nn.Module` - Network (e.g. driu, hed, unet) + Neural network model (e.g. pasa). 
data_loader : :py:class:`torch.utils.data.DataLoader` - To be used to train the model + The pytorch Dataloader used to iterate over batches. valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader` To be used to validate the model and enable automatic checkpointing. @@ -546,29 +185,21 @@ def run( an extra column with the loss of every dataset in this list is kept on the final training log. - optimizer : :py:mod:`torch.optim` - - criterion : :py:class:`torch.nn.modules.loss._Loss` - loss function - - checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer` - checkpointer implementation - checkpoint_period : int - save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do - not save intermediary checkpoints + Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do + not save intermediary checkpoints. - device : :py:class:`torch.device` - device to use + accelerator : str + A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0) arguments : dict - start and end epochs + Start and end epochs: output_folder : str - output path + Directory in which the results will be saved. monitoring_interval : int, float - interval, in seconds (or fractions), through which we should monitor + Interval, in seconds (or fractions), through which we should monitor resources during training. batch_chunk_count: int @@ -577,137 +208,59 @@ def run( mini-batch. This is particularly interesting when one has limited RAM on the GPU, but would like to keep training with larger batches. One exchanges for longer processing times in this case. 
- - criterion_valid : :py:class:`torch.nn.modules.loss._Loss` - specific loss function for the validation set """ - start_epoch = arguments["epoch"] max_epoch = arguments["max_epoch"] - check_gpu(device) + accelerator_processor = AcceleratorProcessor(accelerator) os.makedirs(output_folder, exist_ok=True) # Save model summary r, n = save_model_summary(output_folder, model) - # write static information to a CSV file - static_logfile_name = os.path.join(output_folder, "constants.csv") - - static_information_to_csv(static_logfile_name, device, n) - - # Log continous information to (another) file - logfile_name = os.path.join(output_folder, "trainlog.csv") + csv_logger = CSVLogger(output_folder, "logs_csv") + tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard") - check_exist_logfile(logfile_name, arguments) + resource_monitor = ResourceMonitor( + interval=monitoring_interval, + has_gpu=(accelerator_processor.accelerator == "gpu"), + main_pid=os.getpid(), + logging_level=logging.ERROR, + ) - logfile_fields = create_logfile_fields( - valid_loader, extra_valid_loaders, device + checkpoint_callback = ModelCheckpoint( + output_folder, + "model_lowest_valid_loss", + save_last=True, + monitor="validation_loss", + mode="min", + save_on_train_epoch_end=False, + every_n_epochs=checkpoint_period, ) - # the lowest validation loss obtained so far - this value is updated only - # if a validation set is available - lowest_validation_loss = initialize_lowest_validation_loss( - logfile_name, arguments + checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch" + + # write static information to a CSV file + static_logfile_name = os.path.join(output_folder, "constants.csv") + static_information_to_csv( + static_logfile_name, accelerator_processor.to_torch(), n ) - # set a specific validation criterion if the user has set one - criterion_valid = criterion_valid or criterion - - with open(logfile_name, "a+", newline="") as logfile: - logwriter = 
csv.DictWriter(logfile, fieldnames=logfile_fields) - - if arguments["epoch"] == 0: - logwriter.writeheader() - - model.train() # set training mode - - model.to(device) # set/cast parameters to device - for state in optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.to(device) - - # Total training timer - start_training_time = time.time() - - for epoch in tqdm( - range(start_epoch, max_epoch), - desc="epoch", - leave=False, - disable=None, - ): - with ResourceMonitor( - interval=monitoring_interval, - has_gpu=(device.type == "cuda"), - main_pid=os.getpid(), - logging_level=logging.ERROR, - ) as resource_monitor: - epoch = epoch + 1 - arguments["epoch"] = epoch - - # Epoch time - start_epoch_time = time.time() - - train_loss = train_epoch( - data_loader, - model, - optimizer, - device, - criterion, - batch_chunk_count, - ) - - valid_loss = ( - validate_epoch( - valid_loader, model, device, criterion_valid, "valid" - ) - if valid_loader is not None - else None - ) - - extra_valid_losses = [] - for pos, extra_valid_loader in enumerate(extra_valid_loaders): - loss = validate_epoch( - extra_valid_loader, - model, - device, - criterion_valid, - f"xval@{pos+1}", - ) - extra_valid_losses.append(loss) - - lowest_validation_loss = checkpointer_process( - checkpointer, - checkpoint_period, - valid_loss, - lowest_validation_loss, - arguments, - epoch, - max_epoch, - ) - - # computes ETA (estimated time-of-arrival; end of training) taking - # into consideration previous epoch performance - epoch_time = time.time() - start_epoch_time - eta_seconds = epoch_time * (max_epoch - epoch) - current_time = time.time() - start_training_time - - write_log_info( - epoch, - current_time, - eta_seconds, - train_loss, - valid_loss, - extra_valid_losses, - optimizer, - logwriter, - logfile, - resource_monitor.data, - ) - - total_training_time = time.time() - start_training_time - logger.info( - f"Total training time: 
{datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)" + if accelerator_processor.device is None: + devices = "auto" + else: + devices = accelerator_processor.device + + with resource_monitor: + trainer = Trainer( + accelerator=accelerator_processor.accelerator, + devices=devices, + max_epochs=max_epoch, + accumulate_grad_batches=batch_chunk_count, + logger=[csv_logger, tensorboard_logger], + check_val_every_n_epoch=1, + callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], ) + + _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index ea096ecbdb51f05679d37594fa41d7c4788d8874..e871a982ea393c919aab6c819ef2dd2bb70fcc96 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -2,62 +2,107 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import lightning.pytorch as pl +import torch import torch.nn as nn import torchvision.models as models from .normalizer import TorchVisionNormalizer -class Alexnet(nn.Module): +class Alexnet(pl.LightningModule): """Alexnet module. 
Note: only usable with a normalized dataset """ - def __init__(self, pretrained=False): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=False, + ): super().__init__() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + + self.name = "AlexNet" + # Load pretrained model weights = ( None if pretrained is False else models.AlexNet_Weights.DEFAULT ) self.model_ft = models.alexnet(weights=weights) + self.normalizer = TorchVisionNormalizer(nb_channels=1) + # Adapt output features self.model_ft.classifier[4] = nn.Linear(4096, 512) self.model_ft.classifier[6] = nn.Linear(512, 1) def forward(self, x): - """ + x = self.normalizer(x) + x = self.model_ft(x) - Parameters - ---------- + return x - x : list - list of tensors. + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - Returns - ------- + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - tensor : :py:class:`torch.Tensor` + # Forward pass on the network + outputs = self(images) - """ - return self.model_ft(x) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.float()) + return {"loss": training_loss} -def build_alexnet(pretrained=False): - """Build Alexnet CNN. 
+ def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - Returns - ------- + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - module : :py:class:`torch.nn.Module` - """ - model = Alexnet(pretrained=pretrained) - model = [("normalizer", TorchVisionNormalizer()), ("model", model)] - model = nn.Sequential(OrderedDict(model)) + # data forwarding on the existing network + outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + + def configure_optimizers(self): + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) - model.name = "AlexNet" - return model + return optimizer diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 7a98acac0c9b2ab3bd8994426807e2f99da9f7d0..ea6e623c3a9cdfc3f9d6896ea77a681f7a2f5cc7 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -2,23 +2,37 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import lightning.pytorch as pl +import torch import torch.nn as nn import torchvision.models as models from .normalizer import TorchVisionNormalizer -class Densenet(nn.Module): 
+class Densenet(pl.LightningModule): """Densenet module. Note: only usable with a normalized dataset """ - def __init__(self, pretrained=False): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=False, + nb_channels=3, + ): super().__init__() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + + self.name = "Densenet" + + self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels) + # Load pretrained model weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT self.model_ft = models.densenet121(weights=weights) @@ -29,37 +43,67 @@ class Densenet(nn.Module): ) def forward(self, x): - """ + x = self.normalizer(x) + x = self.model_ft(x) - Parameters - ---------- + return x - x : list - list of tensors. + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - Returns - ------- + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - tensor : :py:class:`torch.Tensor` + # Forward pass on the network + outputs = self(images) - """ - return self.model_ft(x) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.float()) + return {"loss": training_loss} -def build_densenet(pretrained=False, nb_channels=3): - """Build Densenet CNN. 
+ def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - Returns - ------- + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - module : :py:class:`torch.nn.Module` - """ - model = Densenet(pretrained=pretrained) - model = [ - ("normalizer", TorchVisionNormalizer(nb_channels=nb_channels)), - ("model", model), - ] - model = nn.Sequential(OrderedDict(model)) - - model.name = "Densenet" - return model + # data forwarding on the existing network + outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) + + return optimizer diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index c4448fbca8d97a60ccc041cc847bd79a0fd58056..a9d69e27928d5ec9a3d525d1a043370deeacb119 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -2,20 +2,32 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import lightning.pytorch as pl +import torch import torch.nn as nn 
import torchvision.models as models from .normalizer import TorchVisionNormalizer -class DensenetRS(nn.Module): +class DensenetRS(pl.LightningModule): """Densenet121 module for radiological extraction.""" - def __init__(self): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + ): super().__init__() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + + self.name = "DensenetRS" + + self.normalizer = TorchVisionNormalizer() + # Load pretrained model self.model_ft = models.densenet121( weights=models.DenseNet121_Weights.DEFAULT @@ -26,34 +38,67 @@ class DensenetRS(nn.Module): self.model_ft.classifier = nn.Linear(num_ftrs, 14) def forward(self, x): - """ + x = self.normalizer(x) + x = self.model_ft(x) + return x + + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - Parameters - ---------- + # Forward pass on the network + outputs = self(images) - x : list - list of tensors. + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.float()) - Returns - ------- + return {"loss": training_loss} - tensor : :py:class:`torch.Tensor` + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - """ - return self.model_ft(x) + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + # data forwarding on the existing network + outputs = self(images) -def build_densenetrs(): - """Build DensenetRS CNN. + # Manually move criterion to selected device, since not part of the model. 
+ self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] - Returns - ------- + outputs = self(images) + probabilities = torch.sigmoid(outputs) - module : :py:class:`torch.nn.Module` - """ - model = DensenetRS() - model = [("normalizer", TorchVisionNormalizer()), ("model", model)] - model = nn.Sequential(OrderedDict(model)) + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + filter(lambda p: p.requires_grad, self.model_ft.parameters()), + **self.hparams.optimizer_configs, + ) - model.name = "DensenetRS" - return model + return optimizer diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 7e7818c71d8ebdb636a9b863965974cb71f91fba..6efd2a25c9726d5aeb081ae2a7ed22192b9befcd 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -2,45 +2,90 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import lightning.pytorch as pl import torch import torch.nn as nn -class LogisticRegression(nn.Module): +class LogisticRegression(pl.LightningModule): """Radiological signs to Tuberculosis module.""" - def __init__(self, input_size): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + input_size, + ): super().__init__() - self.linear = torch.nn.Linear(input_size, 1) + + self.save_hyperparameters(ignore=["criterion", 
"criterion_valid"]) + + self.name = "logistic_regression" + + self.linear = nn.Linear(self.hparams.input_size, 1) def forward(self, x): - """ + output = self.linear(x) - Parameters - ---------- + return output - x : list - list of tensors. + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - Returns - ------- + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - tensor : :py:class:`torch.Tensor` + # Forward pass on the network + outputs = self(images) - """ - output = self.linear(x) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.float()) - return output + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] -def build_logistic_regression(input_size): - """Build logistic regression module. 
+ return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) - Returns - ------- + def configure_optimizers(self): + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) - module : :py:class:`torch.nn.Module` - """ - model = LogisticRegression(input_size) - model.name = "logistic_regression" - return model + return optimizer diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 10e6cedeb672094f62f5c1d16dbaeb9d6983ce34..125867bda1aa4f6e2317708cc5010d9120518f46 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -2,8 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import lightning.pytorch as pl import torch import torch.nn as nn import torch.nn.functional as F @@ -11,14 +10,23 @@ import torch.nn.functional as F from .normalizer import TorchVisionNormalizer -class PASA(nn.Module): +class PASA(pl.LightningModule): """PASA module. Based on paper by [PASA-2019]_. """ - def __init__(self): + def __init__( + self, criterion, criterion_valid, optimizer, optimizer_configs + ): super().__init__() + + self.save_hyperparameters() + + self.name = "pasa" + + self.normalizer = TorchVisionNormalizer(nb_channels=1) + # First convolution block self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) @@ -68,20 +76,8 @@ class PASA(nn.Module): self.dense = nn.Linear(80, 1) # Fully connected layer def forward(self, x): - """ - - Parameters - ---------- + x = self.normalizer(x) - x : list - list of tensors. - - Returns - ------- - - tensor : :py:class:`torch.Tensor` - - """ # First convolution block _x = x x = F.relu(self.batchNorm2d_4(self.fc1(x))) # 1st convolution @@ -127,21 +123,62 @@ class PASA(nn.Module): return x + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] -def build_pasa(): - """Build pasa CNN. 
+ # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - Returns - ------- + # Forward pass on the network + outputs = self(images) - module : :py:class:`torch.nn.Module` - """ - model = PASA() - model = [ - ("normalizer", TorchVisionNormalizer(nb_channels=1)), - ("model", model), - ] - model = nn.Sequential(OrderedDict(model)) - - model.name = "pasa" - return model + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.double()) + + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. 
+ self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) + + return optimizer diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index f3b3d5eac6ad6169625ffc4e5388f1a2dc87f4c3..aa22864558aec7340cfd53b7e3b9622e72980e8a 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -2,36 +2,35 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import lightning.pytorch as pl import torch -import torch.nn as nn -class SignsToTB(nn.Module): +class SignsToTB(pl.LightningModule): """Radiological signs to Tuberculosis module.""" - def __init__(self, input_size, hidden_size): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + input_size, + hidden_size, + ): super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) - self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(self.hidden_size, 1) - - def forward(self, x): - """ - Parameters - ---------- + self.save_hyperparameters() - x : list - list of tensors. 
+ self.name = "signs_to_tb" - Returns - ------- - - tensor : :py:class:`torch.Tensor` + self.fc1 = torch.nn.Linear( + self.hparams.input_size, self.hparams.hidden_size + ) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(self.hparams.hidden_size, 1) - """ + def forward(self, x): hidden = self.fc1(x) relu = self.relu(hidden) @@ -39,15 +38,62 @@ class SignsToTB(nn.Module): return output + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.float()) + + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] -def build_signs_to_tb(input_size, hidden_size): - """Build SignsToTB shallow model. 
+ return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) - Returns - ------- + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) - module : :py:class:`torch.nn.Module` - """ - model = SignsToTB(input_size, hidden_size) - model.name = "signs_to_tb" - return model + return optimizer diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 51275fc4efc4cda47172fece53ed3911a15b97cf..3613e75ad03c7f64ce9ce2c81e013d251fdf2858 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -62,8 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @click.option( - "--device", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + "--accelerator", + "-a", + help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). 
The device can also be specified (gpu:0)', show_default=True, required=True, default="cpu", @@ -72,7 +73,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--weight", "-w", - help="Path or URL to pretrained model file (.pth extension)", + help="Path or URL to pretrained model file (.ckpt extension)", required=True, cls=ResourceOption, ) @@ -97,7 +98,7 @@ def predict( model, dataset, batch_size, - device, + accelerator, weight, relevance_analysis, grad_cams, @@ -117,21 +118,12 @@ def predict( from torch.utils.data import ConcatDataset, DataLoader from ..engine.predictor import run - from ..utils.checkpointer import Checkpointer - from ..utils.download import download_to_tempfile from ..utils.plot import relevance_analysis_plot dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) - if weight.startswith("http"): - logger.info(f"Temporarily downloading '{weight}'...") - f = download_to_tempfile(weight, progress=True) - weight_fullpath = os.path.abspath(f.name) - else: - weight_fullpath = os.path.abspath(weight) - - checkpointer = Checkpointer(model) - checkpointer.load(weight_fullpath) + logger.info(f"Loading checkpoint from {weight}") + model = model.load_from_checkpoint(weight, strict=False) # Logistic regressor weights if model.name == "logistic_regression": @@ -162,7 +154,7 @@ def predict( pin_memory=torch.cuda.is_available(), ) predictions = run( - model, data_loader, k, device, output_folder, grad_cams + model, data_loader, k, accelerator, output_folder, grad_cams ) # Relevance analysis using permutation feature importance @@ -197,7 +189,11 @@ def predict( ) predictions_with_mean = run( - model, data_loader, k, device, output_folder + "_temp" + model, + data_loader, + k, + accelerator, + output_folder + "_temp", ) # Compute MSE between original and new predictions diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 
bafeb0303899086cd1b00cfe52d4d0896f7045b8..12c5a287f5682340ecb1275f4439ebd4056c657b 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -2,91 +2,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import os - import click from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup +from lightning.pytorch import seed_everything -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -def setup_pytorch_device(name): - """Sets-up the pytorch device to use. - - Parameters - ---------- - - name : str - The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on). If you - set a specific cuda device such as ``cuda:1``, then we'll make sure it - is currently set. - - - Returns - ------- - - device : :py:class:`torch.device` - The pytorch device to use, pre-configured (and checked) - """ - import torch - - if name.startswith("cuda:"): - # In case one has multiple devices, we must first set the one - # we would like to use so pytorch can find it. - logger.info(f"User set device to '{name}' - trying to force device...") - os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1] - if not torch.cuda.is_available(): - raise RuntimeError( - f"CUDA is not currently available, but " - f"you set device to '{name}'" - ) - # Let pytorch auto-select from environment variable - return torch.device("cuda") - - elif name.startswith("cuda"): # use default device - logger.info(f"User set device to '{name}' - using default CUDA device") - assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None - - # cuda or cpu - return torch.device(name) - - -def set_seeds(value, all_gpus): - """Sets up all relevant random seeds (numpy, python, cuda) - - If running with multiple GPUs **at the same time**, set ``all_gpus`` to - ``True`` to force all GPU seeds to be initialized. - - Reference: `PyTorch page for reproducibility - <https://pytorch.org/docs/stable/notes/randomness.html>`_. 
- - - Parameters - ---------- - - value : int - The random seed value to use - - all_gpus : :py:class:`bool`, Optional - If set, then reset the seed on all GPUs available at once. This is - normally **not** what you want if running on a single GPU - """ - import random - - import numpy.random - import torch - import torch.cuda +from ..utils.checkpointer import get_checkpoint - random.seed(value) - numpy.random.seed(value) - torch.manual_seed(value) - torch.cuda.manual_seed(value) # noop if cuda not available - - # set seeds for all gpus - if all_gpus: - torch.cuda.manual_seed_all(value) # noop if cuda not available +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def set_reproducible_cuda(): @@ -160,12 +84,6 @@ def set_reproducible_cuda(): required=True, cls=ResourceOption, ) -@click.option( - "--optimizer", - help="A torch.optim.Optimizer that will be used to train the network", - required=True, - cls=ResourceOption, -) @click.option( "--criterion", help="A loss function to compute the CNN error for every sample " @@ -252,14 +170,15 @@ def set_reproducible_cuda(): "last saved checkpoint if training is restarted with the same " "configuration.", show_default=True, - required=True, - default=0, + required=False, + default=None, type=click.IntRange(min=0), cls=ResourceOption, ) @click.option( - "--device", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + "--accelerator", + "-a", + help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)', show_default=True, required=True, default="cpu", @@ -314,10 +233,17 @@ def set_reproducible_cuda(): default=5.0, cls=ResourceOption, ) +@click.option( + "--resume-from", + help="Which checkpoint to resume training from. 
Can be one of 'None', 'best', 'last', or a path to a model checkpoint.", + type=str, + required=False, + default=None, + cls=ResourceOption, +) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def train( model, - optimizer, output_folder, epochs, batch_size, @@ -327,11 +253,12 @@ def train( criterion_valid, dataset, checkpoint_period, - device, + accelerator, seed, parallel, normalization, monitoring_interval, + resume_from, **_, ): """Trains an CNN to perform tuberculosis detection. @@ -354,11 +281,8 @@ def train( from ..configs.datasets import get_positive_weights, get_samples_weights from ..engine.trainer import run - from ..utils.checkpointer import Checkpointer - device = setup_pytorch_device(device) - - set_seeds(seed, all_gpus=False) + seed_everything(seed) use_dataset = dataset validation_dataset = None @@ -418,9 +342,6 @@ def train( # Create weighted random sampler train_samples_weights = get_samples_weights(use_dataset) - train_samples_weights = train_samples_weights.to( - device=device, non_blocking=torch.cuda.is_available() - ) train_sampler = WeightedRandomSampler( train_samples_weights, len(train_samples_weights), replacement=True ) @@ -428,10 +349,7 @@ def train( # Redefine a weighted criterion if possible if isinstance(criterion, torch.nn.BCEWithLogitsLoss): positive_weights = get_positive_weights(use_dataset) - positive_weights = positive_weights.to( - device=device, non_blocking=torch.cuda.is_available() - ) - criterion = BCEWithLogitsLoss(pos_weight=positive_weights) + model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) else: logger.warning("Weighted criterion not supported") @@ -454,10 +372,9 @@ def train( or criterion_valid is None ): positive_weights = get_positive_weights(validation_dataset) - positive_weights = positive_weights.to( - device=device, non_blocking=torch.cuda.is_available() + model.hparams.criterion_valid = BCEWithLogitsLoss( + pos_weight=positive_weights ) - criterion_valid = 
BCEWithLogitsLoss(pos_weight=positive_weights) else: logger.warning("Weighted valid criterion not supported") @@ -513,15 +430,16 @@ def train( ) logger.info(f"Z-normalization with mean {mean} and std {std}") - # Checkpointer - checkpointer = Checkpointer(model, optimizer, path=output_folder) - - # Initialize epoch information arguments = {} - arguments["epoch"] = 0 - extra_checkpoint_data = checkpointer.load() - arguments.update(extra_checkpoint_data) arguments["max_epoch"] = epochs + arguments["epoch"] = 0 + + checkpoint_file = get_checkpoint(output_folder, resume_from) + + # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit() + if checkpoint_file is not None: + checkpoint = torch.load(checkpoint_file) + arguments["epoch"] = checkpoint["epoch"] logger.info("Training for {} epochs".format(arguments["max_epoch"])) logger.info("Continuing from epoch {}".format(arguments["epoch"])) @@ -531,14 +449,11 @@ def train( data_loader=data_loader, valid_loader=valid_loader, extra_valid_loaders=extra_valid_loaders, - optimizer=optimizer, - criterion=criterion, - checkpointer=checkpointer, checkpoint_period=checkpoint_period, - device=device, + accelerator=accelerator, arguments=arguments, output_folder=output_folder, monitoring_interval=monitoring_interval, batch_chunk_count=batch_chunk_count, - criterion_valid=criterion_valid, + checkpoint=checkpoint_file, ) diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py new file mode 100644 index 0000000000000000000000000000000000000000..dcfa2f733e1d091c5bb9a4e5785ee47f8e49497c --- /dev/null +++ b/src/ptbench/utils/accelerator.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import os + +import torch + +logger = logging.getLogger(__name__) + + +class AcceleratorProcessor: + """This class is used to convert the 
torch device naming convention to + lightning's device convention and vice versa. + + It also sets the CUDA_VISIBLE_DEVICES if a gpu accelerator is used. + """ + + def __init__(self, name): + # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now. + self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"} + + self.lightning_to_torch = { + v: k for k, v in self.torch_to_lightning.items() + } + + self.valid_accelerators = set( + list(self.torch_to_lightning.keys()) + + list(self.lightning_to_torch.keys()) + ) + + self.accelerator, self.device = self._split_accelerator_name(name) + + if self.accelerator not in self.valid_accelerators: + raise ValueError(f"Unknown accelerator {self.accelerator}") + + # Keep lightning's convention by default + self.accelerator = self.to_lightning() + self.setup_accelerator() + + def setup_accelerator(self): + """If a gpu accelerator is chosen, checks the CUDA_VISIBLE_DEVICES + environment variable exists or sets its value if specified.""" + if self.accelerator == "gpu": + if not torch.cuda.is_available(): + raise RuntimeError( + f"CUDA is not currently available, but " + f"you set accelerator to '{self.accelerator}'" + ) + + if self.device is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device[0]) + else: + if os.environ.get("CUDA_VISIBLE_DEVICES") is None: + raise ValueError( + "Environment variable 'CUDA_VISIBLE_DEVICES' is not set. " + "Please set 'CUDA_VISIBLE_DEVICES' or specify a device to use, e.g. cuda:0" + ) + else: + # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu + pass + + logger.info( + f"Accelerator set to {self.accelerator} and device to {self.device}" + ) + + def _split_accelerator_name(self, accelerator_name): + """Splits an accelerator string into accelerator and device components.
+ + Parameters + ---------- + + accelerator_name: str + The accelerator (or device in pytorch convention) string (e.g. cuda:0) + + Returns + ------- + + accelerator: str + The accelerator name + device: dict[int] + The selected devices + """ + + split_accelerator = accelerator_name.split(":") + accelerator = split_accelerator[0] + + if len(split_accelerator) > 1: + device = split_accelerator[1] + device = [int(device)] + else: + device = None + + return accelerator, device + + def to_torch(self): + """Converts the accelerator string to torch convention. + + Returns + ------- + + accelerator: str + The accelerator name in pytorch convention + """ + if self.accelerator in self.lightning_to_torch: + return self.lightning_to_torch[self.accelerator] + elif self.accelerator in self.torch_to_lightning: + return self.accelerator + else: + raise ValueError("Unknown accelerator.") + + def to_lightning(self): + """Converts the accelerator string to lightning convention. + + Returns + ------- + + accelerator: str + The accelerator name in lightning convention + """ + if self.accelerator in self.torch_to_lightning: + return self.torch_to_lightning[self.accelerator] + elif self.accelerator in self.lightning_to_torch: + return self.accelerator + else: + raise ValueError("Unknown accelerator.") diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py index 3e839b0e11bd29848d4acfd1b5364a3e0ff5431e..81516d28110871690ab148305fb9f5925a728601 100644 --- a/src/ptbench/utils/checkpointer.py +++ b/src/ptbench/utils/checkpointer.py @@ -1,99 +1,61 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - import logging import os -import torch - logger = logging.getLogger(__name__) -class Checkpointer: - """A simple pytorch checkpointer. +def get_checkpoint(output_folder, resume_from): + """Gets a checkpoint file. 
+ + Can return the best or last checkpoint, or a checkpoint at a specific path. + Ensures the checkpoint exists, raising an error if it is not the case. Parameters ---------- - model : torch.nn.Module - Network model, eventually loaded from a checkpointed file + output_folder : :py:class:`str` + Directory in which checkpoints are stored. - optimizer : :py:mod:`torch.optim`, Optional - Optimizer + resume_from : :py:class:`str` + Which model to get. Can be one of "best", "last", or a path to a checkpoint. - scheduler : :py:mod:`torch.optim`, Optional - Learning rate scheduler + Returns + ------- - path : :py:class:`str`, Optional - Directory where to save checkpoints. + checkpoint_file : :py:class:`str` + The requested model. """ - - def __init__(self, model, optimizer=None, scheduler=None, path="."): - self.model = model - self.optimizer = optimizer - self.scheduler = scheduler - self.path = os.path.realpath(path) - - def save(self, name, **kwargs): - data = {} - data["model"] = self.model.state_dict() - if self.optimizer is not None: - data["optimizer"] = self.optimizer.state_dict() - if self.scheduler is not None: - data["scheduler"] = self.scheduler.state_dict() - data.update(kwargs) - - name = f"{name}.pth" - outf = os.path.join(self.path, name) - logger.info(f"Saving checkpoint to {outf}") - torch.save(data, outf) - with open(self._last_checkpoint_filename, "w") as f: - f.write(name) - - def load(self, f=None): - """Loads model, optimizer and scheduler from file. - - Parameters - ========== - - f : :py:class:`str`, Optional - Name of a file (absolute or relative to ``self.path``), that - contains the checkpoint data to load into the model, and optionally - into the optimizer and the scheduler. If not specified, loads data - from current path. 
- """ - if f is None: - f = self.last_checkpoint() - - if f is None: - # no checkpoint could be found - logger.warning("No checkpoint found (and none passed)") - return {} - - # loads file data into memory - logger.info(f"Loading checkpoint from {f}...") - checkpoint = torch.load(f, map_location=torch.device("cpu")) - - # converts model entry to model parameters - self.model.load_state_dict(checkpoint.pop("model")) - - if self.optimizer is not None: - self.optimizer.load_state_dict(checkpoint.pop("optimizer")) - if self.scheduler is not None: - self.scheduler.load_state_dict(checkpoint.pop("scheduler")) - - return checkpoint - - @property - def _last_checkpoint_filename(self): - return os.path.join(self.path, "last_checkpoint") - - def has_checkpoint(self): - return os.path.exists(self._last_checkpoint_filename) - - def last_checkpoint(self): - if self.has_checkpoint(): - with open(self._last_checkpoint_filename) as fobj: - return os.path.join(self.path, fobj.read().strip()) - return None + last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") + best_checkpoint_path = os.path.join( + output_folder, "model_lowest_valid_loss.ckpt" + ) + + if resume_from == "last": + if os.path.isfile(last_checkpoint_path): + checkpoint_file = last_checkpoint_path + logger.info(f"Resuming training from {resume_from} checkpoint") + else: + raise FileNotFoundError( + f"Could not find checkpoint {last_checkpoint_path}" + ) + + elif resume_from == "best": + if os.path.isfile(best_checkpoint_path): + checkpoint_file = last_checkpoint_path + logger.info(f"Resuming training from {resume_from} checkpoint") + else: + raise FileNotFoundError( + f"Could not find checkpoint {best_checkpoint_path}" + ) + + elif resume_from is None: + checkpoint_file = None + + else: + if os.path.isfile(resume_from): + checkpoint_file = resume_from + logger.info(f"Resuming training from checkpoint {resume_from}") + else: + raise FileNotFoundError(f"Could not find checkpoint {resume_from}") + 
+ return checkpoint_file diff --git a/src/ptbench/utils/resources.py b/src/ptbench/utils/resources.py index be23ee452a1823555220c5d92d80a2f7c6a9223f..fa0ac3dd2b2332a8d938850d97739b457eeed13a 100644 --- a/src/ptbench/utils/resources.py +++ b/src/ptbench/utils/resources.py @@ -233,6 +233,7 @@ class CPULogger: cpu_percent = [] open_files = [] gone = set() + for k in self.cluster: try: memory_info.append(k.memory_info()) @@ -243,7 +244,6 @@ class CPULogger: # it is too late to update any intermediate list # at this point, but ensures to update counts later on gone.add(k) - return ( ("cpu_memory_used", psutil.virtual_memory().used / GB), ("cpu_rss", sum([k.rss for k in memory_info]) / GB), @@ -288,6 +288,10 @@ class _InformationGatherer: for i, k in enumerate(gpu_log()): self.data[i + self.cpu_keys_len].append(k[1]) + def clear(self): + """Clears accumulated data.""" + self.data = [[] for _ in self.keys] + def summary(self): """Returns the current data.""" if len(self.data[0]) == 0: @@ -298,7 +302,9 @@ class _InformationGatherer: return tuple(retval) -def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level): +def _monitor_worker( + interval, has_gpu, main_pid, stop, summary_event, queue, logging_level +): """A monitoring worker that measures resources and returns lists. Parameters @@ -329,6 +335,12 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level): while not stop.is_set(): try: ra.acc() # guarantees at least an entry will be available + + if summary_event.is_set(): + queue.put(ra.summary()) + ra.clear() + summary_event.clear() + time.sleep(interval) except Exception: logger.warning( @@ -337,8 +349,6 @@ def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level): ) time.sleep(0.5) # wait half a second, and try again! - queue.put(ra.summary()) - class ResourceMonitor: """An external, non-blocking CPU/GPU resource monitor. 
@@ -364,7 +374,8 @@ class ResourceMonitor: self.interval = interval self.has_gpu = has_gpu self.main_pid = main_pid - self.event = multiprocessing.Event() + self.stop_event = multiprocessing.Event() + self.summary_event = multiprocessing.Event() self.q = multiprocessing.Queue() self.logging_level = logging_level @@ -375,7 +386,8 @@ class ResourceMonitor: self.interval, self.has_gpu, self.main_pid, - self.event, + self.stop_event, + self.summary_event, self.q, self.logging_level, ), @@ -390,19 +402,9 @@ class ResourceMonitor: def __enter__(self): """Starts the monitoring process.""" self.monitor.start() - return self - def __exit__(self, *exc): - """Stops the monitoring process and returns the summary of - observations.""" - - self.event.set() - self.monitor.join() - if self.monitor.exitcode != 0: - logger.error( - f"CPU/GPU resource monitor process exited with code " - f"{self.monitor.exitcode}. Check logs for errors!" - ) + def trigger_summary(self): + self.summary_event.set() try: data = self.q.get(timeout=2 * self.interval) @@ -426,3 +428,15 @@ class ResourceMonitor: else: summary.append((k, 0.0)) self.data = tuple(summary) + + def __exit__(self, *exc): + """Stops the monitoring process and returns the summary of + observations.""" + + self.stop_event.set() + self.monitor.join() + if self.monitor.exitcode != 0: + logger.error( + f"CPU/GPU resource monitor process exited with code " + f"{self.monitor.exitcode}. Check logs for errors!" 
+ ) diff --git a/tests/data/lfs b/tests/data/lfs index 69185f0d9ea67893722c5a840e2caa59946b3b83..9e3818d3fe0944697e91f171c59fdf035a22868c 160000 --- a/tests/data/lfs +++ b/tests/data/lfs @@ -1 +1 @@ -Subproject commit 69185f0d9ea67893722c5a840e2caa59946b3b83 +Subproject commit 9e3818d3fe0944697e91f171c59fdf035a22868c diff --git a/tests/test_11k.py b/tests/test_11k.py new file mode 100644 index 0000000000000000000000000000000000000000..8f101c82b2f298f8c3ce676a003e0ff18aa2a731 --- /dev/null +++ b/tests/test_11k.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for TBX11K simplified dataset split 1.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 2767 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 706 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 957 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-9 + for f in range(10): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 3177 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 810 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 443 + 
for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +def test_protocol_consistency_bbox(): + from ptbench.data.tbx11k_simplified import dataset_with_bboxes + + # Default protocol + subset = dataset_with_bboxes.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 2767 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 706 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 957 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + # Cross-validation fold 0-9 + for f in range(10): + subset = dataset_with_bboxes.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 3177 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 810 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 443 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in 
subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading(): + from ptbench.data.tbx11k_simplified import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading_bbox(): + from ptbench.data.tbx11k_simplified import dataset_with_bboxes + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 3 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + assert "bboxes" in data + assert data["bboxes"] == "none" or data["bboxes"][0].startswith( + "{'xmin':" + ) + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset_with_bboxes.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_check(): + from ptbench.data.tbx11k_simplified import dataset + + assert dataset.check() == 0 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_check_bbox(): + from ptbench.data.tbx11k_simplified import dataset_with_bboxes + + assert dataset_with_bboxes.check() == 0 diff --git a/tests/test_11k_RS.py b/tests/test_11k_RS.py new file mode 100644 index 0000000000000000000000000000000000000000..601bbc4628ea752f3ad52b78cedecd64a4b215dc --- /dev/null +++ 
b/tests/test_11k_RS.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for Extended TBX11K simplified dataset split 1.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified_RS import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 2767 + + assert "validation" in subset + assert len(subset["validation"]) == 706 + + assert "test" in subset + assert len(subset["test"]) == 957 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-9 + for f in range(10): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 3177 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 810 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 443 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading(): + from ptbench.data.tbx11k_simplified_RS import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert len(data["data"]) == 14 # Check radiological signs + + assert "label" in 
data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) diff --git a/tests/test_11k_v2.py b/tests/test_11k_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..12662886ed4eea1a2fa654c80b9666c53e5af515 --- /dev/null +++ b/tests/test_11k_v2.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for TBX11K simplified dataset split 2.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified_v2 import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 5241 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1335 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 1793 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-8 + for f in range(9): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1529 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 837 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels 
+ for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +def test_protocol_consistency_bbox(): + from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes + + # Default protocol + subset = dataset_with_bboxes.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 5241 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1335 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 1793 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + # Cross-validation fold 0-8 + for f in range(9): + subset = dataset_with_bboxes.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + 
assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1529 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 837 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + # Cross-validation fold 9 + subset = dataset_with_bboxes.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_loading(): + from ptbench.data.tbx11k_simplified_v2 import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use 
this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_loading_bbox(): + from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 3 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + assert "bboxes" in data + assert data["bboxes"] == "none" or data["bboxes"][0].startswith( + "{'xmin':" + ) + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset_with_bboxes.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_check(): + from ptbench.data.tbx11k_simplified_v2 import dataset + + assert dataset.check() == 0 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_check_bbox(): + from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes + + assert dataset_with_bboxes.check() == 0 diff --git a/tests/test_11k_v2_RS.py b/tests/test_11k_v2_RS.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ac2464324aee1aa45e185c13380e301a949597 --- /dev/null +++ b/tests/test_11k_v2_RS.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for Extended TBX11K simplified dataset split 2.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified_v2_RS import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 5241 + + 
assert "validation" in subset + assert len(subset["validation"]) == 1335 + + assert "test" in subset + assert len(subset["test"]) == 1793 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-8 + for f in range(9): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1529 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 837 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading(): + from ptbench.data.tbx11k_simplified_v2_RS import dataset + + 
def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert len(data["data"]) == 14 # Check radiological signs + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) diff --git a/tests/test_checkpointer.py b/tests/test_checkpointer.py deleted file mode 100644 index aca95248732f307cd6a0b17cff21628a565e6d50..0000000000000000000000000000000000000000 --- a/tests/test_checkpointer.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import os -import unittest - -from collections import OrderedDict -from tempfile import TemporaryDirectory - -import torch - -from ptbench.utils.checkpointer import Checkpointer - - -class TestCheckpointer(unittest.TestCase): - def create_model(self): - return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1)) - - def create_complex_model(self): - m = torch.nn.Module() - m.block1 = torch.nn.Module() - m.block1.layer1 = torch.nn.Linear(2, 3) - m.layer2 = torch.nn.Linear(3, 2) - m.res = torch.nn.Module() - m.res.layer2 = torch.nn.Linear(3, 2) - - state_dict = OrderedDict() - state_dict["layer1.weight"] = torch.rand(3, 2) - state_dict["layer1.bias"] = torch.rand(3) - state_dict["layer2.weight"] = torch.rand(2, 3) - state_dict["layer2.bias"] = torch.rand(2) - state_dict["res.layer2.weight"] = torch.rand(2, 3) - state_dict["res.layer2.bias"] = torch.rand(2) - - return m, state_dict - - def test_from_last_checkpoint_model(self): - # test that loading works even if they differ by a prefix - trained_model = self.create_model() - fresh_model = self.create_model() - with TemporaryDirectory() as f: - checkpointer = Checkpointer(trained_model, path=f) - 
checkpointer.save("checkpoint_file") - - # in the same folder - fresh_checkpointer = Checkpointer(fresh_model, path=f) - assert fresh_checkpointer.has_checkpoint() - assert fresh_checkpointer.last_checkpoint() == os.path.realpath( - os.path.join(f, "checkpoint_file.pth") - ) - _ = fresh_checkpointer.load() - - for trained_p, loaded_p in zip( - trained_model.parameters(), fresh_model.parameters() - ): - # different tensor references - assert id(trained_p) != id(loaded_p) - # same content - assert trained_p.equal(loaded_p) - - def test_from_name_file_model(self): - # test that loading works even if they differ by a prefix - trained_model = self.create_model() - fresh_model = self.create_model() - with TemporaryDirectory() as f: - checkpointer = Checkpointer(trained_model, path=f) - checkpointer.save("checkpoint_file") - - # on different folders - with TemporaryDirectory() as g: - fresh_checkpointer = Checkpointer(fresh_model, path=g) - assert not fresh_checkpointer.has_checkpoint() - assert fresh_checkpointer.last_checkpoint() is None - _ = fresh_checkpointer.load( - os.path.join(f, "checkpoint_file.pth") - ) - - for trained_p, loaded_p in zip( - trained_model.parameters(), fresh_model.parameters() - ): - # different tensor references - assert id(trained_p) != id(loaded_p) - # same content - assert trained_p.equal(loaded_p) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2feb5e6c9370374874b959b20bf3cad5dda10283..1204257d4e4e06b08914f02f655b880ac1ea8034 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -195,14 +195,18 @@ def test_train_pasa_montgomery(temporary_basedir): _assert_exit_0(result) assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.pth") + os.path.join(output_folder, "model_final_epoch.ckpt") ) assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.pth") + os.path.join(output_folder, "model_lowest_valid_loss.ckpt") ) - assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) assert 
os.path.exists(os.path.join(output_folder, "constants.csv")) - assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists( + os.path.join(output_folder, "logs_csv", "version_0", "metrics.csv") + ) + assert os.path.exists( + os.path.join(output_folder, "logs_tensorboard", "version_0") + ) assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) keywords = { @@ -210,10 +214,6 @@ def test_train_pasa_montgomery(temporary_basedir): r"^Found \(dedicated\) '__valid__' set for validation$": 1, r"^Continuing from epoch 0$": 1, r"^Saving model summary at.*$": 1, - r"^Model has.*$": 1, - r"^Saving checkpoint": 2, - r"^Total training time:": 1, - r"^Z-normalization with mean": 1, } buf.seek(0) logging_output = buf.read() @@ -247,13 +247,17 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): ) _assert_exit_0(result0) - assert os.path.exists(os.path.join(output_folder, "model_final_epoch.pth")) + assert os.path.exists(os.path.join(output_folder, "model_final_epoch.ckpt")) assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.pth") + os.path.join(output_folder, "model_lowest_valid_loss.ckpt") ) - assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) assert os.path.exists(os.path.join(output_folder, "constants.csv")) - assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists( + os.path.join(output_folder, "logs_csv", "version_0", "metrics.csv") + ) + assert os.path.exists( + os.path.join(output_folder, "logs_tensorboard", "version_0") + ) assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) with stdout_logging() as buf: @@ -272,25 +276,25 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): _assert_exit_0(result) assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.pth") + os.path.join(output_folder, "model_final_epoch.ckpt") ) assert os.path.exists( - os.path.join(output_folder, 
"model_lowest_valid_loss.pth") + os.path.join(output_folder, "model_lowest_valid_loss.ckpt") ) - assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) assert os.path.exists(os.path.join(output_folder, "constants.csv")) - assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists( + os.path.join(output_folder, "logs_csv", "version_0", "metrics.csv") + ) + assert os.path.exists( + os.path.join(output_folder, "logs_tensorboard", "version_0") + ) assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 1$": 1, + r"^Continuing from epoch 0$": 1, r"^Saving model summary at.*$": 1, - r"^Model has.*$": 1, - r"^Found lowest validation loss from previous session.*$": 1, - r"^Total training time:": 1, - r"^Z-normalization with mean": 1, } buf.seek(0) logging_output = buf.read() @@ -302,10 +306,10 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): f"instead of the expected {v}:\nOutput:\n{logging_output}" ) - extra_keyword = "Saving checkpoint" - assert ( - extra_keyword in logging_output - ), f"String '{extra_keyword}' did not appear at least once in the output:\nOutput:\n{logging_output}" + # extra_keyword = "Saving checkpoint" + # assert ( + # extra_keyword in logging_output + # ), f"String '{extra_keyword}' did not appear at least once in the output:\nOutput:\n{logging_output}" @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @@ -324,7 +328,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): "-vv", "--batch-size=1", "--relevance-analysis", - f"--weight={str(datadir / 'lfs' / 'models' / 'pasa.pth')}", + f"--weight={str(datadir / 'lfs' / 'models' / 'pasa.ckpt')}", f"--output-folder={output_folder}", ], ) @@ -342,7 +346,6 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): keywords = { r"^Loading 
checkpoint from.*$": 1, - r"^Total time:.*$": 3, r"^Relevance analysis.*$": 3, } buf.seek(0) @@ -513,14 +516,18 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): _assert_exit_0(result) assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.pth") + os.path.join(output_folder, "model_final_epoch.ckpt") ) assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.pth") + os.path.join(output_folder, "model_lowest_valid_loss.ckpt") ) - assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) assert os.path.exists(os.path.join(output_folder, "constants.csv")) - assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists( + os.path.join(output_folder, "logs_csv", "version_0", "metrics.csv") + ) + assert os.path.exists( + os.path.join(output_folder, "logs_tensorboard", "version_0") + ) assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) keywords = { @@ -528,9 +535,6 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): r"^Found \(dedicated\) '__valid__' set for validation$": 1, r"^Continuing from epoch 0$": 1, r"^Saving model summary at.*$": 1, - r"^Model has.*$": 1, - r"^Saving checkpoint": 2, - r"^Total training time:": 1, } buf.seek(0) logging_output = buf.read() @@ -559,7 +563,7 @@ def test_predict_signstotb_montgomery_rs(temporary_basedir, datadir): "-vv", "--batch-size=1", "--relevance-analysis", - f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}", + f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.ckpt')}", f"--output-folder={output_folder}", ], ) @@ -577,7 +581,6 @@ def test_predict_signstotb_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Loading checkpoint from.*$": 1, - r"^Total time:.*$": 3 * 15, r"^Starting relevance analysis for subset.*$": 3, r"^Creating and saving plot at.*$": 3, } @@ -614,14 +617,18 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): _assert_exit_0(result) 
assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.pth") + os.path.join(output_folder, "model_final_epoch.ckpt") ) assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.pth") + os.path.join(output_folder, "model_lowest_valid_loss.ckpt") ) - assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) assert os.path.exists(os.path.join(output_folder, "constants.csv")) - assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists( + os.path.join(output_folder, "logs_csv", "version_0", "metrics.csv") + ) + assert os.path.exists( + os.path.join(output_folder, "logs_tensorboard", "version_0") + ) assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) keywords = { @@ -629,9 +636,6 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): r"^Found \(dedicated\) '__valid__' set for validation$": 1, r"^Continuing from epoch 0$": 1, r"^Saving model summary at.*$": 1, - r"^Model has.*$": 1, - r"^Saving checkpoint": 2, - r"^Total training time:": 1, } buf.seek(0) logging_output = buf.read() @@ -659,7 +663,7 @@ def test_predict_logreg_montgomery_rs(temporary_basedir, datadir): "montgomery_rs", "-vv", "--batch-size=1", - f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}", + f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.ckpt')}", f"--output-folder={output_folder}", ], ) @@ -673,7 +677,6 @@ def test_predict_logreg_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Loading checkpoint from.*$": 1, - r"^Total time:.*$": 3, r"^Logistic regression identified: saving model weights.*$": 1, } buf.seek(0) diff --git a/tests/test_mc_ch_in_11k.py b/tests/test_mc_ch_in_11k.py new file mode 100644 index 0000000000000000000000000000000000000000..9aeb0c36951d1383cd758fb5bcd2172a2426411a --- /dev/null +++ b/tests/test_mc_ch_in_11k.py @@ -0,0 +1,699 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# 
SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for the aggregated Montgomery-Shenzhen-Indian dataset.""" + + +def test_dataset_consistency(): + from ptbench.configs.datasets.indian import default as indian + from ptbench.configs.datasets.indian import fold_0 as indian_f0 + from ptbench.configs.datasets.indian import fold_0_rgb as indian_f0_rgb + from ptbench.configs.datasets.indian import fold_1 as indian_f1 + from ptbench.configs.datasets.indian import fold_1_rgb as indian_f1_rgb + from ptbench.configs.datasets.indian import fold_2 as indian_f2 + from ptbench.configs.datasets.indian import fold_2_rgb as indian_f2_rgb + from ptbench.configs.datasets.indian import fold_3 as indian_f3 + from ptbench.configs.datasets.indian import fold_3_rgb as indian_f3_rgb + from ptbench.configs.datasets.indian import fold_4 as indian_f4 + from ptbench.configs.datasets.indian import fold_4_rgb as indian_f4_rgb + from ptbench.configs.datasets.indian import fold_5 as indian_f5 + from ptbench.configs.datasets.indian import fold_5_rgb as indian_f5_rgb + from ptbench.configs.datasets.indian import fold_6 as indian_f6 + from ptbench.configs.datasets.indian import fold_6_rgb as indian_f6_rgb + from ptbench.configs.datasets.indian import fold_7 as indian_f7 + from ptbench.configs.datasets.indian import fold_7_rgb as indian_f7_rgb + from ptbench.configs.datasets.indian import fold_8 as indian_f8 + from ptbench.configs.datasets.indian import fold_8_rgb as indian_f8_rgb + from ptbench.configs.datasets.indian import fold_9 as indian_f9 + from ptbench.configs.datasets.indian import fold_9_rgb as indian_f9_rgb + from ptbench.configs.datasets.mc_ch_in_11k import default as mc_ch_in_11k + from ptbench.configs.datasets.mc_ch_in_11k import fold_0 as mc_ch_in_11k_f0 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_0_rgb as mc_ch_in_11k_f0_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_1 as mc_ch_in_11k_f1 + from ptbench.configs.datasets.mc_ch_in_11k import ( + 
fold_1_rgb as mc_ch_in_11k_f1_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_2 as mc_ch_in_11k_f2 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_2_rgb as mc_ch_in_11k_f2_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_3 as mc_ch_in_11k_f3 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_3_rgb as mc_ch_in_11k_f3_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_4 as mc_ch_in_11k_f4 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_4_rgb as mc_ch_in_11k_f4_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_5 as mc_ch_in_11k_f5 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_5_rgb as mc_ch_in_11k_f5_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_6 as mc_ch_in_11k_f6 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_6_rgb as mc_ch_in_11k_f6_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_7 as mc_ch_in_11k_f7 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_7_rgb as mc_ch_in_11k_f7_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_8 as mc_ch_in_11k_f8 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_8_rgb as mc_ch_in_11k_f8_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11k import fold_9 as mc_ch_in_11k_f9 + from ptbench.configs.datasets.mc_ch_in_11k import ( + fold_9_rgb as mc_ch_in_11k_f9_rgb, + ) + from ptbench.configs.datasets.montgomery import default as mc + from ptbench.configs.datasets.montgomery import fold_0 as mc_f0 + from ptbench.configs.datasets.montgomery import fold_0_rgb as mc_f0_rgb + from ptbench.configs.datasets.montgomery import fold_1 as mc_f1 + from ptbench.configs.datasets.montgomery import fold_1_rgb as mc_f1_rgb + from ptbench.configs.datasets.montgomery import fold_2 as mc_f2 + from ptbench.configs.datasets.montgomery import fold_2_rgb as mc_f2_rgb + from ptbench.configs.datasets.montgomery import fold_3 as mc_f3 + from 
ptbench.configs.datasets.montgomery import fold_3_rgb as mc_f3_rgb + from ptbench.configs.datasets.montgomery import fold_4 as mc_f4 + from ptbench.configs.datasets.montgomery import fold_4_rgb as mc_f4_rgb + from ptbench.configs.datasets.montgomery import fold_5 as mc_f5 + from ptbench.configs.datasets.montgomery import fold_5_rgb as mc_f5_rgb + from ptbench.configs.datasets.montgomery import fold_6 as mc_f6 + from ptbench.configs.datasets.montgomery import fold_6_rgb as mc_f6_rgb + from ptbench.configs.datasets.montgomery import fold_7 as mc_f7 + from ptbench.configs.datasets.montgomery import fold_7_rgb as mc_f7_rgb + from ptbench.configs.datasets.montgomery import fold_8 as mc_f8 + from ptbench.configs.datasets.montgomery import fold_8_rgb as mc_f8_rgb + from ptbench.configs.datasets.montgomery import fold_9 as mc_f9 + from ptbench.configs.datasets.montgomery import fold_9_rgb as mc_f9_rgb + from ptbench.configs.datasets.shenzhen import default as ch + from ptbench.configs.datasets.shenzhen import fold_0 as ch_f0 + from ptbench.configs.datasets.shenzhen import fold_0_rgb as ch_f0_rgb + from ptbench.configs.datasets.shenzhen import fold_1 as ch_f1 + from ptbench.configs.datasets.shenzhen import fold_1_rgb as ch_f1_rgb + from ptbench.configs.datasets.shenzhen import fold_2 as ch_f2 + from ptbench.configs.datasets.shenzhen import fold_2_rgb as ch_f2_rgb + from ptbench.configs.datasets.shenzhen import fold_3 as ch_f3 + from ptbench.configs.datasets.shenzhen import fold_3_rgb as ch_f3_rgb + from ptbench.configs.datasets.shenzhen import fold_4 as ch_f4 + from ptbench.configs.datasets.shenzhen import fold_4_rgb as ch_f4_rgb + from ptbench.configs.datasets.shenzhen import fold_5 as ch_f5 + from ptbench.configs.datasets.shenzhen import fold_5_rgb as ch_f5_rgb + from ptbench.configs.datasets.shenzhen import fold_6 as ch_f6 + from ptbench.configs.datasets.shenzhen import fold_6_rgb as ch_f6_rgb + from ptbench.configs.datasets.shenzhen import fold_7 as ch_f7 + from 
ptbench.configs.datasets.shenzhen import fold_7_rgb as ch_f7_rgb + from ptbench.configs.datasets.shenzhen import fold_8 as ch_f8 + from ptbench.configs.datasets.shenzhen import fold_8_rgb as ch_f8_rgb + from ptbench.configs.datasets.shenzhen import fold_9 as ch_f9 + from ptbench.configs.datasets.shenzhen import fold_9_rgb as ch_f9_rgb + from ptbench.configs.datasets.tbx11k_simplified import default as tbx11k + from ptbench.configs.datasets.tbx11k_simplified import fold_0 as tbx11k_f0 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_0_rgb as tbx11k_f0_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_1 as tbx11k_f1 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_1_rgb as tbx11k_f1_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_2 as tbx11k_f2 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_2_rgb as tbx11k_f2_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_3 as tbx11k_f3 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_3_rgb as tbx11k_f3_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_4 as tbx11k_f4 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_4_rgb as tbx11k_f4_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_5 as tbx11k_f5 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_5_rgb as tbx11k_f5_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_6 as tbx11k_f6 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_6_rgb as tbx11k_f6_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_7 as tbx11k_f7 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_7_rgb as tbx11k_f7_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified import fold_8 as tbx11k_f8 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_8_rgb as tbx11k_f8_rgb, + ) + from 
ptbench.configs.datasets.tbx11k_simplified import fold_9 as tbx11k_f9 + from ptbench.configs.datasets.tbx11k_simplified import ( + fold_9_rgb as tbx11k_f9_rgb, + ) + + # Default protocol + mc_ch_in_11k_dataset = mc_ch_in_11k.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc.dataset + ch_dataset = ch.dataset + in_dataset = indian.dataset + tbx11k_dataset = tbx11k.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 0 + mc_ch_in_11k_dataset = mc_ch_in_11k_f0.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f0.dataset + ch_dataset = ch_f0.dataset + in_dataset = indian_f0.dataset + tbx11k_dataset = tbx11k_f0.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 1 + mc_ch_in_11k_dataset = mc_ch_in_11k_f1.dataset + 
assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f1.dataset + ch_dataset = ch_f1.dataset + in_dataset = indian_f1.dataset + tbx11k_dataset = tbx11k_f1.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 2 + mc_ch_in_11k_dataset = mc_ch_in_11k_f2.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f2.dataset + ch_dataset = ch_f2.dataset + in_dataset = indian_f2.dataset + tbx11k_dataset = tbx11k_f2.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 3 + mc_ch_in_11k_dataset = mc_ch_in_11k_f3.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f3.dataset + ch_dataset = ch_f3.dataset + in_dataset = indian_f3.dataset + tbx11k_dataset = tbx11k_f3.dataset + + assert "train" in mc_ch_in_11k_dataset + 
assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 4 + mc_ch_in_11k_dataset = mc_ch_in_11k_f4.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f4.dataset + ch_dataset = ch_f4.dataset + in_dataset = indian_f4.dataset + tbx11k_dataset = tbx11k_f4.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 5 + mc_ch_in_11k_dataset = mc_ch_in_11k_f5.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f5.dataset + ch_dataset = ch_f5.dataset + in_dataset = indian_f5.dataset + tbx11k_dataset = tbx11k_f5.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert 
len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 6 + mc_ch_in_11k_dataset = mc_ch_in_11k_f6.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f6.dataset + ch_dataset = ch_f6.dataset + in_dataset = indian_f6.dataset + tbx11k_dataset = tbx11k_f6.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 7 + mc_ch_in_11k_dataset = mc_ch_in_11k_f7.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f7.dataset + ch_dataset = ch_f7.dataset + in_dataset = indian_f7.dataset + tbx11k_dataset = tbx11k_f7.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in 
mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 8 + mc_ch_in_11k_dataset = mc_ch_in_11k_f8.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f8.dataset + ch_dataset = ch_f8.dataset + in_dataset = indian_f8.dataset + tbx11k_dataset = tbx11k_f8.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 9 + mc_ch_in_11k_dataset = mc_ch_in_11k_f9.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f9.dataset + ch_dataset = ch_f9.dataset + in_dataset = indian_f9.dataset + tbx11k_dataset = tbx11k_f9.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 0, RGB + mc_ch_in_11k_dataset = 
mc_ch_in_11k_f0_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f0_rgb.dataset + ch_dataset = ch_f0_rgb.dataset + in_dataset = indian_f0_rgb.dataset + tbx11k_dataset = tbx11k_f0_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 1, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f1_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f1_rgb.dataset + ch_dataset = ch_f1_rgb.dataset + in_dataset = indian_f1_rgb.dataset + tbx11k_dataset = tbx11k_f1_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 2, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f2_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f2_rgb.dataset + ch_dataset = ch_f2_rgb.dataset + in_dataset = 
indian_f2_rgb.dataset + tbx11k_dataset = tbx11k_f2_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 3, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f3_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f3_rgb.dataset + ch_dataset = ch_f3_rgb.dataset + in_dataset = indian_f3_rgb.dataset + tbx11k_dataset = tbx11k_f3_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 4, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f4_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f4_rgb.dataset + ch_dataset = ch_f4_rgb.dataset + in_dataset = indian_f4_rgb.dataset + tbx11k_dataset = tbx11k_f4_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == 
len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 5, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f5_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f5_rgb.dataset + ch_dataset = ch_f5_rgb.dataset + in_dataset = indian_f5_rgb.dataset + tbx11k_dataset = tbx11k_f5_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 6, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f6_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f6_rgb.dataset + ch_dataset = ch_f6_rgb.dataset + in_dataset = indian_f6_rgb.dataset + tbx11k_dataset = tbx11k_f6_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert 
len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 7, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f7_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f7_rgb.dataset + ch_dataset = ch_f7_rgb.dataset + in_dataset = indian_f7_rgb.dataset + tbx11k_dataset = tbx11k_f7_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 8, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f8_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f8_rgb.dataset + ch_dataset = ch_f8_rgb.dataset + in_dataset = indian_f8_rgb.dataset + tbx11k_dataset = tbx11k_f8_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + 
"""Tests for the aggregated Montgomery-Shenzhen-Indian-tbx11k_simplified
dataset (radiological-sign / *_RS configurations)."""


def test_dataset_consistency():
    """Check every aggregated *_RS protocol against its component datasets.

    For the default protocol and each of the 10 cross-validation folds,
    the aggregated ``mc_ch_in_11k_RS`` dataset must expose ``train``,
    ``validation`` and ``test`` splits whose lengths equal the sum of the
    corresponding splits of the Montgomery, Shenzhen, Indian and
    tbx11k_simplified *_RS datasets.
    """
    import importlib

    base = "ptbench.configs.datasets"

    # Component packages whose split sizes must add up to the aggregate's.
    component_pkgs = (
        f"{base}.montgomery_RS",
        f"{base}.shenzhen_RS",
        f"{base}.indian_RS",
        f"{base}.tbx11k_simplified_RS",
    )
    aggregate_pkg = f"{base}.mc_ch_in_11k_RS"

    def _check_protocol(protocol):
        """Assert split-wise size consistency for one protocol module."""
        # NOTE(review): each protocol ("default", "fold_0", ...) is assumed
        # to be a submodule of its dataset package (matching the original
        # ``from <pkg> import fold_N`` imports) — importlib loads it directly.
        aggregated = importlib.import_module(f"{aggregate_pkg}.{protocol}")
        dataset = aggregated.dataset
        assert isinstance(dataset, dict)

        parts = [
            importlib.import_module(f"{pkg}.{protocol}").dataset
            for pkg in component_pkgs
        ]

        for split in ("train", "validation", "test"):
            assert split in dataset
            assert len(dataset[split]) == sum(len(p[split]) for p in parts)

    # Default protocol plus folds 0-9, exactly as the original per-stanza
    # checks enumerated them.
    for protocol in ["default"] + [f"fold_{k}" for k in range(10)]:
        _check_protocol(protocol)
fold_8_rgb as indian_f8_rgb + from ptbench.configs.datasets.indian import fold_9 as indian_f9 + from ptbench.configs.datasets.indian import fold_9_rgb as indian_f9_rgb + from ptbench.configs.datasets.mc_ch_in_11kv2 import default as mc_ch_in_11k + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_0 as mc_ch_in_11k_f0, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_0_rgb as mc_ch_in_11k_f0_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_1 as mc_ch_in_11k_f1, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_1_rgb as mc_ch_in_11k_f1_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_2 as mc_ch_in_11k_f2, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_2_rgb as mc_ch_in_11k_f2_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_3 as mc_ch_in_11k_f3, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_3_rgb as mc_ch_in_11k_f3_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_4 as mc_ch_in_11k_f4, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_4_rgb as mc_ch_in_11k_f4_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_5 as mc_ch_in_11k_f5, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_5_rgb as mc_ch_in_11k_f5_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_6 as mc_ch_in_11k_f6, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_6_rgb as mc_ch_in_11k_f6_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_7 as mc_ch_in_11k_f7, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_7_rgb as mc_ch_in_11k_f7_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_8 as mc_ch_in_11k_f8, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_8_rgb as mc_ch_in_11k_f8_rgb, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_9 as 
mc_ch_in_11k_f9, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2 import ( + fold_9_rgb as mc_ch_in_11k_f9_rgb, + ) + from ptbench.configs.datasets.montgomery import default as mc + from ptbench.configs.datasets.montgomery import fold_0 as mc_f0 + from ptbench.configs.datasets.montgomery import fold_0_rgb as mc_f0_rgb + from ptbench.configs.datasets.montgomery import fold_1 as mc_f1 + from ptbench.configs.datasets.montgomery import fold_1_rgb as mc_f1_rgb + from ptbench.configs.datasets.montgomery import fold_2 as mc_f2 + from ptbench.configs.datasets.montgomery import fold_2_rgb as mc_f2_rgb + from ptbench.configs.datasets.montgomery import fold_3 as mc_f3 + from ptbench.configs.datasets.montgomery import fold_3_rgb as mc_f3_rgb + from ptbench.configs.datasets.montgomery import fold_4 as mc_f4 + from ptbench.configs.datasets.montgomery import fold_4_rgb as mc_f4_rgb + from ptbench.configs.datasets.montgomery import fold_5 as mc_f5 + from ptbench.configs.datasets.montgomery import fold_5_rgb as mc_f5_rgb + from ptbench.configs.datasets.montgomery import fold_6 as mc_f6 + from ptbench.configs.datasets.montgomery import fold_6_rgb as mc_f6_rgb + from ptbench.configs.datasets.montgomery import fold_7 as mc_f7 + from ptbench.configs.datasets.montgomery import fold_7_rgb as mc_f7_rgb + from ptbench.configs.datasets.montgomery import fold_8 as mc_f8 + from ptbench.configs.datasets.montgomery import fold_8_rgb as mc_f8_rgb + from ptbench.configs.datasets.montgomery import fold_9 as mc_f9 + from ptbench.configs.datasets.montgomery import fold_9_rgb as mc_f9_rgb + from ptbench.configs.datasets.shenzhen import default as ch + from ptbench.configs.datasets.shenzhen import fold_0 as ch_f0 + from ptbench.configs.datasets.shenzhen import fold_0_rgb as ch_f0_rgb + from ptbench.configs.datasets.shenzhen import fold_1 as ch_f1 + from ptbench.configs.datasets.shenzhen import fold_1_rgb as ch_f1_rgb + from ptbench.configs.datasets.shenzhen import fold_2 as ch_f2 + from 
ptbench.configs.datasets.shenzhen import fold_2_rgb as ch_f2_rgb + from ptbench.configs.datasets.shenzhen import fold_3 as ch_f3 + from ptbench.configs.datasets.shenzhen import fold_3_rgb as ch_f3_rgb + from ptbench.configs.datasets.shenzhen import fold_4 as ch_f4 + from ptbench.configs.datasets.shenzhen import fold_4_rgb as ch_f4_rgb + from ptbench.configs.datasets.shenzhen import fold_5 as ch_f5 + from ptbench.configs.datasets.shenzhen import fold_5_rgb as ch_f5_rgb + from ptbench.configs.datasets.shenzhen import fold_6 as ch_f6 + from ptbench.configs.datasets.shenzhen import fold_6_rgb as ch_f6_rgb + from ptbench.configs.datasets.shenzhen import fold_7 as ch_f7 + from ptbench.configs.datasets.shenzhen import fold_7_rgb as ch_f7_rgb + from ptbench.configs.datasets.shenzhen import fold_8 as ch_f8 + from ptbench.configs.datasets.shenzhen import fold_8_rgb as ch_f8_rgb + from ptbench.configs.datasets.shenzhen import fold_9 as ch_f9 + from ptbench.configs.datasets.shenzhen import fold_9_rgb as ch_f9_rgb + from ptbench.configs.datasets.tbx11k_simplified_v2 import default as tbx11k + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_0 as tbx11k_f0, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_0_rgb as tbx11k_f0_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_1 as tbx11k_f1, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_1_rgb as tbx11k_f1_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_2 as tbx11k_f2, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_2_rgb as tbx11k_f2_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_3 as tbx11k_f3, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_3_rgb as tbx11k_f3_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_4 as tbx11k_f4, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + 
fold_4_rgb as tbx11k_f4_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_5 as tbx11k_f5, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_5_rgb as tbx11k_f5_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_6 as tbx11k_f6, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_6_rgb as tbx11k_f6_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_7 as tbx11k_f7, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_7_rgb as tbx11k_f7_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_8 as tbx11k_f8, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_8_rgb as tbx11k_f8_rgb, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_9 as tbx11k_f9, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2 import ( + fold_9_rgb as tbx11k_f9_rgb, + ) + + # Default protocol + mc_ch_in_11k_dataset = mc_ch_in_11k.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc.dataset + ch_dataset = ch.dataset + in_dataset = indian.dataset + tbx11k_dataset = tbx11k.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 0 + mc_ch_in_11k_dataset = mc_ch_in_11k_f0.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f0.dataset + ch_dataset 
= ch_f0.dataset + in_dataset = indian_f0.dataset + tbx11k_dataset = tbx11k_f0.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 1 + mc_ch_in_11k_dataset = mc_ch_in_11k_f1.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f1.dataset + ch_dataset = ch_f1.dataset + in_dataset = indian_f1.dataset + tbx11k_dataset = tbx11k_f1.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 2 + mc_ch_in_11k_dataset = mc_ch_in_11k_f2.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f2.dataset + ch_dataset = ch_f2.dataset + in_dataset = indian_f2.dataset + tbx11k_dataset = tbx11k_f2.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + 
ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 3 + mc_ch_in_11k_dataset = mc_ch_in_11k_f3.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f3.dataset + ch_dataset = ch_f3.dataset + in_dataset = indian_f3.dataset + tbx11k_dataset = tbx11k_f3.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 4 + mc_ch_in_11k_dataset = mc_ch_in_11k_f4.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f4.dataset + ch_dataset = ch_f4.dataset + in_dataset = indian_f4.dataset + tbx11k_dataset = tbx11k_f4.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + 
len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 5 + mc_ch_in_11k_dataset = mc_ch_in_11k_f5.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f5.dataset + ch_dataset = ch_f5.dataset + in_dataset = indian_f5.dataset + tbx11k_dataset = tbx11k_f5.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 6 + mc_ch_in_11k_dataset = mc_ch_in_11k_f6.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f6.dataset + ch_dataset = ch_f6.dataset + in_dataset = indian_f6.dataset + tbx11k_dataset = tbx11k_f6.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( 
+ ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 7 + mc_ch_in_11k_dataset = mc_ch_in_11k_f7.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f7.dataset + ch_dataset = ch_f7.dataset + in_dataset = indian_f7.dataset + tbx11k_dataset = tbx11k_f7.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 8 + mc_ch_in_11k_dataset = mc_ch_in_11k_f8.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f8.dataset + ch_dataset = ch_f8.dataset + in_dataset = indian_f8.dataset + tbx11k_dataset = tbx11k_f8.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 9 + mc_ch_in_11k_dataset = mc_ch_in_11k_f9.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f9.dataset + 
ch_dataset = ch_f9.dataset + in_dataset = indian_f9.dataset + tbx11k_dataset = tbx11k_f9.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 0, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f0_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f0_rgb.dataset + ch_dataset = ch_f0_rgb.dataset + in_dataset = indian_f0_rgb.dataset + tbx11k_dataset = tbx11k_f0_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 1, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f1_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f1_rgb.dataset + ch_dataset = ch_f1_rgb.dataset + in_dataset = indian_f1_rgb.dataset + tbx11k_dataset = tbx11k_f1_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert 
len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 2, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f2_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f2_rgb.dataset + ch_dataset = ch_f2_rgb.dataset + in_dataset = indian_f2_rgb.dataset + tbx11k_dataset = tbx11k_f2_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 3, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f3_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f3_rgb.dataset + ch_dataset = ch_f3_rgb.dataset + in_dataset = indian_f3_rgb.dataset + tbx11k_dataset = tbx11k_f3_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert 
"validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 4, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f4_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f4_rgb.dataset + ch_dataset = ch_f4_rgb.dataset + in_dataset = indian_f4_rgb.dataset + tbx11k_dataset = tbx11k_f4_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 5, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f5_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f5_rgb.dataset + ch_dataset = ch_f5_rgb.dataset + in_dataset = indian_f5_rgb.dataset + tbx11k_dataset = tbx11k_f5_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + 
len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 6, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f6_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f6_rgb.dataset + ch_dataset = ch_f6_rgb.dataset + in_dataset = indian_f6_rgb.dataset + tbx11k_dataset = tbx11k_f6_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 7, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f7_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f7_rgb.dataset + ch_dataset = ch_f7_rgb.dataset + in_dataset = indian_f7_rgb.dataset + tbx11k_dataset = tbx11k_f7_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == 
len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 8, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f8_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f8_rgb.dataset + ch_dataset = ch_f8_rgb.dataset + in_dataset = indian_f8_rgb.dataset + tbx11k_dataset = tbx11k_f8_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) + + # Fold 9, RGB + mc_ch_in_11k_dataset = mc_ch_in_11k_f9_rgb.dataset + assert isinstance(mc_ch_in_11k_dataset, dict) + + mc_dataset = mc_f9_rgb.dataset + ch_dataset = ch_f9_rgb.dataset + in_dataset = indian_f9_rgb.dataset + tbx11k_dataset = tbx11k_f9_rgb.dataset + + assert "train" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["train"]) == len(mc_dataset["train"]) + len( + ch_dataset["train"] + ) + len(in_dataset["train"]) + len(tbx11k_dataset["train"]) + + assert "validation" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_dataset + assert len(mc_ch_in_11k_dataset["test"]) == len(mc_dataset["test"]) + len( + ch_dataset["test"] + ) + len(in_dataset["test"]) + len(tbx11k_dataset["test"]) diff --git a/tests/test_mc_ch_in_11kv2_RS.py 
b/tests/test_mc_ch_in_11kv2_RS.py new file mode 100644 index 0000000000000000000000000000000000000000..61f4f003c399a6a4bb82e85359e1fbf2bee3e176 --- /dev/null +++ b/tests/test_mc_ch_in_11kv2_RS.py @@ -0,0 +1,440 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for the aggregated Montgomery-Shenzhen-Indian-tbx11k_simplified_v2 +dataset.""" + + +def test_dataset_consistency(): + from ptbench.configs.datasets.indian_RS import default as indian_RS + from ptbench.configs.datasets.indian_RS import fold_0 as indian_f0 + from ptbench.configs.datasets.indian_RS import fold_1 as indian_f1 + from ptbench.configs.datasets.indian_RS import fold_2 as indian_f2 + from ptbench.configs.datasets.indian_RS import fold_3 as indian_f3 + from ptbench.configs.datasets.indian_RS import fold_4 as indian_f4 + from ptbench.configs.datasets.indian_RS import fold_5 as indian_f5 + from ptbench.configs.datasets.indian_RS import fold_6 as indian_f6 + from ptbench.configs.datasets.indian_RS import fold_7 as indian_f7 + from ptbench.configs.datasets.indian_RS import fold_8 as indian_f8 + from ptbench.configs.datasets.indian_RS import fold_9 as indian_f9 + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + default as mc_ch_in_11k_RS, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_0 as mc_ch_in_11k_f0, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_1 as mc_ch_in_11k_f1, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_2 as mc_ch_in_11k_f2, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_3 as mc_ch_in_11k_f3, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_4 as mc_ch_in_11k_f4, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_5 as mc_ch_in_11k_f5, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_6 as mc_ch_in_11k_f6, + ) + from 
ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_7 as mc_ch_in_11k_f7, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_8 as mc_ch_in_11k_f8, + ) + from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( + fold_9 as mc_ch_in_11k_f9, + ) + from ptbench.configs.datasets.montgomery_RS import default as mc_RS + from ptbench.configs.datasets.montgomery_RS import fold_0 as mc_f0 + from ptbench.configs.datasets.montgomery_RS import fold_1 as mc_f1 + from ptbench.configs.datasets.montgomery_RS import fold_2 as mc_f2 + from ptbench.configs.datasets.montgomery_RS import fold_3 as mc_f3 + from ptbench.configs.datasets.montgomery_RS import fold_4 as mc_f4 + from ptbench.configs.datasets.montgomery_RS import fold_5 as mc_f5 + from ptbench.configs.datasets.montgomery_RS import fold_6 as mc_f6 + from ptbench.configs.datasets.montgomery_RS import fold_7 as mc_f7 + from ptbench.configs.datasets.montgomery_RS import fold_8 as mc_f8 + from ptbench.configs.datasets.montgomery_RS import fold_9 as mc_f9 + from ptbench.configs.datasets.shenzhen_RS import default as ch_RS + from ptbench.configs.datasets.shenzhen_RS import fold_0 as ch_f0 + from ptbench.configs.datasets.shenzhen_RS import fold_1 as ch_f1 + from ptbench.configs.datasets.shenzhen_RS import fold_2 as ch_f2 + from ptbench.configs.datasets.shenzhen_RS import fold_3 as ch_f3 + from ptbench.configs.datasets.shenzhen_RS import fold_4 as ch_f4 + from ptbench.configs.datasets.shenzhen_RS import fold_5 as ch_f5 + from ptbench.configs.datasets.shenzhen_RS import fold_6 as ch_f6 + from ptbench.configs.datasets.shenzhen_RS import fold_7 as ch_f7 + from ptbench.configs.datasets.shenzhen_RS import fold_8 as ch_f8 + from ptbench.configs.datasets.shenzhen_RS import fold_9 as ch_f9 + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + default as tbx11k_RS, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_0 as tbx11k_f0, + ) + from 
ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_1 as tbx11k_f1, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_2 as tbx11k_f2, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_3 as tbx11k_f3, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_4 as tbx11k_f4, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_5 as tbx11k_f5, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_6 as tbx11k_f6, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_7 as tbx11k_f7, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_8 as tbx11k_f8, + ) + from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( + fold_9 as tbx11k_f9, + ) + + # Default protocol + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_RS.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_RS_dataset = mc_RS.dataset + ch_RS_dataset = ch_RS.dataset + in_RS_dataset = indian_RS.dataset + tbx11k_RS_dataset = tbx11k_RS.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_RS_dataset["train"] + ) + len(ch_RS_dataset["train"]) + len(in_RS_dataset["train"]) + len( + tbx11k_RS_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_RS_dataset["validation"] + ) + len(ch_RS_dataset["validation"]) + len( + in_RS_dataset["validation"] + ) + len( + tbx11k_RS_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_RS_dataset["test"] + ) + len(ch_RS_dataset["test"]) + len(in_RS_dataset["test"]) + len( + tbx11k_RS_dataset["test"] + ) + + # Fold 0 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f0.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f0.dataset + ch_dataset = ch_f0.dataset + in_dataset = 
indian_f0.dataset + tbx11k_dataset = tbx11k_f0.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 1 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f1.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f1.dataset + ch_dataset = ch_f1.dataset + in_dataset = indian_f1.dataset + tbx11k_dataset = tbx11k_f1.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 2 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f2.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f2.dataset + ch_dataset = ch_f2.dataset + in_dataset = indian_f2.dataset + tbx11k_dataset = tbx11k_f2.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == 
len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 3 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f3.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f3.dataset + ch_dataset = ch_f3.dataset + in_dataset = indian_f3.dataset + tbx11k_dataset = tbx11k_f3.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 4 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f4.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f4.dataset + ch_dataset = ch_f4.dataset + in_dataset = indian_f4.dataset + tbx11k_dataset = tbx11k_f4.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in 
mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 5 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f5.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f5.dataset + ch_dataset = ch_f5.dataset + in_dataset = indian_f5.dataset + tbx11k_dataset = tbx11k_f5.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 6 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f6.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f6.dataset + ch_dataset = ch_f6.dataset + in_dataset = indian_f6.dataset + tbx11k_dataset = tbx11k_f6.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + 
len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 7 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f7.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f7.dataset + ch_dataset = ch_f7.dataset + in_dataset = indian_f7.dataset + tbx11k_dataset = tbx11k_f7.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 8 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f8.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f8.dataset + ch_dataset = ch_f8.dataset + in_dataset = indian_f8.dataset + tbx11k_dataset = tbx11k_f8.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert 
len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) + + # Fold 9 + mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f9.dataset + assert isinstance(mc_ch_in_11k_RS_dataset, dict) + + mc_dataset = mc_f9.dataset + ch_dataset = ch_f9.dataset + in_dataset = indian_f9.dataset + tbx11k_dataset = tbx11k_f9.dataset + + assert "train" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["train"]) == len( + mc_dataset["train"] + ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( + tbx11k_dataset["train"] + ) + + assert "validation" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( + mc_dataset["validation"] + ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( + tbx11k_dataset["validation"] + ) + + assert "test" in mc_ch_in_11k_RS_dataset + assert len(mc_ch_in_11k_RS_dataset["test"]) == len( + mc_dataset["test"] + ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( + tbx11k_dataset["test"] + ) diff --git a/tests/test_summary.py b/tests/test_summary.py index 4315a39f670a4235da14948f800102a8b21d1e8d..1f178fdc91dacf4a99eca3ced83d6628cca749d4 100644 --- a/tests/test_summary.py +++ b/tests/test_summary.py @@ -4,7 +4,8 @@ import unittest -from ptbench.models.pasa import build_pasa +import ptbench.configs.models.pasa as pasa_config + from ptbench.utils.summary import summary @@ -12,7 +13,7 @@ class Tester(unittest.TestCase): """Unit test for model architectures.""" def test_summary_driu(self): - model = build_pasa() + model = pasa_config.model s, param = summary(model) self.assertIsInstance(s, str) self.assertIsInstance(param, int)