database.py 34.3 KB
Newer Older
1
2
3
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

Samuel GAIST's avatar
Samuel GAIST committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
35
36


37
38
39
40
41
42
43
"""
========
database
========

Validation of databases
"""
44
45
46
47
48
49

import os
import sys

import six
import simplejson
Philip ABBET's avatar
Philip ABBET committed
50
51
import itertools
import numpy as np
52
from collections import namedtuple
53
54

from . import loader
Philip ABBET's avatar
Philip ABBET committed
55
56
from . import utils

Samuel GAIST's avatar
Samuel GAIST committed
57
from .protocoltemplate import ProtocolTemplate
Philip ABBET's avatar
Philip ABBET committed
58
from .dataformat import DataFormat
Philip ABBET's avatar
Philip ABBET committed
59
from .outputs import OutputList
Samuel GAIST's avatar
Samuel GAIST committed
60
from .exceptions import OutputError
Philip ABBET's avatar
Philip ABBET committed
61
62


63
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
64

Philip ABBET's avatar
Philip ABBET committed
65
66

class Storage(utils.CodeStorage):
    """Locates the files backing a database declaration inside a prefix

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.

    """

    asset_type = "database"
    asset_folder = "databases"

    def __init__(self, prefix, name):

        # a valid name holds exactly one separator: ``<name>/<version>``
        parts = name.split("/")
        if len(parts) != 2:
            raise RuntimeError("invalid database name: `%s'" % name)

        self.name, self.version = parts
        self.fullname = name
        self.prefix = prefix

        # extension-less base path handed to the parent storage class;
        # views are coded in Python
        base_path = os.path.join(self.prefix, self.asset_folder, name)
        super(Storage, self).__init__(base_path, "python")
94
95


96
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
97

98

99
class Runner(object):
    """A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The definition of the view. Its ``view`` entry names
        the class to fetch from ``module``; the optional ``parameters`` and
        ``outputs`` entries are read when indexing and setting up the view.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`): The class to use as base exception when
        translating the exception from the user code. Read the documentation of
        :py:func:`.loader.run` for more details.

    """

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition["view"])
        except Exception:
            if exc is not None:
                # translate the failure into the caller-provided exception
                # type while keeping the original traceback
                _, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
                raise  # just re-raise the user exception

        self.obj = loader.run(class_, "__new__", exc)
        self.ready = False
        self.prefix = prefix
        self.root_folder = root_folder
        self.definition = definition
        self.exc = exc or RuntimeError
        self.data_sources = None

    def index(self, filename):
        """Index the content of the view

        Runs the ``index`` method of the view and writes the resulting list,
        JSON-encoded in UTF-8, to ``filename``. Missing parent directories are
        created.

        Parameters:

          filename (str): The path of the index file to write

        Raises:

          The exception class given at construction (``RuntimeError`` by
          default) if the view did not return a list.

        """

        parameters = self.definition.get("parameters", {})

        objs = loader.run(self.obj, "index", self.exc, self.root_folder, parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        # ``dirname`` is empty when ``filename`` has no directory component;
        # os.makedirs("") would fail, and there is nothing to create anyway
        dirname = os.path.dirname(filename)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname)

        with open(filename, "wb") as f:
            data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder)
            f.write(data.encode("utf-8"))

    def setup(self, filename, start_index=None, end_index=None, pack=True):
        """Sets up the view

        Loads the index file, converts its entries to named tuples, runs the
        ``setup`` method of the view and creates one data source per declared
        output. Does nothing if the view is already set up.

        Parameters:

          filename (str): The path of the index file to read

          start_index (int): Forwarded to the view and to the data sources

          end_index (int): Forwarded to the view and to the data sources

          pack (bool): Forwarded to the data sources

        """

        if self.ready:
            return

        with open(filename, "rb") as f:
            objs = simplejson.loads(f.read().decode("utf-8"))

        # the field names come from the first entry; an empty index has none,
        # so it is passed along unchanged instead of failing on ``objs[0]``
        if objs:
            Entry = namedtuple("Entry", sorted(objs[0].keys()))
            objs = [Entry(**x) for x in objs]

        parameters = self.definition.get("parameters", {})

        loader.run(
            self.obj,
            "setup",
            self.exc,
            self.root_folder,
            parameters,
            objs,
            start_index=start_index,
            end_index=end_index,
        )

        # Create data sources for the outputs
        # NOTE(review): imported locally rather than at module level,
        # presumably to avoid a circular import -- confirm
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        for output_name, output_format in self.definition.get("outputs", {}).items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(
                self,
                output_name,
                output_format,
                self.prefix,
                start_index=start_index,
                end_index=end_index,
                pack=pack,
            )
            self.data_sources[output_name] = data_source

        self.ready = True

    def get(self, output, index):
        """Returns the data of the provided output at the provided index

        Raises:

          The exception class given at construction (``RuntimeError`` by
          default) if the view was not set up beforehand.

        """

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, "get", self.exc, output, index)

    def get_output_mapping(self, output):
        """Returns the result of the view's ``get_output_mapping`` for the
        given output
        """
        return loader.run(self.obj, "get_output_mapping", self.exc, output)

    def objects(self):
        """Returns the ``objs`` attribute of the underlying view object"""
        return self.obj.objs


