#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################


"""
========
database
========

Validation of databases
"""

import itertools
import os
import sys

from collections import namedtuple

import numpy as np
import simplejson as json
import six

from . import loader
from . import utils
from .dataformat import DataFormat
from .exceptions import OutputError
from .outputs import OutputList
from .protocoltemplate import ProtocolTemplate


# ----------------------------------------------------------


class Storage(utils.CodeStorage):
    """Resolves paths for databases

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.

    Raises:

      RuntimeError: If ``name`` does not contain exactly one ``/``.

    """

    asset_type = "database"
    asset_folder = "databases"

    def __init__(self, prefix, name):

        if name.count("/") != 1:
            raise RuntimeError("invalid database name: `%s'" % name)

        self.name, self.version = name.split("/")
        self.fullname = name
        self.prefix = prefix

        # extension-less base path of the asset; the parent class receives it
        # together with the language of the code (views are coded in Python)
        path = os.path.join(self.prefix, self.asset_folder, name)
        super(Storage, self).__init__(path, "python")


# ----------------------------------------------------------


class Runner(object):
    """A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The definition of the view. Its ``view`` entry names
        the class to instantiate from ``module``; its optional ``parameters``
        and ``outputs`` entries are used when indexing and setting up the
        view.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`): The class to use as base exception when
        translating the exception from the user code. Read the documentation
        of :py:func:`.loader.run` for more details.

    """

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition["view"])
        except Exception:
            if exc is not None:
                _, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
                raise  # just re-raise the user exception

        self.obj = loader.run(class_, "__new__", exc)
        self.ready = False
        self.prefix = prefix
        self.root_folder = root_folder
        self.definition = definition
        self.exc = exc or RuntimeError
        self.data_sources = None

    @staticmethod
    def _ensure_directory(filename):
        """Creates the directory that will contain ``filename``, if missing"""

        dirname = os.path.dirname(filename)

        # a bare file name has an empty dirname: there is nothing to create
        # and os.makedirs("") would raise
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname)

    def index(self, filename):
        """Index the content of the view

        The list returned by the ``index()`` method of the view is serialized
        as JSON into ``filename``. Missing parent directories are created.

        Parameters:

          filename (str): Path of the index file to write

        Raises:

          exc: If the view's ``index()`` does not return a list

        """

        parameters = self.definition.get("parameters", {})

        objs = loader.run(self.obj, "index", self.exc, self.root_folder, parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        self._ensure_directory(filename)

        with open(filename, "wb") as f:
            data = json.dumps(objs, cls=utils.NumpyJSONEncoder)
            f.write(data.encode("utf-8"))

    def setup(self, filename, start_index=None, end_index=None, pack=True):
        """Sets up the view

        Loads the index stored in ``filename``, hands it over to the view's
        ``setup()`` and creates one data source per declared output. Calling
        this method again on a view that is already set up is a no-op.

        Parameters:

          filename (str): Path of the index file written by :py:meth:`index`

          start_index (int): Forwarded to the view and to the data sources

          end_index (int): Forwarded to the view and to the data sources

          pack (bool): Forwarded to the data sources

        Raises:

          exc: If the index file contains no entry

        """

        if self.ready:
            return

        with open(filename, "rb") as f:
            objs = json.loads(
                f.read().decode("utf-8"),
                object_pairs_hook=utils.error_on_duplicate_key_hook,
            )

        # the fields of the entries are taken from the first one, so an empty
        # index cannot be used: report it explicitly
        if not objs:
            raise self.exc("index file `%s' contains no entries" % filename)

        Entry = namedtuple("Entry", sorted(objs[0].keys()))
        objs = [Entry(**x) for x in objs]

        parameters = self.definition.get("parameters", {})

        loader.run(
            self.obj,
            "setup",
            self.exc,
            self.root_folder,
            parameters,
            objs,
            start_index=start_index,
            end_index=end_index,
        )

        # Create data sources for the outputs
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        for output_name, output_format in self.definition.get("outputs", {}).items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(
                self,
                output_name,
                output_format,
                self.prefix,
                start_index=start_index,
                end_index=end_index,
                pack=pack,
            )
            self.data_sources[output_name] = data_source

        self.ready = True

    def get(self, output, index):
        """Returns the data of the provided output at the provided index"""

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, "get", self.exc, output, index)

    def get_output_mapping(self, output):
        """Returns what the view's ``get_output_mapping()`` gives for ``output``"""

        return loader.run(self.obj, "get_output_mapping", self.exc, output)

    def objects(self):
        """Returns the ``objs`` attribute of the underlying view"""

        return self.obj.objs


# ----------------------------------------------------------


class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

      errors (list): Messages describing why the declaration could not be
        loaded. The database is valid only when this list is empty.

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # if the user has not provided a cache, still use one for performance
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

        self._load(name, dataformat_cache)

    def _update_dataformat_cache(self, outputs, dataformat_cache):
        """Registers the dataformats used by ``outputs``

        Dataformats are taken from ``dataformat_cache`` when available,
        otherwise they are loaded and added to it.
        """

        for dataformat_name in outputs.values():

            if dataformat_name in self.dataformats:
                continue

            if dataformat_name in dataformat_cache:
                dataformat = dataformat_cache[dataformat_name]
            else:
                dataformat = DataFormat(self.prefix, dataformat_name)
                dataformat_cache[dataformat_name] = dataformat

            self.dataformats[dataformat_name] = dataformat

    def _load_v1(self, dataformat_cache):
        """Loads a v1 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            for set_ in protocol["sets"]:
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load_v2(self, dataformat_cache):
        """Loads a v2 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            protocol_template = ProtocolTemplate(
                self.prefix, protocol["template"], dataformat_cache
            )
            for set_ in protocol_template.sets():
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load(self, data, dataformat_cache):
        """Loads the database

        Problems with the declaration file (missing, malformed JSON,
        duplicated keys) are recorded in ``self.errors`` instead of being
        raised.

        Parameters:

          data (str): The fully qualified database name

          dataformat_cache (dict): Cache of already loaded dataformats

        Raises:

          RuntimeError: If the schema version of the declaration is unknown

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            self.errors.append("Database declaration file not found: %s" % json_path)
            return

        with open(json_path, "rb") as f:
            try:
                self.data = json.loads(
                    f.read().decode("utf-8"),
                    object_pairs_hook=utils.error_on_duplicate_key_hook,
                )
            except (RuntimeError, ValueError) as error:
                # RuntimeError: raised by the duplicate-key hook
                # ValueError: malformed JSON or undecodable content
                self.errors.append("Database declaration file invalid: %s" % error)
                return

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        if self.schema_version == 1:
            self._load_v1(dataformat_cache)
        elif self.schema_version == 2:
            self._load_v2(dataformat_cache)
        else:
            raise RuntimeError(
                "Invalid schema version {schema_version}".format(
                    schema_version=self.schema_version
                )
            )

    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or "__unnamed_database__"

    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)

    @property
    def description(self):
        """The short description for this object"""
        return self.data.get("description", None)

    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data["description"] = value

    @property
    def documentation(self):
        """The full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None

    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        # accept both file-like objects and plain content
        if hasattr(value, "read"):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)

    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()

    @property
    def schema_version(self):
        """Returns the schema version"""
        return self.data.get("schema_version", 1)

    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)

    @property
    def environment(self):
        """Returns the run environment if any has been set"""

        return self.data.get("environment")

    @property
    def protocols(self):
        """The declaration of all the protocols of the database"""

        return {k["name"]: k for k in self.data["protocols"]}

    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]

    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        return [k["name"] for k in self.data["protocols"]]

    def _set_declarations(self, protocol):
        """Returns the list of set declarations of the given protocol

        For schema version 1 the sets are part of the protocol declaration,
        otherwise they come from the protocol template.

        Raises:

          RuntimeError: If the protocol template is not valid

        """

        declaration = self.protocol(protocol)

        if self.schema_version == 1:
            return declaration["sets"]

        protocol_template = ProtocolTemplate(self.prefix, declaration["template"])
        if not protocol_template.valid:
            raise RuntimeError(
                "\n  * {}".format("\n  * ".join(protocol_template.errors))
            )
        return protocol_template.sets()

    def sets(self, protocol):
        """The declaration of all the sets of a protocol, indexed by name"""

        return {k["name"]: k for k in self._set_declarations(protocol)}

    def set(self, protocol, name):
        """The declaration of a specific set in the database protocol"""

        return self.sets(protocol)[name]

    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        return [k["name"] for k in self._set_declarations(protocol)]

    def view_definition(self, protocol_name, set_name):
        """Returns the definition of a view

        Parameters:
          protocol_name (str): The name of the protocol where to retrieve the view
            from

          set_name (str): The name of the set in the protocol where to retrieve the
            view from

        """

        if self.schema_version == 1:
            view_definition = self.set(protocol_name, set_name)
        else:
            protocol = self.protocol(protocol_name)
            template_name = protocol["template"]
            protocol_template = ProtocolTemplate(self.prefix, template_name)
            view_definition = protocol_template.set(set_name)
            view_definition["view"] = protocol["views"][set_name]["view"]
            parameters = protocol["views"][set_name].get("parameters")
            if parameters is not None:
                view_definition["parameters"] = parameters

        return view_definition

    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str): If passed, used as root folder of the database
            instead of the ``root_folder`` entry of the declaration.

        Returns:

          The database view, which will be constructed, but not setup. You
          **must** set it up before using methods ``done`` or ``next``.

        """

        if not self._name:
            exc = exc or RuntimeError
            raise exc("database has no name")

        if not self.valid:
            # `name' is the set, `protocol' the protocol
            message = (
                "cannot load view for set `%s' of protocol `%s' "
                "from invalid database (%s)\n%s"
                % (name, protocol, self.name, "   \n".join(self.errors))
            )
            if exc:
                raise exc(message)

            raise RuntimeError(message)

        # loads the module only once through the lifetime of the database
        # object
        try:
            if not hasattr(self, "_module"):
                self._module = loader.load_module(
                    self.name.replace(os.sep, "_"), self.storage.code.path, {}
                )
        except Exception:
            if exc is not None:
                _, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data["root_folder"]

        return Runner(
            self._module,
            self.view_definition(protocol, name),
            self.prefix,
            root_folder,
            exc,
        )

    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return json.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)

    def __str__(self):
        return self.json_dumps()

    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)

    def export(self, prefix):
        """Recursively exports itself into another prefix

        Dataformats associated are also exported recursively


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If the database has no name, is not valid, or if
            ``prefix`` is equal to ``self.prefix``.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        if prefix == self.prefix:
            raise RuntimeError(
                "Cannot export database to the same prefix (%s)" % prefix
            )

        for k in self.dataformats.values():
            k.export(prefix)

        # only databases newer than schema version 1 rely on protocol templates
        if self.schema_version != 1:
            for protocol in self.protocols.values():
                protocol_template = ProtocolTemplate(self.prefix, protocol["template"])
                protocol_template.export(prefix)

        self.write(Storage(prefix, self.name))


