database.py 34.6 KB
Newer Older
1
2
3
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

Samuel GAIST's avatar
Samuel GAIST committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
35
36


37
38
39
40
41
42
43
"""
========
database
========

Validation of databases
"""
44
45
46
47
48

import os
import sys

import six
49
import simplejson as json
Philip ABBET's avatar
Philip ABBET committed
50
51
import itertools
import numpy as np
52
from collections import namedtuple
53
54

from . import loader
Philip ABBET's avatar
Philip ABBET committed
55
56
from . import utils

Samuel GAIST's avatar
Samuel GAIST committed
57
from .protocoltemplate import ProtocolTemplate
Philip ABBET's avatar
Philip ABBET committed
58
from .dataformat import DataFormat
Philip ABBET's avatar
Philip ABBET committed
59
from .outputs import OutputList
Samuel GAIST's avatar
Samuel GAIST committed
60
from .exceptions import OutputError
Philip ABBET's avatar
Philip ABBET committed
61
62


63
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
64

Philip ABBET's avatar
Philip ABBET committed
65
66

class Storage(utils.CodeStorage):
    """Locates the declaration and view code of a database on disk

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the database object in the format
        ``<name>/<version>``.

    """

    asset_type = "database"
    asset_folder = "databases"

    def __init__(self, prefix, name):

        # exactly one separator is accepted: ``<name>/<version>``
        if name.count("/") != 1:
            raise RuntimeError("invalid database name: `%s'" % name)

        db_name, db_version = name.split("/")
        self.name = db_name
        self.version = db_version
        self.fullname = name
        self.prefix = prefix

        # extension-less base path of the asset inside the prefix; the views
        # of a database are coded in Python
        base_path = os.path.join(self.prefix, self.asset_folder, name)
        super(Storage, self).__init__(base_path, "python")
94
95


96
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
97

98

99
class Runner(object):
    """A special loader class for database views, with specialized methods

    Parameters:

      module (:std:term:`module`): The preloaded module containing the database
        views as returned by :py:func:`.loader.load_module`.

      definition (dict): The view definition (as built by
        ``Database.view_definition``). Its ``view`` key names the view class
        inside ``module``; the optional ``parameters`` and ``outputs`` keys are
        forwarded to the view and used to create the output data sources.

      prefix (str): Establishes the prefix of your installation.

      root_folder (str): The path pointing to the root folder of this database

      exc (:std:term:`class`): The class to use as base exception when
        translating the exception from the user code. Read the documention of
        :py:func:`.loader.run` for more details.

    """

    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition["view"])
        except Exception:
            if exc is not None:
                # translate into the caller's exception type, keeping the
                # original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        # the view object is created through loader.run, which also receives
        # ``exc`` for error translation
        self.obj = loader.run(class_, "__new__", exc)
        self.ready = False
        self.prefix = prefix
        self.root_folder = root_folder
        self.definition = definition
        self.exc = exc or RuntimeError
        self.data_sources = None

    @staticmethod
    def _ensure_parent_folder(filename):
        """Creates the folder that will contain ``filename`` if it is missing"""

        folder = os.path.dirname(filename)

        # a bare file name has no folder component: there is nothing to
        # create (and ``os.makedirs("")`` would raise)
        if folder and not os.path.exists(folder):
            os.makedirs(folder)

    @staticmethod
    def _to_entries(objs):
        """Converts the dictionaries loaded from an index file into named tuples

        The tuple fields are the sorted keys of the first entry. An empty index
        yields an empty list.
        """

        if not objs:
            return []

        Entry = namedtuple("Entry", sorted(objs[0].keys()))
        return [Entry(**x) for x in objs]

    def index(self, filename):
        """Index the content of the view

        Calls the view's ``index()`` method and writes the returned list, JSON
        encoded, into ``filename`` (creating its folder if needed).

        Raises:

          exc: If the view's ``index()`` did not return a list

        """

        parameters = self.definition.get("parameters", {})

        objs = loader.run(self.obj, "index", self.exc, self.root_folder, parameters)

        if not isinstance(objs, list):
            raise self.exc("index() didn't return a list")

        self._ensure_parent_folder(filename)

        with open(filename, "wb") as f:
            data = json.dumps(objs, cls=utils.NumpyJSONEncoder)
            f.write(data.encode("utf-8"))

    def setup(self, filename, start_index=None, end_index=None, pack=True):
        """Sets up the view

        Loads the index file written by :py:meth:`index`, passes its entries to
        the view's ``setup()`` and creates one data source per output declared
        in the definition. Does nothing if the view is already set up.

        Parameters:

          filename (str): Path of the index file written by :py:meth:`index`

          start_index (int): Optional start index, forwarded to the view and
            to the data sources

          end_index (int): Optional end index, forwarded to the view and to
            the data sources

          pack (bool): Forwarded to each output data source

        """

        if self.ready:
            return

        with open(filename, "rb") as f:
            objs = json.loads(
                f.read().decode("utf-8"),
                object_pairs_hook=utils.error_on_duplicate_key_hook,
            )

        objs = self._to_entries(objs)

        parameters = self.definition.get("parameters", {})

        loader.run(
            self.obj,
            "setup",
            self.exc,
            self.root_folder,
            parameters,
            objs,
            start_index=start_index,
            end_index=end_index,
        )

        # Create data sources for the outputs (imported locally, presumably to
        # avoid a circular import with the ``data`` module)
        from .data import DatabaseOutputDataSource

        self.data_sources = {}
        for output_name, output_format in self.definition.get("outputs", {}).items():
            data_source = DatabaseOutputDataSource()
            data_source.setup(
                self,
                output_name,
                output_format,
                self.prefix,
                start_index=start_index,
                end_index=end_index,
                pack=pack,
            )
            self.data_sources[output_name] = data_source

        self.ready = True

    def get(self, output, index):
        """Returns the data of the provided output at the provided index

        Raises:

          exc: If :py:meth:`setup` was not called before

        """

        if not self.ready:
            raise self.exc("Database view not yet setup")

        return loader.run(self.obj, "get", self.exc, output, index)

    def get_output_mapping(self, output):
        """Returns the view's object member mapped to ``output``"""

        return loader.run(self.obj, "get_output_mapping", self.exc, output)

    def objects(self):
        """Returns the entries held by the view (its ``objs`` attribute)"""

        return self.obj.objs