218
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
219

220
221

class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # if the user has not provided a cache, still use one for performance
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

        self._load(name, dataformat_cache)

    def _update_dataformat_cache(self, outputs, dataformat_cache):
        """Makes sure every dataformat named by ``outputs`` is loaded

        Parameters:

          outputs (dict): Maps output names to dataformat names; only the
            dataformat names are used here.

          dataformat_cache (dict): Shared cache of loaded dataformats; it is
            updated in place with any dataformat loaded by this call.

        """

        for value in outputs.values():

            if value in self.dataformats:
                continue

            if value in dataformat_cache:
                dataformat = dataformat_cache[value]
            else:
                dataformat = DataFormat(self.prefix, value)
                dataformat_cache[value] = dataformat

            self.dataformats[value] = dataformat

    def _load_v1(self, dataformat_cache):
        """Loads a v1 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            for set_ in protocol["sets"]:
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load_v2(self, dataformat_cache):
        """Loads a v2 database and fills the dataformat cache"""

        for protocol in self.data["protocols"]:
            protocol_template = ProtocolTemplate(
                self.prefix, protocol["template"], dataformat_cache
            )
            for set_ in protocol_template.sets():
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load(self, data, dataformat_cache):
        """Loads the database

        Parameters:

          data (str): The fully qualified name of the database to load

          dataformat_cache (dict): Cache of loaded dataformats, updated in
            place

        Raises:

          RuntimeError: If the declaration carries an unsupported schema
            version.

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            # a missing declaration is recorded rather than raised; ``data``
            # stays None and ``valid`` becomes False
            self.errors.append("Database declaration file not found: %s" % json_path)
            return

        with open(json_path, "rb") as f:
            self.data = simplejson.loads(f.read().decode("utf-8"))

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        if self.schema_version == 1:
            self._load_v1(dataformat_cache)
        elif self.schema_version == 2:
            self._load_v2(dataformat_cache)
        else:
            raise RuntimeError(
                "Invalid schema version {schema_version}".format(
                    schema_version=self.schema_version
                )
            )

    def _set_declarations(self, protocol):
        """Returns the list of set declarations of the given protocol

        For schema version 1 the sets are read from the protocol itself;
        otherwise they come from the protocol template it refers to.

        Raises:

          RuntimeError: If the protocol template is not valid.

        """

        if self.schema_version == 1:
            return self.protocol(protocol)["sets"]

        declaration = self.protocol(protocol)
        protocol_template = ProtocolTemplate(self.prefix, declaration["template"])
        if not protocol_template.valid:
            raise RuntimeError(
                "\n  * {}".format("\n  * ".join(protocol_template.errors))
            )
        return protocol_template.sets()

    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or "__unnamed_database__"

    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)

    @property
    def description(self):
        """The short description for this object"""
        return self.data.get("description", None)

    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data["description"] = value

    @property
    def documentation(self):
        """The full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None

    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object

        ``value`` may be the text itself or an object with a ``read`` method.
        """

        if not self._name:
            raise RuntimeError("database has no name")

        if hasattr(value, "read"):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)

    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()

    @property
    def schema_version(self):
        """Returns the schema version (1 when the declaration has none)"""
        return self.data.get("schema_version", 1)

    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)

    @property
    def protocols(self):
        """The declaration of all the protocols of the database, by name"""

        data = self.data["protocols"]
        return dict(zip([k["name"] for k in data], data))

    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]

    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        data = self.data["protocols"]
        return [k["name"] for k in data]

    def sets(self, protocol):
        """The declaration of all the sets of a database protocol, by name"""

        data = self._set_declarations(protocol)
        return dict(zip([k["name"] for k in data], data))

    def set(self, protocol, name):
        """The declaration of a specific set in the database protocol"""

        return self.sets(protocol)[name]

    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        return [k["name"] for k in self._set_declarations(protocol)]

    def view_definition(self, protocol_name, set_name):
        """Returns the definition of a view

        Parameters:
          protocol_name (str): The name of the protocol where to retrieve the view
            from

          set_name (str): The name of the set in the protocol where to retrieve the
            view from

        """

        if self.schema_version == 1:
            view_definition = self.set(protocol_name, set_name)
        else:
            # the set declaration comes from the template, while the view
            # class and its parameters come from the protocol itself
            protocol = self.protocol(protocol_name)
            template_name = protocol["template"]
            protocol_template = ProtocolTemplate(self.prefix, template_name)
            view_definition = protocol_template.set(set_name)
            view_definition["view"] = protocol["views"][set_name]["view"]
            parameters = protocol["views"][set_name].get("parameters")
            if parameters is not None:
                view_definition["parameters"] = parameters

        return view_definition

    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str): The root folder handed to the view. If ``None``,
            the ``root_folder`` entry of the database declaration is used.

        Returns:

          The database view (a :py:class:`Runner`), which will be constructed,
          but not setup. You **must** set it up before calling ``get`` on it.

        """

        if not self._name:
            exc = exc or RuntimeError
            raise exc("database has no name")

        if not self.valid:
            # ``name`` is the set and ``protocol`` the protocol: keep the
            # arguments in the order the message announces them
            message = (
                "cannot load view for set `%s' of protocol `%s' "
                "from invalid database (%s)\n%s"
                % (name, protocol, self.name, "   \n".join(self.errors))
            )
            if exc:
                raise exc(message)

            raise RuntimeError(message)

        # loads the module only once through the lifetime of the database
        # object
        try:
            if not hasattr(self, "_module"):
                self._module = loader.load_module(
                    self.name.replace(os.sep, "_"), self.storage.code.path, {}
                )
        except Exception:
            if exc is not None:
                # translate the failure into the caller-provided exception
                # type while keeping the original traceback
                _, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data["root_folder"]

        return Runner(
            self._module,
            self.view_definition(protocol, name),
            self.prefix,
            root_folder,
            exc,
        )

    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return simplejson.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)

    def __str__(self):
        return self.json_dumps()

    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        Raises:

          RuntimeError: If no storage is given and the database has no name.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)

    def export(self, prefix):
        """Recursively exports itself into another prefix

        Dataformats associated are also exported recursively, as are the
        protocol templates when the schema version is not 1.


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If the database has no name, is not valid, or if
            prefix and self.prefix are equal.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        if prefix == self.prefix:
            raise RuntimeError(
                "Cannot export database to the same prefix (%s)" % prefix
            )

        for k in self.dataformats.values():
            k.export(prefix)

        if self.schema_version != 1:
            for protocol in self.protocols.values():
                protocol_template = ProtocolTemplate(self.prefix, protocol["template"])
                protocol_template.export(prefix)

        self.write(Storage(prefix, self.name))


