Commit afd61d86 authored by Samuel GAIST's avatar Samuel GAIST

[database] Implement v2 support

parent c7363305
Pipeline #29236 passed with stage
in 7 minutes and 26 seconds
This diff is collapsed.
......@@ -42,24 +42,25 @@ exceptions
Custom exceptions
"""
class RemoteException(Exception):
"""Exception happening on a remote location"""
def __init__(self, kind, message):
super(RemoteException, self).__init__()
if kind == 'sys':
if kind == "sys":
self.system_error = message
self.user_error = ''
self.user_error = ""
else:
self.system_error = ''
self.system_error = ""
self.user_error = message
def __str__(self):
if self.system_error:
return '(sys) {}'.format(self.system_error)
return "(sys) {}".format(self.system_error)
else:
return '(usr) {}'.format(self.user_error)
return "(usr) {}".format(self.user_error)
class UserError(Exception):
......@@ -70,3 +71,9 @@ class UserError(Exception):
def __str__(self):
return repr(self.value)
class OutputError(Exception):
"""Error happening on output"""
pass
{
"schema_version": 2,
"root_folder": "/tmp/path/not/set",
"protocols": [
{
"name": "double",
"template": "double/1",
"views": {
"double": {
"view": "Double"
}
}
},
{
"name": "triple",
"template": "triple/1",
"views": {
"triple": {
"view": "Triple"
}
}
},
{
"name": "two_sets",
"template": "two_sets/1",
"views": {
"double": {
"view": "Double"
},
"triple": {
"view": "Triple"
}
}
},
{
"name": "labelled",
"template": "labelled/1",
"views": {
"labelled": {
"view": "Labelled"
}
}
},
{
"name": "different_frequencies",
"template": "different_frequencies/1",
"views": {
"double" : {
"view": "DifferentFrequencies"
}
}
}
]
}
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###################################################################################
# #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, this #
# list of conditions and the following disclaimer. #
# #
# 2. Redistributions in binary form must reproduce the above copyright notice, #
# this list of conditions and the following disclaimer in the documentation #
# and/or other materials provided with the distribution. #
# #
# 3. Neither the name of the copyright holder nor the names of its contributors #
# may be used to endorse or promote products derived from this software without #
# specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# #
###################################################################################
import numpy
from collections import namedtuple
from beat.backend.python.database import View
class Double(View):
def index(self, root_folder, parameters):
Entry = namedtuple("Entry", ["a", "b", "sum"])
return [
Entry(1, 10, 11),
Entry(2, 20, 22),
Entry(3, 30, 33),
Entry(4, 40, 44),
Entry(5, 50, 55),
Entry(6, 60, 66),
Entry(7, 70, 77),
Entry(8, 80, 88),
Entry(9, 90, 99),
]
def get(self, output, index):
obj = self.objs[index]
if output == "a":
return {"value": numpy.int32(obj.a)}
elif output == "b":
return {"value": numpy.int32(obj.b)}
elif output == "sum":
return {"value": numpy.int32(obj.sum)}
elif output == "class":
return {"value": numpy.int32(obj.cls)}
# ----------------------------------------------------------
class Triple(View):
def index(self, root_folder, parameters):
Entry = namedtuple("Entry", ["a", "b", "c", "sum"])
return [
Entry(1, 10, 100, 111),
Entry(2, 20, 200, 222),
Entry(3, 30, 300, 333),
Entry(4, 40, 400, 444),
Entry(5, 50, 500, 555),
Entry(6, 60, 600, 666),
Entry(7, 70, 700, 777),
Entry(8, 80, 800, 888),
Entry(9, 90, 900, 999),
]
def get(self, output, index):
obj = self.objs[index]
if output == "a":
return {"value": numpy.int32(obj.a)}
elif output == "b":
return {"value": numpy.int32(obj.b)}
elif output == "c":
return {"value": numpy.int32(obj.c)}
elif output == "sum":
return {"value": numpy.int32(obj.sum)}
# ----------------------------------------------------------
class Labelled(View):
def index(self, root_folder, parameters):
Entry = namedtuple("Entry", ["label", "value"])
return [
Entry("A", 1),
Entry("A", 2),
Entry("A", 3),
Entry("A", 4),
Entry("A", 5),
Entry("B", 10),
Entry("B", 20),
Entry("B", 30),
Entry("B", 40),
Entry("B", 50),
Entry("C", 100),
Entry("C", 200),
Entry("C", 300),
Entry("C", 400),
Entry("C", 500),
]
def get(self, output, index):
obj = self.objs[index]
if output == "label":
return {"value": obj.label}
elif output == "value":
return {"value": numpy.int32(obj.value)}
# ----------------------------------------------------------
class DifferentFrequencies(View):
def index(self, root_folder, parameters):
Entry = namedtuple("Entry", ["a", "b"])
return [
Entry(1, 10),
Entry(1, 20),
Entry(1, 30),
Entry(1, 40),
Entry(2, 50),
Entry(2, 60),
Entry(2, 70),
Entry(2, 80),
]
def get(self, output, index):
obj = self.objs[index]
if output == "a":
return {"value": numpy.int32(obj.a)}
elif output == "b":
return {"value": numpy.int32(obj.b)}
......@@ -3,8 +3,6 @@
"sets": [
{
"name": "double",
"template": "double",
"view": "DifferentFrequencies",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1"
......
......@@ -3,8 +3,6 @@
"sets": [
{
"name": "double",
"template": "double",
"view": "Double",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1",
......
......@@ -3,8 +3,6 @@
"sets": [
{
"name": "labelled",
"template": "labelled",
"view": "Labelled",
"outputs": {
"value": "user/single_integer/1",
"label": "user/single_string/1"
......
......@@ -3,8 +3,6 @@
"sets": [
{
"name": "triple",
"view": "Triple",
"template": "triple",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1",
......
......@@ -3,8 +3,6 @@
"sets": [
{
"name": "double",
"template": "double",
"view": "Double",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1",
......@@ -13,8 +11,6 @@
},
{
"name": "triple",
"template": "triple",
"view": "Triple",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1",
......
......@@ -40,74 +40,92 @@ from ..database import Database
from . import prefix
#----------------------------------------------------------
INTEGERS_DBS = ["integers_db/{}".format(i) for i in range(1, 3)]
# ----------------------------------------------------------
def load(database_name):
database = Database(prefix, database_name)
assert database.valid
nose.tools.assert_true(database.valid, "\n * %s" % "\n * ".join(database.errors))
return database
#----------------------------------------------------------
# ----------------------------------------------------------
def test_load_valid_database():
database = Database(prefix, 'integers_db/1')
assert database.valid, '\n * %s' % '\n * '.join(database.errors)
for db_name in INTEGERS_DBS:
yield load_valid_database, db_name
def load_valid_database(db_name):
database = load(db_name)
nose.tools.eq_(len(database.sets("double")), 1)
nose.tools.eq_(len(database.sets("triple")), 1)
nose.tools.eq_(len(database.sets("two_sets")), 2)
#----------------------------------------------------------
# ----------------------------------------------------------
def test_load_protocol_with_one_set():
database = Database(prefix, 'integers_db/1')
for db_name in INTEGERS_DBS:
yield load_valid_database, db_name
def load_protocol_with_one_set(db_name):
database = load(db_name)
protocol = database.protocol("double")
nose.tools.eq_(len(protocol['sets']), 1)
nose.tools.eq_(len(protocol["sets"]), 1)
set = database.set("double", "double")
set_ = database.set("double", "double")
nose.tools.eq_(set['name'], 'double')
nose.tools.eq_(len(set['outputs']), 3)
nose.tools.eq_(set_["name"], "double")
nose.tools.eq_(len(set_["outputs"]), 3)
assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None
assert set['outputs']['sum'] is not None
nose.tools.assert_is_not_none(set_["outputs"]["a"])
nose.tools.assert_is_not_none(set_["outputs"]["b"])
nose.tools.assert_is_not_none(set_["outputs"]["sum"])
#----------------------------------------------------------
# ----------------------------------------------------------
def test_load_protocol_with_two_sets():
database = Database(prefix, 'integers_db/1')
for db_name in INTEGERS_DBS:
yield load_valid_database, db_name
def load_protocol_with_two_sets(db_name):
database = load(db_name)
protocol = database.protocol("two_sets")
nose.tools.eq_(len(protocol['sets']), 2)
nose.tools.eq_(len(protocol["sets"]), 2)
set = database.set("two_sets", "double")
set_ = database.set("two_sets", "double")
nose.tools.eq_(set['name'], 'double')
nose.tools.eq_(len(set['outputs']), 3)
nose.tools.eq_(set["name"], "double")
nose.tools.eq_(len(set["outputs"]), 3)
assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None
assert set['outputs']['sum'] is not None
nose.tools.assert_is_not_none(set_["outputs"]["a"])
nose.tools.assert_is_not_none(set_["outputs"]["b"])
nose.tools.assert_is_not_none(set_["outputs"]["sum"])
set = database.set("two_sets", "triple")
set_ = database.set("two_sets", "triple")
nose.tools.eq_(set['name'], 'triple')
nose.tools.eq_(len(set['outputs']), 4)
nose.tools.eq_(set_["name"], "triple")
nose.tools.eq_(len(set_["outputs"]), 4)
assert set['outputs']['a'] is not None
assert set['outputs']['b'] is not None
assert set['outputs']['c'] is not None
assert set['outputs']['sum'] is not None
nose.tools.assert_is_not_none(set_["outputs"]["a"])
nose.tools.assert_is_not_none(set_["outputs"]["b"])
nose.tools.assert_is_not_none(set_["outputs"]["c"])
nose.tools.assert_is_not_none(set_["outputs"]["sum"])
......@@ -39,116 +39,114 @@ import tempfile
import shutil
import os
from ddt import ddt
from ddt import idata
from ..database import Database
from .test_database import INTEGERS_DBS
from . import prefix
#----------------------------------------------------------
# ----------------------------------------------------------
class MyExc(Exception):
pass
#----------------------------------------------------------
# ----------------------------------------------------------
@ddt
class TestDatabaseViewRunner(unittest.TestCase):
def setUp(self):
self.cache_root = tempfile.mkdtemp(prefix=__name__)
def tearDown(self):
shutil.rmtree(self.cache_root)
def test_syntax_error(self):
db = Database(prefix, 'syntax_error/1')
db = Database(prefix, "syntax_error/1")
self.assertTrue(db.valid)
with self.assertRaises(SyntaxError):
view = db.view('protocol', 'set')
db.view("protocol", "set")
def test_unknown_view(self):
db = Database(prefix, 'integers_db/1')
db = Database(prefix, "integers_db/1")
self.assertTrue(db.valid)
with self.assertRaises(KeyError):
view = db.view('protocol', 'does_not_exist')
db.view("protocol", "does_not_exist")
def test_valid_view(self):
db = Database(prefix, 'integers_db/1')
@idata(INTEGERS_DBS)
def test_valid_view(self, db_name):
db = Database(prefix, db_name)
self.assertTrue(db.valid)
view = db.view('double', 'double')
view = db.view("double", "double")
self.assertTrue(view is not None)
def test_indexing_crash(self):
db = Database(prefix, 'crash/1')
db = Database(prefix, "crash/1")
self.assertTrue(db.valid)
view = db.view('protocol', 'index_crashes', MyExc)
view = db.view("protocol", "index_crashes", MyExc)
with self.assertRaises(MyExc):
view.index(os.path.join(self.cache_root, 'data.db'))
view.index(os.path.join(self.cache_root, "data.db"))
def test_get_crash(self):
db = Database(prefix, 'crash/1')
db = Database(prefix, "crash/1")
self.assertTrue(db.valid)
view = db.view('protocol', 'get_crashes', MyExc)
view.index(os.path.join(self.cache_root, 'data.db'))
view.setup(os.path.join(self.cache_root, 'data.db'))
view = db.view("protocol", "get_crashes", MyExc)
view.index(os.path.join(self.cache_root, "data.db"))
view.setup(os.path.join(self.cache_root, "data.db"))
with self.assertRaises(MyExc):
view.get('a', 0)
view.get("a", 0)
def test_not_setup(self):
db = Database(prefix, 'crash/1')
db = Database(prefix, "crash/1")
self.assertTrue(db.valid)
view = db.view('protocol', 'get_crashes', MyExc)
view = db.view("protocol", "get_crashes", MyExc)
with self.assertRaises(MyExc):
view.get('a', 0)
view.get("a", 0)
def test_success(self):
db = Database(prefix, 'integers_db/1')
@idata(INTEGERS_DBS)
def test_success(self, db_name):
db = Database(prefix, db_name)
self.assertTrue(db.valid)
view = db.view('double', 'double', MyExc)
view.index(os.path.join(self.cache_root, 'data.db'))
view.setup(os.path.join(self.cache_root, 'data.db'))
view = db.view("double", "double", MyExc)
view.index(os.path.join(self.cache_root, "data.db"))
view.setup(os.path.join(self.cache_root, "data.db"))
self.assertTrue(view.data_sources is not None)
self.assertEqual(len(view.data_sources), 3)
for i in range(0, 9):
self.assertEqual(view.get('a', i)['value'], i + 1)
self.assertEqual(view.get('b', i)['value'], (i + 1) * 10)
self.assertEqual(view.get('sum', i)['value'], (i + 1) * 10 + i + 1)
self.assertEqual(view.get("a", i)["value"], i + 1)
self.assertEqual(view.get("b", i)["value"], (i + 1) * 10)
self.assertEqual(view.get("sum", i)["value"], (i + 1) * 10 + i + 1)
def test_success_using_keywords(self):
db = Database(prefix, 'python_keyword/1')
db = Database(prefix, "python_keyword/1")
self.assertTrue(db.valid)
view = db.view('keyword', 'keyword', MyExc)
view.index(os.path.join(self.cache_root, 'data.db'))
view.setup(os.path.join(self.cache_root, 'data.db'))
view = db.view("keyword", "keyword", MyExc)
view.index(os.path.join(self.cache_root, "data.db"))
view.setup(os.path.join(self.cache_root, "data.db"))
self.assertTrue(view.data_sources is not None)
self.assertEqual(len(view.data_sources), 3)
for i in range(0, 9):
self.assertEqual(view.get('class', i)['value'], i + 1)
self.assertEqual(view.get('def', i)['value'], (i + 1) * 10)
self.assertEqual(view.get('sum', i)['value'], (i + 1) * 10 + i + 1)
self.assertEqual(view.get("class", i)["value"], i + 1)
self.assertEqual(view.get("def", i)["value"], (i + 1) * 10)
self.assertEqual(view.get("sum", i)["value"], (i + 1) * 10 + i + 1)
......@@ -43,48 +43,49 @@ import multiprocessing
import tempfile
import shutil
from ddt import ddt
from ddt import idata
from ..scripts import index
from ..hash import hashDataset
from ..hash import toPath
from .test_database import INTEGERS_DBS
from . import prefix
#----------------------------------------------------------
# ----------------------------------------------------------
class IndexationProcess(multiprocessing.Process):
def __init__(self, queue, arguments):
super(IndexationProcess, self).__init__()
self.queue = queue
self.arguments = arguments
def run(self):
self.queue.put('STARTED')
self.queue.put("STARTED")
index.main(self.arguments)
#----------------------------------------------------------
# ----------------------------------------------------------
@ddt
class TestDatabaseIndexation(unittest.TestCase):
def __init__(self, methodName='runTest'):
def __init__(self, methodName="runTest"):
super(TestDatabaseIndexation, self).__init__(methodName)
self.databases_indexation_process = None
self.working_dir = None
self.cache_root = None
def setUp(self):
self.shutdown_everything() # In case another test failed badly during its setUp()
self.working_dir = tempfile.mkdtemp(prefix=__name__)
self.cache_root = tempfile.mkdtemp(prefix=__name__)
def tearDown(self):
self.shutdown_everything()
......@@ -95,7 +96,6 @@ class TestDatabaseIndexation(unittest.TestCase):
self.cache_root = None
self.data_source = None