From 8222eb3c353b2d9e9dc82929fde1322f1faf2f9a Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI Date: Thu, 1 Feb 2018 12:21:03 +0100 Subject: [PATCH 1/3] Add a test for when groupby is used in templates --- gridtk/tests/test_generator.py | 69 ++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/gridtk/tests/test_generator.py b/gridtk/tests/test_generator.py index 6a6d8ea..94a5b80 100644 --- a/gridtk/tests/test_generator.py +++ b/gridtk/tests/test_generator.py @@ -275,3 +275,72 @@ def test_cmdline_unique_aggregation(): finally: shutil.rmtree(tmpdir) + + +def test_cmdline_aggregation_dict_groupby(): + + data = """ +model: + - {name: patch_1, patch_size: 28} + +train: + - {database: replaymobile, protocol: grandtest} + - {database: replay, protocol: grandtest} + +eval: + - {database: replaymobile, protocol: grandtest, groups: ['dev', 'eval']} + - {database: replay, protocol: grandtest, groups: ['dev', 'eval']} +""" + + template = '{{ model.name }}-{{ train.database }}-{{ eval.database }}' + + aggtmpl = """ +{% set cfg2 = cfgset|groupby('train')|map(attribute='list') -%} +{% for cfg3 in cfg2 %} +{% set k = cfg3[0] -%} +test-{{ k.model.name }}-{{ k.train.database }}-{{ k.eval.database }} +{%- endfor %} +""" + + gen_expected = [ + 'patch_1-replay-replay', + 'patch_1-replay-replaymobile', + 'patch_1-replaymobile-replay', + 'patch_1-replaymobile-replaymobile', + ] + + agg_expected = '\n'.join([ + '', + '', + 'test-patch_1-replay-replaymobile', + 'test-patch_1-replaymobile-replaymobile', + ]) + tmpdir = tempfile.mkdtemp() + + try: + variables = os.path.join(tmpdir, 'variables.yaml') + with open(variables, 'wt') as f: f.write(data) + gentmpl = os.path.join(tmpdir, 'gentmpl.txt') + with open(gentmpl, 'wt') as f: f.write(template) + genout = os.path.join(tmpdir, 'out', template + '.txt') + + aggtmpl_file = os.path.join(tmpdir, 'agg.txt') + with open(aggtmpl_file, 'wt') as f: f.write(aggtmpl) + aggout = os.path.join(tmpdir, 'out', 'agg.txt') + + nose.tools.eq_(jgen.main(['-vv', variables, gentmpl, genout, aggtmpl_file, + aggout]), 0) + + # check all files are there and correspond to the expected output + outdir = os.path.dirname(genout) + for k in gen_expected: + ofile = os.path.join(outdir, k + '.txt') + assert os.path.exists(ofile) + with open(ofile, 'rt') as f: contents = f.read() + nose.tools.eq_(contents, k) + assert os.path.exists(aggout) + with open(aggout, 'rt') as f: contents = f.read() + nose.tools.eq_(contents, agg_expected) + + finally: + shutil.rmtree(tmpdir) -- GitLab From bac33736ad819fce8ae0d5d96362c4ecb6da529c Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI Date: Thu, 1 Feb 2018 12:23:31 +0100 Subject: [PATCH 2/3] Fix sort errors that happen in Python 3 --- gridtk/generator.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/gridtk/generator.py b/gridtk/generator.py index b82ea60..a1a744c 100644 --- a/gridtk/generator.py +++ b/gridtk/generator.py @@ -11,9 +11,17 @@ import yaml import jinja2 +class _OrderedDict(collections.OrderedDict): + """An OrderedDict class that can be compared. + This is to avoid sort errors (in Python 3) that happen in jinja internally. + """ + def __lt__(self, other): + return id(self) < id(other) + + def _ordered_load(stream, Loader=yaml.Loader, - object_pairs_hook=collections.OrderedDict): - '''Loads the contents of the YAML stream into :py:class:`collection.OrderedDict`'s + object_pairs_hook=_OrderedDict): + '''Loads the contents of the YAML stream into :py:class:`collections.OrderedDict`'s See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts @@ -123,8 +131,8 @@ def expand(data): # separates "unique" objects from the ones we have to iterate # pre-assemble return dictionary - iterables = collections.OrderedDict() - unique = collections.OrderedDict() + iterables = _OrderedDict() + unique = _OrderedDict() for key, value in data.items(): if isinstance(value, list) and not key.startswith('_'): iterables[key] = value @@ -133,7 +141,7 @@ def expand(data): # generates all possible combinations of iterables for values in itertools.product(*iterables.values()): - retval = collections.OrderedDict(unique) + retval = _OrderedDict(unique) keys = list(iterables.keys()) retval.update(dict(zip(keys, values))) yield retval -- GitLab From 0fe5fc9429fa16777a6d4731fc73181eb8ee10ed Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI Date: Fri, 16 Mar 2018 15:50:30 +0100 Subject: [PATCH 3/3] Make the tests less relaxed The behavior of jinja2 seems to be different between linux and osx --- gridtk/tests/test_generator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gridtk/tests/test_generator.py b/gridtk/tests/test_generator.py index 94a5b80..880afbc 100644 --- a/gridtk/tests/test_generator.py +++ b/gridtk/tests/test_generator.py @@ -309,12 +309,12 @@ test-{{ k.model.name }}-{{ k.train.database }}-{{ k.eval.database }} 'patch_1-replaymobile-replaymobile', ] - agg_expected = '\n'.join([ + agg_expected = [ '', '', - 'test-patch_1-replay-replaymobile', 'test-patch_1-replaymobile-replaymobile', - ]) + 'test-patch_1-replay-replaymobile', + ] tmpdir = tempfile.mkdtemp() try: @@ -340,7 +340,8 @@ test-{{ k.model.name }}-{{ k.train.database }}-{{ k.eval.database }} nose.tools.eq_(contents, k) assert os.path.exists(aggout) with open(aggout, 'rt') as f: contents = f.read() - nose.tools.eq_(contents, agg_expected) + for line in agg_expected: + assert line in contents, contents finally: shutil.rmtree(tmpdir) -- GitLab