612
# ----------------------------------------------------------
613
614
615


class View(object):
    """Base class of the database views

    Subclasses are expected to implement :py:meth:`index` and
    :py:meth:`get`, which both raise ``NotImplementedError`` here.
    """

    def __init__(self):
        #  Entries of a view are stored as named tuples, whose field names
        #  cannot be Python keywords such as `class`. This map lets an output
        #  carry such a name while the named tuple uses another field name
        #  for it (for example cls, klass, etc.)
        self.output_member_map = {}

    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        Values may appear in any order inside a tuple, but the list itself is
        expected to follow a consistent order (ie. all train images of person
        A, then all train images of person B, ...).

        As an illustration, take a view providing this kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        The matching list would look like this:

        .. code-block:: python

           [
               (client_id=1, file_id=1, image=filename1),
               (client_id=1, file_id=2, image=filename2),
               (client_id=1, file_id=3, image=filename3),
               (client_id=2, file_id=4, image=filename4),
               (client_id=2, file_id=5, image=filename5),
               (client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           Never put images, sound files or any data that can be loaded from
           a file in the list!  Put the path of the file to load there instead.

        """

        raise NotImplementedError

    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the view configuration and the entries it has to provide

        ``start_index`` defaults to 0 and ``end_index`` to the index of the
        last entry; both bounds are inclusive. Only the entries inside that
        range are kept in ``self.objs``.
        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters

        # Determine the range of indices that must be provided
        first = 0 if start_index is None else start_index
        last = len(objs) - 1 if end_index is None else end_index
        self.start_index = first
        self.end_index = last

        # the upper bound is inclusive, hence the + 1
        self.objs = objs[first : last + 1]  # noqa

    def get(self, output, index):
        """Returns the data of the given output for the entry found at the
        given index of the list of (named) tuples kept by the view
        (available as self.objs)
        """

        raise NotImplementedError

    def get_output_mapping(self, output):
        """Returns the tuple member registered for the given output, or the
        output name itself when nothing is registered for it.
        """
        mapping = self.output_member_map
        if output in mapping:
            return mapping[output]
        return output

Samuel GAIST's avatar
Samuel GAIST committed
695

696
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also reports some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.
    """

    # Mock output class
    class MockOutput:
        """Stand-in for an output that simply records what is written to it.

        Every call to :py:meth:`write` is stored in ``written_data`` as a
        ``(start_index, end_index, data)`` tuple.
        """

        def __init__(self, name, connected):
            self.name, self.connected = name, connected
            self.written_data = []
            # -1 so that the first block written starts at index 0
            self.last_written_data_index = -1

        def write(self, data, end_data_index):
            """Record ``data`` as covering the indices following the last
            written one, up to ``end_data_index`` (inclusive)."""
            first_index = self.last_written_data_index + 1
            record = (first_index, end_data_index, data)
            self.written_data.append(record)
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Return the ``connected`` flag given at construction."""
            return self.connected

    class SynchronizedUnit:
        """Node of a tree of index ranges used to lay out the written data.

        A unit covers the inclusive index range ``[start, end]``. ``data``
        maps an output name to the printable form of the data written over
        exactly that range, and ``children`` lists the units nested inside
        this one.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Register ``data`` written by ``output`` over ``[start, end]``.

            The data is stored on this unit when the range matches exactly,
            otherwise it is delegated to (or wrapped around) child units.
            Ranges matching none of the handled cases are ignored.
            """
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
                return

            # The range lies after every known child: append a new one.
            # ``type(self)`` avoids depending on the enclosing class name.
            if (not self.children) or (self.children[-1].end < start):
                unit = type(self)(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
                return

            for index, unit in enumerate(self.children):
                # The range fits inside an existing child: delegate to it
                if (unit.start <= start) and (unit.end >= end):
                    unit.addData(output, start, end, data)
                    return

                # The range starts on a child but extends past it: wrap this
                # child and all the following ones that end within the range
                # into a new unit.
                if (unit.start == start) and (unit.end < end):
                    stop = index + 1
                    while (
                        stop < len(self.children)
                        and self.children[stop].end <= end
                    ):
                        stop += 1

                    new_unit = type(self)(start, end)
                    new_unit.addData(output, start, end, data)
                    new_unit.children = self.children[index:stop]

                    # ``stop`` is one past the last absorbed child, so no
                    # child is both absorbed and kept as a sibling.
                    self.children = (
                        self.children[:index] + [new_unit] + self.children[stop:]
                    )
                    return

        def toString(self):
            """Return a dict mapping each output name to its text line.

            Children are rendered first and joined with spaces; the data of
            this unit is rendered as ``|- value -|``, widened with dashes so
            it spans the children. All returned lines are right-padded with
            spaces to the same length. A unit with neither data nor children
            yields an empty dict.
            """
            texts = {}

            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += " " + text
                    else:
                        texts[output] = text

            if self.data:
                # 6 == len("|- ") + len(" -|")
                length = max(len(value) + 6 for value in self.data.values())

                if texts:
                    children_length = len(next(iter(texts.values())))
                    if children_length >= length:
                        length = children_length
                    else:
                        # Children are narrower than this unit: widen them
                        diff = length - children_length
                        left = diff // 2
                        right = diff - left
                        for key, text in list(texts.items()):
                            texts[key] = "|%s%s%s|" % (
                                "-" * left,
                                text[1:-1],
                                "-" * right,
                            )

                for output, value in self.data.items():
                    diff = length - (len(value) + 6)
                    # Integer split; the extra dash (odd diff) goes right
                    left = diff // 2
                    right = diff - left
                    texts[output] = "|-%s %s %s-|" % (
                        "-" * left,
                        value,
                        "-" * right,
                    )

            if not texts:
                return texts

            length = max(len(text) for text in texts.values())
            for key, text in list(texts.items()):
                if len(text) < length:
                    texts[key] = text + " " * (length - len(text))

            return texts

        def _dataToString(self, data):
            """Return a short printable form of ``data``.

            ``"X"`` is returned unless ``data`` holds exactly one entry whose
            value is neither a numpy array nor a dict; in that case the
            ``str()`` of that value is returned.
            """
            if len(data) != 1:
                return "X"

            # dict views are not subscriptable on Python 3
            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return "X"

            return str(value)

Samuel GAIST's avatar
Samuel GAIST committed
832
833
834
835
836
837
838
839
840
    def __init__(
        self,
        name,
        view_class,
        outputs_declaration,
        parameters,
        irregular_outputs=None,
        all_combinations=True,
    ):
        """Run the synchronization checks on a view.

        Parameters:

          name: Label used in the printed report.

          view_class: Class of the view under test; instantiated without
            arguments.

          outputs_declaration: Iterable of the names of the outputs of the
            view.

          parameters: Parameters handed over to the view's ``setup()``.

          irregular_outputs: Names of the outputs that don't produce data at
            constant intervals. ``None`` (the default) means none.

          all_combinations: When true, the view is tested once for every
            subset of connected outputs; otherwise only once with all the
            outputs connected.
        """
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters
        # A fresh list per instance: a mutable default would be shared by
        # every tester created without this argument.
        self.irregular_outputs = (
            [] if irregular_outputs is None else irregular_outputs
        )

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            for size in range(0, len(self.outputs_declaration) + 1):
                for subset in itertools.combinations(
                    self.outputs_declaration.keys(), size
                ):
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())

    def determine_regular_intervals(self, outputs_declaration):
        """Fill ``self.outputs_declaration`` with the interval of each output.

        A view is set up with every declared output connected to a mock
        output and ``next()`` is called once on it. For each output not
        listed in ``self.irregular_outputs`` the stored value is the last
        index written by that first call plus one; for irregular outputs
        ``None`` is stored.
        """
        mock_outputs = OutputList()
        for output_name in outputs_declaration:
            mock_outputs.add(DatabaseTester.MockOutput(output_name, True))

        probe_view = self.view_class()
        probe_view.setup("", mock_outputs, self.parameters)
        probe_view.next()

        for mock in mock_outputs:
            if mock.name in self.irregular_outputs:
                interval = None
            else:
                interval = mock.last_written_data_index + 1
            self.outputs_declaration[mock.name] = interval

    def run(self, connected_outputs):
        if len(connected_outputs) == 0:
            return

Samuel GAIST's avatar
Samuel GAIST committed
880
881
882
883
        print(
            "Testing '%s', with %d output(s): %s"
            % (self.name, len(connected_outputs), ", ".join(connected_outputs))
        )
Philip ABBET's avatar
Philip ABBET committed
884
885

        # Create the mock outputs
Samuel GAIST's avatar
Samuel GAIST committed
886
887
888
889
890
891
892
893
894
895
896
        connected_outputs = dict(
            [x for x in self.outputs_declaration.items() if x[0] in connected_outputs]
        )

        not_connected_outputs = dict(
            [
                x
                for x in self.outputs_declaration.items()
                if x[0] not in connected_outputs
            ]
        )
Philip ABBET's avatar
Philip ABBET committed
897
898
899
900
901
902
903

        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs))

        # Create the view
        view = self.view_class()
Samuel GAIST's avatar
Samuel GAIST committed
904
        view.setup("", outputs, self.parameters)
Philip ABBET's avatar
Philip ABBET committed
905
906
907
908
909
910
911
912
913
914

        # Initialisations
        next_expected_indices = {}
        for name, interval in connected_outputs.items():
            next_expected_indices[name] = 0

        next_index = 0

        def _done():
            for output in outputs:
Samuel GAIST's avatar
Samuel GAIST committed
915
916
917
                if output.isConnected() and not view.done(
                    output.last_written_data_index
                ):
Philip ABBET's avatar
Philip ABBET committed
918
919
920
921
                    return False
            return True

        # Ask for all the data
Samuel GAIST's avatar
Samuel GAIST committed
922
        while not (_done()):
Philip ABBET's avatar
Philip ABBET committed
923
924
925
926
927
            view.next()

            # Check the indices for the connected outputs
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
Samuel GAIST's avatar
Samuel GAIST committed
928
929
930
931
932
933
934
935
936
                    if not (
                        outputs[name].written_data[-1][0] == next_expected_indices[name]
                    ):
                        raise OutputError("Wrong current index")
                    if not (
                        outputs[name].written_data[-1][1]
                        == next_expected_indices[name] + connected_outputs[name] - 1
                    ):
                        raise OutputError("Wrong next index")
Philip ABBET's avatar
Philip ABBET committed
937
                else:
Samuel GAIST's avatar
Samuel GAIST committed
938
939
940
941
942
943
944
945
                    if not (
                        outputs[name].written_data[-1][0] == next_expected_indices[name]
                    ):
                        raise OutputError("Wrong current index")
                    if not (
                        outputs[name].written_data[-1][1] >= next_expected_indices[name]
                    ):
                        raise OutputError("Wrong next index")
Philip ABBET's avatar
Philip ABBET committed
946
947
948

            # Check that the not connected outputs didn't produce data
            for name in not_connected_outputs.keys():
Samuel GAIST's avatar
Samuel GAIST committed
949
950
                if len(outputs[name].written_data) != 0:
                    raise OutputError("Data written on unconnected output")
Philip ABBET's avatar
Philip ABBET committed
951
952

            # Compute the next data index that should be produced
Samuel GAIST's avatar
Samuel GAIST committed
953
954
955
            next_index = 1 + min(
                [x.written_data[-1][1] for x in outputs if x.isConnected()]
            )
Philip ABBET's avatar
Philip ABBET committed
956

957
958
            # Compute the next data index that should be produced by each
            # connected output
Philip ABBET's avatar
Philip ABBET committed
959
960
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
Samuel GAIST's avatar
Samuel GAIST committed
961
962
963
964
                    if (
                        next_index
                        == next_expected_indices[name] + connected_outputs[name]
                    ):
Philip ABBET's avatar
Philip ABBET committed
965
966
967
                        next_expected_indices[name] += connected_outputs[name]
                else:
                    if next_index > outputs[name].written_data[-1][1]:
Samuel GAIST's avatar
Samuel GAIST committed
968
969
970
                        next_expected_indices[name] = (
                            outputs[name].written_data[-1][1] + 1
                        )
Philip ABBET's avatar
Philip ABBET committed
971
972
973

        # Check the number of data produced on the regular outputs
        for name in connected_outputs.keys():
Samuel GAIST's avatar
Samuel GAIST committed
974
            print("  - %s: %d data" % (name, len(outputs[name].written_data)))
Philip ABBET's avatar
Philip ABBET committed
975
            if name not in self.irregular_outputs:
Samuel GAIST's avatar
Samuel GAIST committed
976
977
978
979
980
                if not (
                    len(outputs[name].written_data)
                    == next_index / connected_outputs[name]
                ):
                    raise OutputError("Invalid number of data produced")
Philip ABBET's avatar
Philip ABBET committed
981
982
983

        # Check that all outputs ends on the same index
        for name in connected_outputs.keys():
Samuel GAIST's avatar
Samuel GAIST committed
984
985
            if not outputs[name].written_data[-1][1] == next_index - 1:
                raise OutputError("Outputs not on same index")
Philip ABBET's avatar
Philip ABBET committed
986

987
988
        # Generate a text file with lots of details (only if all outputs are
        # connected)
Philip ABBET's avatar
Philip ABBET committed
989
990
991
        if len(connected_outputs) == len(self.outputs_declaration):
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

Samuel GAIST's avatar
Samuel GAIST committed
992
993
994
            unit = DatabaseTester.SynchronizedUnit(
                0, sorted_outputs[0].written_data[-1][1]
            )
Philip ABBET's avatar
Philip ABBET committed
995
996
997
998
999
1000

            for output in sorted_outputs:
                for data in output.written_data:
                    unit.addData(output.name, data[0], data[1], data[2])

            texts = unit.toString()
For faster browsing, not all history is shown. View entire blame