221
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
222

223
224

class Database(object):
    """Databases define the start point of the dataflow in an experiment.


    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The fully qualified database name (e.g. ``db/1``)

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.


    Attributes:

      name (str): The full, valid name of this database

      data (dict): The original data for this database, as loaded by our JSON
        decoder.

      errors (list): Error messages collected while loading; the database is
        :py:attr:`valid` only when this list is empty.

      dataformats (dict): The data formats used by the outputs of this
        database, indexed by data format name.

    """

    def __init__(self, prefix, name, dataformat_cache=None):

        self._name = None
        self.prefix = prefix
        self.dataformats = {}  # preloaded dataformats
        self.storage = None

        self.errors = []
        self.data = None

        # if the user has not provided a cache, still use one for performance
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

        self._load(name, dataformat_cache)

    def _update_dataformat_cache(self, outputs, dataformat_cache):
        """Registers the data formats used by ``outputs``

        ``outputs`` maps output names to data format names. Each data format is
        taken from ``dataformat_cache`` when present, otherwise it is loaded
        and added to the cache; it is then recorded in :py:attr:`dataformats`.
        """

        for dataformat_name in outputs.values():

            if dataformat_name in self.dataformats:
                continue

            if dataformat_name in dataformat_cache:
                dataformat = dataformat_cache[dataformat_name]
            else:
                dataformat = DataFormat(self.prefix, dataformat_name)
                dataformat_cache[dataformat_name] = dataformat

            self.dataformats[dataformat_name] = dataformat

    def _load_v1(self, dataformat_cache):
        """Loads a v1 database and fills the dataformat cache"""

        # v1: sets (and their outputs) are declared inline in each protocol
        for protocol in self.data["protocols"]:
            for set_ in protocol["sets"]:
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load_v2(self, dataformat_cache):
        """Loads a v2 database and fills the dataformat cache"""

        # v2: sets are declared by the protocol template each protocol names
        for protocol in self.data["protocols"]:
            protocol_template = ProtocolTemplate(
                self.prefix, protocol["template"], dataformat_cache
            )
            for set_ in protocol_template.sets():
                self._update_dataformat_cache(set_["outputs"], dataformat_cache)

    def _load(self, data, dataformat_cache):
        """Loads the database declaration named ``data`` from the prefix

        Missing or invalid declaration files are reported in
        :py:attr:`errors` instead of raising.

        Raises:

          RuntimeError: If the declared schema version is not supported

        """

        self._name = data

        self.storage = Storage(self.prefix, self._name)
        json_path = self.storage.json.path
        if not self.storage.json.exists():
            self.errors.append("Database declaration file not found: %s" % json_path)
            return

        with open(json_path, "rb") as f:
            try:
                self.data = json.loads(
                    f.read().decode("utf-8"),
                    object_pairs_hook=utils.error_on_duplicate_key_hook,
                )
            except RuntimeError as error:
                self.errors.append("Database declaration file invalid: %s" % error)
                return

        self.code_path = self.storage.code.path
        self.code = self.storage.code.load()

        if self.schema_version == 1:
            self._load_v1(dataformat_cache)
        elif self.schema_version == 2:
            self._load_v2(dataformat_cache)
        else:
            raise RuntimeError(
                "Invalid schema version {schema_version}".format(
                    schema_version=self.schema_version
                )
            )

    @property
    def name(self):
        """Returns the name of this object
        """
        return self._name or "__unnamed_database__"

    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)

    @property
    def description(self):
        """The short description for this object"""
        return self.data.get("description", None)

    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data["description"] = value

    @property
    def documentation(self):
        """The full-length description for this object"""

        if not self._name:
            raise RuntimeError("database has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None

    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object

        ``value`` may be a string or a file-like object (anything with a
        ``read`` method).
        """

        if not self._name:
            raise RuntimeError("database has no name")

        if hasattr(value, "read"):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)

    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("database has no name")

        return self.storage.hash()

    @property
    def schema_version(self):
        """Returns the schema version (1 when not declared)"""
        return self.data.get("schema_version", 1)

    @property
    def valid(self):
        """A boolean that indicates if this database is valid or not"""

        return not bool(self.errors)

    @property
    def protocols(self):
        """The declaration of all the protocols of the database, by name"""

        return {k["name"]: k for k in self.data["protocols"]}

    def protocol(self, name):
        """The declaration of a specific protocol in the database"""

        return self.protocols[name]

    @property
    def protocol_names(self):
        """Names of protocols declared for this database"""

        return [k["name"] for k in self.data["protocols"]]

    def _set_declarations(self, protocol):
        """Returns the list of set declarations of the given protocol

        v1 databases declare their sets inline; later schema versions take them
        from the protocol template.

        Raises:

          RuntimeError: If the protocol template is invalid

        """

        if self.schema_version == 1:
            return self.protocol(protocol)["sets"]

        protocol_template = ProtocolTemplate(
            self.prefix, self.protocol(protocol)["template"]
        )
        if not protocol_template.valid:
            raise RuntimeError(
                "\n  * {}".format("\n  * ".join(protocol_template.errors))
            )
        return protocol_template.sets()

    def sets(self, protocol):
        """The declarations of all the sets of a protocol, indexed by set name"""

        return {k["name"]: k for k in self._set_declarations(protocol)}

    def set(self, protocol, name):
        """The declaration of a specific set in the given protocol"""

        return self.sets(protocol)[name]

    def set_names(self, protocol):
        """The names of sets in a given protocol for this database"""

        return [k["name"] for k in self._set_declarations(protocol)]

    def view_definition(self, protocol_name, set_name):
        """Returns the definition of a view

        Parameters:
          protocol_name (str): The name of the protocol where to retrieve the view
            from

          set_name (str): The name of the set in the protocol where to retrieve the
            view from

        """

        if self.schema_version == 1:
            view_definition = self.set(protocol_name, set_name)
        else:
            # the set comes from the template, while the view class and its
            # optional parameters come from the protocol's ``views`` entry
            protocol = self.protocol(protocol_name)
            template_name = protocol["template"]
            protocol_template = ProtocolTemplate(self.prefix, template_name)
            view_definition = protocol_template.set(set_name)
            view_definition["view"] = protocol["views"][set_name]["view"]
            parameters = protocol["views"][set_name].get("parameters")
            if parameters is not None:
                view_definition["parameters"] = parameters

        return view_definition

    def view(self, protocol, name, exc=None, root_folder=None):
        """Returns the database view, given the protocol and the set name

        Parameters:

          protocol (str): The name of the protocol where to retrieve the view
            from

          name (str): The name of the set in the protocol where to retrieve the
            view from

          exc (:std:term:`class`): If passed, must be a valid exception class
            that will be used to report errors in the read-out of this
            database's view.

          root_folder (str): Optional root folder of the data; defaults to the
            ``root_folder`` of the declaration.

        Returns:

          The database view, which will be constructed, but not setup. You
          **must** set it up before using methods ``done`` or ``next``.

        """

        if not self._name:
            exc = exc or RuntimeError
            raise exc("database has no name")

        if not self.valid:
            message = (
                "cannot load view for set `%s' of protocol `%s' "
                "from invalid database (%s)\n%s"
                % (name, protocol, self.name, "   \n".join(self.errors))
            )
            raise (exc or RuntimeError)(message)

        # loads the module only once through the lifetime of the database
        # object
        try:
            if not hasattr(self, "_module"):
                self._module = loader.load_module(
                    self.name.replace(os.sep, "_"), self.storage.code.path, {}
                )
        except Exception:
            if exc is not None:
                # translate into the caller's exception type, keeping the
                # original traceback
                _, value, tb = sys.exc_info()
                six.reraise(exc, exc(value), tb)
            else:
                raise  # just re-raise the user exception

        if root_folder is None:
            root_folder = self.data["root_folder"]

        return Runner(
            self._module,
            self.view_definition(protocol, name),
            self.prefix,
            root_folder,
            exc,
        )

    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string


        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level


        Returns:

          str: The JSON representation for this object

        """

        return json.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)

    def __str__(self):
        return self.json_dumps()

    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("database has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.code, self.description)

    def export(self, prefix):
        """Recursively exports itself into another prefix

        Dataformats associated are also exported recursively


        Parameters:

          prefix (str): A path to a prefix that must be different from my own.


        Returns:

          None


        Raises:

          RuntimeError: If prefix and self.prefix point to the same directory.

        """

        if not self._name:
            raise RuntimeError("database has no name")

        if not self.valid:
            raise RuntimeError("database is not valid")

        # compare resolved paths so that spellings such as ``p`` and ``./p/``
        # are recognized as the same directory
        if os.path.realpath(prefix) == os.path.realpath(self.prefix):
            raise RuntimeError(
                "Cannot export database to the same prefix (%s)" % prefix
            )

        for k in self.dataformats.values():
            k.export(prefix)

        if self.schema_version != 1:
            for protocol in self.protocols.values():
                protocol_template = ProtocolTemplate(self.prefix, protocol["template"])
                protocol_template.export(prefix)

        self.write(Storage(prefix, self.name))


622
# ----------------------------------------------------------
623
624
625


class View(object):
    """Base class for database views

    Subclasses implement :py:meth:`index` and :py:meth:`get`. The base
    :py:meth:`setup` stores the entries produced from the index and keeps only
    the requested range of them.
    """

    def __init__(self):
        #  Current databases definitions uses named tuple to store information.
        #  This has one limitation, python keywords like `class` cannot be used.
        #  output_member_map allows to use that kind of keyword as output name
        #  while using something different for the named tuple (for example cls,
        #  klass, etc.)
        self.output_member_map = {}

    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        The ordering of values inside the tuples is free, but it is expected
        that the list is ordered in a consistent manner (ie. all train images
        of person A, then all train images of person B, ...).

        For instance, assuming a view providing that kind of data:

        .. code-block:: text

           ----------- ----------- ----------- ----------- ----------- -----------
           |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------- ----------- ----------- ----------- ----------- -----------
           | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
           ----------- ----------- ----------- ----------- ----------- -----------
           ----------------------------------- -----------------------------------
           |             client_id           | |             client_id           |
           ----------------------------------- -----------------------------------

        a list like the following should be generated:

        .. code-block:: python

           [
               (client_id=1, file_id=1, image=filename1),
               (client_id=1, file_id=2, image=filename2),
               (client_id=1, file_id=3, image=filename3),
               (client_id=2, file_id=4, image=filename4),
               (client_id=2, file_id=5, image=filename5),
               (client_id=2, file_id=6, image=filename6),
               ...
           ]

        .. warning::

           DO NOT store images, sound files or data loadable from a file in the
           list!  Store the path of the file to load instead.

        """

        raise NotImplementedError

    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Stores the view configuration and the range of entries to provide

        Parameters:

          root_folder (str): The path pointing to the root folder of the data

          parameters (dict): The view parameters

          objs (list): The entries generated from :py:meth:`index`

          start_index (int): First entry to keep (defaults to 0)

          end_index (int): Last entry to keep, inclusive (defaults to the last
            entry)

        """

        # Initialisation
        self.root_folder = root_folder
        self.parameters = parameters
        self.objs = objs

        # Determine the range of indices that must be provided
        self.start_index = start_index if start_index is not None else 0
        self.end_index = end_index if end_index is not None else len(self.objs) - 1

        self.objs = self.objs[self.start_index : self.end_index + 1]  # noqa

    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)
        """

        raise NotImplementedError

    def get_output_mapping(self, output):
        """Returns the object member to use for given output if any otherwise
        the member name is the output name.
        """

        # subclasses overriding ``__init__`` without calling the base one have
        # no ``output_member_map``: fall back to the identity mapping
        member_map = getattr(self, "output_member_map", {})
        return member_map.get(output, output)

Samuel GAIST's avatar
Samuel GAIST committed
705

706
# ----------------------------------------------------------
Philip ABBET's avatar
Philip ABBET committed
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738


class DatabaseTester:
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seems consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also report some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern, must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.
    """

    # Mock output class
    class MockOutput:
        """Fake output that records every write instead of forwarding it."""

        def __init__(self, name, connected):
            self.name = name
            self.connected = connected
            # Last data index covered by a write so far (-1: nothing written)
            self.last_written_data_index = -1
            # (start_index, end_index, data) tuples, in write order
            self.written_data = []

        def write(self, data, end_data_index):
            """Record ``data`` as covering the indices following the previous
            write, up to and including ``end_data_index``."""
            first_index = self.last_written_data_index + 1
            record = (first_index, end_data_index, data)
            self.written_data.append(record)
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Return the connection flag given at construction."""
            return self.connected

    class SynchronizedUnit:
        """Node of a tree grouping the data written on the outputs by index range.

        A unit covers the inclusive index range ``[start, end]``. Data whose
        range matches the unit exactly is stored in ``data`` (keyed by output).
        Data covering a smaller range is pushed down into ``children``. The tree
        is used to render an aligned textual view of the outputs (``toString``).
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            # output -> short string representation of the data written on it
            self.data = {}
            # Sub-units, appended in increasing index order
            self.children = []

        def addData(self, output, start, end, data):
            """Insert ``data`` written by ``output`` over ``[start, end]``.

            Data that matches none of the handled cases is silently ignored.
            """
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
            elif (len(self.children) == 0) or (self.children[-1].end < start):
                # Starts after every existing child: new trailing child
                unit = type(self)(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
            else:
                for index, unit in enumerate(self.children):
                    if (unit.start <= start) and (unit.end >= end):
                        # Fully contained in an existing child: delegate to it
                        unit.addData(output, start, end, data)
                        break
                    elif (unit.start == start) and (unit.end < end):
                        # Starts with an existing child but spans further: wrap
                        # that child, and the following ones that fit, in a new
                        # larger unit
                        new_unit = type(self)(start, end)
                        new_unit.addData(output, start, end, data)
                        new_unit.children.append(unit)

                        # Explicit cursor: works when ``unit`` is the last child
                        # and never keeps an absorbed child twice
                        next_index = index + 1
                        while (next_index < len(self.children)) and (
                            self.children[next_index].end <= end
                        ):
                            new_unit.children.append(self.children[next_index])
                            next_index += 1

                        self.children = (
                            self.children[:index]
                            + [new_unit]
                            + self.children[next_index:]
                        )
                        break

        def toString(self):
            """Render the subtree as one text line per output.

            Returns a dict mapping each output to its line, right-padded with
            spaces to a common length. Each data item is drawn as
            ``|- value -|``, widened with dashes to span the lines of the
            children below it.
            """
            texts = {}

            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += " " + text
                    else:
                        texts[output] = text

            if len(self.data) > 0:
                # 6 == len("|- ") + len(" -|")
                length = max(len(x) + 6 for x in self.data.values())

                if len(texts) > 0:
                    # Widest children line (dict views are not indexable in
                    # Python 3, and the widest line is order-independent)
                    children_length = max(len(x) for x in texts.values())
                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the children lines with dashes, split as evenly
                        # as possible on both sides (integer division: the
                        # counts are used to repeat strings)
                        diff = length - children_length
                        left = diff // 2
                        right = diff - left
                        for k, v in texts.items():
                            texts[k] = "|%s%s%s|" % ("-" * left, v[1:-1], "-" * right)

                for output, value in self.data.items():
                    diff = length - (len(value) + 6)
                    left = diff // 2
                    right = diff - left
                    texts[output] = "|-%s %s %s-|" % ("-" * left, value, "-" * right)

            if not texts:
                # Nothing recorded in this subtree (max() would fail below)
                return texts

            # Right-pad every line with spaces to the longest one
            length = max(len(x) for x in texts.values())
            for k, v in texts.items():
                if len(v) < length:
                    texts[k] += " " * (length - len(v))

            return texts

        def _dataToString(self, data):
            """Short representation of ``data`` used by ``toString``.

            Returns "X" unless ``data`` holds exactly one field whose value is
            neither a numpy array nor a dict; then returns ``str`` of that value.
            """
            if len(data) != 1:
                return "X"

            # dict views are not subscriptable in Python 3
            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return "X"

            return str(value)

Samuel GAIST's avatar
Samuel GAIST committed
842
843
844
845
846
847
848
849
850
    def __init__(
        self,
        name,
        view_class,
        outputs_declaration,
        parameters,
        irregular_outputs=[],
        all_combinations=True,
    ):
Philip ABBET's avatar
Philip ABBET committed
851
852
853
854
855
856
857
858
859
860
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters
        self.irregular_outputs = irregular_outputs

        self.determine_regular_intervals(outputs_declaration)

        if all_combinations:
            for L in range(0, len(self.outputs_declaration) + 1):
Samuel GAIST's avatar
Samuel GAIST committed
861
862
863
                for subset in itertools.combinations(
                    self.outputs_declaration.keys(), L
                ):
Philip ABBET's avatar
Philip ABBET committed
864
865
866
867
868
869
870
871
872
873
                    self.run(subset)
        else:
            self.run(self.outputs_declaration.keys())

    def determine_regular_intervals(self, outputs_declaration):
        """Fill ``self.outputs_declaration`` by running the view once.

        All outputs are connected and ``next()`` is called a single time. For
        each regular output, the stored value is the number of data indices
        written during that first call (last written index + 1). Outputs listed
        in ``self.irregular_outputs`` get ``None``.
        """
        mock_outputs = OutputList()
        for output_name in outputs_declaration:
            mock_outputs.add(DatabaseTester.MockOutput(output_name, True))

        probe_view = self.view_class()
        probe_view.setup("", mock_outputs, self.parameters)
        probe_view.next()

        for mock in mock_outputs:
            is_irregular = mock.name in self.irregular_outputs
            self.outputs_declaration[mock.name] = (
                None if is_irregular else mock.last_written_data_index + 1
            )

    def run(self, connected_outputs):
        if len(connected_outputs) == 0:
            return

Samuel GAIST's avatar
Samuel GAIST committed
890
891
892
893
        print(
            "Testing '%s', with %d output(s): %s"
            % (self.name, len(connected_outputs), ", ".join(connected_outputs))
        )
Philip ABBET's avatar
Philip ABBET committed
894
895

        # Create the mock outputs
Samuel GAIST's avatar
Samuel GAIST committed
896
897
898
899
900
901
902
903
904
905
906
        connected_outputs = dict(
            [x for x in self.outputs_declaration.items() if x[0] in connected_outputs]
        )

        not_connected_outputs = dict(
            [
                x
                for x in self.outputs_declaration.items()
                if x[0] not in connected_outputs
            ]
        )
Philip ABBET's avatar
Philip ABBET committed
907
908
909
910
911
912
913

        outputs = OutputList()
        for name in self.outputs_declaration.keys():
            outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs))

        # Create the view
        view = self.view_class()
Samuel GAIST's avatar
Samuel GAIST committed
914
        view.setup("", outputs, self.parameters)
Philip ABBET's avatar
Philip ABBET committed
915
916
917
918
919
920
921
922
923
924

        # Initialisations
        next_expected_indices = {}
        for name, interval in connected_outputs.items():
            next_expected_indices[name] = 0

        next_index = 0

        def _done():
            for output in outputs:
Samuel GAIST's avatar
Samuel GAIST committed
925
926
927
                if output.isConnected() and not view.done(
                    output.last_written_data_index
                ):
Philip ABBET's avatar
Philip ABBET committed
928
929
930
931
                    return False
            return True

        # Ask for all the data
Samuel GAIST's avatar
Samuel GAIST committed
932
        while not (_done()):
Philip ABBET's avatar
Philip ABBET committed
933
934
935
936
937
            view.next()

            # Check the indices for the connected outputs
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
Samuel GAIST's avatar
Samuel GAIST committed
938
939
940
941
942
943
944
945
946
                    if not (
                        outputs[name].written_data[-1][0] == next_expected_indices[name]
                    ):
                        raise OutputError("Wrong current index")
                    if not (
                        outputs[name].written_data[-1][1]
                        == next_expected_indices[name] + connected_outputs[name] - 1
                    ):
                        raise OutputError("Wrong next index")
Philip ABBET's avatar
Philip ABBET committed
947
                else:
Samuel GAIST's avatar
Samuel GAIST committed
948
949
950
951
952
953
954
955
                    if not (
                        outputs[name].written_data[-1][0] == next_expected_indices[name]
                    ):
                        raise OutputError("Wrong current index")
                    if not (
                        outputs[name].written_data[-1][1] >= next_expected_indices[name]
                    ):
                        raise OutputError("Wrong next index")
Philip ABBET's avatar
Philip ABBET committed
956
957
958

            # Check that the not connected outputs didn't produce data
            for name in not_connected_outputs.keys():
Samuel GAIST's avatar
Samuel GAIST committed
959
960
                if len(outputs[name].written_data) != 0:
                    raise OutputError("Data written on unconnected output")
Philip ABBET's avatar
Philip ABBET committed
961
962

            # Compute the next data index that should be produced
Samuel GAIST's avatar
Samuel GAIST committed
963
964
965
            next_index = 1 + min(
                [x.written_data[-1][1] for x in outputs if x.isConnected()]
            )
Philip ABBET's avatar
Philip ABBET committed
966

967
968
            # Compute the next data index that should be produced by each
            # connected output
Philip ABBET's avatar
Philip ABBET committed
969
970
            for name in connected_outputs.keys():
                if name not in self.irregular_outputs:
