diff --git a/gridtk/generator.py b/gridtk/generator.py
index b82ea6081907cc512780a170cbdc9e8f29096810..a1a744cf1105e65386527a6afc1c1ebcb5134841 100644
--- a/gridtk/generator.py
+++ b/gridtk/generator.py
@@ -11,9 +11,17 @@ import yaml
 import jinja2
 
 
+class _OrderedDict(collections.OrderedDict):
+  """An OrderedDict class that can be compared.
+  This is to avoid sort errors (in Python 3) that happen in jinja internally.
+  """
+  def __lt__(self, other):
+    return id(self) < id(other)
+
+
 def _ordered_load(stream, Loader=yaml.Loader,
-    object_pairs_hook=collections.OrderedDict):
-  '''Loads the contents of the YAML stream into :py:class:`collection.OrderedDict`'s
+    object_pairs_hook=_OrderedDict):
+  '''Loads the contents of the YAML stream into :py:class:`collections.OrderedDict`'s
 
   See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
 
@@ -123,8 +131,8 @@ def expand(data):
 
   # separates "unique" objects from the ones we have to iterate
   # pre-assemble return dictionary
-  iterables = collections.OrderedDict()
-  unique = collections.OrderedDict()
+  iterables = _OrderedDict()
+  unique = _OrderedDict()
   for key, value in data.items():
     if isinstance(value, list) and not key.startswith('_'):
       iterables[key] = value
@@ -133,7 +141,7 @@ def expand(data):
 
   # generates all possible combinations of iterables
   for values in itertools.product(*iterables.values()):
-    retval = collections.OrderedDict(unique)
+    retval = _OrderedDict(unique)
     keys = list(iterables.keys())
     retval.update(dict(zip(keys, values)))
     yield retval
diff --git a/gridtk/tests/test_generator.py b/gridtk/tests/test_generator.py
index 6a6d8ea9b4cde860f78a409bfa05423c1500c7e0..880afbcfb748dcbc7843e13dd890e2a4715279d8 100644
--- a/gridtk/tests/test_generator.py
+++ b/gridtk/tests/test_generator.py
@@ -275,3 +275,73 @@ def test_cmdline_unique_aggregation():
 
   finally:
     shutil.rmtree(tmpdir)
+
+
+def test_cmdline_aggregation_dict_groupby():
+
+  data = """
+model:
+  - {name: patch_1, patch_size: 28}
+
+train:
+  - {database: replaymobile, protocol: grandtest}
+  - {database: replay, protocol: grandtest}
+
+eval:
+  - {database: replaymobile, protocol: grandtest, groups: ['dev', 'eval']}
+  - {database: replay, protocol: grandtest, groups: ['dev', 'eval']}
+"""
+
+  template = '{{ model.name }}-{{ train.database }}-{{ eval.database }}'
+
+  aggtmpl = """
+{% set cfg2 = cfgset|groupby('train')|map(attribute='list') -%}
+{% for cfg3 in cfg2 %}
+{% set k = cfg3[0] -%}
+test-{{ k.model.name }}-{{ k.train.database }}-{{ k.eval.database }}
+{%- endfor %}
+"""
+
+  gen_expected = [
+      'patch_1-replay-replay',
+      'patch_1-replay-replaymobile',
+      'patch_1-replaymobile-replay',
+      'patch_1-replaymobile-replaymobile',
+      ]
+
+  agg_expected = [
+      '',
+      '',
+      'test-patch_1-replaymobile-replaymobile',
+      'test-patch_1-replay-replaymobile',
+      ]
+  tmpdir = tempfile.mkdtemp()
+
+  try:
+    variables = os.path.join(tmpdir, 'variables.yaml')
+    with open(variables, 'wt') as f: f.write(data)
+    gentmpl = os.path.join(tmpdir, 'gentmpl.txt')
+    with open(gentmpl, 'wt') as f: f.write(template)
+    genout = os.path.join(tmpdir, 'out', template + '.txt')
+
+    aggtmpl_file = os.path.join(tmpdir, 'agg.txt')
+    with open(aggtmpl_file, 'wt') as f: f.write(aggtmpl)
+    aggout = os.path.join(tmpdir, 'out', 'agg.txt')
+
+    nose.tools.eq_(jgen.main(['-vv', variables, gentmpl, genout, aggtmpl_file,
+      aggout]), 0)
+
+    # check all files are there and correspond to the expected output
+    outdir = os.path.dirname(genout)
+    for k in gen_expected:
+      ofile = os.path.join(outdir, k + '.txt')
+      assert os.path.exists(ofile)
+      with open(ofile, 'rt') as f: contents = f.read()
+      nose.tools.eq_(contents, k)
+    assert os.path.exists(aggout)
+    with open(aggout, 'rt') as f: contents = f.read()
+    for line in agg_expected:
+      assert line in contents, contents
+
+  finally:
+    shutil.rmtree(tmpdir)
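
Note on the comparable _OrderedDict (a minimal standalone sketch, not part of the patch above): jinja2's groupby filter sorts entries by the grouping attribute before grouping them, and the attribute used in the new test ('train') is itself a mapping. In Python 3, dict and OrderedDict instances define no ordering, so that internal sort raises TypeError; giving the class a __lt__ (an arbitrary but total ordering by object id) lets the sort complete, while grouping still relies on equality only. The class name in the sketch is illustrative, not taken from the codebase.

import collections

class ComparableOrderedDict(collections.OrderedDict):
  # Arbitrary but total ordering: enough to keep sorted() from raising.
  def __lt__(self, other):
    return id(self) < id(other)

plain = [collections.OrderedDict(a=1), collections.OrderedDict(a=2)]
try:
  sorted(plain)  # roughly what jinja2's |groupby does before grouping
except TypeError as exc:
  print('plain OrderedDict is unsortable:', exc)

comparable = [ComparableOrderedDict(a=1), ComparableOrderedDict(a=2)]
print('comparable variant sorts fine:', len(sorted(comparable)))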