diff --git a/bob/ip/binseg/configs/datasets/stare/__init__.py b/bob/ip/binseg/configs/datasets/stare/__init__.py index a5229939a9ef5881f0a4ef8ccf44400d57aafcd1..82d1e14a7813b52dffa558924c861a42bab92760 100644 --- a/bob/ip/binseg/configs/datasets/stare/__init__.py +++ b/bob/ip/binseg/configs/datasets/stare/__init__.py @@ -1,9 +1,10 @@ #!/usr/bin/env python # coding=utf-8 -def _maker(protocol): +def _maker(protocol, raw=None): from ....data.transforms import Pad - from ....data.stare import dataset as raw + from ....data.stare import dataset as _raw + raw = raw or _raw #allows user to recreate dataset for testing purposes from .. import make_dataset as mk return mk(raw.subsets(protocol), [Pad((2, 1, 2, 2))]) diff --git a/bob/ip/binseg/data/stare/__init__.py b/bob/ip/binseg/data/stare/__init__.py index eacf312410e5e624727905e2ba7567139344a045..c673d62e5330c27ad76e4c8e13ad29d3611015d2 100644 --- a/bob/ip/binseg/data/stare/__init__.py +++ b/bob/ip/binseg/data/stare/__init__.py @@ -43,7 +43,7 @@ _root_path = bob.extension.rc.get( "bob.ip.binseg.stare.datadir", os.path.realpath(os.curdir) ) -def _make_loader(root_path): +def _make_dataset(root_path): def _loader(context, sample): # "context" is ignore in this case - database is homogeneous @@ -51,13 +51,13 @@ def _make_loader(root_path): data=load_pil_rgb(os.path.join(root_path, sample["data"])), label=load_pil_1(os.path.join(root_path, sample["label"])), ) - return _loader + return JSONDataset( + protocols=_protocols, + fieldnames=_fieldnames, + loader=_loader, + keymaker=data_path_keymaker, + ) -dataset = JSONDataset( - protocols=_protocols, - fieldnames=_fieldnames, - loader=_make_loader(_root_path), - keymaker=data_path_keymaker, -) +dataset = _make_dataset(_root_path) """STARE dataset object""" diff --git a/bob/ip/binseg/test/__init__.py b/bob/ip/binseg/test/__init__.py index 2e507ed77cc08ad5f73b8c579defe79d5759b202..e4b13525b1e78beb2b5b4e9a141b3bb5fa2f4f8c 100644 --- a/bob/ip/binseg/test/__init__.py +++ b/bob/ip/binseg/test/__init__.py @@ -66,15 +66,11 @@ def mock_dataset(): # if the user has the STARE directory ready, then we do a normal return from .utils import rc_variable_set - return stare.dataset, rc_variable_set + return rc["bob.ip.binseg.stare.datadir"], stare.dataset, rc_variable_set # else, we do a "mock" return return ( - stare.JSONDataset( - stare._protocols, - stare._fieldnames, - stare._make_loader(TESTDB_TMPDIR.name), - stare.data_path_keymaker, - ), + TESTDB_TMPDIR.name, + stare._make_dataset(TESTDB_TMPDIR.name), _mock_test_skipper, ) diff --git a/bob/ip/binseg/test/test_cli.py b/bob/ip/binseg/test/test_cli.py index 656fbc998a09daa3dcf89ef959b81cc9e9514709..b244bf8b8e78d3fedac17d2897a09aea21646e9d 100644 --- a/bob/ip/binseg/test/test_cli.py +++ b/bob/ip/binseg/test/test_cli.py @@ -4,13 +4,14 @@ """Tests for our CLI applications""" import re +import tempfile import contextlib from click.testing import CliRunner -## special trick for CI builds from . import mock_dataset -_, rc_variable_set = mock_dataset() + +stare_datadir, stare_dataset, rc_variable_set = mock_dataset() @contextlib.contextmanager @@ -20,11 +21,12 @@ def stdout_logging(): import sys import logging import io + buf = io.StringIO() ch = logging.StreamHandler(buf) - ch.setFormatter(logging.Formatter('%(message)s')) + ch.setFormatter(logging.Formatter("%(message)s")) ch.setLevel(logging.INFO) - logger = logging.getLogger('bob') + logger = logging.getLogger("bob") logger.addHandler(ch) yield buf logger.removeHandler(ch) @@ -32,9 +34,10 @@ def stdout_logging(): def _assert_exit_0(result): - assert result.exit_code == 0, ( - f"Exit code != 0 ({result.exit_code}); Output:\n{result.output}" - ) + assert ( + result.exit_code == 0 + ), f"Exit code != 0 ({result.exit_code}); Output:\n{result.output}" + def _check_help(entry_point): @@ -57,7 +60,7 @@ def test_experiment_help(): def _str_counter(substr, s): - return sum(1 for _ in re.finditer(r'\b%s\b' % re.escape(substr), s)) + return sum(1 for _ in re.finditer(r"\b%s\b" % re.escape(substr), s)) @rc_variable_set("bob.ip.binseg.stare.datadir") @@ -65,39 +68,55 @@ def test_experiment_stare(): from ..script.experiment import experiment runner = CliRunner() - with runner.isolated_filesystem(), stdout_logging() as buf: - result = runner.invoke(experiment, ["m2unet", "stare", "-vv", - "--epochs=1", "--batch-size=1", "--overlayed"]) + with runner.isolated_filesystem(), \ + stdout_logging() as buf, \ + tempfile.NamedTemporaryFile(mode="wt") as config: + + # re-write STARE dataset configuration for test + config.write("from bob.ip.binseg.data.stare import _make_dataset\n") + config.write(f"_raw = _make_dataset('{stare_datadir}')\n") + config.write( + "from bob.ip.binseg.configs.datasets.stare import _maker\n" + ) + config.write("dataset = _maker('ah', _raw)\n") + config.flush() + + result = runner.invoke( + experiment, + ["m2unet", config.name, "-vv", "--epochs=1", "--batch-size=1", + "--overlayed"], + ) _assert_exit_0(result) - keywords = { #from different logging systems - "Started training": 1, #logging - "epoch: 1|total-time": 1, #logging - "Saving checkpoint to results/model/model_final.pth": 1, #logging - "Ended training": 1, #logging - "Started prediction": 1, #logging - "Loading checkpoint from": 2, #logging - #"Saving results/overlayed/probabilities": 1, #tqdm.write - "Ended prediction": 1, #logging - "Started evaluation": 1, #logging - "Highest F1-score of": 2, #logging - "Saving overall precision-recall plot": 2, #logging - #"Saving results/overlayed/analysis": 1, #tqdm.write - "Ended evaluation": 1, #logging - "Started comparison": 1, #logging - "Loading metrics from results/analysis": 2, #logging - "Ended comparison": 1, #logging - } + keywords = { # from different logging systems + "Started training": 1, # logging + "epoch: 1|total-time": 1, # logging + "Saving checkpoint to results/model/model_final.pth": 1, # logging + "Ended training": 1, # logging + "Started prediction": 1, # logging + "Loading checkpoint from": 2, # logging + # "Saving results/overlayed/probabilities": 1, #tqdm.write + "Ended prediction": 1, # logging + "Started evaluation": 1, # logging + "Highest F1-score of": 2, # logging + "Saving overall precision-recall plot": 2, # logging + # "Saving results/overlayed/analysis": 1, #tqdm.write + "Ended evaluation": 1, # logging + "Started comparison": 1, # logging + "Loading metrics from results/analysis": 2, # logging + "Ended comparison": 1, # logging + } buf.seek(0) logging_output = buf.read() - for k,v in keywords.items(): - #if _str_counter(k, logging_output) != v: + for k, v in keywords.items(): + # if _str_counter(k, logging_output) != v: # print(f"Count for string '{k}' appeared " \ # f"({_str_counter(k, result.output)}) " \ # f"instead of the expected {v}") - assert _str_counter(k, logging_output) == v, \ - f"Count for string '{k}' appeared " \ - f"({_str_counter(k, result.output)}) " \ - f"instead of the expected {v}" + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, result.output)}) " + f"instead of the expected {v}" + ) def test_train_help(): diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py index 9b0463d6e9c39f04be896ea31a32ac9dd2bed870..84af4ed4dab532a088feaff95176eb26b7417d53 100644 --- a/bob/ip/binseg/test/test_config.py +++ b/bob/ip/binseg/test/test_config.py @@ -6,7 +6,7 @@ import nose.tools import torch from . import mock_dataset -stare_dataset, stare_variable_set = mock_dataset() +stare_datadir, stare_dataset, stare_variable_set = mock_dataset() from .utils import rc_variable_set # we only iterate over the first N elements at most - dataset loading has @@ -50,9 +50,9 @@ def test_stare_augmentation_manipulation(): # some tests to check our context management for dataset augmentation works # adequately, with one example dataset - from ..configs.datasets.stare.ah import dataset # hack to allow testing on the CI - dataset["train"]._samples = stare_dataset.subsets("ah")["train"] + from ..configs.datasets.stare import _maker + dataset = _maker("ah", stare_dataset) nose.tools.eq_(dataset["train"].augmented, True) nose.tools.eq_(dataset["test"].augmented, False) @@ -74,9 +74,9 @@ def test_stare_augmentation_manipulation(): @stare_variable_set("bob.ip.binseg.stare.datadir") def test_stare_ah(): - from ..configs.datasets.stare.ah import dataset # hack to allow testing on the CI - dataset["train"]._samples = stare_dataset.subsets("ah")["train"] + from ..configs.datasets.stare import _maker + dataset = _maker("ah", stare_dataset) nose.tools.eq_(len(dataset["train"]), 10) nose.tools.eq_(dataset["train"].augmented, True) @@ -88,9 +88,6 @@ def test_stare_ah(): nose.tools.eq_(sample[2].shape, (1, 608, 704)) #planes, height, width nose.tools.eq_(sample[2].dtype, torch.float32) - # hack to allow testing on the CI - dataset["test"]._samples = stare_dataset.subsets("ah")["test"] - nose.tools.eq_(len(dataset["test"]), 10) nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: @@ -105,9 +102,9 @@ def test_stare_ah(): @stare_variable_set("bob.ip.binseg.stare.datadir") def test_stare_vk(): - from ..configs.datasets.stare.vk import dataset # hack to allow testing on the CI - dataset["train"]._samples = stare_dataset.subsets("vk")["train"] + from ..configs.datasets.stare import _maker + dataset = _maker("vk", stare_dataset) nose.tools.eq_(len(dataset["train"]), 10) nose.tools.eq_(dataset["train"].augmented, True) @@ -119,9 +116,6 @@ def test_stare_vk(): nose.tools.eq_(sample[2].shape, (1, 608, 704)) #planes, height, width nose.tools.eq_(sample[2].dtype, torch.float32) - # hack to allow testing on the CI - dataset["test"]._samples = stare_dataset.subsets("vk")["test"] - nose.tools.eq_(len(dataset["test"]), 10) nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: diff --git a/bob/ip/binseg/test/test_stare.py b/bob/ip/binseg/test/test_stare.py index 73c7a5275a9e2a24d6e5c69a062a0673a9f78f28..edc52ae1cccb0d2efa6b00d7e0469898f4c1eef5 100644 --- a/bob/ip/binseg/test/test_stare.py +++ b/bob/ip/binseg/test/test_stare.py @@ -11,7 +11,7 @@ import nose.tools ## special trick for CI builds from . import mock_dataset -dataset, rc_variable_set = mock_dataset() +datadir, dataset, rc_variable_set = mock_dataset() from .utils import count_bw