Commit 26083434 authored by Samuel GAIST's avatar Samuel GAIST

[database] Use error_on_duplicate_key_hook when loading data

parent 24ff81ae
......@@ -46,7 +46,7 @@ import os
import sys
import six
import simplejson
import simplejson as json
import itertools
import numpy as np
from collections import namedtuple
......@@ -153,7 +153,7 @@ class Runner(object):
os.makedirs(os.path.dirname(filename))
with open(filename, "wb") as f:
data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder)
data = json.dumps(objs, cls=utils.NumpyJSONEncoder)
f.write(data.encode("utf-8"))
def setup(self, filename, start_index=None, end_index=None, pack=True):
......@@ -163,7 +163,10 @@ class Runner(object):
return
with open(filename, "rb") as f:
objs = simplejson.loads(f.read().decode("utf-8"))
objs = json.loads(
f.read().decode("utf-8"),
object_pairs_hook=utils.error_on_duplicate_key_hook,
)
Entry = namedtuple("Entry", sorted(objs[0].keys()))
objs = [Entry(**x) for x in objs]
......@@ -303,7 +306,14 @@ class Database(object):
return
with open(json_path, "rb") as f:
self.data = simplejson.loads(f.read().decode("utf-8"))
try:
self.data = json.loads(
f.read().decode("utf-8"),
object_pairs_hook=utils.error_on_duplicate_key_hook,
)
except RuntimeError as error:
self.errors.append("Database declaration file invalid: %s" % error)
return
self.code_path = self.storage.code.path
self.code = self.storage.code.load()
......@@ -542,7 +552,7 @@ class Database(object):
"""
return simplejson.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)
return json.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)
def __str__(self):
return self.json_dumps()
......
{
"root_folder": "/tmp/foo/bar",
"protocols": [
{
"name": "test_duplicate_key",
"name": "test_duplicate_key",
"template": "double/1",
"views": {
"double": {
"view": "MyView",
"parameters": {
"threshold": 3
}
}
}
}
],
"schema_version": 2
}
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###################################################################################
# #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, this #
# list of conditions and the following disclaimer. #
# #
# 2. Redistributions in binary form must reproduce the above copyright notice, #
# this list of conditions and the following disclaimer in the documentation #
# and/or other materials provided with the distribution. #
# #
# 3. Neither the name of the copyright holder nor the names of its contributors #
# may be used to endorse or promote products derived from this software without #
# specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# #
###################################################################################
import numpy
from collections import namedtuple
from beat.backend.python.database import View as BaseView
class MyView(BaseView):
def index(self, root_folder, parameters):
Entry = namedtuple("Entry", ["out"])
return [Entry(42)]
def get(self, output, index):
obj = self.objs[index]
if output == "out":
return {"value": numpy.int32(obj.out)}
......@@ -149,4 +149,15 @@ def compare_definitions(db_name, protocol_name, view_name):
db_2 = load("{}/2".format(db_name))
db_2_view_definition = db_2.view_definition(protocol_name, view_name)
nose.tools.eq_(db_1_view_definition, db_2_view_definition)
db_1_sorted = sorted(db_1_view_definition)
db_2_sorted = sorted(db_2_view_definition)
nose.tools.eq_(db_1_sorted, db_2_sorted)
# ----------------------------------------------------------
def test_duplicate_key_error():
database = Database(prefix, "duplicate_key_error/1")
nose.tools.assert_false(database.valid)
nose.tools.assert_true("Database declaration file invalid" in database.errors[0])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment