Commit 9a8c3745 authored by Philip ABBET

[scripts] Add a script to index a database

parent aec7d89b
......@@ -382,7 +382,7 @@ class Database(object):
return [k['name'] for k in data]
def view(self, protocol, name, exc=None):
def view(self, protocol, name, exc=None, root_folder=None):
"""Returns the database view, given the protocol and the set name
Parameters:
......@@ -426,8 +426,11 @@ class Database(object):
else:
raise #just re-raise the user exception
return Runner(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
if root_folder is None:
root_folder = self.data['root_folder']
return Runner(self._module, self.set(protocol, name),
self.prefix, root_folder, exc)
def json_dumps(self, indent=4):
......
......@@ -150,7 +150,7 @@ def main(arguments=None):
'--disabled-login', '--gecos', '""', '-q',
'beat-nobody'])
if retcode != 0:
send_error(logger, socket, 'sys', 'Failed to create an user with the UID %s' % args['uid'])
message_handler.send_error('Failed to create an user with the UID %s' % args['uid'], 'sys')
return 1
# Next, ensure that the needed files are readable by this user
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
"""Executes some database views. (%(version)s)
usage:
%(prog)s [--debug] [--uid=UID] [--db_root_folder=root_folder] <prefix> <cache> <database> [<protocol> [<set>]]
%(prog)s (--help)
%(prog)s (--version)
arguments:
<prefix> Path to the prefix
<cache> Path to the cache
<database> Full name of the database
options:
-h, --help Shows this help message and exit
-V, --version Shows program's version number and exit
-d, --debug Runs in debugging mode
--uid=UID UID to run as
--db_root_folder=root_folder Root folder to use for the database data (overrides the
one declared by the database)
"""
import logging
import os
import sys
import docopt
import pwd
from ..database import Database
from ..hash import hashDataset
from ..hash import toPath
#----------------------------------------------------------
def main(arguments=None):
    """Indexes the sets of a database and stores the results in the cache

    The command line is parsed with docopt against the module docstring.
    Depending on the arguments, either one set, all the sets of one protocol
    or all the sets of all the protocols of the database are indexed. For each
    of them, ``index()`` is called on the corresponding database view, with an
    output path located in the cache folder and derived from
    ``hashDataset()`` / ``toPath()``.


    Parameters:

      arguments (list): Command-line arguments (without the program name).
        When ``None``, ``sys.argv[1:]`` is used.


    Returns:

      int: ``0`` on success, ``1`` when something went wrong (the reason is
      sent to the logger)

    """

    import traceback

    # Parse the command-line arguments
    if arguments is None:
        arguments = sys.argv[1:]

    package = __name__.rsplit('.', 2)[0]
    version = package + ' v' + \
        __import__('pkg_resources').require(package)[0].version

    prog = os.path.basename(sys.argv[0])

    args = docopt.docopt(
        __doc__ % dict(prog=prog, version=version),
        argv=arguments,
        version=version
    )

    # Setup the logging system
    formatter = logging.Formatter(fmt="[%(asctime)s - index.py - " \
                                      "%(name)s] %(levelname)s: %(message)s",
                                  datefmt="%d/%b/%Y %H:%M:%S")

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)

    root_logger = logging.getLogger('beat.backend.python')
    root_logger.addHandler(handler)

    if args['--debug']:
        root_logger.setLevel(logging.DEBUG)
    else:
        root_logger.setLevel(logging.INFO)

    logger = logging.getLogger(__name__)

    if args['--uid']:
        # A non-numeric UID is a user error: report it instead of crashing
        try:
            uid = int(args['--uid'])
        except ValueError:
            logger.error('Invalid UID: %s' % args['--uid'])
            return 1

        # First create the user (if it doesn't exist yet). Only a failed
        # lookup (KeyError) may trigger the creation: any other problem must
        # not end up in an 'adduser' call.
        try:
            pwd.getpwuid(uid)
        except KeyError:
            import subprocess
            retcode = subprocess.call(['adduser', '--uid', str(uid),
                                       '--no-create-home', '--disabled-password',
                                       '--disabled-login', '--gecos', '""', '-q',
                                       'beat-nobody'])
            if retcode != 0:
                logger.error('Failed to create an user with the UID %d' % uid)
                return 1

        # Change the current user (the group is changed first, while the
        # process is still allowed to do it)
        try:
            os.setgid(uid)
            os.setuid(uid)
        except Exception:
            logger.error(traceback.format_exc())
            return 1

    # Check the paths
    if not os.path.exists(args['<prefix>']):
        logger.error('Invalid prefix path: %s' % args['<prefix>'])
        return 1

    if not os.path.exists(args['<cache>']):
        logger.error('Invalid cache path: %s' % args['<cache>'])
        return 1

    # Indexing
    try:
        database = Database(args['<prefix>'], args['<database>'])

        if args['<protocol>'] is None:
            protocols = database.protocol_names
        else:
            protocols = [ args['<protocol>'] ]

        for protocol in protocols:

            if args['<set>'] is None:
                sets = database.set_names(protocol)
            else:
                sets = [ args['<set>'] ]

            for set_name in sets:
                filename = toPath(hashDataset(args['<database>'], protocol, set_name),
                                  suffix='.db')

                view = database.view(protocol, set_name,
                                     root_folder=args['--db_root_folder'])
                view.index(os.path.join(args['<cache>'], filename))

    except Exception:
        # Top-level boundary of the script: log the failure, report it through
        # the return code
        logger.error(traceback.format_exc())
        return 1

    return 0
# Script entry point: the value returned by main() becomes the exit status
if __name__ == '__main__':
    status = main()
    sys.exit(status)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
# Tests for database indexation
import os
import unittest
import multiprocessing
import Queue
import tempfile
import shutil
from ..scripts import index
from ..database import Database
from ..hash import hashDataset
from ..hash import toPath
from . import prefix
from . import tmp_prefix
#----------------------------------------------------------
class IndexationProcess(multiprocessing.Process):
    """Process running the main function of the ``index`` script

    Parameters:

      queue: Object on which ``put()`` is called with ``'STARTED'`` once the
        process runs

      arguments (list): Command-line arguments handed over to ``index.main()``

    """

    def __init__(self, queue, arguments):
        multiprocessing.Process.__init__(self)

        self.arguments = arguments
        self.queue = queue

    def run(self):
        # Tell the parent that the process is alive before the indexation starts
        self.queue.put('STARTED')
        index.main(self.arguments)
#----------------------------------------------------------
class TestDatabaseIndexation(unittest.TestCase):
    """Starts an ``IndexationProcess`` on some databases and checks which
    index files are present in the cache folder afterwards
    """

    def __init__(self, methodName='runTest'):
        super(TestDatabaseIndexation, self).__init__(methodName)

        self.databases_indexation_process = None
        self.working_dir = None
        self.cache_root = None

    def setUp(self):
        # In case another test failed badly during its setUp()
        self.shutdown_everything()

        self.working_dir = tempfile.mkdtemp(prefix=__name__)
        self.cache_root = tempfile.mkdtemp(prefix=__name__)

    def tearDown(self):
        self.shutdown_everything()

        for folder in (self.working_dir, self.cache_root):
            shutil.rmtree(folder)

        self.working_dir = None
        self.cache_root = None
        self.data_source = None

    def shutdown_everything(self):
        # Nothing to do when no indexation process is registered
        running = self.databases_indexation_process
        if running is None:
            return

        running.terminate()
        running.join()
        self.databases_indexation_process = None

    def process(self, database, protocol_name=None, set_name=None):
        # Positional arguments of the script: <prefix> <cache> <database>,
        # followed by the optional protocol and set names
        args = [prefix.paths[0], self.cache_root, database]
        args.extend(name for name in (protocol_name, set_name)
                    if name is not None)

        worker = IndexationProcess(multiprocessing.Queue(), args)
        self.databases_indexation_process = worker

        worker.start()
        worker.queue.get()    # Blocks until the process put its first message
        worker.join()

        self.databases_indexation_process = None

    def _is_in_cache(self, dataset_hash):
        # Tells if the '.db' file of the given dataset hash exists in the cache
        index_path = os.path.join(self.cache_root,
                                  toPath(dataset_hash, suffix='.db'))
        return os.path.exists(index_path)

    def test_one_set(self):
        self.process('integers_db/1', 'double', 'double')

        expected_files = [
            hashDataset('integers_db/1', 'double', 'double')
        ]

        for dataset_hash in expected_files:
            self.assertTrue(self._is_in_cache(dataset_hash))

    def test_one_protocol(self):
        self.process('integers_db/1', 'two_sets')

        expected_files = [
            hashDataset('integers_db/1', 'two_sets', 'double'),
            hashDataset('integers_db/1', 'two_sets', 'triple')
        ]

        for dataset_hash in expected_files:
            self.assertTrue(self._is_in_cache(dataset_hash))

    def test_whole_database(self):
        self.process('integers_db/1')

        expected_files = [
            hashDataset('integers_db/1', 'double', 'double'),
            hashDataset('integers_db/1', 'triple', 'triple'),
            hashDataset('integers_db/1', 'two_sets', 'double'),
            hashDataset('integers_db/1', 'two_sets', 'triple'),
            hashDataset('integers_db/1', 'labelled', 'labelled'),
            hashDataset('integers_db/1', 'different_frequencies', 'double'),
        ]

        for dataset_hash in expected_files:
            self.assertTrue(self._is_in_cache(dataset_hash))

    def test_error(self):
        self.process('crash/1', 'protocol', 'index_crashes')

        unexpected_files = [
            hashDataset('crash/1', 'protocol', 'index_crashes'),
        ]

        for dataset_hash in unexpected_files:
            self.assertFalse(self._is_in_cache(dataset_hash))
......@@ -144,10 +144,10 @@ class DatabasesProviderProcess(multiprocessing.Process):
#----------------------------------------------------------
class TestDatabasesProviderBase(unittest.TestCase):
class TestDatabasesProvider(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(TestDatabasesProviderBase, self).__init__(methodName)
super(TestDatabasesProvider, self).__init__(methodName)
self.databases_provider_process = None
self.working_dir = None
self.cache_root = None
......
......@@ -69,6 +69,7 @@ setup(
'execute = beat.backend.python.scripts.execute:main',
'describe = beat.backend.python.scripts.describe:main',
'databases_provider = beat.backend.python.scripts.databases_provider:main',
'index = beat.backend.python.scripts.index:main',
],
},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment