#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################


"""
========
database
========

Validation of databases
"""

import os
import sys

import six
import simplejson
import itertools
import numpy as np
from collections import namedtuple

from . import loader
from . import utils

from .protocoltemplate import ProtocolTemplate
from .dataformat import DataFormat
from .outputs import OutputList
from .exceptions import OutputError


# ----------------------------------------------------------


class Storage(utils.CodeStorage):
    """Resolves paths for databases

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.

    """

    def __init__(self, prefix, name):

        # a valid name has exactly one separator: ``<name>/<version>``
        components = name.split("/")
        if len(components) != 2:
            raise RuntimeError("invalid database name: `%s'" % name)

        self.name, self.version = components
        self.fullname = name
        self.prefix = prefix

        # the storage is addressed by its extension-less base path:
        # ``<prefix>/databases/<name>/<version>``
        base_path = os.path.join(self.prefix, "databases", name)

        # views are coded in Python
        super(Storage, self).__init__(base_path, "python")


# ----------------------------------------------------------


class Runner(object):
    """A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The definition of the view. Its ``view`` entry names
        the view class to look up in ``module``; its optional ``parameters``
        and ``outputs`` entries are used when indexing and setting up the view.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`): The class to use as base exception when
        translating the exception from the user code. Read the documentation of
        :py:func:`.loader.run` for more details.

    """

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition["view"])
        except Exception:
            if exc is not None:
                # translate the lookup failure into the caller's exception
                # class while preserving the original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        self.obj = loader.run(class_, "__new__", exc)
        self.ready = False
        self.prefix = prefix
        self.root_folder = root_folder
        self.definition = definition
        self.exc = exc or RuntimeError
        self.data_sources = None

    def index(self, filename):
        """Index the content of the view

        The list returned by the view's ``index()`` method is serialized as
        JSON into ``filename``; missing parent directories are created.

        Parameters:

          filename (str): Path of the index file to write

        Raises:

          ``self.exc``: If the view's ``index()`` does not return a list

        """

        parameters = self.definition.get("parameters", {})

        objs = loader.run(self.obj, "index", self.exc, self.root_folder, parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        # a bare file name has no directory component: there is nothing to
        # create in that case (and ``os.makedirs("")`` would raise)
        directory = os.path.dirname(filename)
        if directory and not os.path.exists(directory):
            os.makedirs(directory)

        with open(filename, "wb") as f:
            data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder)
            f.write(data.encode("utf-8"))

    def setup(self, filename, start_index=None, end_index=None, pack=True):
        """Sets up the view

        Loads the index previously written by :py:meth:`index`, hands it over
        to the view's ``setup()`` method and creates one data source per
        declared output. Calling this method again once the view is ready is a
        no-op.

        Parameters:

          filename (str): Path of the index file to read

          start_index (int): Forwarded to the view and to the data sources

          end_index (int): Forwarded to the view and to the data sources

          pack (bool): Forwarded to the data sources

        """

        if self.ready:
            return

        with open(filename, "rb") as f:
            objs = simplejson.loads(f.read().decode("utf-8"))

        # the field names are taken from the first entry; an empty index has
        # no first entry and nothing to convert
        if objs:
            Entry = namedtuple("Entry", sorted(objs[0].keys()))
            objs = [Entry(**x) for x in objs]

        parameters = self.definition.get("parameters", {})

        loader.run(
            self.obj,
            "setup",
            self.exc,
            self.root_folder,
            parameters,
            objs,
            start_index=start_index,
            end_index=end_index,
        )

        # Create data sources for the outputs
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the ``data`` module - confirm before moving it
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        for output_name, output_format in self.definition.get("outputs", {}).items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(
                self,
                output_name,
                output_format,
                self.prefix,
                start_index=start_index,
                end_index=end_index,
                pack=pack,
            )
            self.data_sources[output_name] = data_source

        self.ready = True

    def get(self, output, index):
        """Returns the data of the provided output at the provided index

        Raises:

          ``self.exc``: If the view was not set up beforehand

        """

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, "get", self.exc, output, index)

    def get_output_mapping(self, output):
        """Returns what the view's ``get_output_mapping()`` gives for ``output``"""

        return loader.run(self.obj, "get_output_mapping", self.exc, output)

    def objects(self):
        """Returns the ``objs`` attribute of the underlying view"""

        return self.obj.objs


# ----------------------------------------------------------


class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

      errors (list): Messages describing why the database could not be loaded;
        the database is valid when this list is empty.

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # if the user has not provided a cache, still use one for performance
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

        self._load(name, dataformat_cache)

    def _update_dataformat_cache(self, outputs, dataformat_cache):
        """Registers the dataformats used by ``outputs``

        Dataformats that are not yet in ``dataformat_cache`` are loaded and
        added to it.

        Parameters:

          outputs (dict): Maps output names to dataformat names

          dataformat_cache (dict): Maps dataformat names to loaded dataformats

        """

        for format_name in outputs.values():

            if format_name in self.dataformats:
                continue

            if format_name in dataformat_cache:
                dataformat = dataformat_cache[format_name]
            else:
                dataformat = DataFormat(self.prefix, format_name)
                dataformat_cache[format_name] = dataformat

            self.dataformats[format_name] = dataformat

    def _load_v1(self, dataformat_cache):
        """Loads a v1 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            for set_ in protocol["sets"]:
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load_v2(self, dataformat_cache):
        """Loads a v2 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            protocol_template = ProtocolTemplate(
                self.prefix, protocol["template"], dataformat_cache
            )
            for set_ in protocol_template.sets():
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load(self, data, dataformat_cache):
        """Loads the database

        Parameters:

          data (str): The fully qualified database name (e.g. ``db/1``)

          dataformat_cache (dict): Maps dataformat names to loaded dataformats

        Raises:

          RuntimeError: If the declaration uses an unknown schema version

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            # a missing declaration is recorded, not raised: the object is
            # left in an invalid state (see ``valid``)
            self.errors.append("Database declaration file not found: %s" % json_path)
            return

        with open(json_path, "rb") as f:
            self.data = simplejson.loads(f.read().decode("utf-8"))

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        if self.schema_version == 1:
            self._load_v1(dataformat_cache)
        elif self.schema_version == 2:
            self._load_v2(dataformat_cache)
        else:
            raise RuntimeError(
                "Invalid schema version {schema_version}".format(
                    schema_version=self.schema_version
                )
            )

    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or "__unnamed_database__"

    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)

    @property
    def description(self):
        """The short description for this object"""
        return self.data.get("description", None)

    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data["description"] = value

    @property
    def documentation(self):
        """The full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None

    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object

        ``value`` may either be the text itself or an object with a ``read()``
        method providing it.
        """

        if not self._name:
            raise RuntimeError("database has no name")

        if hasattr(value, "read"):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)

    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()

    @property
    def schema_version(self):
        """Returns the schema version (1 when the declaration has none)"""
        return self.data.get("schema_version", 1)

    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)

    @property
    def protocols(self):
        """The declaration of all the protocols of the database, by name"""

        return {k["name"]: k for k in self.data["protocols"]}

    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]

    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        return [k["name"] for k in self.data["protocols"]]

    def _set_declarations(self, protocol):
        """Returns the list of set declarations of the protocol named ``protocol``

        Raises:

          RuntimeError: If the protocol template of a v2 database is invalid

        """

        if self.schema_version == 1:
            # v1 databases declare their sets inline
            return self.protocol(protocol)["sets"]

        # later schemas delegate the set declarations to a protocol template
        declaration = self.protocol(protocol)
        protocol_template = ProtocolTemplate(self.prefix, declaration["template"])
        if not protocol_template.valid:
            raise RuntimeError(
                "\n  * {}".format("\n  * ".join(protocol_template.errors))
            )
        return protocol_template.sets()

    def sets(self, protocol):
        """The declaration of all the sets of a database protocol, by name"""

        return {k["name"]: k for k in self._set_declarations(protocol)}

    def set(self, protocol, name):
        """The declaration of a specific set in the database protocol"""

        return self.sets(protocol)[name]

    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        return [k["name"] for k in self._set_declarations(protocol)]

    def view_definition(self, protocol_name, set_name):
        """Returns the definition of a view

        Parameters:
          protocol_name (str): The name of the protocol where to retrieve the view
            from

          set_name (str): The name of the set in the protocol where to retrieve the
            view from

        Returns:

          dict: The set declaration; for databases that are not v1 it is
          completed with the ``view`` (and ``parameters``, if any) declared by
          the protocol for that set.

        """

        if self.schema_version == 1:
            view_definition = self.set(protocol_name, set_name)
        else:
            protocol = self.protocol(protocol_name)
            template_name = protocol["template"]
            protocol_template = ProtocolTemplate(self.prefix, template_name)
            view_definition = protocol_template.set(set_name)
            view_definition["view"] = protocol["views"][set_name]["view"]
            parameters = protocol["views"][set_name].get("parameters")
            if parameters is not None:
                view_definition["parameters"] = parameters

        return view_definition

    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str): If passed, used as root folder of the database
            instead of the ``root_folder`` entry of its declaration.

        Returns:

          The database view, which will be constructed, but not setup. You
          **must** set it up before using methods ``done`` or ``next``.

        Raises:

          ``exc`` (or RuntimeError when ``exc`` is not given): If the database
          has no name or is invalid

        """

        if not self._name:
            exc = exc or RuntimeError
            raise exc("database has no name")

        if not self.valid:
            # ``name`` is the set and ``protocol`` the protocol: keep the
            # arguments in the order announced by the message
            message = (
                "cannot load view for set `%s' of protocol `%s' "
                "from invalid database (%s)\n%s"
                % (name, protocol, self.name, "   \n".join(self.errors))
            )
            if exc:
                raise exc(message)

            raise RuntimeError(message)

        # loads the module only once through the lifetime of the database
        # object
        try:
            if not hasattr(self, "_module"):
                self._module = loader.load_module(
                    self.name.replace(os.sep, "_"), self.storage.code.path, {}
                )
        except Exception:
            if exc is not None:
                # translate the failure into the caller's exception class
                # while preserving the original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data["root_folder"]

        return Runner(
            self._module,
            self.view_definition(protocol, name),
            self.prefix,
            root_folder,
            exc,
        )

    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return simplejson.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)

    def __str__(self):
        return self.json_dumps()

    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        Raises:

          RuntimeError: If no storage is given and the database has no name

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)

    def export(self, prefix):
        """Recursively exports itself into another prefix

        Dataformats associated are also exported recursively


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If the database has no name, is not valid, or if
            prefix and self.prefix point to the same directory.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        # compare resolved paths so that different spellings of the same
        # directory (trailing separator, ``.``, symbolic links) are detected
        if os.path.realpath(prefix) == os.path.realpath(self.prefix):
            raise RuntimeError(
                "Cannot export database to the same prefix (%s)" % prefix
            )

        for k in self.dataformats.values():
            k.export(prefix)

        if self.schema_version != 1:
            for protocol in self.protocols.values():
                protocol_template = ProtocolTemplate(self.prefix, protocol["template"])
                protocol_template.export(prefix)

        self.write(Storage(prefix, self.name))


# ----------------------------------------------------------


class View(object):
    """Base class of database views

    Subclasses must implement :py:meth:`index` and :py:meth:`get`.
    """

    def __init__(self):
        #  Entries of the index are stored as named tuples, whose field names
        #  cannot be python keywords (e.g. `class`). output_member_map makes
        #  it possible to use such a keyword as output name, by mapping it to
        #  the field name used in the named tuple (for example cls, klass,
        #  etc.)
        self.output_member_map = {}

    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        The ordering of values inside the tuples is free, but it is expected
        that the list is ordered in a consistent manner (ie. all train images
        of person A, then all train images of person B, ...).

        For instance, assuming a view providing that kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        a list like the following should be generated:

        .. code-block:: python

           [
               (client_id=1, file_id=1, image=filename1),
               (client_id=1, file_id=2, image=filename2),
               (client_id=1, file_id=3, image=filename3),
               (client_id=2, file_id=4, image=filename4),
               (client_id=2, file_id=5, image=filename5),
               (client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           DO NOT store images, sound files or data loadable from a file in the
           list!  Store the path of the file to load instead.

        """

        raise NotImplementedError

    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the configuration of the view and the slice of ``objs`` to serve

        ``start_index`` defaults to the first entry and ``end_index`` (which
        is inclusive) to the last one.
        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters
        self.objs = objs

        # Determine the range of indices that must be provided
        if start_index is None:
            start_index = 0
        self.start_index = start_index

        if end_index is None:
            end_index = len(self.objs) - 1
        self.end_index = end_index

        # only keep the requested (inclusive) range of entries
        self.objs = self.objs[start_index : end_index + 1]  # noqa

    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)
        """

        raise NotImplementedError

    def get_output_mapping(self, output):
        """Returns the object member to use for the given output, which is the
        output name itself unless ``output_member_map`` says otherwise.
        """
        member = self.output_member_map.get(output, output)
        return member


# ----------------------------------------------------------


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also reports some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern, must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.
    """

    # Mock output class
    class MockOutput:
        """Stand-in for an output that records everything written to it

        Attributes:

          name: Name of the output

          connected: Value reported by ``isConnected()``

          last_written_data_index: End index of the most recent write
            (-1 as long as nothing was written)

          written_data: One ``(start_index, end_index, data)`` tuple per write
        """

        def __init__(self, name, connected):
            self.name = name
            self.connected = connected
            self.last_written_data_index = -1
            self.written_data = []

        def write(self, data, end_data_index):
            """Records ``data`` as covering the indices up to ``end_data_index``

            The start index of the record is the index following the end of
            the previous write.
            """
            self.written_data.append(
                (self.last_written_data_index + 1, end_data_index, data)
            )
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Tells whether the output was declared as connected"""
            return self.connected

    class SynchronizedUnit:
        """Node of a tree of index ranges, used to render outputs side by side

        A unit covers the inclusive range ``[start, end]``. ``data`` maps the
        name of each output that wrote exactly this range to a short text, and
        ``children`` holds the units covering sub-ranges, in increasing order.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Registers that ``output`` wrote ``data`` for ``[start, end]``

            The data is stored on the unit whose range matches exactly; child
            units are created (or regrouped under a new, wider child) as
            needed.
            """
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
                return

            # The range lies after every known child: append a new child
            if (len(self.children) == 0) or (self.children[-1].end < start):
                unit = type(self)(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
                return

            for index, unit in enumerate(self.children):
                # The range fits inside an existing child: delegate to it
                if (unit.start <= start) and (unit.end >= end):
                    unit.addData(output, start, end, data)
                    return

                # The range starts with this child but is wider: group this
                # child and the following ones it covers under a new unit
                if (unit.start == start) and (unit.end < end):
                    new_unit = type(self)(start, end)
                    new_unit.addData(output, start, end, data)

                    # 'stop' ends up as the index of the first child that is
                    # NOT absorbed (possibly len(self.children))
                    stop = index
                    while (
                        stop < len(self.children)
                        and self.children[stop].end <= end
                    ):
                        new_unit.children.append(self.children[stop])
                        stop += 1

                    self.children = (
                        self.children[:index] + [new_unit] + self.children[stop:]
                    )
                    return

        def toString(self):
            """Renders this unit and its descendants

            Returns:

              dict: One text per output name, all padded to the same length
            """
            texts = {}

            # Concatenate the texts of the children, output by output
            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += " " + text
                    else:
                        texts[output] = text

            if len(self.data) > 0:
                # 6 is the size of the "|- " and " -|" decorations
                length = max(len(x) + 6 for x in self.data.values())

                if len(texts) > 0:
                    children_length = len(next(iter(texts.values())))
                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the children texts to the size of this unit
                        diff1, diff2 = self._splitPadding(length - children_length)
                        for k, v in texts.items():
                            texts[k] = "|%s%s%s|" % ("-" * diff1, v[1:-1], "-" * diff2)

                for output, value in self.data.items():
                    diff1, diff2 = self._splitPadding(length - (len(value) + 6))
                    texts[output] = "|-%s %s %s-|" % ("-" * diff1, value, "-" * diff2)

            # Right-pad with spaces so that all the texts have the same length
            length = max((len(x) for x in texts.values()), default=0)
            for k, v in texts.items():
                if len(v) < length:
                    texts[k] += " " * (length - len(v))

            return texts

        @staticmethod
        def _splitPadding(diff):
            """Splits ``diff`` in two integer halves, the extra unit going right"""
            left = diff // 2
            return left, diff - left

        def _dataToString(self, data):
            """Returns the text shown for ``data``

            ``"X"`` is used unless ``data`` has exactly one entry whose value
            is neither a numpy array nor a dict.
            """
            if len(data) != 1:
                return "X"

            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return "X"

            return str(value)

829 830 831 832 833 834 835 836 837
    def __init__(
        self,
        name,
        view_class,
        outputs_declaration,
        parameters,
        irregular_outputs=[],
        all_combinations=True,
    ):
838 839 840 841 842 843 844 845 846 847
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters
        self.irregular_outputs = irregular_outputs

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            for L in range(0, len(self.outputs_declaration) + 1):
848 849 850
                for subset in itertools.combinations(
                    self.outputs_declaration.keys(), L
                ):
851 852 853 854 855 856 857 858 859 860
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())

    def determine_regular_intervals(self, outputs_declaration):
        """Fills ``self.outputs_declaration`` with the interval of each output

        A view is created with all the outputs connected and asked for data
        once (a single ``next()`` call). Each output not listed in
        ``self.irregular_outputs`` is then associated with its
        ``last_written_data_index + 1``; the irregular ones are associated
        with ``None``.

        Parameters:

          outputs_declaration: Iterable of the names of the outputs of the view
        """
        outputs = OutputList()
        for name in outputs_declaration:
            outputs.add(DatabaseTester.MockOutput(name, True))

        view = self.view_class()
        view.setup("", outputs, self.parameters)

        view.next()

        for output in outputs:
            if output.name not in self.irregular_outputs:
                self.outputs_declaration[output.name] = (
                    output.last_written_data_index + 1
                )
            else:
                self.outputs_declaration[output.name] = None

    def run(self, connected_outputs):
        """Runs the view with the given outputs connected and checks the result

        The view is asked for data (``next()``) until ``done()`` is true for
        every connected output. After each call, the indices of the last write
        of each connected output are verified, as well as the absence of data
        on the not connected outputs. When all the declared outputs are
        connected, a text file named after ``self.name`` (spaces replaced by
        underscores, ``.txt`` suffix) detailing the written data is generated.

        Parameters:

          connected_outputs: Iterable of the names of the outputs to connect;
            nothing is done when it is empty

        Raises:

          OutputError: If one of the checks fails
        """
        if len(connected_outputs) == 0:
            return

        print(
            "Testing '%s', with %d output(s): %s"
            % (self.name, len(connected_outputs), ", ".join(connected_outputs))
        )

        # Split the declared outputs (name -> interval) in connected and not
        # connected ones
        connected_outputs = {
            name: interval
            for name, interval in self.outputs_declaration.items()
            if name in connected_outputs
        }

        not_connected_outputs = {
            name: interval
            for name, interval in self.outputs_declaration.items()
            if name not in connected_outputs
        }

        # Create the mock outputs
        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs))

        # Create the view
        view = self.view_class()
        view.setup("", outputs, self.parameters)

        # Initialisations
        next_expected_indices = {name: 0 for name in connected_outputs}

        next_index = 0

        def _done():
            for output in outputs:
                if output.isConnected() and not view.done(
                    output.last_written_data_index
                ):
                    return False
            return True

        # Ask for all the data
        while not (_done()):
            view.next()

            # Check the indices for the connected outputs
            for name in connected_outputs.keys():
                start, end = outputs[name].written_data[-1][:2]
                expected = next_expected_indices[name]

                if not (start == expected):
                    raise OutputError("Wrong current index")

                if name not in self.irregular_outputs:
                    # Regular outputs always cover exactly one interval
                    if not (end == expected + connected_outputs[name] - 1):
                        raise OutputError("Wrong next index")
                else:
                    if not (end >= expected):
                        raise OutputError("Wrong next index")

            # Check that the not connected outputs didn't produce data
            for name in not_connected_outputs.keys():
                if len(outputs[name].written_data) != 0:
                    raise OutputError("Data written on unconnected output")

            # Compute the next data index that should be produced
            next_index = 1 + min(
                x.written_data[-1][1] for x in outputs if x.isConnected()
            )

            # Compute the next data index that should be produced by each
            # connected output
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
                    if (
                        next_index
                        == next_expected_indices[name] + connected_outputs[name]
                    ):
                        next_expected_indices[name] += connected_outputs[name]
                else:
                    if next_index > outputs[name].written_data[-1][1]:
                        next_expected_indices[name] = (
                            outputs[name].written_data[-1][1] + 1
                        )

        # Check the number of data produced on the regular outputs
        for name in connected_outputs.keys():
            print("  - %s: %d data" % (name, len(outputs[name].written_data)))

            if name not in self.irregular_outputs:
                if not (
                    len(outputs[name].written_data)
                    == next_index / connected_outputs[name]
                ):
                    raise OutputError("Invalid number of data produced")

        # Check that all outputs end on the same index
        for name in connected_outputs.keys():
            if not outputs[name].written_data[-1][1] == next_index - 1:
                raise OutputError("Outputs not on same index")

        # Generate a text file with lots of details (only if all outputs are
        # connected)
        if len(connected_outputs) == len(self.outputs_declaration):
            # Outputs with the fewest writes first
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

            unit = DatabaseTester.SynchronizedUnit(
                0, sorted_outputs[0].written_data[-1][1]
            )

            for output in sorted_outputs:
                for data in output.written_data:
                    unit.addData(output.name, data[0], data[1], data[2])

            texts = unit.toString()

            outputs_max_length = max(len(x) for x in self.outputs_declaration.keys())

            with open(self.name.replace(" ", "_") + ".txt", "w") as f:
                # One line per output, the one with the most writes first
                for output in reversed(sorted_outputs):
                    output_name = output.name
                    f.write(output_name + ": ")

                    # Align the texts of all the outputs on the same column
                    if len(output_name) < outputs_max_length:
                        f.write(" " * (outputs_max_length - len(output_name)))

                    f.write(texts[output_name] + "\n")