database.py 34.8 KB
Newer Older
1 2 3
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
35 36


37 38 39 40 41 42 43
"""
========
database
========

Validation of databases
"""
44 45 46 47 48

import os
import sys

import six
49
import simplejson as json
50 51
import itertools
import numpy as np
52
from collections import namedtuple
53 54

from . import loader
Philip ABBET's avatar
Philip ABBET committed
55 56
from . import utils

57
from .protocoltemplate import ProtocolTemplate
Philip ABBET's avatar
Philip ABBET committed
58
from .dataformat import DataFormat
59
from .outputs import OutputList
60
from .exceptions import OutputError
Philip ABBET's avatar
Philip ABBET committed
61 62


63
# ----------------------------------------------------------
64

Philip ABBET's avatar
Philip ABBET committed
65 66

class Storage(utils.CodeStorage):
    """Path resolver for database assets living under an installation prefix

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.

    """

    asset_type = "database"
    asset_folder = "databases"

    def __init__(self, prefix, name):

        # A valid name holds exactly one separator: ``<name>/<version>``
        if name.count("/") != 1:
            raise RuntimeError("invalid database name: `%s'" % name)

        db_name, db_version = name.split("/")

        self.name = db_name
        self.version = db_version
        self.fullname = name
        self.prefix = prefix

        # The base class is handed the asset path *without* any extension
        # (``<prefix>/databases/<name>/<version>``)
        base_path = os.path.join(self.prefix, self.asset_folder, name)

        # views are coded in Python
        super(Storage, self).__init__(base_path, "python")
94 95


96
# ----------------------------------------------------------
97

98

99
class Runner(object):
    """A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The definition of the view to run. The ``view`` key
        names the class to fetch from ``module``; the optional ``parameters``
        and ``outputs`` keys are read when indexing and setting up the view.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`): The class to use as base exception when
        translating the exception from the user code. Read the documentation
        of :py:func:`.loader.run` for more details.

    """

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition["view"])
        except Exception:
            if exc is not None:
                # re-raise as the caller-provided exception type while keeping
                # the original traceback
                _, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
                raise  # just re-raise the user exception

        self.obj = loader.run(class_, "__new__", exc)
        self.ready = False
        self.prefix = prefix
        self.root_folder = root_folder
        self.definition = definition
        self.exc = exc or RuntimeError
        self.data_sources = None

    def index(self, filename):
        """Index the content of the view

        The list returned by the ``index()`` method of the view is serialized
        as UTF-8 encoded JSON into ``filename``. Missing parent directories are
        created.

        Parameters:

          filename (str): Path of the index file to write

        Raises:

          exc: If the view's ``index()`` does not return a list

        """

        parameters = self.definition.get("parameters", {})

        objs = loader.run(self.obj, "index", self.exc, self.root_folder, parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        # A bare file name has an empty dirname: there is nothing to create in
        # that case (and ``os.makedirs("")`` would fail)
        folder = os.path.dirname(filename)
        if folder and not os.path.exists(folder):
            os.makedirs(folder)

        with open(filename, "wb") as f:
            data = json.dumps(objs, cls=utils.NumpyJSONEncoder)
            f.write(data.encode("utf-8"))

    def setup(self, filename, start_index=None, end_index=None, pack=True):
        """Sets up the view

        Reads back the index stored in ``filename``, hands it over to the
        ``setup()`` method of the view and creates one data source per output
        declared in the view definition. Calling this method again once the
        view is ready is a no-op.

        Parameters:

          filename (str): Path of the index file written by :py:meth:`index`

          start_index (int): Forwarded to the view and to the data sources

          end_index (int): Forwarded to the view and to the data sources

          pack (bool): Forwarded to the data sources

        """

        if self.ready:
            return

        with open(filename, "rb") as f:
            objs = json.loads(
                f.read().decode("utf-8"),
                object_pairs_hook=utils.error_on_duplicate_key_hook,
            )

        # The fields of the named tuple are taken from the first entry; an
        # empty index has no entry to take them from and is passed on as is
        if objs:
            Entry = namedtuple("Entry", sorted(objs[0].keys()))
            objs = [Entry(**x) for x in objs]

        parameters = self.definition.get("parameters", {})

        loader.run(
            self.obj,
            "setup",
            self.exc,
            self.root_folder,
            parameters,
            objs,
            start_index=start_index,
            end_index=end_index,
        )

        # Create data sources for the outputs
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        for output_name, output_format in self.definition.get("outputs", {}).items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(
                self,
                output_name,
                output_format,
                self.prefix,
                start_index=start_index,
                end_index=end_index,
                pack=pack,
            )
            self.data_sources[output_name] = data_source

        self.ready = True

    def get(self, output, index):
        """Returns the data of the provided output at the provided index

        Raises:

          exc: If :py:meth:`setup` has not been called yet

        """

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, "get", self.exc, output, index)

    def get_output_mapping(self, output):
        """Returns what the view's ``get_output_mapping()`` yields for
        ``output``
        """

        return loader.run(self.obj, "get_output_mapping", self.exc, output)

    def objects(self):
        """Returns the ``objs`` attribute of the wrapped view instance"""

        return self.obj.objs


