Commit 7509e337 authored by Samuel GAIST's avatar Samuel GAIST

[database][view] Fix output to named tuple member retrieval

This allows to properly initialise and use database views that
uses python keywords as output name.

Fixes #16
parent 67abe074
Pipeline #23547 passed with stage
in 13 minutes and 55 seconds
......@@ -563,11 +563,13 @@ class DatabaseOutputDataSource(DataSource):
end = None
previous_value = None
attribute = self.view.get_output_mapping(output_name)
for index, obj in enumerate(objects):
if start is None:
start = index
previous_value = getattr(obj, output_name)
elif getattr(obj, output_name) != previous_value:
previous_value = getattr(obj, attribute)
elif getattr(obj, attribute) != previous_value:
end = index - 1
previous_value = None
......@@ -576,7 +578,7 @@ class DatabaseOutputDataSource(DataSource):
self.infos.append(Infos(start_index=start, end_index=end))
start = index
previous_value = getattr(obj, output_name)
previous_value = getattr(obj, attribute)
end = index
......
......@@ -186,6 +186,9 @@ class Runner(object):
return loader.run(self.obj, 'get', self.exc, output, index)
def get_output_mapping(self, output):
return loader.run(self.obj, 'get_output_mapping', self.exc, output)
def objects(self):
return self.obj.objs
......@@ -523,6 +526,15 @@ class Database(object):
class View(object):
def __init__(self):
# Current databases definitions uses named tuple to store information.
# This has one limitation, python keywords like `class` cannot be used.
# output_member_map allows to use that kind of keyword as output name
# while using something different for the named tuple (for example cls,
# klass, etc.)
self.output_member_map = {}
def index(self, root_folder, parameters):
"""Returns a list of (named) tuples describing the data provided by the view.
......@@ -591,6 +603,12 @@ class View(object):
raise NotImplementedError
def get_output_mapping(self, output):
"""Returns the object member to use for given output if any otherwise
the member name is the output name.
"""
return self.output_member_map.get(output, output)
# ----------------------------------------------------------
......
......@@ -12,7 +12,8 @@
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1",
"sum": "user/single_integer/1"
"sum": "user/single_integer/1",
"class": "user/single_integer/1"
}
}
]
......@@ -29,7 +30,8 @@
"a": "user/single_integer/1",
"b": "user/single_integer/1",
"c": "user/single_integer/1",
"sum": "user/single_integer/1"
"sum": "user/single_integer/1",
"class": "user/single_integer/1"
}
}
]
......@@ -45,7 +47,8 @@
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1",
"sum": "user/single_integer/1"
"sum": "user/single_integer/1",
"class": "user/single_integer/1"
}
},
{
......@@ -56,7 +59,8 @@
"a": "user/single_integer/1",
"b": "user/single_integer/1",
"c": "user/single_integer/1",
"sum": "user/single_integer/1"
"sum": "user/single_integer/1",
"class": "user/single_integer/1"
}
}
]
......
......@@ -32,20 +32,23 @@ from beat.backend.python.database import View
class Double(View):
def __init__(self):
super(Double, self)
self.output_member_map = {'class': 'cls'}
def index(self, root_folder, parameters):
Entry = namedtuple('Entry', ['a', 'b', 'sum'])
Entry = namedtuple('Entry', ['a', 'b', 'sum', 'cls'])
return [
Entry(1, 10, 11),
Entry(2, 20, 22),
Entry(3, 30, 33),
Entry(4, 40, 44),
Entry(5, 50, 55),
Entry(6, 60, 66),
Entry(7, 70, 77),
Entry(8, 80, 88),
Entry(9, 90, 99),
Entry(1, 10, 11, 41),
Entry(2, 20, 22, 42),
Entry(3, 30, 33, 43),
Entry(4, 40, 44, 44),
Entry(5, 50, 55, 45),
Entry(6, 60, 66, 46),
Entry(7, 70, 77, 47),
Entry(8, 80, 88, 48),
Entry(9, 90, 99, 49),
]
......@@ -66,6 +69,10 @@ class Double(View):
return {
'value': numpy.int32(obj.sum)
}
elif output == 'class':
return {
'value': numpy.int32(obj.cls)
}
#----------------------------------------------------------
......
......@@ -397,7 +397,7 @@ class TestDatabaseOutputDataSource(unittest.TestCase):
view.setup(os.path.join(self.cache_root, 'data.db'), pack=False)
self.assertTrue(view.data_sources is not None)
self.assertEqual(len(view.data_sources), 3)
self.assertEqual(len(view.data_sources), 4)
for output_name, data_source in view.data_sources.items():
self.assertEqual(9, len(data_source))
......
......@@ -68,11 +68,12 @@ def test_load_protocol_with_one_set():
set = database.set("double", "double")
nose.tools.eq_(set['name'], 'double')
nose.tools.eq_(len(set['outputs']), 3)
nose.tools.eq_(len(set['outputs']), 4)
assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None
assert set['outputs']['sum'] is not None
assert set['outputs']['class'] is not None
#----------------------------------------------------------
......@@ -88,18 +89,20 @@ def test_load_protocol_with_two_sets():
set = database.set("two_sets", "double")
nose.tools.eq_(set['name'], 'double')
nose.tools.eq_(len(set['outputs']), 3)
nose.tools.eq_(len(set['outputs']), 4)
assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None
assert set['outputs']['sum'] is not None
assert set['outputs']['class'] is not None
set = database.set("two_sets", "triple")
nose.tools.eq_(set['name'], 'triple')
nose.tools.eq_(len(set['outputs']), 4)
nose.tools.eq_(len(set['outputs']), 5)
assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None
assert set['outputs']['c'] is not None
assert set['outputs']['sum'] is not None
assert set['outputs']['class'] is not None
\ No newline at end of file
......@@ -121,7 +121,7 @@ class TestDatabaseViewRunner(unittest.TestCase):
view.setup(os.path.join(self.cache_root, 'data.db'))
self.assertTrue(view.data_sources is not None)
self.assertEqual(len(view.data_sources), 3)
self.assertEqual(len(view.data_sources), 4)
for i in range(0, 9):
self.assertEqual(view.get('a', i)['value'], i + 1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment