test.py 5.62 KB
Newer Older
Guillaume HEUSCH's avatar
Guillaume HEUSCH committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python
# encoding: utf-8
# Guillaume HEUSCH <guillaume.heusch@idiap.ch>
# Fri 23 Dec 09:49:48 CET 2016


"""A few checks on the clients and protocols of the FARGO public database.
"""
9
10
import os, sys
import bob.db.fargo
11

12
13
14
15
16
def db_available(test):
  """Decorator that skips ``test`` unless the FARGO database file exists.

  The returned wrapper resolves ``db.sql3`` via
  ``bob.io.base.test_utils.datafile`` each time it is called. When that path
  exists, the wrapped test is run and its return value passed through.
  Otherwise nose's ``SkipTest`` is raised with a hint on how to create the
  database.
  """
  from bob.io.base.test_utils import datafile
  from nose.plugins.skip import SkipTest
  import functools

  @functools.wraps(test)
  def wrapper(*args, **kwargs):
    db_path = datafile("db.sql3", __name__, None)
    # Guard clause: bail out early when the SQLite file has not been created.
    if not os.path.exists(db_path):
      raise SkipTest("The database file '%s' is not available; did you forget to run 'bob_dbmanage.py %s create' ?" % (db_path, 'fargo'))
    return test(*args, **kwargs)

  return wrapper
27
28


29
30
@db_available
def test_clients():
  """Check the number of groups and clients, overall and per group."""
  db = bob.db.fargo.Database()
  assert len(db.groups()) == 3
  assert len(db.clients()) == 75
  # Each of the three groups holds 25 clients.
  for group in ('world', 'dev', 'eval'):
    assert len(db.clients(groups=group)) == 25
39
40


41
42
43
44
45
@db_available
def test_objects():
  """Check the number of File objects returned by the single-modality protocols."""

  db = bob.db.fargo.Database()

  # (protocol, expected number of probe files in each of dev/eval)
  protocol_probe_counts = [('mc-rgb', 500), ('ud-nir', 1000), ('uo-depth', 1000)]
  # (group, model id used for the per-model queries in that group)
  group_models = [('dev', 26), ('eval', 51)]

  for protocol, n_probe in protocol_probe_counts:
    assert len(db.objects(protocol=protocol, groups='world')) == 1000
    for group, model in group_models:
      assert len(db.objects(protocol=protocol, groups=group, purposes='enroll')) == 500
      assert len(db.objects(protocol=protocol, groups=group, purposes='enroll', model_ids=model)) == 20
      assert len(db.objects(protocol=protocol, groups=group, purposes='probe')) == n_probe
      # dense probing: restricting probes to one model still yields the full probe set
      assert len(db.objects(protocol=protocol, groups=group, purposes='probe', model_ids=model)) == n_probe
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133


@db_available
def test_heterogeneous():
    """Check the heterogeneous (cross-modality) protocols.

    For every protocol, this verifies:

    - the total number of File objects;
    - the number of objects in the ``world`` group;
    - the number of objects in each of ``dev`` and ``eval``;
    - that enrollment files are all ``rgb``, while probe files all have the
      protocol's target modality.
    """
    db = bob.db.fargo.Database()

    groups = ["dev", "eval"]

    # (protocols, probe modality of each protocol,
    #  expected total objects, expected objects per dev/eval group)
    scenarios = [
        # controlled
        (["mc-rgb2nir", "mc-rgb2depth"],
         ["nir", "depth"],
         3000, 1000),
        # uncontrolled
        (["ud-rgb2nir", "ud-rgb2depth", "uo-rgb2nir", "uo-rgb2depth"],
         ["nir", "depth", "nir", "depth"],
         4000, 1500),
    ]

    for protocols, probe_modalities, n_total, n_per_group in scenarios:
        for p, m in zip(protocols, probe_modalities):
            assert len(db.objects(protocol=p)) == n_total
            assert len(db.objects(protocol=p, groups="world")) == 1000

            for g in groups:
                # Count the objects of the group currently being iterated over.
                assert len(db.objects(protocol=p, groups=g)) == n_per_group

                # Enrollment is always RGB; probes all share the target modality.
                enroll_modalities = {o.modality for o in db.objects(protocol=p, groups=g, purposes="enroll")}
                assert enroll_modalities == {"rgb"}

                probe_modalities_found = {o.modality for o in db.objects(protocol=p, groups=g, purposes="probe")}
                assert probe_modalities_found == {m}