221
# ----------------------------------------------------------
222

223 224

class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

      errors (list): Messages collected while loading the declaration. The
        database is valid when this list is empty.

      storage (Storage): Resolves the paths of the files of this database

      dataformats (dict): Dataformats used by the outputs of this database,
        indexed by name

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # if the user has not provided a cache, still use one for performance
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

        self._load(name, dataformat_cache)

    def _update_dataformat_cache(self, outputs, dataformat_cache):
        """Registers the dataformats of ``outputs``, loading the missing ones

        Parameters:

          outputs (dict): Maps output names to dataformat names

          dataformat_cache (dict): Cache of already loaded dataformats,
            updated in place

        """

        for value in outputs.values():

            if value in self.dataformats:
                continue

            if value in dataformat_cache:
                dataformat = dataformat_cache[value]
            else:
                dataformat = DataFormat(self.prefix, value)
                dataformat_cache[value] = dataformat

            self.dataformats[value] = dataformat

    def _load_v1(self, dataformat_cache):
        """Loads a v1 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            for set_ in protocol["sets"]:
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load_v2(self, dataformat_cache):
        """Loads a v2 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            protocol_template = ProtocolTemplate(
                self.prefix, protocol["template"], dataformat_cache
            )
            for set_ in protocol_template.sets():
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load(self, data, dataformat_cache):
        """Loads the database

        Problems with the declaration file (missing or not decodable) are
        recorded in ``self.errors`` instead of being raised.

        Parameters:

          data (str): The fully qualified database name

          dataformat_cache (dict): Cache of already loaded dataformats

        Raises:

          RuntimeError: If the schema version of the declaration is neither 1
            nor 2

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            self.errors.append("Database declaration file not found: %s" % json_path)
            return

        with open(json_path, "rb") as f:
            try:
                self.data = json.loads(
                    f.read().decode("utf-8"),
                    object_pairs_hook=utils.error_on_duplicate_key_hook,
                )
            except (RuntimeError, ValueError) as error:
                # ValueError covers malformed JSON (JSON decoding errors derive
                # from it) as well as undecodable bytes
                self.errors.append("Database declaration file invalid: %s" % error)
                return

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        if self.schema_version == 1:
            self._load_v1(dataformat_cache)
        elif self.schema_version == 2:
            self._load_v2(dataformat_cache)
        else:
            raise RuntimeError(
                "Invalid schema version {schema_version}".format(
                    schema_version=self.schema_version
                )
            )

    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or "__unnamed_database__"

    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)

    @property
    def description(self):
        """The short description for this object"""
        return self.data.get("description", None)

    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data["description"] = value

    @property
    def documentation(self):
        """The full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None

    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        # accept both file-like objects and plain content
        if hasattr(value, "read"):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)

    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()

    @property
    def schema_version(self):
        """Returns the schema version (1 when the declaration has none)"""
        return self.data.get("schema_version", 1)

    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)

    @property
    def environment(self):
        """Returns the run environment if any has been set"""

        return self.data.get("environment")

    @property
    def protocols(self):
        """The declaration of all the protocols of the database"""

        return {k["name"]: k for k in self.data["protocols"]}

    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]

    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        return [k["name"] for k in self.data["protocols"]]

    def _set_declarations(self, protocol):
        """Returns the list of set declarations of the given protocol

        For schema version 1 the sets are read from the protocol itself,
        otherwise they come from the protocol template it refers to.

        Raises:

          RuntimeError: If the protocol template is not valid

        """

        if self.schema_version == 1:
            return self.protocol(protocol)["sets"]

        declaration = self.protocol(protocol)
        protocol_template = ProtocolTemplate(self.prefix, declaration["template"])
        if not protocol_template.valid:
            raise RuntimeError(
                "\n  * {}".format("\n  * ".join(protocol_template.errors))
            )
        return protocol_template.sets()

    def sets(self, protocol):
        """The declaration of all the sets of a protocol, indexed by name"""

        return {k["name"]: k for k in self._set_declarations(protocol)}

    def set(self, protocol, name):
        """The declaration of a specific set in the database protocol"""

        return self.sets(protocol)[name]

    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        return [k["name"] for k in self._set_declarations(protocol)]

    def view_definition(self, protocol_name, set_name):
        """Returns the definition of a view

        Parameters:
          protocol_name (str): The name of the protocol where to retrieve the view
            from

          set_name (str): The name of the set in the protocol where to retrieve the
            view from

        """

        if self.schema_version == 1:
            view_definition = self.set(protocol_name, set_name)
        else:
            protocol = self.protocol(protocol_name)
            template_name = protocol["template"]
            protocol_template = ProtocolTemplate(self.prefix, template_name)
            view_definition = protocol_template.set(set_name)

            # the set comes from the template, the view class and its
            # parameters from the protocol declaration
            view_definition["view"] = protocol["views"][set_name]["view"]
            parameters = protocol["views"][set_name].get("parameters")
            if parameters is not None:
                view_definition["parameters"] = parameters

        return view_definition

    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str): If passed, used instead of the ``root_folder``
            entry of the database declaration

        Returns:

          The database view, which will be constructed, but not setup. You
          **must** set it up before using methods ``done`` or ``next``.

        """

        if not self._name:
            exc = exc or RuntimeError
            raise exc("database has no name")

        if not self.valid:
            # note the order: the set name comes first in the message
            message = (
                "cannot load view for set `%s' of protocol `%s' "
                "from invalid database (%s)\n%s"
                % (name, protocol, self.name, "   \n".join(self.errors))
            )
            if exc:
                raise exc(message)

            raise RuntimeError(message)

        # loads the module only once through the lifetime of the database
        # object
        try:
            if not hasattr(self, "_module"):
                self._module = loader.load_module(
                    self.name.replace(os.sep, "_"), self.storage.code.path, {}
                )
        except Exception:
            if exc is not None:
                _, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data["root_folder"]

        return Runner(
            self._module,
            self.view_definition(protocol, name),
            self.prefix,
            root_folder,
            exc,
        )

    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return json.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)

    def __str__(self):
        return self.json_dumps()

    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)

    def export(self, prefix):
        """Recursively exports itself into another prefix

        Associated dataformats (and protocol templates, for schema versions
        other than 1) are also exported.


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If the database has no name, is not valid, or if
            prefix and self.prefix point to the same directory.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        if prefix == self.prefix:
            raise RuntimeError(
                "Cannot export database to the same prefix (" "%s)" % prefix
            )

        for k in self.dataformats.values():
            k.export(prefix)

        if self.schema_version != 1:
            for protocol in self.protocols.values():
                protocol_template = ProtocolTemplate(self.prefix, protocol["template"])
                protocol_template.export(prefix)

        self.write(Storage(prefix, self.name))


628
# ----------------------------------------------------------
629 630 631


class View(object):
    """Base class of the database views, to be specialized by each view"""

    def __init__(self):
        #  Database definitions currently store their entries in named tuples,
        #  whose fields cannot be python keywords such as `class`.
        #  output_member_map lets an output carry such a name while the named
        #  tuple uses another field name for it (for example cls, klass, etc.)
        self.output_member_map = dict()

    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        Values may appear in any order inside the tuples, but the list itself
        is expected to be ordered in a consistent manner (ie. all train images
        of person A, then all train images of person B, ...).

        For instance, assuming a view providing that kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        a list like the following should be generated:

        .. code-block:: python

           [
               (client_id=1, file_id=1, image=filename1),
               (client_id=1, file_id=2, image=filename2),
               (client_id=1, file_id=3, image=filename3),
               (client_id=2, file_id=4, image=filename4),
               (client_id=2, file_id=5, image=filename5),
               (client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           DO NOT store images, sound files or data loadable from a file in the
           list!  Store the path of the file to load instead.

        """

        raise NotImplementedError

    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the configuration of the view and the entries it must provide

        Only the entries between ``start_index`` and ``end_index`` (both
        included) are kept in ``self.objs``. By default the range covers the
        whole list.
        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters

        # Determine the range of indices that must be provided
        first = 0
        if start_index is not None:
            first = start_index

        last = len(objs) - 1
        if end_index is not None:
            last = end_index

        self.start_index = first
        self.end_index = last

        # the end of the range is inclusive
        self.objs = objs[first : last + 1]  # noqa

    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)
        """

        raise NotImplementedError

    def get_output_mapping(self, output):
        """Returns the name of the object member holding the data of the given
        output: the mapped name when one is registered, the output name itself
        otherwise.
        """
        if output in self.output_member_map:
            return self.output_member_map[output]
        return output

711

712
# ----------------------------------------------------------
713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also reports some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.

    The tests are executed as soon as the object is instantiated.

    Parameters:

      name (str): Name displayed in the progress messages, also used to build
        the name of the generated text file

      view_class: Class of the view to test, instantiated without argument

      outputs_declaration: Iterable over the names of the outputs of the view

      parameters: Parameters given to the ``setup()`` method of the view

      irregular_outputs (list): Names of the outputs that don't produce data
        at constant intervals

      all_combinations (bool): If ``True``, every non-empty combination of
        connected outputs is tested, otherwise only the one where all the
        outputs are connected
    """

    # Mock output class
    class MockOutput:
        """Output replacement that records everything written on it

        Each entry of ``written_data`` is a tuple ``(start index, end index,
        data)``, where the start index directly follows the end index of the
        previous entry.
        """

        def __init__(self, name, connected):
            self.name = name
            self.connected = connected
            self.last_written_data_index = -1
            self.written_data = []

        def write(self, data, end_data_index):
            """Records ``data`` as covering all the indices up to
            ``end_data_index`` (inclusive) since the last written one"""
            self.written_data.append(
                (self.last_written_data_index + 1, end_data_index, data)
            )
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Returns the connection state given at construction time"""
            return self.connected

    class SynchronizedUnit:
        """Node of a tree of index ranges, used to generate the text report

        A unit covers the range ``[start, end]``, holds the textual
        representation of the data written on that exact range by each output
        and has (ordered) children covering sub-ranges.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Registers the data written by ``output`` on ``[start, end]``,
            creating or reorganising the child units when needed"""

            # The data covers exactly this unit
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
                return

            # The data is located after all the existing children
            if (len(self.children) == 0) or (self.children[-1].end < start):
                unit = DatabaseTester.SynchronizedUnit(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
                return

            # Note: data matching none of the two cases below is not recorded
            for index, unit in enumerate(self.children):
                # The data fits inside an existing child
                if (unit.start <= start) and (unit.end >= end):
                    unit.addData(output, start, end, data)
                    break

                # The data spans several existing children: insert a new unit
                # in-between that adopts all the children it covers
                if (unit.start == start) and (unit.end < end):
                    new_unit = DatabaseTester.SynchronizedUnit(start, end)
                    new_unit.addData(output, start, end, data)

                    # 'stop' is the index of the first child NOT covered by
                    # the new unit (it is always well defined, even when the
                    # adopted children go up to the end of the list)
                    stop = index
                    while (stop < len(self.children)) and (
                        self.children[stop].end <= end
                    ):
                        new_unit.children.append(self.children[stop])
                        stop += 1

                    self.children[index:stop] = [new_unit]
                    break

        def toString(self):
            """Returns a dictionary containing, for each output, one line of
            text representing the data of this unit and of its children

            All the returned texts have the same length. An empty dictionary
            is returned when the unit has neither data nor children.
            """
            texts = {}

            # Concatenate the texts of the children, output by output
            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += " " + text
                    else:
                        texts[output] = text

            if len(self.data) > 0:
                # Each value is decorated with 6 characters: '|- ' and ' -|'
                length = max(len(x) + 6 for x in self.data.values())

                if len(texts) > 0:
                    children_length = len(next(iter(texts.values())))
                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the texts of the children (integer arithmetic
                        # only, a string can't be repeated a float number of
                        # times)
                        diff = length - children_length
                        diff1 = diff // 2
                        diff2 = diff - diff1

                        for k, v in texts.items():
                            texts[k] = "|%s%s%s|" % ("-" * diff1, v[1:-1], "-" * diff2)

                for output, value in self.data.items():
                    diff = length - (len(value) + 6)
                    diff1 = diff // 2
                    diff2 = diff - diff1
                    texts[output] = "|-%s %s %s-|" % ("-" * diff1, value, "-" * diff2)

            # Pad all the texts to the same length
            if len(texts) > 0:
                length = max(len(x) for x in texts.values())
                for k, v in texts.items():
                    if len(v) < length:
                        texts[k] += " " * (length - len(v))

            return texts

        def _dataToString(self, data):
            """Returns the text representing ``data``: the string conversion
            of its unique value, or 'X' if there isn't exactly one value or if
            that value is an array or a dictionary"""
            if len(data) != 1:
                return "X"

            # Works with dictionary views too (they can't be indexed)
            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return "X"

            return str(value)

    def __init__(
        self,
        name,
        view_class,
        outputs_declaration,
        parameters,
        irregular_outputs=None,
        all_combinations=True,
    ):
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters

        # Don't share a mutable default list between the instances
        self.irregular_outputs = (
            irregular_outputs if irregular_outputs is not None else []
        )

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            names = list(self.outputs_declaration.keys())
            for size in range(1, len(names) + 1):
                for subset in itertools.combinations(names, size):
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())

    def determine_regular_intervals(self, outputs_declaration):
        """Fills ``self.outputs_declaration`` with, for each output, the number
        of indices covered by its first block of data (``None`` for the
        irregular outputs)

        The view is instantiated with all its outputs connected and asked for
        data only once.
        """
        outputs = OutputList()
        for name in outputs_declaration:
            outputs.add(DatabaseTester.MockOutput(name, True))

        view = self.view_class()
        view.setup("", outputs, self.parameters)

        view.next()

        for output in outputs:
            if output.name not in self.irregular_outputs:
                self.outputs_declaration[output.name] = (
                    output.last_written_data_index + 1
                )
            else:
                self.outputs_declaration[output.name] = None

    def run(self, connected_outputs):
        """Tests the view with only the given outputs connected

        Nothing is done if ``connected_outputs`` is empty. When all the
        outputs are connected, a text file (named after the tester, spaces
        being replaced by underscores) detailing the generated data is
        written in the current directory.

        Parameters:

          connected_outputs: Names of the outputs to connect

        Raises:

          OutputError: if the data produced by the view are not consistent
        """
        if len(connected_outputs) == 0:
            return

        print(
            "Testing '%s', with %d output(s): %s"
            % (self.name, len(connected_outputs), ", ".join(connected_outputs))
        )

        # Split the declared outputs (name -> interval) in two groups
        requested = set(connected_outputs)

        connected_outputs = {
            name: interval
            for name, interval in self.outputs_declaration.items()
            if name in requested
        }

        not_connected_outputs = {
            name: interval
            for name, interval in self.outputs_declaration.items()
            if name not in requested
        }

        # Create the mock outputs
        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs))

        # Create the view
        view = self.view_class()
        view.setup("", outputs, self.parameters)

        # Initialisations
        next_expected_indices = {name: 0 for name in connected_outputs}
        next_index = 0

        def _done():
            return all(
                view.done(output.last_written_data_index)
                for output in outputs
                if output.isConnected()
            )

        # Ask for all the data
        while not _done():
            view.next()

            # Check the indices for the connected outputs
            for name, interval in connected_outputs.items():
                first, last = outputs[name].written_data[-1][:2]
                expected = next_expected_indices[name]

                if first != expected:
                    raise OutputError("Wrong current index")

                if name not in self.irregular_outputs:
                    if last != expected + interval - 1:
                        raise OutputError("Wrong next index")
                elif last < expected:
                    raise OutputError("Wrong next index")

            # Check that the not connected outputs didn't produce data
            for name in not_connected_outputs.keys():
                if len(outputs[name].written_data) != 0:
                    raise OutputError("Data written on unconnected output")

            # Compute the next data index that should be produced
            next_index = 1 + min(
                x.written_data[-1][1] for x in outputs if x.isConnected()
            )

            # Compute the next data index that should be produced by each
            # connected output
            for name, interval in connected_outputs.items():
                if name not in self.irregular_outputs:
                    if next_index == next_expected_indices[name] + interval:
                        next_expected_indices[name] += interval
                else:
                    last = outputs[name].written_data[-1][1]
                    if next_index > last:
                        next_expected_indices[name] = last + 1

        # Check the number of data produced on the regular outputs
        for name, interval in connected_outputs.items():
            nb_data = len(outputs[name].written_data)
            print("  - %s: %d data" % (name, nb_data))
            if name not in self.irregular_outputs:
                # Exact integer comparison (no division involved)
                if nb_data * interval != next_index:
                    raise OutputError("Invalid number of data produced")

        # Check that all outputs ends on the same index
        for name in connected_outputs.keys():
            if not outputs[name].written_data[-1][1] == next_index - 1:
                raise OutputError("Outputs not on same index")

        # Generate a text file with lots of details (only if all outputs are
        # connected)
        if len(connected_outputs) == len(self.outputs_declaration):
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

            unit = DatabaseTester.SynchronizedUnit(
                0, sorted_outputs[0].written_data[-1][1]
            )

            for output in sorted_outputs:
                for data in output.written_data:
                    unit.addData(output.name, data[0], data[1], data[2])

            texts = unit.toString()

            outputs_max_length = max(len(x) for x in self.outputs_declaration.keys())

            with open(self.name.replace(" ", "_") + ".txt", "w") as f:
                # One line per output, the one with the most data first
                for output in reversed(sorted_outputs):
                    output_name = output.name
                    padding = " " * (outputs_max_length - len(output_name))
                    f.write(output_name + ": " + padding + texts[output_name] + "\n")