#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


"""
========
database
========

Validation of databases
"""

import itertools
import os
import sys
from collections import namedtuple

import numpy as np
import simplejson
import six

from . import loader
from . import utils
from .dataformat import DataFormat
from .outputs import OutputList


# ----------------------------------------------------------


class Storage(utils.CodeStorage):
    """Resolves paths for databases

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.


    Raises:

      RuntimeError: If ``name`` does not contain exactly one ``/``.

    """

    def __init__(self, prefix, name):

        if name.count('/') != 1:
            raise RuntimeError("invalid database name: `%s'" % name)

        self.name, self.version = name.split('/')
        self.fullname = name
        self.prefix = prefix

        # The base class receives the extension-less path of the object
        # (``<prefix>/databases/<name>/<version>``)
        path = os.path.join(self.prefix, 'databases', name)

        # views are coded in Python
        super(Storage, self).__init__(path, 'python')


# ----------------------------------------------------------


class Runner(object):
    '''A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The declaration of the view. Its ``view`` entry names
        the class to retrieve from ``module``; the optional ``parameters`` and
        ``outputs`` entries are read when indexing and setting up the view.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`): The class to use as base exception when
        translating the exception from the user code. Read the documentation
        of :py:func:`.loader.run` for more details.

    '''

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition['view'])
        except Exception:
            if exc is not None:
                # translate the failure into the caller-provided exception
                # type, keeping the original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        self.obj          = loader.run(class_, '__new__', exc)
        self.ready        = False
        self.prefix       = prefix
        self.root_folder  = root_folder
        self.definition   = definition
        self.exc          = exc or RuntimeError
        self.data_sources = None


    def index(self, filename):
        '''Index the content of the view

        Runs the ``index()`` method of the view and stores the returned list,
        JSON-encoded, in ``filename``. Missing parent directories are created.

        Parameters:

          filename (str): Path of the index file to write


        Raises:

          self.exc: If the view's ``index()`` doesn't return a list

        '''

        parameters = self.definition.get('parameters', {})

        objs = loader.run(self.obj, 'index', self.exc, self.root_folder,
                          parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        # Serialize before touching the filesystem, so that an encoding
        # failure doesn't leave a truncated index file behind
        data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder)

        # ``dirname`` is empty when ``filename`` has no directory component,
        # in which case there is nothing to create
        dirname = os.path.dirname(filename)
        if dirname and not os.path.isdir(dirname):
            try:
                os.makedirs(dirname)
            except OSError:
                # somebody else may have created it in the meantime
                if not os.path.isdir(dirname):
                    raise

        with open(filename, 'wb') as f:
            f.write(data.encode('utf-8'))


    def setup(self, filename, start_index=None, end_index=None, pack=True):
        '''Sets up the view

        Loads the index stored in ``filename``, converts each entry into a
        named tuple, runs the ``setup()`` method of the view and creates one
        data source per declared output. Does nothing if the view is already
        set up.

        Parameters:

          filename (str): Path of the index file written by :py:meth:`index`

          start_index (int): Forwarded to the view and to the data sources

          end_index (int): Forwarded to the view and to the data sources

          pack (bool): Forwarded to the data sources

        '''

        if self.ready:
            return

        with open(filename, 'rb') as f:
            objs = simplejson.loads(f.read().decode('utf-8'))

        # The fields of the named tuple are taken from the first entry; an
        # empty index has no first entry and is passed along as is
        if objs:
            Entry = namedtuple('Entry', sorted(objs[0].keys()))
            objs = [Entry(**x) for x in objs]

        parameters = self.definition.get('parameters', {})

        loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters,
                   objs, start_index=start_index, end_index=end_index)

        # Create data sources for the outputs
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        outputs = self.definition.get('outputs', {})
        for output_name, output_format in outputs.items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(self, output_name, output_format, self.prefix,
                              start_index=start_index, end_index=end_index,
                              pack=pack)
            self.data_sources[output_name] = data_source

        self.ready = True


    def get(self, output, index):
        '''Returns the data of the provided output at the provided index

        Raises:

          self.exc: If the view was not set up yet

        '''

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, 'get', self.exc, output, index)


    def get_output_mapping(self, output):
        '''Returns the result of the view's ``get_output_mapping()`` for the
        given output'''

        return loader.run(self.obj, 'get_output_mapping', self.exc, output)


    def objects(self):
        '''Returns the ``objs`` attribute of the view'''

        return self.obj.objs