# ----------------------------------------------------------


class View(object):
    """Base class of the database views

    Subclasses must implement :py:meth:`index` and :py:meth:`get`.
    """

    def __init__(self):
        #  Current databases definitions uses named tuple to store information.
        #  This has one limitation, python keywords like `class` cannot be used.
        #  output_member_map allows to use that kind of keyword as output name
        #  while using something different for the named tuple (for example cls,
        #  klass, etc.)
        self.output_member_map = {}

    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        The ordering of values inside the tuples is free, but it is expected
        that the list is ordered in a consistent manner (ie. all train images
        of person A, then all train images of person B, ...).

        For instance, assuming a view providing that kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        a list like the following should be generated:

        .. code-block:: python

           [
               (client_id=1, file_id=1, image=filename1),
               (client_id=1, file_id=2, image=filename2),
               (client_id=1, file_id=3, image=filename3),
               (client_id=2, file_id=4, image=filename4),
               (client_id=2, file_id=5, image=filename5),
               (client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           DO NOT store images, sound files or data loadable from a file in the
           list!  Store the path of the file to load instead.

        """

        raise NotImplementedError

    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the configuration of the view and the entries it must provide

        Parameters:

          root_folder (str): Stored as ``self.root_folder``

          parameters (dict): Stored as ``self.parameters``

          objs (list): The full list of entries of the view

          start_index (int): Index of the first entry to keep, defaults to 0

          end_index (int): Index of the last entry to keep (inclusive),
            defaults to the index of the last entry of ``objs``

        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters
        self.objs = objs

        # Determine the range of indices that must be provided
        self.start_index = start_index if start_index is not None else 0
        self.end_index = end_index if end_index is not None else len(self.objs) - 1

        # end_index is inclusive
        self.objs = self.objs[self.start_index : self.end_index + 1]  # noqa

    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)
        """

        raise NotImplementedError

    def get_output_mapping(self, output):
        """Returns the object member to use for given output if any otherwise
        the member name is the output name.
        """
        return self.output_member_map.get(output, output)

710

711
# ----------------------------------------------------------
712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also reports some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.

    The tests are executed from the constructor. The view class is
    instantiated without argument, and the tester calls
    ``setup("", outputs, parameters)``, ``next()`` and ``done(index)`` on the
    resulting object.
    """

    # Mock output class
    class MockOutput:
        """Stand-in for an output, remembering everything written on it"""

        def __init__(self, name, connected):
            self.name = name
            self.connected = connected
            # -1 means that nothing was written yet
            self.last_written_data_index = -1
            # List of tuples (start index, end index, data), in writing order
            self.written_data = []

        def write(self, data, end_data_index):
            """Records a block of data starting right after the previously
            written one and ending at ``end_data_index`` (inclusive)
            """
            self.written_data.append(
                (self.last_written_data_index + 1, end_data_index, data)
            )
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Returns the connection flag given at construction"""
            return self.connected

    class SynchronizedUnit:
        """Node of the tree of index ranges used to render the outputs side
        by side in the text report

        A unit covers the inclusive range ``[start, end]``. ``data`` maps an
        output name to the text displayed for exactly that range, and
        ``children`` contains the units covering parts of that range.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Inserts the block ``[start, end]`` written on ``output`` into
            the tree rooted at this unit
            """
            if (start == self.start) and (end == self.end):
                # The block matches this unit exactly
                self.data[output] = self._dataToString(data)
            elif (len(self.children) == 0) or (self.children[-1].end < start):
                # The block starts after all the existing children
                unit = DatabaseTester.SynchronizedUnit(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
            else:
                for index, unit in enumerate(self.children):
                    if (unit.start <= start) and (unit.end >= end):
                        # The block fits inside an existing child
                        unit.addData(output, start, end, data)
                        break
                    elif (unit.start == start) and (unit.end < end):
                        # The block spans several existing children: group
                        # them under a new unit
                        new_unit = DatabaseTester.SynchronizedUnit(start, end)
                        new_unit.addData(output, start, end, data)
                        new_unit.children.append(unit)

                        # 'stop' is the index of the first child that is NOT
                        # absorbed. Tracking it explicitly (instead of reusing
                        # a loop variable) is correct when the matched child
                        # is the last one and when every following child is
                        # absorbed.
                        stop = index + 1
                        while (stop < len(self.children)) and (
                            self.children[stop].end <= end
                        ):
                            new_unit.children.append(self.children[stop])
                            stop += 1

                        self.children[index:stop] = [new_unit]
                        break

        def toString(self):
            """Returns a dictionary mapping each output name to one line of
            text representing the blocks of this unit and of its children

            All the returned lines have the same length.
            """
            texts = {}

            # Concatenate the representations of the children, output by output
            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += " " + text
                    else:
                        texts[output] = text

            if self.data:
                # 6 is the size of the decorations: '|- ' and ' -|'
                length = max(len(x) + 6 for x in self.data.values())

                if texts:
                    children_length = len(next(iter(texts.values())))
                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the lines of the children to the size needed
                        # by the data of this unit
                        diff = length - children_length
                        diff1 = diff // 2
                        diff2 = diff - diff1

                        for k, v in texts.items():
                            texts[k] = "|%s%s%s|" % ("-" * diff1, v[1:-1], "-" * diff2)

                for output, value in self.data.items():
                    diff = length - (len(value) + 6)
                    diff1 = diff // 2
                    diff2 = diff - diff1
                    texts[output] = "|-%s %s %s-|" % ("-" * diff1, value, "-" * diff2)

            if not texts:
                # Nothing to render (no data, no children)
                return texts

            # Pad all the lines to the same length
            length = max(len(x) for x in texts.values())
            for k, v in texts.items():
                if len(v) < length:
                    texts[k] += " " * (length - len(v))

            return texts

        def _dataToString(self, data):
            """Returns the text displayed for a block of data: the value of
            its unique field, or 'X' when there isn't exactly one field or
            when the value is an array or a dictionary
            """
            if len(data) != 1:
                return "X"

            # Only rely on keys() and item access, so that both lists of keys
            # and dictionary views are supported
            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return "X"

            return str(value)

    def __init__(
        self,
        name,
        view_class,
        outputs_declaration,
        parameters,
        irregular_outputs=None,
        all_combinations=True,
    ):
        """Runs the tests

        Parameters:

          name (str): Name displayed in the messages, also used (spaces
            replaced by underscores) as the name of the generated text file

          view_class: Class of the view to test, instantiated without
            argument

          outputs_declaration: Iterable over the names of the outputs of the
            view

          parameters: Parameters given to the ``setup()`` method of the view

          irregular_outputs (list): Names of the outputs that don't produce
            data at constant intervals (default: none)

          all_combinations (bool): When ``True``, every combination of
            connected outputs is tested, otherwise only the case where all
            the outputs are connected
        """
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters
        # A fresh list per instance, never a shared default object
        self.irregular_outputs = [] if irregular_outputs is None else irregular_outputs

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            for L in range(0, len(self.outputs_declaration) + 1):
                for subset in itertools.combinations(
                    self.outputs_declaration.keys(), L
                ):
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())

    def determine_regular_intervals(self, outputs_declaration):
        """Fills ``self.outputs_declaration`` with, for each output, the
        number of indices covered after a single call to ``next()`` on a view
        with all its outputs connected (``None`` for the irregular outputs)
        """
        outputs = OutputList()
        for name in outputs_declaration:
            outputs.add(DatabaseTester.MockOutput(name, True))

        view = self.view_class()
        view.setup("", outputs, self.parameters)

        view.next()

        for output in outputs:
            if output.name not in self.irregular_outputs:
                self.outputs_declaration[output.name] = (
                    output.last_written_data_index + 1
                )
            else:
                self.outputs_declaration[output.name] = None

    def run(self, connected_outputs):
        """Tests the view with only the given outputs connected

        Nothing is done when ``connected_outputs`` is empty. When all the
        outputs are connected, a text file detailing the generated data is
        written in the current directory.

        Raises:

          OutputError: when the indices of the written data are inconsistent,
            or when data is written on a not connected output
        """
        if len(connected_outputs) == 0:
            return

        print(
            "Testing '%s', with %d output(s): %s"
            % (self.name, len(connected_outputs), ", ".join(connected_outputs))
        )

        # Create the mock outputs
        requested = connected_outputs

        connected_outputs = {
            name: interval
            for name, interval in self.outputs_declaration.items()
            if name in requested
        }

        not_connected_outputs = {
            name: interval
            for name, interval in self.outputs_declaration.items()
            if name not in connected_outputs
        }

        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs))

        # Create the view
        view = self.view_class()
        view.setup("", outputs, self.parameters)

        # Initialisations
        next_expected_indices = {name: 0 for name in connected_outputs}

        next_index = 0

        def _done():
            # The view is done once it says so for every connected output
            for output in outputs:
                if output.isConnected() and not view.done(
                    output.last_written_data_index
                ):
                    return False
            return True

        # Ask for all the data
        while not _done():
            view.next()

            # Check the indices for the connected outputs
            for name, interval in connected_outputs.items():
                first, last = outputs[name].written_data[-1][:2]
                expected = next_expected_indices[name]

                if first != expected:
                    raise OutputError("Wrong current index")

                if name not in self.irregular_outputs:
                    # Regular outputs always cover exactly 'interval' indices
                    if last != expected + interval - 1:
                        raise OutputError("Wrong next index")
                elif last < expected:
                    raise OutputError("Wrong next index")

            # Check that the not connected outputs didn't produce data
            for name in not_connected_outputs.keys():
                if len(outputs[name].written_data) != 0:
                    raise OutputError("Data written on unconnected output")

            # Compute the next data index that should be produced
            next_index = 1 + min(
                x.written_data[-1][1] for x in outputs if x.isConnected()
            )

            # Compute the next data index that should be produced by each
            # connected output
            for name, interval in connected_outputs.items():
                if name not in self.irregular_outputs:
                    if next_index == next_expected_indices[name] + interval:
                        next_expected_indices[name] += interval
                elif next_index > outputs[name].written_data[-1][1]:
                    next_expected_indices[name] = outputs[name].written_data[-1][1] + 1

        # Check the number of data produced on the regular outputs
        for name in connected_outputs.keys():
            print("  - %s: %d data" % (name, len(outputs[name].written_data)))

            if name not in self.irregular_outputs:
                if not (
                    len(outputs[name].written_data)
                    == next_index / connected_outputs[name]
                ):
                    raise OutputError("Invalid number of data produced")

        # Check that all outputs ends on the same index
        for name in connected_outputs.keys():
            if not outputs[name].written_data[-1][1] == next_index - 1:
                raise OutputError("Outputs not on same index")

        # Generate a text file with lots of details (only if all outputs are
        # connected)
        if len(connected_outputs) == len(self.outputs_declaration):
            # Outputs with the fewest blocks are inserted in the tree first
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

            unit = DatabaseTester.SynchronizedUnit(
                0, sorted_outputs[0].written_data[-1][1]
            )

            for output in sorted_outputs:
                for start, end, data in output.written_data:
                    unit.addData(output.name, start, end, data)

            texts = unit.toString()

            outputs_max_length = max(len(x) for x in self.outputs_declaration.keys())

            with open(self.name.replace(" ", "_") + ".txt", "w") as f:
                # One line per output, the one with the most blocks first,
                # with the names padded so that the data are aligned
                for output in reversed(sorted_outputs):
                    label = (output.name + ": ").ljust(outputs_max_length + 2)
                    f.write(label + texts[output.name] + "\n")