From 26083434da0aa9047bbf7ea7d0583b15459d6e9d Mon Sep 17 00:00:00 2001 From: Samuel Gaist <samuel.gaist@idiap.ch> Date: Fri, 14 Jun 2019 09:56:52 +0200 Subject: [PATCH] [database] Use error_on_duplicate_key_hook when loading data --- beat/backend/python/database.py | 20 +++++-- .../databases/duplicate_key_error/1.json | 19 +++++++ .../prefix/databases/duplicate_key_error/1.py | 53 +++++++++++++++++++ beat/backend/python/test/test_database.py | 13 ++++- 4 files changed, 99 insertions(+), 6 deletions(-) create mode 100644 beat/backend/python/test/prefix/databases/duplicate_key_error/1.json create mode 100644 beat/backend/python/test/prefix/databases/duplicate_key_error/1.py diff --git a/beat/backend/python/database.py b/beat/backend/python/database.py index 73010e9..c810440 100644 --- a/beat/backend/python/database.py +++ b/beat/backend/python/database.py @@ -46,7 +46,7 @@ import os import sys import six -import simplejson +import simplejson as json import itertools import numpy as np from collections import namedtuple @@ -153,7 +153,7 @@ class Runner(object): os.makedirs(os.path.dirname(filename)) with open(filename, "wb") as f: - data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder) + data = json.dumps(objs, cls=utils.NumpyJSONEncoder) f.write(data.encode("utf-8")) def setup(self, filename, start_index=None, end_index=None, pack=True): @@ -163,7 +163,10 @@ class Runner(object): return with open(filename, "rb") as f: - objs = simplejson.loads(f.read().decode("utf-8")) + objs = json.loads( + f.read().decode("utf-8"), + object_pairs_hook=utils.error_on_duplicate_key_hook, + ) Entry = namedtuple("Entry", sorted(objs[0].keys())) objs = [Entry(**x) for x in objs] @@ -303,7 +306,14 @@ class Database(object): return with open(json_path, "rb") as f: - self.data = simplejson.loads(f.read().decode("utf-8")) + try: + self.data = json.loads( + f.read().decode("utf-8"), + object_pairs_hook=utils.error_on_duplicate_key_hook, + ) + except RuntimeError as error: + self.errors.append("Database declaration file invalid: %s" % error) + return self.code_path = self.storage.code.path self.code = self.storage.code.load() @@ -542,7 +552,7 @@ class Database(object): """ - return simplejson.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder) + return json.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder) def __str__(self): return self.json_dumps() diff --git a/beat/backend/python/test/prefix/databases/duplicate_key_error/1.json b/beat/backend/python/test/prefix/databases/duplicate_key_error/1.json new file mode 100644 index 0000000..f684a0f --- /dev/null +++ b/beat/backend/python/test/prefix/databases/duplicate_key_error/1.json @@ -0,0 +1,19 @@ +{ + "root_folder": "/tmp/foo/bar", + "protocols": [ + { + "name": "test_duplicate_key", + "name": "test_duplicate_key", + "template": "double/1", + "views": { + "double": { + "view": "MyView", + "parameters": { + "threshold": 3 + } + } + } + } + ], + "schema_version": 2 +} diff --git a/beat/backend/python/test/prefix/databases/duplicate_key_error/1.py b/beat/backend/python/test/prefix/databases/duplicate_key_error/1.py new file mode 100644 index 0000000..e7037b8 --- /dev/null +++ b/beat/backend/python/test/prefix/databases/duplicate_key_error/1.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +################################################################################### +# # +# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# Redistribution and use in source and binary forms, with or without # +# modification, are permitted provided that the following conditions are met: # +# # +# 1. Redistributions of source code must retain the above copyright notice, this # +# list of conditions and the following disclaimer. # +# # +# 2. Redistributions in binary form must reproduce the above copyright notice, # +# this list of conditions and the following disclaimer in the documentation # +# and/or other materials provided with the distribution. # +# # +# 3. Neither the name of the copyright holder nor the names of its contributors # +# may be used to endorse or promote products derived from this software without # +# specific prior written permission. # +# # +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +# # +################################################################################### + + +import numpy + +from collections import namedtuple +from beat.backend.python.database import View as BaseView + + +class MyView(BaseView): + def index(self, root_folder, parameters): + Entry = namedtuple("Entry", ["out"]) + + return [Entry(42)] + + def get(self, output, index): + obj = self.objs[index] + + if output == "out": + return {"value": numpy.int32(obj.out)} diff --git a/beat/backend/python/test/test_database.py b/beat/backend/python/test/test_database.py index b5e0248..3b68145 100644 --- a/beat/backend/python/test/test_database.py +++ b/beat/backend/python/test/test_database.py @@ -149,4 +149,15 @@ def compare_definitions(db_name, protocol_name, view_name): db_2 = load("{}/2".format(db_name)) db_2_view_definition = db_2.view_definition(protocol_name, view_name) - nose.tools.eq_(db_1_view_definition, db_2_view_definition) + db_1_sorted = sorted(db_1_view_definition) + db_2_sorted = sorted(db_2_view_definition) + nose.tools.eq_(db_1_sorted, db_2_sorted) + + +# ---------------------------------------------------------- + + +def test_duplicate_key_error(): + database = Database(prefix, "duplicate_key_error/1") + nose.tools.assert_false(database.valid) + nose.tools.assert_true("Database declaration file invalid" in database.errors[0]) -- GitLab