# ----------------------------------------------------------


class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

      errors (list): Messages describing why the database could not be loaded;
        empty for a valid database

      dataformats (dict): The dataformats used by the outputs of the database,
        indexed by name

      code_path (str): Path of the file containing the views (``None`` if the
        declaration could not be loaded)

      code: The content of the views file, as returned by the storage
        (``None`` if the declaration could not be loaded)

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # always defined, so an invalid database still exposes the attributes
        self.code_path = None
        self.code = None

        # if the user has not provided a cache, still use one for performance
        if dataformat_cache is None:
            dataformat_cache = {}

        self._load(name, dataformat_cache)


    def _load(self, data, dataformat_cache):
        """Loads the database

        Parameters:

          data (str): The fully qualified name of the database

          dataformat_cache (dict): Dataformats already loaded, indexed by
            name. Newly loaded dataformats are added to it.

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            self.errors.append(
                'Database declaration file not found: %s' % json_path)
            return

        with open(json_path, 'rb') as f:
            self.data = simplejson.loads(f.read().decode('utf-8'))

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        # Loads (once) every dataformat referenced by an output
        for protocol in self.data['protocols']:
            for _set in protocol['sets']:
                for format_name in _set['outputs'].values():

                    if format_name in self.dataformats:
                        continue

                    if format_name not in dataformat_cache:
                        dataformat_cache[format_name] = \
                            DataFormat(self.prefix, format_name)

                    self.dataformats[format_name] = \
                        dataformat_cache[format_name]


    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or '__unnamed_database__'


    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)


    @property
    def description(self):
        """The short description for this object"""
        return self.data.get('description', None)


    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data['description'] = value


    @property
    def documentation(self):
        """The full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None


    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        # accepts either a file-like object or the content itself
        if hasattr(value, 'read'):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)


    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()


    @property
    def schema_version(self):
        """Returns the schema version"""
        return self.data.get('schema_version', 1)


    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)


    @property
    def protocols(self):
        """The declaration of all the protocols of the database"""

        return {k['name']: k for k in self.data['protocols']}


    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]


    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        return [k['name'] for k in self.data['protocols']]


    def sets(self, protocol):
        """The declaration of all the sets of a protocol of the database"""

        return {k['name']: k for k in self.protocol(protocol)['sets']}


    def set(self, protocol, name):
        """The declaration of a specific set in the database protocol"""

        return self.sets(protocol)[name]


    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        return [k['name'] for k in self.protocol(protocol)['sets']]


    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str): If passed, used instead of the ``root_folder``
            entry of the database declaration


        Returns:

          The database view, which will be constructed, but not setup. You
          **must** set it up before using its ``get`` method.

        """

        if not self._name:
            raise (exc or RuntimeError)("database has no name")

        if not self.valid:
            message = "cannot load view for set `%s' of protocol `%s' " \
                    "from invalid database (%s)" % (name, protocol, self.name)
            raise (exc or RuntimeError)(message)

        # loads the module only once through the lifetime of the database
        # object
        try:
            if not hasattr(self, '_module'):
                self._module = loader.load_module(
                    self.name.replace(os.sep, '_'),
                    self.storage.code.path, {})
        except Exception:
            if exc is not None:
                # translate the failure into the caller-provided exception
                # type, keeping the original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data['root_folder']

        return Runner(self._module, self.set(protocol, name),
                      self.prefix, root_folder, exc)


    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return simplejson.dumps(self.data, indent=indent,
                                cls=utils.NumpyJSONEncoder)


    def __str__(self):
        return self.json_dumps()


    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)


    def export(self, prefix):
        """Recursively exports itself into another prefix

        Dataformats associated are also exported recursively


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If the database has no name, is not valid, or if
            prefix and self.prefix are the same.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        if prefix == self.prefix:
            raise RuntimeError("Cannot export database to the same prefix ("
                               "%s)" % prefix)

        for k in self.dataformats.values():
            k.export(prefix)

        self.write(Storage(prefix, self.name))


# ----------------------------------------------------------


class View(object):
    """Base implementation of a database view

    ``index()`` and ``get()`` raise :py:class:`NotImplementedError` here and
    are meant to be provided by the actual views.
    """

    def __init__(self):
        #  Current databases definitions uses named tuple to store information.
        #  This has one limitation, python keywords like `class` cannot be used.
        #  output_member_map allows to use that kind of keyword as output name
        #  while using something different for the named tuple (for example cls,
        #  klass, etc.)
        self.output_member_map = {}


    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        The ordering of values inside the tuples is free, but it is expected
        that the list is ordered in a consistent manner (ie. all train images
        of person A, then all train images of person B, ...).

        For instance, assuming a view providing that kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        a list like the following should be generated:

        .. code-block:: python

           [
               (client_id=1, file_id=1, image=filename1),
               (client_id=1, file_id=2, image=filename2),
               (client_id=1, file_id=3, image=filename3),
               (client_id=2, file_id=4, image=filename4),
               (client_id=2, file_id=5, image=filename5),
               (client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           DO NOT store images, sound files or data loadable from a file in the
           list!  Store the path of the file to load instead.

        """

        raise NotImplementedError


    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the configuration of the view and selects the entries to
        provide

        Parameters:

          root_folder (str): Stored as ``self.root_folder``

          parameters (dict): Stored as ``self.parameters``

          objs (list): The full list of entries of the view

          start_index (int): Index of the first entry to keep (defaults to 0)

          end_index (int): Index of the last entry to keep, inclusive
            (defaults to the last entry of ``objs``)

        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters
        self.objs = objs

        # Determine the range of indices that must be provided
        self.start_index = start_index if start_index is not None else 0
        self.end_index = end_index if end_index is not None else len(self.objs) - 1

        # ``end_index`` is inclusive, hence the + 1
        self.objs = self.objs[self.start_index : self.end_index + 1]


    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)
        """

        raise NotImplementedError


    def get_output_mapping(self, output):
        """Returns the object member to use for given output if any otherwise
        the member name is the output name.
        """
        return self.output_member_map.get(output, output)

# ----------------------------------------------------------


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also reports some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.

    All the tests are run from the constructor.

    Parameters:

      name: Name of the test, used in the printed messages and (with spaces
        replaced by underscores) as the name of the generated text file

      view_class: Class of the view to test, instantiated without argument

      outputs_declaration: Iterable over the names of the outputs of the view

      parameters: Parameters given to the ``setup()`` method of the view

      irregular_outputs: Names of the outputs that don't produce data at
        constant intervals (none by default)

      all_combinations: If ``True``, the view is tested with every subset of
        connected outputs, otherwise only with all the outputs connected
    """

    # Mock output class
    class MockOutput:
        """Minimal replacement of an output, recording what is written on it"""

        def __init__(self, name, connected):
            self.name = name
            self.connected = connected
            self.last_written_data_index = -1
            self.written_data = []

        def write(self, data, end_data_index):
            """Records ``data`` as a ``(start, end, data)`` tuple, where
            ``start`` directly follows the end index of the previous write
            """
            start_data_index = self.last_written_data_index + 1
            self.written_data.append((start_data_index, end_data_index, data))
            self.last_written_data_index = end_data_index

        def isConnected(self):
            return self.connected


    class SynchronizedUnit:
        """Node of a tree of ranges of indices

        ``data`` contains, for each output that wrote some data on exactly
        the range ``[start, end]``, a textual representation of that data.
        ``children`` contains the units covering sub-ranges of this one.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Registers the data written by an output on ``[start, end]``

            The data is either stored in this unit (same range), in a new
            child appended at the end (range located after the last child),
            in the first child containing the range, or in a new unit wrapping
            the existing children covered by the range. When no child
            satisfies any of those conditions, the data is not recorded.
            """
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
            elif (len(self.children) == 0) or (self.children[-1].end < start):
                unit = DatabaseTester.SynchronizedUnit(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
            else:
                for index, unit in enumerate(self.children):
                    if (unit.start <= start) and (unit.end >= end):
                        unit.addData(output, start, end, data)
                        break
                    elif (unit.start == start) and (unit.end < end):
                        new_unit = DatabaseTester.SynchronizedUnit(start, end)
                        new_unit.addData(output, start, end, data)
                        new_unit.children.append(unit)

                        # Move all the following children covered by the new
                        # range into the new unit. 'stop' is always defined
                        # and ends up one past the last absorbed child, even
                        # when there is no following child or when all of
                        # them are absorbed.
                        stop = index + 1
                        while (stop < len(self.children)) and \
                              (self.children[stop].end <= end):
                            new_unit.children.append(self.children[stop])
                            stop += 1

                        self.children = self.children[:index] + [new_unit] + \
                                        self.children[stop:]
                        break

        def toString(self):
            """Returns a dictionary containing, for each output, one line of
            text representing the data it produced on the range of this unit

            All the returned lines have the same length. An empty dictionary
            is returned when the unit has neither data nor children.
            """
            texts = {}

            # Concatenate the lines generated by the children
            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += ' ' + text
                    else:
                        texts[output] = text

            if len(self.data) > 0:
                # Each entry is rendered as '|- value -|': 6 extra characters
                length = max(len(x) + 6 for x in self.data.values())

                if len(texts) > 0:
                    # next(iter(...)) works with both lists and dict views
                    children_length = len(next(iter(texts.values())))
                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the lines of the children to the new length
                        left, right = self._splitPadding(length - children_length)
                        for k, v in list(texts.items()):
                            texts[k] = '|%s%s%s|' % ('-' * left, v[1:-1], '-' * right)

                for output, value in self.data.items():
                    left, right = self._splitPadding(length - (len(value) + 6))
                    texts[output] = '|-%s %s %s-|' % ('-' * left, value, '-' * right)

            if len(texts) == 0:
                return texts

            # Pad all the lines to the same length
            length = max(len(x) for x in texts.values())
            for k, v in list(texts.items()):
                if len(v) < length:
                    texts[k] += ' ' * (length - len(v))

            return texts

        @staticmethod
        def _splitPadding(diff):
            """Splits a number of padding characters into (left, right) parts,
            the extra character going to the right when ``diff`` is odd
            """
            # Integer division: a float can't be used to repeat a string
            left = diff // 2
            return left, diff - left

        def _dataToString(self, data):
            """Returns the text used to represent ``data``: its unique value
            converted to a string, or 'X' when it doesn't contain exactly one
            entry or when that entry is an array or a dictionary
            """
            if len(data) != 1:
                return 'X'

            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return 'X'

            return str(value)


    def __init__(self, name, view_class, outputs_declaration, parameters,
                 irregular_outputs=None, all_combinations=True):
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters
        self.irregular_outputs = [] if irregular_outputs is None else irregular_outputs

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            # The empty subset (L == 0) is ignored by run()
            for L in range(0, len(self.outputs_declaration) + 1):
                for subset in itertools.combinations(self.outputs_declaration.keys(), L):
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())


    def determine_regular_intervals(self, outputs_declaration):
        """Fills ``self.outputs_declaration`` with, for each output, the
        number of indices covered by the first data it writes when all the
        outputs are connected (``None`` for the irregular outputs)
        """
        outputs = OutputList()
        for name in outputs_declaration:
            outputs.add(DatabaseTester.MockOutput(name, True))

        # NOTE(review): the view is driven through setup(root_folder,
        # outputs, parameters) / next() / done(); confirm that the tested
        # view classes follow this protocol
        view = self.view_class()
        view.setup('', outputs, self.parameters)

        view.next()

        for output in outputs:
            if output.name not in self.irregular_outputs:
                self.outputs_declaration[output.name] = output.last_written_data_index + 1
            else:
                self.outputs_declaration[output.name] = None


    def run(self, connected_outputs):
        """Runs the view with only the given outputs connected, and checks
        that the data written on the outputs is correctly synchronized

        Parameters:

          connected_outputs: Names of the outputs to connect (nothing is done
            if empty)

        The checks are done with assertions. When all the outputs are
        connected, a text file representing the data written on each output
        is also generated in the current directory.
        """
        if len(connected_outputs) == 0:
            return

        print("Testing '%s', with %d output(s): %s" % \
            (self.name, len(connected_outputs), ', '.join(connected_outputs)))

        # Split the declared outputs (name -> interval) between the connected
        # and the not connected ones
        connected = {}
        not_connected = {}
        for name, interval in self.outputs_declaration.items():
            if name in connected_outputs:
                connected[name] = interval
            else:
                not_connected[name] = interval

        # Create the mock outputs
        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected))


        # Create the view
        view = self.view_class()
        view.setup('', outputs, self.parameters)


        # Initialisations
        next_expected_indices = dict.fromkeys(connected, 0)
        next_index = 0

        def _done():
            for output in outputs:
                if output.isConnected() and not view.done(output.last_written_data_index):
                    return False
            return True


        # Ask for all the data
        while not(_done()):
            view.next()

            # Check the indices for the connected outputs
            for name, interval in connected.items():
                last_start, last_end = outputs[name].written_data[-1][:2]
                expected_start = next_expected_indices[name]

                assert(last_start == expected_start)
                if name not in self.irregular_outputs:
                    assert(last_end == expected_start + interval - 1)
                else:
                    assert(last_end >= expected_start)

            # Check that the not connected outputs didn't produce data
            for name in not_connected.keys():
                assert(len(outputs[name].written_data) == 0)

            # Compute the next data index that should be produced
            next_index = 1 + min([ x.written_data[-1][1] for x in outputs if x.isConnected() ])

            # Compute the next data index that should be produced by each
            # connected output
            for name, interval in connected.items():
                if name not in self.irregular_outputs:
                    if next_index == next_expected_indices[name] + interval:
                        next_expected_indices[name] += interval
                else:
                    last_end = outputs[name].written_data[-1][1]
                    if next_index > last_end:
                        next_expected_indices[name] = last_end + 1

        # Check the number of data produced on the regular outputs
        for name, interval in connected.items():
            print('  - %s: %d data' % (name, len(outputs[name].written_data)))
            if name not in self.irregular_outputs:
                assert(len(outputs[name].written_data) == next_index / interval)

        # Check that all outputs ends on the same index
        for name in connected.keys():
            assert(outputs[name].written_data[-1][1] == next_index - 1)


        # Generate a text file with lots of details (only if all outputs are
        # connected)
        if len(connected) == len(self.outputs_declaration):
            # Outputs with the fewest data first, so that the widest ranges
            # are inserted in the tree before the ones they contain
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

            unit = DatabaseTester.SynchronizedUnit(0, sorted_outputs[0].written_data[-1][1])

            for output in sorted_outputs:
                for data in output.written_data:
                    unit.addData(output.name, data[0], data[1], data[2])

            texts = unit.toString()

            outputs_max_length = max([ len(x) for x in self.outputs_declaration.keys() ])

            with open(self.name.replace(' ', '_') + '.txt', 'w') as f:
                for i in range(1, len(sorted_outputs) + 1):
                    output_name = sorted_outputs[-i].name
                    f.write(output_name + ': ')

                    if len(output_name) < outputs_max_length:
                        f.write(' ' * (outputs_max_length - len(output_name)))

                    f.write(texts[output_name] + '\n')