Samuel GAIST's avatar
Samuel GAIST committed
971
972
973
974
                    if (
                        next_index
                        == next_expected_indices[name] + connected_outputs[name]
                    ):
Philip ABBET's avatar
Philip ABBET committed
975
976
977
                        next_expected_indices[name] += connected_outputs[name]
                else:
                    if next_index > outputs[name].written_data[-1][1]:
Samuel GAIST's avatar
Samuel GAIST committed
978
979
980
                        next_expected_indices[name] = (
                            outputs[name].written_data[-1][1] + 1
                        )
Philip ABBET's avatar
Philip ABBET committed
981
982
983

        # Check the number of data produced on the regular outputs
        for name in connected_outputs.keys():
Samuel GAIST's avatar
Samuel GAIST committed
984
            print("  - %s: %d data" % (name, len(outputs[name].written_data)))
Philip ABBET's avatar
Philip ABBET committed
985
            if name not in self.irregular_outputs:
Samuel GAIST's avatar
Samuel GAIST committed
986
987
988
989
990
                if not (
                    len(outputs[name].written_data)
                    == next_index / connected_outputs[name]
                ):
                    raise OutputError("Invalid number of data produced")
Philip ABBET's avatar
Philip ABBET committed
991
992
993

        # Check that all outputs ends on the same index
        for name in connected_outputs.keys():
Samuel GAIST's avatar
Samuel GAIST committed
994
995
            if not outputs[name].written_data[-1][1] == next_index - 1:
                raise OutputError("Outputs not on same index")
Philip ABBET's avatar
Philip ABBET committed
996

997
998
        # Generate a text file with lots of details (only if all outputs are
        # connected)
Philip ABBET's avatar
Philip ABBET committed
999
1000
        if len(connected_outputs) == len(self.outputs_declaration):
            sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))
For faster browsing, not all history is shown. View entire blame