database.py 29.9 KB
Newer Older
1 2 3
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
35 36


37 38 39 40 41 42 43
"""
========
database
========

Validation of databases
"""
44 45 46 47 48 49

import os
import sys

import six
import simplejson
50 51
import itertools
import numpy as np
52
from collections import namedtuple
53 54

from . import loader
Philip ABBET's avatar
Philip ABBET committed
55 56 57
from . import utils

from .dataformat import DataFormat
58
from .outputs import OutputList
Philip ABBET's avatar
Philip ABBET committed
59 60


61
# ----------------------------------------------------------
62

Philip ABBET's avatar
Philip ABBET committed
63 64

class Storage(utils.CodeStorage):
    """Locates the files that make up a database on disk

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.

    Raises:

      RuntimeError: if ``name`` does not contain exactly one ``/``

    """

    def __init__(self, prefix, name):

        slashes = name.count('/')
        if slashes != 1:
            raise RuntimeError("invalid database name: `%s'" % name)

        # exactly one separator is guaranteed above, so partition() yields the
        # same two components as split('/')
        db_name, _, db_version = name.partition('/')
        self.name = db_name
        self.version = db_version
        self.fullname = name
        self.prefix = prefix

        # base path of the database files, without any file extension
        base_path = os.path.join(self.prefix, 'databases', name)

        # views are coded in Python
        super(Storage, self).__init__(base_path, 'python')
89 90


91
# ----------------------------------------------------------
92

93

94
class Runner(object):
    '''A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The declaration of the database set served by this
        runner. Its ``view`` entry names the view class to instantiate from
        ``module``; the optional ``parameters`` and ``outputs`` entries are
        read by :py:meth:`index` and :py:meth:`setup`.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`, Optional): The class to use as base exception
        when translating the exception from the user code. Read the
        documentation of :py:func:`.loader.run` for more details. When not
        given, errors raised directly by this class use
        :py:class:`RuntimeError`.

    '''

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition['view'])
        except Exception:
            if exc is not None:
                # re-raise as the caller-provided exception type, keeping the
                # original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        self.obj          = loader.run(class_, '__new__', exc)
        self.ready        = False
        self.prefix       = prefix
        self.root_folder  = root_folder
        self.definition   = definition
        self.exc          = exc or RuntimeError
        self.data_sources = None  # populated by setup()


    def index(self, filename):
        '''Index the content of the view and write it to ``filename`` as JSON

        Parameters:

          filename (str): Path of the JSON file to create. Missing parent
            directories are created.

        Raises:

          self.exc: if the view's ``index()`` does not return a list

        '''

        parameters = self.definition.get('parameters', {})

        objs = loader.run(self.obj, 'index', self.exc, self.root_folder, parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        # dirname() is '' for a bare file name, and os.makedirs('') raises
        dirname = os.path.dirname(filename)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname)

        with open(filename, 'wb') as f:
            data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder)
            f.write(data.encode('utf-8'))


    def setup(self, filename, start_index=None, end_index=None, pack=True):
        '''Sets up the view from an index file written by :py:meth:`index`

        Does nothing if the view is already set up.

        Parameters:

          filename (str): Path of the JSON index file

          start_index (int, Optional): Forwarded to the view's ``setup()`` and
            to the output data sources

          end_index (int, Optional): Forwarded to the view's ``setup()`` and to
            the output data sources

          pack (bool): Forwarded to the output data sources

        '''

        if self.ready:
            return

        with open(filename, 'rb') as f:
            objs = simplejson.loads(f.read().decode('utf-8'))

        # entries are stored as JSON objects; rebuild named tuples whose fields
        # are the (sorted) keys of the first entry.
        # NOTE(review): an empty index raises IndexError here
        Entry = namedtuple('Entry', sorted(objs[0].keys()))
        objs = [Entry(**x) for x in objs]

        parameters = self.definition.get('parameters', {})

        loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters,
                   objs, start_index=start_index, end_index=end_index)

        # Create data sources for the outputs
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        for output_name, output_format in self.definition.get('outputs', {}).items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(self, output_name, output_format, self.prefix,
                              start_index=start_index, end_index=end_index, pack=pack)
            self.data_sources[output_name] = data_source

        self.ready = True


    def get(self, output, index):
        '''Returns the data of the provided output at the provided index

        Raises:

          self.exc: if :py:meth:`setup` has not been called yet

        '''

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, 'get', self.exc, output, index)


    def get_output_mapping(self, output):
        '''Returns the view's member name used for ``output``'''
        return loader.run(self.obj, 'get_output_mapping', self.exc, output)


    def objects(self):
        '''Returns the entries held by the underlying view (``self.obj.objs``)'''
        return self.obj.objs


204
# ----------------------------------------------------------
205

206 207

class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

      errors (list): Error messages collected while loading. The database is
        valid only when this list is empty.

      dataformats (dict): The data formats referenced by the outputs of the
        database sets, keyed by data format name.

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # if the user has not provided a cache, still use one for performance
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

        self._load(name, dataformat_cache)


    def _load(self, data, dataformat_cache):
        """Loads the database declaration, its code and its data formats

        If the JSON declaration is missing, an error is recorded in
        ``self.errors`` and nothing else is loaded.

        Parameters:

          data (str): The fully qualified database name (e.g. ``db/1``)

          dataformat_cache (dict): Data formats already loaded, keyed by name.
            Newly loaded data formats are added to it.

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            self.errors.append('Database declaration file not found: %s' % json_path)
            return

        with open(json_path, 'rb') as f:
            self.data = simplejson.loads(f.read().decode('utf-8'))

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        # load (or reuse from the cache) every data format used by a set output
        for protocol in self.data['protocols']:
            for _set in protocol['sets']:
                for value in _set['outputs'].values():

                    if value in self.dataformats:
                        continue

                    if value in dataformat_cache:
                        dataformat = dataformat_cache[value]
                    else:
                        dataformat = DataFormat(self.prefix, value)
                        dataformat_cache[value] = dataformat

                    self.dataformats[value] = dataformat


    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or '__unnamed_database__'


    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)


    @property
    def description(self):
        """The short description for this object"""
        return self.data.get('description', None)


    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data['description'] = value


    @property
    def documentation(self):
        """The full-length description for this object, or ``None``"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None


    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object

        ``value`` may be a string or a file-like object with a ``read()``
        method.
        """

        if not self._name:
            raise RuntimeError("database has no name")

        if hasattr(value, 'read'):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)


    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()


    @property
    def schema_version(self):
        """Returns the schema version (defaults to 1 when not declared)"""
        return self.data.get('schema_version', 1)


    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)


    @property
    def protocols(self):
        """The declaration of all the protocols of the database, keyed by name"""

        data = self.data['protocols']
        return dict(zip([k['name'] for k in data], data))


    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]


    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        data = self.data['protocols']
        return [k['name'] for k in data]


    def sets(self, protocol):
        """The declaration of all the sets of a protocol, keyed by set name"""

        data = self.protocol(protocol)['sets']
        return dict(zip([k['name'] for k in data], data))


    def set(self, protocol, name):
        """The declaration of a specific set in the database protocol"""

        return self.sets(protocol)[name]


    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        data = self.protocol(protocol)['sets']
        return [k['name'] for k in data]


    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str, Optional): Overrides the ``root_folder`` entry of
            the database declaration.

        Returns:

          The database view, which will be constructed, but not setup. You
          **must** set it up before using methods ``done`` or ``next``.

        """

        if not self._name:
            exc = exc or RuntimeError
            raise exc("database has no name")

        if not self.valid:
            message = "cannot load view for set `%s' of protocol `%s' " \
                    "from invalid database (%s)" % (name, protocol, self.name)
            raise (exc or RuntimeError)(message)

        # loads the module only once through the lifetime of the database
        # object. Database names always use '/' as separator (see Storage),
        # so replace that rather than the platform-dependent os.sep.
        try:
            if not hasattr(self, '_module'):
                self._module = loader.load_module(self.name.replace('/', '_'),
                          self.storage.code.path, {})
        except Exception:
            if exc is not None:
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data['root_folder']

        return Runner(self._module, self.set(protocol, name),
                      self.prefix, root_folder, exc)


    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return simplejson.dumps(self.data, indent=indent,
                                cls=utils.NumpyJSONEncoder)


    def __str__(self):
        return self.json_dumps()


    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)


    def export(self, prefix):
        """Recursively exports itself into another prefix

        Dataformats associated are also exported recursively


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If prefix and self.prefix point to the same directory.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        # compare resolved paths so that e.g. trailing separators or symlinks
        # do not let an export overwrite the source prefix
        if os.path.realpath(prefix) == os.path.realpath(self.prefix):
            raise RuntimeError("Cannot export database to the same prefix ("
                               "%s)" % prefix)

        for k in self.dataformats.values():
            k.export(prefix)

        self.write(Storage(prefix, self.name))


532
# ----------------------------------------------------------
533 534 535 536


class View(object):
    """Base class for database views

    Subclasses implement :py:meth:`index` (list the entries of the view) and
    :py:meth:`get` (return the data of one output for one entry).
    :py:meth:`setup` stores the entries and restricts them to a range of
    indices.
    """

    def __init__(self):
        #  Current databases definitions uses named tuple to store information.
        #  This has one limitation, python keywords like `class` cannot be used.
        #  output_member_map allows to use that kind of keyword as output name
        #  while using something different for the named tuple (for example cls,
        #  klass, etc.)
        self.output_member_map = {}


    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        The ordering of values inside the tuples is free, but it is expected
        that the list is ordered in a consistent manner (ie. all train images
        of person A, then all train images of person B, ...).

        For instance, assuming a view providing that kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        a list like the following should be generated (``Entry`` being a
        :py:func:`collections.namedtuple` with those three fields):

        .. code-block:: python

           [
               Entry(client_id=1, file_id=1, image=filename1),
               Entry(client_id=1, file_id=2, image=filename2),
               Entry(client_id=1, file_id=3, image=filename3),
               Entry(client_id=2, file_id=4, image=filename4),
               Entry(client_id=2, file_id=5, image=filename5),
               Entry(client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           DO NOT store images, sound files or data loadable from a file in the
           list!  Store the path of the file to load instead.

        """

        raise NotImplementedError


    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the view entries, keeping only the requested index range

        Parameters:

          root_folder (str): The root folder of the database

          parameters (dict): The parameters of the database set

          objs (list): The entries of the view, as produced by :py:meth:`index`

          start_index (int, Optional): First entry to keep (default: 0)

          end_index (int, Optional): Last entry to keep, inclusive (default:
            the last entry)

        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters
        self.objs = objs

        # Determine the range of indices that must be provided
        self.start_index = start_index if start_index is not None else 0
        self.end_index = end_index if end_index is not None else len(self.objs) - 1

        # end_index is inclusive, hence the + 1
        self.objs = self.objs[self.start_index : self.end_index + 1]


    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)
        """

        raise NotImplementedError


    def get_output_mapping(self, output):
        """Returns the object member to use for given output if any otherwise
        the member name is the output name.
        """
        # user views that override __init__ without calling the base
        # constructor have no output_member_map: fall back to the identity
        member_map = getattr(self, 'output_member_map', None) or {}
        return member_map.get(output, output)

620
# ----------------------------------------------------------
621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/unconnected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - None of the unconnected outputs produce data

    It also reports some statistics, and can generate a text file detailing
    the data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.
    """

    # Mock output class
    class MockOutput:
        """Stand-in for a real output that only records what is written."""

        def __init__(self, name, connected):
            self.name, self.connected = name, connected
            # Index of the last data unit written so far (-1: nothing yet)
            self.last_written_data_index = -1
            # One (first_index, last_index, data) tuple per write() call
            self.written_data = []

        def write(self, data, end_data_index):
            """Records ``data`` as covering every index after the previous
            write, up to ``end_data_index`` (inclusive)."""
            first_index = self.last_written_data_index + 1
            record = (first_index, end_data_index, data)
            self.written_data.append(record)
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Returns the ``connected`` flag given at construction."""
            return self.connected


    class SynchronizedUnit:
        """Node of a tree grouping the data written by several outputs, used
        to render a textual timeline of the outputs.

        A unit covers the inclusive index range ``[start, end]``. ``data``
        maps an output name to the text representation of the data that
        output wrote for exactly that range; ``children`` holds the units
        covering sub-ranges, in increasing index order.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Inserts the data written by ``output`` for ``[start, end]``
            at the right place in the tree.

            - A range equal to this unit's range is attached to this unit.
            - A range starting after the last child becomes a new child.
            - Otherwise the data goes to the child containing the range, or,
              when the range starts with a child but ends after it, a new unit
              is created that adopts that child and every following child
              ending within the range.

            NOTE(review): a range matching none of these cases is silently
            ignored (unchanged from the original behavior).
            """
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
                return

            if (len(self.children) == 0) or (self.children[-1].end < start):
                unit = self.__class__(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
                return

            for index, unit in enumerate(self.children):
                if (unit.start <= start) and (unit.end >= end):
                    unit.addData(output, start, end, data)
                    return

                if (unit.start == start) and (unit.end < end):
                    new_unit = self.__class__(start, end)
                    new_unit.addData(output, start, end, data)

                    # 'stop' ends on the first child that is NOT adopted (or
                    # past the end of the list): adopted children are thus
                    # never kept twice, and the slices below stay valid even
                    # when the current child is the last one.
                    stop = index + 1
                    while (stop < len(self.children)) and \
                            (self.children[stop].end <= end):
                        stop += 1

                    new_unit.children.extend(self.children[index:stop])
                    self.children = self.children[:index] + [new_unit] + \
                        self.children[stop:]
                    return

        def toString(self):
            """Renders this unit and its descendants, one line per output.

            Returns a dict mapping each output name found in the tree to a
            string of ``|- value -|`` cells. Children are laid out side by
            side (separated by a space), and the cells of this unit are
            stretched to the children's width, or vice versa. All returned
            strings are right-padded with spaces to the same length; an empty
            unit yields an empty dict.
            """
            texts = {}

            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += ' ' + text
                    else:
                        texts[output] = text

            if len(self.data) > 0:
                # A cell is the value framed by '|- ' and ' -|' (6 chars)
                length = max([ len(x) + 6 for x in self.data.values() ])

                if len(texts) > 0:
                    # NOTE(review): the width of an arbitrary children line is
                    # used, which assumes all children lines have the same
                    # width.
                    children_length = len(next(iter(texts.values())))
                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the children lines by inserting dashes inside
                        # their outermost '|' delimiters
                        diff1, diff2 = self._split(length - children_length)
                        for k, v in texts.items():
                            texts[k] = '|%s%s%s|' % ('-' * diff1, v[1:-1], '-' * diff2)

                for output, value in self.data.items():
                    diff1, diff2 = self._split(length - (len(value) + 6))
                    texts[output] = '|-%s %s %s-|' % ('-' * diff1, value, '-' * diff2)

            if len(texts) > 0:
                length = max(len(x) for x in texts.values())
                for k, v in texts.items():
                    if len(v) < length:
                        texts[k] += ' ' * (length - len(v))

            return texts

        def _split(self, diff):
            """Splits the integer ``diff`` into two integer halves, the
            second one getting the extra unit when ``diff`` is odd."""
            half = diff // 2
            return half, diff - half

        def _dataToString(self, data):
            """Short text representation of the data written for one range.

            Returns ``'X'`` unless ``data`` holds exactly one field whose
            value is neither a numpy array nor a dict; in that case, returns
            ``str()`` of that value.
            """
            if len(data) != 1:
                return 'X'

            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return 'X'

            return str(value)


    def __init__(self, name, view_class, outputs_declaration, parameters,
                 irregular_outputs=[], all_combinations=True):
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters
        self.irregular_outputs = irregular_outputs

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            for L in range(0, len(self.outputs_declaration) + 1):
                for subset in itertools.combinations(self.outputs_declaration.keys(), L):
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())


    def determine_regular_intervals(self, outputs_declaration):
        """Fills ``self.outputs_declaration`` with the data interval of each
        output.

        A throw-away view is set up with every output connected, and its
        ``next()`` method is called once. A regular output gets, as its
        interval, the number of indices covered by what it wrote during that
        call (``last_written_data_index + 1``). An output listed in
        ``self.irregular_outputs`` gets ``None``.

        The view is expected to provide ``setup(root_folder, outputs,
        parameters)`` and ``next()``.
        """
        mock_outputs = OutputList()
        for output_name in outputs_declaration:
            mock_outputs.add(DatabaseTester.MockOutput(output_name, True))

        probe_view = self.view_class()
        probe_view.setup('', mock_outputs, self.parameters)
        probe_view.next()

        for mock in mock_outputs:
            interval = None
            if mock.name not in self.irregular_outputs:
                interval = mock.last_written_data_index + 1
            self.outputs_declaration[mock.name] = interval


    def run(self, connected_outputs):
        if len(connected_outputs) == 0:
            return

        print("Testing '%s', with %d output(s): %s" % \
            (self.name, len(connected_outputs), ', '.join(connected_outputs)))

        # Create the mock outputs
        connected_outputs = dict([ x for x in self.outputs_declaration.items()
                                     if x[0] in connected_outputs ])

        not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
                                         if x[0] not in connected_outputs ])

        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs))


        # Create the view
        view = self.view_class()
        view.setup('', outputs, self.parameters)


        # Initialisations
        next_expected_indices = {}
        for name, interval in connected_outputs.items():
            next_expected_indices[name] = 0

        next_index = 0

        def _done():
            for output in outputs:
                if output.isConnected() and not view.done(output.last_written_data_index):
                    return False
            return True


        # Ask for all the data
        while not(_done()):
            view.next()

            # Check the indices for the connected outputs
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
                    assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
                    assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
                else:
                    assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
                    assert(outputs[name].written_data[-1][1] >= next_expected_indices[name])

            # Check that the not connected outputs didn't produce data
            for name in not_connected_outputs.keys():
                assert(len(outputs[name].written_data) == 0)

            # Compute the next data index that should be produced
            next_index = 1 + min([ x.written_data[-1][1] for x in outputs if x.isConnected() ])

            # Compute the next data index that should be produced by each
            # connected output
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
                    if next_index == next_expected_indices[name] + connected_outputs[name]:
                        next_expected_indices[name] += connected_outputs[name]
                else:
                    if next_index > outputs[name].written_data[-1][1]:
                        next_expected_indices[name] = outputs[name].written_data[-1][1] + 1

        # Check the number of data produced on the regular outputs
        for name in connected_outputs.keys():
            print('  - %s: %d data' % (name, len(outputs[name].written_data)))
            if name not in self.irregular_outputs:
                assert(len(outputs[name].written_data) == next_index / connected_outputs[name])

        # Check that all outputs ends on the same index
        for name in connected_outputs.keys():
            assert(outputs[name].written_data[-1][1] == next_index - 1)


        # Generate a text file with lots of details (only if all outputs are
        # connected)
        if len(connected_outputs) == len(self.outputs_declaration):
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

            unit = DatabaseTester.SynchronizedUnit(0, sorted_outputs[0].written_data[-1][1])

            for output in sorted_outputs:
                for data in output.written_data:
                    unit.addData(output.name, data[0], data[1], data[2])

            texts = unit.toString()

            outputs_max_length = max([ len(x) for x in self.outputs_declaration.keys() ])

            with open(self.name.replace(' ', '_') + '.txt', 'w') as f:
                for i in range(1, len(sorted_outputs) + 1):
                    output_name = sorted_outputs[-i].name
                    f.write(output_name + ': ')

                    if len(output_name) < outputs_max_length:
                        f.write(' ' * (outputs_max_length - len(output_name)))

                    f.write(texts[output_name] + '\n')