Skip to content
Snippets Groups Projects
Commit 90411bed authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'dict-sort' into 'master'

Avoid dict sort errors that happen in Python 3

See merge request bob/gridtk!18
parents b132968e 0fe5fc94
No related branches found
No related tags found
1 merge request!18Avoid dict sort errors that happen in Python 3
Pipeline #
......@@ -11,9 +11,17 @@ import yaml
import jinja2
class _OrderedDict(collections.OrderedDict):
"""An OrderedDict class that can be compared.
This is to avoid sort errors (in Python 3) that happen in jinja internally.
"""
def __lt__(self, other):
return id(self) < id(other)
def _ordered_load(stream, Loader=yaml.Loader,
object_pairs_hook=collections.OrderedDict):
'''Loads the contents of the YAML stream into :py:class:`collection.OrderedDict`'s
object_pairs_hook=_OrderedDict):
'''Loads the contents of the YAML stream into :py:class:`collections.OrderedDict`'s
See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
......@@ -123,8 +131,8 @@ def expand(data):
# separates "unique" objects from the ones we have to iterate
# pre-assemble return dictionary
iterables = collections.OrderedDict()
unique = collections.OrderedDict()
iterables = _OrderedDict()
unique = _OrderedDict()
for key, value in data.items():
if isinstance(value, list) and not key.startswith('_'):
iterables[key] = value
......@@ -133,7 +141,7 @@ def expand(data):
# generates all possible combinations of iterables
for values in itertools.product(*iterables.values()):
retval = collections.OrderedDict(unique)
retval = _OrderedDict(unique)
keys = list(iterables.keys())
retval.update(dict(zip(keys, values)))
yield retval
......
......@@ -275,3 +275,73 @@ def test_cmdline_unique_aggregation():
finally:
shutil.rmtree(tmpdir)
def test_cmdline_aggregation_dict_groupby():
data = """
model:
- {name: patch_1, patch_size: 28}
train:
- {database: replaymobile, protocol: grandtest}
- {database: replay, protocol: grandtest}
eval:
- {database: replaymobile, protocol: grandtest, groups: ['dev', 'eval']}
- {database: replay, protocol: grandtest, groups: ['dev', 'eval']}
"""
template = '{{ model.name }}-{{ train.database }}-{{ eval.database }}'
aggtmpl = """
{% set cfg2 = cfgset|groupby('train')|map(attribute='list') -%}
{% for cfg3 in cfg2 %}
{% set k = cfg3[0] -%}
test-{{ k.model.name }}-{{ k.train.database }}-{{ k.eval.database }}
{%- endfor %}
"""
gen_expected = [
'patch_1-replay-replay',
'patch_1-replay-replaymobile',
'patch_1-replaymobile-replay',
'patch_1-replaymobile-replaymobile',
]
agg_expected = [
'',
'',
'test-patch_1-replaymobile-replaymobile',
'test-patch_1-replay-replaymobile',
]
tmpdir = tempfile.mkdtemp()
try:
variables = os.path.join(tmpdir, 'variables.yaml')
with open(variables, 'wt') as f: f.write(data)
gentmpl = os.path.join(tmpdir, 'gentmpl.txt')
with open(gentmpl, 'wt') as f: f.write(template)
genout = os.path.join(tmpdir, 'out', template + '.txt')
aggtmpl_file = os.path.join(tmpdir, 'agg.txt')
with open(aggtmpl_file, 'wt') as f: f.write(aggtmpl)
aggout = os.path.join(tmpdir, 'out', 'agg.txt')
nose.tools.eq_(jgen.main(['-vv', variables, gentmpl, genout, aggtmpl_file,
aggout]), 0)
# check all files are there and correspond to the expected output
outdir = os.path.dirname(genout)
for k in gen_expected:
ofile = os.path.join(outdir, k + '.txt')
assert os.path.exists(ofile)
with open(ofile, 'rt') as f: contents = f.read()
nose.tools.eq_(contents, k)
assert os.path.exists(aggout)
with open(aggout, 'rt') as f: contents = f.read()
for line in agg_expected:
assert line in contents, contents
finally:
shutil.rmtree(tmpdir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment