Skip to content
Snippets Groups Projects
Commit 7509e337 authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[database][view] Fix output to named tuple member retrieval

This allows to properly initialise and use database views that
uses python keywords as output name.

Fixes #16
parent 67abe074
Branches
Tags
2 merge requests!17Merge development branch 1.5.x,!16Fix output to named tuple member retrieval
Pipeline #23547 passed
...@@ -563,11 +563,13 @@ class DatabaseOutputDataSource(DataSource): ...@@ -563,11 +563,13 @@ class DatabaseOutputDataSource(DataSource):
end = None end = None
previous_value = None previous_value = None
attribute = self.view.get_output_mapping(output_name)
for index, obj in enumerate(objects): for index, obj in enumerate(objects):
if start is None: if start is None:
start = index start = index
previous_value = getattr(obj, output_name) previous_value = getattr(obj, attribute)
elif getattr(obj, output_name) != previous_value: elif getattr(obj, attribute) != previous_value:
end = index - 1 end = index - 1
previous_value = None previous_value = None
...@@ -576,7 +578,7 @@ class DatabaseOutputDataSource(DataSource): ...@@ -576,7 +578,7 @@ class DatabaseOutputDataSource(DataSource):
self.infos.append(Infos(start_index=start, end_index=end)) self.infos.append(Infos(start_index=start, end_index=end))
start = index start = index
previous_value = getattr(obj, output_name) previous_value = getattr(obj, attribute)
end = index end = index
......
...@@ -186,6 +186,9 @@ class Runner(object): ...@@ -186,6 +186,9 @@ class Runner(object):
return loader.run(self.obj, 'get', self.exc, output, index) return loader.run(self.obj, 'get', self.exc, output, index)
def get_output_mapping(self, output):
return loader.run(self.obj, 'get_output_mapping', self.exc, output)
def objects(self): def objects(self):
return self.obj.objs return self.obj.objs
...@@ -523,6 +526,15 @@ class Database(object): ...@@ -523,6 +526,15 @@ class Database(object):
class View(object): class View(object):
def __init__(self):
# Current databases definitions uses named tuple to store information.
# This has one limitation, python keywords like `class` cannot be used.
# output_member_map allows to use that kind of keyword as output name
# while using something different for the named tuple (for example cls,
# klass, etc.)
self.output_member_map = {}
def index(self, root_folder, parameters): def index(self, root_folder, parameters):
"""Returns a list of (named) tuples describing the data provided by the view. """Returns a list of (named) tuples describing the data provided by the view.
...@@ -591,6 +603,12 @@ class View(object): ...@@ -591,6 +603,12 @@ class View(object):
raise NotImplementedError raise NotImplementedError
def get_output_mapping(self, output):
"""Returns the object member to use for given output if any otherwise
the member name is the output name.
"""
return self.output_member_map.get(output, output)
# ---------------------------------------------------------- # ----------------------------------------------------------
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
"outputs": { "outputs": {
"a": "user/single_integer/1", "a": "user/single_integer/1",
"b": "user/single_integer/1", "b": "user/single_integer/1",
"sum": "user/single_integer/1" "sum": "user/single_integer/1",
"class": "user/single_integer/1"
} }
} }
] ]
...@@ -29,7 +30,8 @@ ...@@ -29,7 +30,8 @@
"a": "user/single_integer/1", "a": "user/single_integer/1",
"b": "user/single_integer/1", "b": "user/single_integer/1",
"c": "user/single_integer/1", "c": "user/single_integer/1",
"sum": "user/single_integer/1" "sum": "user/single_integer/1",
"class": "user/single_integer/1"
} }
} }
] ]
...@@ -45,7 +47,8 @@ ...@@ -45,7 +47,8 @@
"outputs": { "outputs": {
"a": "user/single_integer/1", "a": "user/single_integer/1",
"b": "user/single_integer/1", "b": "user/single_integer/1",
"sum": "user/single_integer/1" "sum": "user/single_integer/1",
"class": "user/single_integer/1"
} }
}, },
{ {
...@@ -56,7 +59,8 @@ ...@@ -56,7 +59,8 @@
"a": "user/single_integer/1", "a": "user/single_integer/1",
"b": "user/single_integer/1", "b": "user/single_integer/1",
"c": "user/single_integer/1", "c": "user/single_integer/1",
"sum": "user/single_integer/1" "sum": "user/single_integer/1",
"class": "user/single_integer/1"
} }
} }
] ]
......
...@@ -32,20 +32,23 @@ from beat.backend.python.database import View ...@@ -32,20 +32,23 @@ from beat.backend.python.database import View
class Double(View): class Double(View):
def __init__(self):
super(Double, self)
self.output_member_map = {'class': 'cls'}
def index(self, root_folder, parameters): def index(self, root_folder, parameters):
Entry = namedtuple('Entry', ['a', 'b', 'sum']) Entry = namedtuple('Entry', ['a', 'b', 'sum', 'cls'])
return [ return [
Entry(1, 10, 11), Entry(1, 10, 11, 41),
Entry(2, 20, 22), Entry(2, 20, 22, 42),
Entry(3, 30, 33), Entry(3, 30, 33, 43),
Entry(4, 40, 44), Entry(4, 40, 44, 44),
Entry(5, 50, 55), Entry(5, 50, 55, 45),
Entry(6, 60, 66), Entry(6, 60, 66, 46),
Entry(7, 70, 77), Entry(7, 70, 77, 47),
Entry(8, 80, 88), Entry(8, 80, 88, 48),
Entry(9, 90, 99), Entry(9, 90, 99, 49),
] ]
...@@ -66,6 +69,10 @@ class Double(View): ...@@ -66,6 +69,10 @@ class Double(View):
return { return {
'value': numpy.int32(obj.sum) 'value': numpy.int32(obj.sum)
} }
elif output == 'class':
return {
'value': numpy.int32(obj.cls)
}
#---------------------------------------------------------- #----------------------------------------------------------
......
...@@ -397,7 +397,7 @@ class TestDatabaseOutputDataSource(unittest.TestCase): ...@@ -397,7 +397,7 @@ class TestDatabaseOutputDataSource(unittest.TestCase):
view.setup(os.path.join(self.cache_root, 'data.db'), pack=False) view.setup(os.path.join(self.cache_root, 'data.db'), pack=False)
self.assertTrue(view.data_sources is not None) self.assertTrue(view.data_sources is not None)
self.assertEqual(len(view.data_sources), 3) self.assertEqual(len(view.data_sources), 4)
for output_name, data_source in view.data_sources.items(): for output_name, data_source in view.data_sources.items():
self.assertEqual(9, len(data_source)) self.assertEqual(9, len(data_source))
......
...@@ -68,11 +68,12 @@ def test_load_protocol_with_one_set(): ...@@ -68,11 +68,12 @@ def test_load_protocol_with_one_set():
set = database.set("double", "double") set = database.set("double", "double")
nose.tools.eq_(set['name'], 'double') nose.tools.eq_(set['name'], 'double')
nose.tools.eq_(len(set['outputs']), 3) nose.tools.eq_(len(set['outputs']), 4)
assert set['outputs']['a'] is not None assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None assert set['outputs']['b'] is not None
assert set['outputs']['sum'] is not None assert set['outputs']['sum'] is not None
assert set['outputs']['class'] is not None
#---------------------------------------------------------- #----------------------------------------------------------
...@@ -88,18 +89,20 @@ def test_load_protocol_with_two_sets(): ...@@ -88,18 +89,20 @@ def test_load_protocol_with_two_sets():
set = database.set("two_sets", "double") set = database.set("two_sets", "double")
nose.tools.eq_(set['name'], 'double') nose.tools.eq_(set['name'], 'double')
nose.tools.eq_(len(set['outputs']), 3) nose.tools.eq_(len(set['outputs']), 4)
assert set['outputs']['a'] is not None assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None assert set['outputs']['b'] is not None
assert set['outputs']['sum'] is not None assert set['outputs']['sum'] is not None
assert set['outputs']['class'] is not None
set = database.set("two_sets", "triple") set = database.set("two_sets", "triple")
nose.tools.eq_(set['name'], 'triple') nose.tools.eq_(set['name'], 'triple')
nose.tools.eq_(len(set['outputs']), 4) nose.tools.eq_(len(set['outputs']), 5)
assert set['outputs']['a'] is not None assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None assert set['outputs']['b'] is not None
assert set['outputs']['c'] is not None assert set['outputs']['c'] is not None
assert set['outputs']['sum'] is not None assert set['outputs']['sum'] is not None
assert set['outputs']['class'] is not None
\ No newline at end of file
...@@ -121,7 +121,7 @@ class TestDatabaseViewRunner(unittest.TestCase): ...@@ -121,7 +121,7 @@ class TestDatabaseViewRunner(unittest.TestCase):
view.setup(os.path.join(self.cache_root, 'data.db')) view.setup(os.path.join(self.cache_root, 'data.db'))
self.assertTrue(view.data_sources is not None) self.assertTrue(view.data_sources is not None)
self.assertEqual(len(view.data_sources), 3) self.assertEqual(len(view.data_sources), 4)
for i in range(0, 9): for i in range(0, 9):
self.assertEqual(view.get('a', i)['value'], i + 1) self.assertEqual(view.get('a', i)['value'], i + 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment