diff --git a/beat/backend/python/data.py b/beat/backend/python/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ac9c59b182852bd644c0e1ce4995f2713a858099 --- /dev/null +++ b/beat/backend/python/data.py @@ -0,0 +1,908 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.backend.python module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + + +"""Data I/O classes and functions""" + +import os +import re +import glob +import simplejson as json +import select +import time +import tempfile +import abc +from functools import reduce + +import logging +logger = logging.getLogger(__name__) + +import six +from .hash import hashFileContents +from .dataformat import DataFormat +from .algorithm import Algorithm + + +class DataSource(object): + + """Interface of all the Data Sources + + Data Sources are used to provides data to the inputs of an algorithm. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def next(self, load=True): + """Retrieves the next block of data + + Returns: + + A tuple (*data*, *start_index*, *end_index*) + + """ + + pass + + @abc.abstractmethod + def hasMoreData(self): + """Indicates if there is more data to process on some of the inputs""" + + pass + + +class DataSink(object): + + """Interface of all the Data Sinks + + Data Sinks are used by the outputs of an algorithm to write/transmit data. 
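# Illustrative sketch (not part of this changeset): a minimal DataSource built
# on the abstract interface above.  next() must return a
# (data, start_index, end_index) tuple and hasMoreData() must report whether
# anything is left; the in-memory block list below is purely hypothetical.

class ListDataSource(DataSource):
    """Serves pre-built (data, start_index, end_index) tuples in order"""

    def __init__(self, blocks):
        self.blocks = list(blocks)  # e.g. [(obj0, 0, 0), (obj1, 1, 3), ...]

    def next(self, load=True):
        if not self.blocks:
            return (None, None, None)
        return self.blocks.pop(0)

    def hasMoreData(self):
        return len(self.blocks) > 0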
+ """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def write(self, data, start_data_index, end_data_index): + """Writes a block of data + + Parameters: + + data (beat.core.baseformat.baseformat): The block of data to write + + start_data_index (int): Start index of the written data + + end_data_index (int): End index of the written data + + """ + + pass + + @abc.abstractmethod + def isConnected(self): + pass + + +class CachedDataSource(DataSource): + """Data Source that load data from the Cache""" + + def __init__(self): + self.filenames = None + self.cur_file = None + self.cur_file_index = None + self.encoding = None # must be 'binary' or 'json' + self.prefix = None # where to find dataformats + self.dataformat = None # the dataformat itself + self.preloaded = False + self.next_start_index = None + self.next_end_index = None + self.next_data_size = None + self.force_start_index = None + self.force_end_index = None + self._cache_size = 10 * 1024 * 1024 # 10 megabytes + self._cache = six.b('') + self._nb_bytes_read = 0 + self._read_duration = 0 + + def _readHeader(self): + """Read the header of the current file""" + + # Read file format + encoding = self.cur_file.readline()[:-1] + if not isinstance(encoding, str): encoding = encoding.decode('utf8') + + if encoding not in ('binary', 'json'): + raise RuntimeError("valid formats for data reading are 'binary' " + "or 'json': the format `%s' is invalid" % (encoding,)) + self.encoding = encoding + + # Read data format + dataformat_name = self.cur_file.readline()[:-1] + if not isinstance(dataformat_name, str): + dataformat_name = dataformat_name.decode('utf8') + if dataformat_name.startswith('analysis:'): + algo_name = dataformat_name.split(':')[1] + algo = Algorithm(self.prefix, algo_name) + if not algo.valid: + raise RuntimeError("the dataformat `%s' is the result of an " \ + "algorithm which is not valid" % algo_name) + self.dataformat = algo.result_dataformat() + else: + self.dataformat = DataFormat(self.prefix, dataformat_name) + if not self.dataformat.valid: + raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name) + + return True + + def setup(self, filename, prefix, force_start_index=None, + force_end_index=None): + """Configures the data source + + + Parameters: + + filename (str): Name of the file to read the data from + + prefix (str, path): Path to the prefix where the dataformats are stored. + + force_start_index (int): The starting index (if not set or set to + ``None``, the default, read data from the begin of file) + + force_end_index (int): The end index (if not set or set to ``None``, the + default, reads the data until the end) + + + Returns: + + ``True``, if successful, or ``False`` otherwise. + + """ + index_re = re.compile(r'^.*\.(\d+)\.(\d+)\.(data|index)(.checksum)?$') + + def file_start(f): + """Returns the converted start indexes from a filename, otherwise 0""" + + r = index_re.match(f) + if r: return int(r.group(1)) + return 0 + + def trim_filename(l, start_index, end_index): + """Function to trim out the useless file given a range of indices + """ + + res = [] + for f in l: + r = index_re.match(f) + if r: + s = int(r.group(1)) + e = int(r.group(2)) + if (start_index is not None and e < start_index) or \ + (end_index is not None and s > end_index): + continue + res.append(f) + return res + + def check_consistency(data_filenames, basename, data_ext): + """Perform some sanity check on the data/checksum files on disk: + + 1. One-to-one mapping between data and checksum files + 2. 
Checksum comparison between hash(data) and checksum files + 3. Contiguous indices if they are present + """ + + # Check checksum of files + checksum_filenames = sorted(glob.glob(basename + '*' + data_ext + '.checksum'), key=file_start) + + # Make sure that we have a perfect match between data files and checksum + # files + checksum_filenames_noext = [os.path.splitext(f)[0] for f in checksum_filenames] + + if data_filenames != checksum_filenames_noext: + raise IOError("number of data files and checksum files for `%s' " \ + "does not match (%d != %d)" % (filename, len(data_filenames), + len(checksum_filenames_noext))) + + # list of start/end indices to check that there are contiguous + indices = [] + for f_data, f_chck in zip(data_filenames, checksum_filenames): + + expected_chksum = open(f_chck, 'rt').read().strip() + current_chksum = hashFileContents(f_data) + if expected_chksum != current_chksum: + raise IOError("data file `%s' has a checksum (%s) that differs " \ + "from expected one (%s)" % (f_data, current_chksum, + expected_chksum)) + + r = index_re.match(f_data) + if r: indices.append((int(r.group(1)), int(r.group(2)))) + + indices = sorted(indices, key=lambda v: v[0]) + ok_indices = True + + if len(indices) > 0: + ok_indices = (indices[0][0] == 0) + + if ok_indices and len(indices) > 1: + ok_indices = sum([indices[i + 1][0] - indices[i][1] == 1 + for i in range(len(indices) - 1)]) + + if not ok_indices: + raise IOError("data file `%s' have missing indices." % f_data) + + self.prefix = prefix + basename, data_ext = os.path.splitext(filename) + data_filenames = sorted(glob.glob(basename + '*' + data_ext), + key=file_start) + + # Check consistency of the data/checksum files + check_consistency(data_filenames, basename, data_ext) + + # List files to process + self.force_start_index = force_start_index + self.force_end_index = force_end_index + self.filenames = trim_filename(data_filenames, force_start_index, + force_end_index) + + # Read the first file to process + self.cur_file_index = 0 + try: + self.cur_file = open(self.filenames[self.cur_file_index], 'rb') + except Exception as e: + logger.warn("Could not setup `%s': %s" % (filename, e)) + return False + + # Reads the header of the current file + self._readHeader() + + if force_start_index is not None: + + # Read chunck until force_start_index is reached + while True: + + # Read indexes + t1 = time.time() + line = self.cur_file.readline() + self._nb_bytes_read += len(line) + t2 = time.time() + self._read_duration += t2 - t1 + + (self.next_start_index, self.next_end_index, self.next_data_size) = \ + [int(x) for x in line.split()] + + # Seek to the next chunck of data if start index is still too small + if self.next_start_index < force_start_index: + self.cur_file.seek(self.next_data_size, 1) + + # Otherwise, read the next 'chunck' of data (binary or json) + else: + t1 = time.time() + data = self.cur_file.read(self._cache_size - len(self._cache)) + t2 = time.time() + + self._nb_bytes_read += len(data) + self._read_duration += t2 - t1 + self._cache += data + + self.preloaded = True + break + + else: + # Preload the data + self._preload() + + return True + + def close(self): + """Closes the data source""" + + if self.cur_file is not None: + self.cur_file.close() + + def __del__(self): + """Makes sure the files are close when the object is deleted""" + + self.close() + + def next(self): + """Retrieve the next block of data + + Returns: + + A tuple (data, start_index, end_index) + + """ + if self.next_start_index is None: + return (None, 
None, None) + + # Determine if the cache already contains all the data we need + if len(self._cache) >= self.next_data_size: + encoded_data = self._cache[:self.next_data_size] + self._cache = self._cache[self.next_data_size:] + else: + t1 = time.time() + data = self.cur_file.read(self.next_data_size - len(self._cache)) + t2 = time.time() + + self._nb_bytes_read += len(data) + self._read_duration += t2 - t1 + + encoded_data = self._cache + data + self._cache = six.b('') + + data = self.dataformat.type() + data.unpack(encoded_data) #checks validity + + result = (data, self.next_start_index, self.next_end_index) + + self._preload() + + return result + + def hasMoreData(self): + """Indicates if there is more data to process on some of the inputs""" + + if not(self.preloaded): + self._preload(blocking=True) + + if self.force_end_index is not None and \ + self.next_start_index is not None and \ + self.next_start_index > self.force_end_index: + return False + + return (self.next_start_index is not None) + + def statistics(self): + """Return the statistics about the number of bytes read from the cache""" + return (self._nb_bytes_read, self._read_duration) + + def _preload(self, blocking=False): + # Determine if the cache already contains all the data we need + offset = self._cache.find(six.b('\n')) + if offset == -1: + + # Try to read the next chunck of data + while True: + + # Read in the current file + t1 = time.time() + + if blocking: + (readable, writable, errors) = select.select([self.cur_file], [], []) + + data = self.cur_file.read(self._cache_size - len(self._cache)) + + t2 = time.time() + + self._nb_bytes_read += len(data) + self._read_duration += t2 - t1 + self._cache += data + + # If not read from the current file + if (len(data) == 0) or (self._cache.find(six.b('\n')) == -1): + # Read the next one if possible + if self.cur_file_index < len(self.filenames) - 1: + + if self.cur_file is not None: + self.cur_file.close() + + self.cur_file_index += 1 + + try: + self.cur_file = open(self.filenames[self.cur_file_index], 'rb') + except: + return + + self._readHeader() + + # Otherwise, stop the parsing + else: + self.next_start_index = None + self.next_end_index = None + self.next_data_size = None + self.preloaded = blocking + return + + else: + break + + offset = self._cache.find(six.b('\n')) + + # Extract the informations about the next block of data + line = self._cache[:offset] + self._cache = self._cache[offset + 1:] + + (self.next_start_index, self.next_end_index, self.next_data_size) = \ + [int(x) for x in line.split()] + + self.preloaded = True + + +class CachedDataSink(DataSink): + + """Data Sink that save data in the Cache + + The default behavior is to save the data in a binary format. + """ + + def __init__(self): + + self.filename = None + self.process_id = None + self.split_id = None + self.max_size = None + + self._nb_bytes_written = 0 + self._write_duration = 0 + self._nb_bytes_written_split = 0 + + self._new_file = False + + self._cur_filename = None + self._cur_file = None + self._cur_indexname = None + self._cur_index = None + + self._cur_start_index = None + self._cur_end_index = None + self._filenames = [] + self._filenames_tmp = [] + + self._tmp_ext = '.tmp' + + self.encoding = None + self.dataformat = None + + def _curTmpFilenameWithSplit(self): + + filename, data_ext = os.path.splitext(self.filename) + dirname = os.path.dirname(filename) + basename = os.path.basename(filename) + fd, tmp_file = tempfile.mkstemp( + dir=dirname, + prefix=basename+'.' 
+ str(self.process_id)+'.'+ str(self.split_id)+'_', + suffix=data_ext + self._tmp_ext, + ) + os.close(fd) # Preserve only the name + os.unlink(tmp_file) + return tmp_file + + def _curFilenameWithIndices(self): + + basename = os.path.basename(self.filename) + basename, data_ext = os.path.splitext(basename) + dirname = os.path.dirname(self.filename) + return os.path.join(dirname, basename + '.' + str(self._cur_start_index) + '.' + str(self._cur_end_index) + data_ext) + + def _tmpIndexFilenameFromTmpFilename(self, tmp_filename): + + return os.path.splitext(os.path.splitext(tmp_filename)[0])[0] + '.index' + self._tmp_ext + + def _indexFilenameFromFilename(self, filename): + return os.path.splitext(filename)[0] + '.index' + + def _openAndWriteHeader(self): + """Write the header of the current file""" + + # Close current file if open + self._close_current() + + # Open new file in writing mode + self._cur_filename = self._curTmpFilenameWithSplit() + self._cur_indexname = \ + self._tmpIndexFilenameFromTmpFilename(self._cur_filename) + self._filenames_tmp.append(self._cur_filename) + try: + self._cur_file = open(self._cur_filename, 'wb') + self._cur_index = open(self._cur_indexname, 'wt') + except: + return + + # Write dataformat + self._cur_file.write(six.b('%s\n%s\n' % \ + (self.encoding, self.dataformat.name))) + self._cur_file.flush() + + # Reset few flags + self._cur_start_index = None + self._cur_end_index = None + self._new_file = False + self._nb_bytes_written_split = 0 + + def setup(self, filename, dataformat, encoding='binary', process_id=0, + max_size=0): + """Configures the data sink + + Parameters: + + filename (str): Name of the file to generate + + dataformat (beat.core.dataformat.DataFormat): The dataformat to be used + inside this file. All objects stored inside this file will respect that + format. + + encoding (str): String defining the encoding to be used for encoding the + data. Only a few options are supported: ``binary`` (the default) or + ``json`` (debugging purposes). 
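# Illustrative round trip (not part of this changeset): writing a few blocks
# with CachedDataSink and reading them back with CachedDataSource.  The cache
# path and the dataformat name ('user/single_integer/1') are hypothetical; any
# dataformat available under `prefix' with a single 'value' field would do.

from beat.backend.python.data import CachedDataSink, CachedDataSource
from beat.backend.python.dataformat import DataFormat

prefix = '/path/to/prefix'  # hypothetical installation prefix
dataformat = DataFormat(prefix, 'user/single_integer/1')

sink = CachedDataSink()
assert sink.setup('/tmp/cache/example.data', dataformat, encoding='binary')
for i in range(3):
    sink.write({'value': i}, start_data_index=i, end_data_index=i)
sink.close()  # renames the temporary files and writes the '.checksum' files

source = CachedDataSource()
assert source.setup('/tmp/cache/example.data', prefix)
while source.hasMoreData():
    data, start, end = source.next()
    print(start, end, data.value)
source.close()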
+ + """ + + if encoding not in ('binary', 'json'): + raise RuntimeError("valid formats for data writting are 'binary' " + "or 'json': the format `%s' is invalid" % format) + + if dataformat.name == '__unnamed_dataformat__': + raise RuntimeError("cannot record data using an unnammed data format") + + self.filename = filename + self.process_id = process_id + self.split_id = 0 + self.max_size = max_size + + self._nb_bytes_written = 0 + self._write_duration = 0 + self._new_file = True + + self._cur_filename = None + self._cur_file = None + self._cur_indexname = None + self._cur_index = None + self._cur_start_index = None + self._cur_end_index = None + + self._filenames = [] + self._filenames_tmp = [] + + self.dataformat = dataformat + self.encoding = encoding + + return True + + def _close_current(self): + """Closes the data sink + """ + + if self._cur_file is not None: + self._cur_file.close() + self._cur_index.close() + + # If file is empty, remove it + if self._cur_start_index is None or self._cur_end_index is None: + try: + os.remove(self._cur_filename) + os.remove(self._cur_index) + except: + return False + self._filenames_tmp.pop() + + # Otherwise, append final filename to list + else: + self._filenames.append(self._curFilenameWithIndices()) + + self._cur_filename = None + self._cur_file = None + self._cur_indexname = None + self._cur_index = None + + def close(self): + """Move the files to final location + """ + + self._close_current() + assert len(self._filenames_tmp) == len(self._filenames) + + for i in range(len(self._filenames_tmp)): + + os.rename(self._filenames_tmp[i], self._filenames[i]) + tmp_indexname = \ + self._tmpIndexFilenameFromTmpFilename(self._filenames_tmp[i]) + final_indexname = self._indexFilenameFromFilename(self._filenames[i]) + os.rename(tmp_indexname, final_indexname) + + # creates the checksums for all data and indexes + chksum_data = hashFileContents(self._filenames[i]) + with open(self._filenames[i] + '.checksum', 'wt') as f: + f.write(chksum_data) + chksum_index = hashFileContents(final_indexname) + with open(final_indexname + '.checksum', 'wt') as f: + f.write(chksum_index) + + self._cur_filename = None + self._cur_file = None + self._cur_indexname = None + self._cur_index = None + + self._cur_start_index = None + self._cur_end_index = None + self._filenames = [] + self._filenames_tmp = [] + + def reset(self): + """Move the files to final location + """ + + self._close_current() + assert len(self._filenames_tmp) == len(self._filenames) + for i in range(len(self._filenames_tmp)): + try: + os.remove(self._filenames_tmp[i]) + tmp_indexname = \ + self._tmpIndexFilenameFromTmpFilename(self._filenames_tmp[i]) + os.remove(tmp_indexname) + except: + return False + + self._cur_filename = None + self._cur_file = None + self._cur_indexname = None + self._cur_index = None + + self._cur_start_index = None + self._cur_end_index = None + self._filenames = [] + self._filenames_tmp = [] + + def __del__(self): + """Make sure the files are close and renamed when the object is deleted + """ + + self.close() + + def write(self, data, start_data_index, end_data_index): + """Writes a block of data to the filesystem + + Parameters: + + data (beat.core.baseformat.baseformat): The block of data to write + + start_data_index (int): Start index of the written data + + end_data_index (int): End index of the written data + + """ + + if isinstance(data, dict): + # the user passed a dictionary - must convert + data = self.dataformat.type(**data) + else: + # Checks that the input data 
conforms to the expected format + if data.__class__._name != self.dataformat.name: + raise TypeError("input data uses format `%s' while this sink " + "expects `%s'" % (data.__class__._name, self.dataformat)) + + # If the flag new_file is set, open new file and write header + if self._new_file: + self._openAndWriteHeader() + + if self._cur_file is None: + raise RuntimeError("no destination file") + + # encoding happens here + if self.encoding == 'binary': + encoded_data = data.pack() + else: + from .utils import NumpyJSONEncoder + encoded_data = json.dumps(data.as_dict(), indent=4, cls=NumpyJSONEncoder) + + # adds a new line by the end of the encoded data, for clarity + encoded_data += six.b('\n') + + informations = six.b('%d %d %d\n' % (start_data_index, + end_data_index, len(encoded_data))) + + t1 = time.time() + + self._cur_file.write(informations + encoded_data) + self._cur_file.flush() + + indexes = '%d %d\n' % (start_data_index, end_data_index) + self._cur_index.write(indexes) + self._cur_index.flush() + + t2 = time.time() + + self._nb_bytes_written += \ + len(informations) + len(encoded_data) + len(indexes) + self._nb_bytes_written_split += \ + len(informations) + len(encoded_data) + len(indexes) + self._write_duration += t2 - t1 + + # Update start and end indices + if self._cur_start_index is None: + self._cur_start_index = start_data_index + self._cur_end_index = end_data_index + + # If file size exceeds max, sets the flag to create a new file + if self.max_size != 0 and self._nb_bytes_written >= self.max_size: + self._new_file = True + self.split_id += 1 + + def statistics(self): + """Return the statistics about the number of bytes written to the cache""" + + return (self._nb_bytes_written, self._write_duration) + + def isConnected(self): + return (self.filename is not None) + + +class MemoryDataSource(DataSource): + + """Interface of all the Data Sources + + Data Sources are used to provides data to the inputs of an algorithm. + """ + + def __init__(self, done_callback, next_callback=None, index=None): + + self.data = [] + self._done_callback = done_callback + self._next_callback = next_callback + + def add(self, data, start_data_index, end_data_index): + self.data.append((data, start_data_index, end_data_index)) + + def next(self): + """Retrieves the next block of data + + :return: A tuple (*data*, *start_index*, *end_index*) + """ + + if (len(self.data) == 0) and (self._next_callback is not None): + if not(self._done_callback()): + self._next_callback() + + if len(self.data) == 0: + return (None, None, None) + + return self.data.pop(0) + + def hasMoreData(self): + + if len(self.data) != 0: + return True + return not self._done_callback() + + def statistics(self): + """Return the statistics about the number of bytes read from the cache""" + + return (0, 0) + + +class MemoryDataSink(DataSink): + + """Data Sink that directly transmit data to associated MemoryDataSource + objects. 
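# Illustrative sketch (not part of this changeset): wiring a MemoryDataSink to
# a MemoryDataSource, the in-process path used between database views and
# algorithm inputs.  The string payloads below are plain placeholders; in the
# platform they would be baseformat objects.

from beat.backend.python.data import MemoryDataSource, MemoryDataSink

produced = ['block-a', 'block-b']

def done():  # no more data once the producer list is exhausted
    return len(produced) == 0

def produce_next():  # called by the source whenever its queue runs empty
    index = 2 - len(produced)  # 0 for the first block, 1 for the second
    sink.write(produced.pop(0), start_data_index=index, end_data_index=index)

source = MemoryDataSource(done, next_callback=produce_next)
sink = MemoryDataSink()
sink.setup([source])  # every write() is forwarded to these sources

while source.hasMoreData():
    data, start, end = source.next()
    print(data, start, end)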
+ """ + + def __init__(self): + self.data_sources = None + + def setup(self, data_sources): + """Configure the data sink + + :param list data_sources: The MemoryDataSource objects to use + """ + self.data_sources = data_sources + + def write(self, data, start_data_index, end_data_index): + """Write a block of data + + Parameters: + + data (beat.core.baseformat.baseformat) The block of data to write + + start_data_index (int): Start index of the written data + + end_data_index (int): End index of the written data + + """ + for data_source in self.data_sources: + data_source.add(data, start_data_index, end_data_index) + + def isConnected(self): + return len(self.data_sources) > 0 + + +def load_data_index(cache_prefix, hash_path): + """Loads a cached-data index if it exists. Returns empty otherwise. + + Parameters: + + cache_prefix (str): The path to the root of the cache directory + + hash_path (str): The hashed path of the input you wish to load the indexes + for, as it is returned by the utility function + :py:func:`beat.core.hash.toPath`. + + + Returns: + + A list, which will be empty if the index file is not present. Note that, + given the current design, an empty list means an error condition. + + """ + + name_no_extension = os.path.splitext(hash_path)[0] # remove .data + index_stem = os.path.join(cache_prefix, name_no_extension) + index_glob = index_stem + '*.index' + + candidates = glob.glob(index_glob) + + assert candidates, "No index file matching the pattern `%s' found." % \ + index_glob + + retval = [] + end_index = 0 + for filename in candidates: + with open(filename, 'rt') as f: + data = [k.split() for k in f.readlines() if k.strip()] + start = [int(k[0]) for k in data] + end = int(data[-1][1]) # last index + + # checks if the sum exists and is correct, only appends in that case + # returns an empty list otherwise, as these indices are considered + # invalid. + expected_chksum = open(filename + '.checksum', 'rt').read().strip() + + current_chksum = hashFileContents(filename) + assert expected_chksum == current_chksum, "index file `%s' has a " \ + "checksum (%s) that differs from expected one (%s)" % \ + (filename, current_chksum, expected_chksum) + + # else, group indices + retval.extend(start) + if end > end_index: + end_index = end + + return sorted(retval) + [end_index + 1] + + +def _foundCommonIndices(lst): + """Returns the list of common indices, given a list of list of indices + """ + + if lst == []: + return lst + lst_set = [set(k) for k in lst] + common_indices = sorted(list(reduce(set.intersection, lst_set))) + return common_indices + + +def foundSplitRanges(lst, n_split): + """Splits a list of lists of indices into n splits for parallelization + purposes. 
""" + + if [] in lst or lst == []: + return [] + ci_lst = _foundCommonIndices(lst) + res = [] + average_length = (float)(ci_lst[-1]) / n_split + c = 0 + s = 0 + for i in range(1, len(ci_lst)): + if (ci_lst[i] - ci_lst[s] >= average_length and c < n_split - 1) or \ + len(ci_lst) - i <= n_split - c: + res.append((ci_lst[s], ci_lst[i] - 1)) + s = i + c += 1 + return res diff --git a/beat/backend/python/database.py b/beat/backend/python/database.py new file mode 100644 index 0000000000000000000000000000000000000000..f4114bd1d98309a977c6f53fae0549650d2747f6 --- /dev/null +++ b/beat/backend/python/database.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.backend.python module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + + +"""Validation of databases""" + +import os +import sys +# import collections + +import six +import simplejson + +from . import loader + +# from . import dataformat +# from . import hash +# from . import utils +# from . import prototypes + + +class View(object): + '''A special loader class for database views, with specialized methods + + Parameters: + + db_name (str): The full name of the database object for this view + + module (module): The preloaded module containing the database views as + returned by :py:func:`beat.core.loader.load_module`. + + prefix (str, path): The prefix path for the current installation + + root_folder (str, path): The path pointing to the root folder of this + database + + exc (class): The class to use as base exception when translating the + exception from the user code. Read the documention of :py:func:`run` + for more details. + + *args: Constructor parameters for the database view. Normally, none. + + **kwargs: Constructor parameters for the database view. Normally, none. 
+ + ''' + + + def __init__(self, module, definition, prefix, root_folder, exc=None, + *args, **kwargs): + + try: + class_ = getattr(module, definition['view']) + except Exception as e: + if exc is not None: + type, value, traceback = sys.exc_info() + six.reraise(exc, exc(value), traceback) + else: + raise #just re-raise the user exception + + self.obj = loader.run(class_, '__new__', exc, *args, **kwargs) + self.ready = False + self.prefix = prefix + self.root_folder = root_folder + self.definition = definition + self.exc = exc or RuntimeError + self.outputs = None + + + def prepare_outputs(self): + '''Prepares the outputs of the dataset''' + + from .outputs import Output, OutputList + from .data import MemoryDataSink + from .dataformat import DataFormat + + # create the stock outputs for this dataset, so data is dumped + # on a in-memory sink + self.outputs = OutputList() + for out_name, out_format in self.definition.get('outputs', {}).items(): + data_sink = MemoryDataSink() + data_sink.dataformat = DataFormat(self.prefix, out_format) + data_sink.setup([]) + self.outputs.add(Output(out_name, data_sink, dataset_output=True)) + + + def setup(self, *args, **kwargs): + '''Sets up the view''' + + kwargs.setdefault('root_folder', self.root_folder) + kwargs.setdefault('parameters', self.definition.get('parameters', {})) + + if 'outputs' not in kwargs: + kwargs['outputs'] = self.outputs + else: + self.outputs = kwargs['outputs'] #record outputs nevertheless + + self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs) + + if not self.ready: + raise self.exc("unknow setup failure") + + return self.ready + + + def input_group(self, name='default', exclude_outputs=[]): + '''A memory-source input group matching the outputs from the view''' + + if not self.ready: + raise self.exc("database view not yet setup") + + from .data import MemoryDataSource + from .outputs import SynchronizationListener + from .inputs import Input, InputGroup + + # Setup the inputs + synchronization_listener = SynchronizationListener() + input_group = InputGroup(name, + synchronization_listener=synchronization_listener, + restricted_access=False) + + for output in self.outputs: + if output.name in exclude_outputs: continue + data_source = MemoryDataSource(self.done, next_callback=self.next) + output.data_sink.data_sources.append(data_source) + input_group.add(Input(output.name, + output.data_sink.dataformat, data_source)) + + return input_group + + + def done(self, *args, **kwargs): + '''Checks if the view is done''' + + if not self.ready: + raise self.exc("database view not yet setup") + + return loader.run(self.obj, 'done', self.exc, *args, **kwargs) + + + def next(self, *args, **kwargs): + '''Runs through the next data chunk''' + + if not self.ready: + raise self.exc("database view not yet setup") + return loader.run(self.obj, 'next', self.exc, *args, **kwargs) + + + def __getattr__(self, key): + '''Returns an attribute of the view - only called at last resort''' + return getattr(self.obj, key) + + + +class Database(object): + """Databases define the start point of the dataflow in an experiment. + + + Parameters: + + prefix (str): Establishes the prefix of your installation. + + name (str): The fully qualified database name (e.g. ``db/1``) + + dataformat_cache (dict, optional): A dictionary mapping dataformat names + to loaded dataformats. This parameter is optional and, if passed, may + greatly speed-up database loading times as dataformats that are already + loaded may be re-used. 
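# Illustrative sketch (not part of this changeset) of the view life cycle
# implemented above.  It assumes `view' was obtained from Database.view()
# (defined further down) and that the corresponding set declares at least one
# output; nothing else here is taken from the changeset.

view.prepare_outputs()      # one MemoryDataSink per declared output
view.setup()                # raises `exc' if the user view reports a failure

group = view.input_group()  # memory-backed inputs mirroring those outputs

while not view.done():
    view.next()             # the user view writes into view.outputs here

print([output.name for output in view.outputs])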
If you use this parameter, you must guarantee + that the cache is refreshed as appropriate in case the underlying + dataformats change. + + + Attributes: + + name (str): The full, valid name of this database + + data (dict): The original data for this database, as loaded by our JSON + decoder. + + """ + + def __init__(self, prefix, name, dataformat_cache=None): + + self._name = None + self.prefix = prefix + self.dataformats = {} # preloaded dataformats + + self.data = None + + # if the user has not provided a cache, still use one for performance + dataformat_cache = dataformat_cache if dataformat_cache is not None else {} + + self._load(name, dataformat_cache) + + + def _load(self, data, dataformat_cache): + """Loads the database""" + + self._name = data + json_path = os.path.join(prefix, 'databases', name + '.json') + with open(json_path, 'rb') as f: self.data = simplejson.load(f) + + + @property + def name(self): + """Returns the name of this object + """ + return self._name or '__unnamed_database__' + + + @property + def schema_version(self): + """Returns the schema version""" + return self.data.get('schema_version', 1) + + + @property + def protocols(self): + """The declaration of all the protocols of the database""" + + data = self.data['protocols'] + return dict(zip([k['name'] for k in data], data)) + + + def protocol(self, name): + """The declaration of a specific protocol in the database""" + + return self.protocols[name] + + + @property + def protocol_names(self): + """Names of protocols declared for this database""" + + data = self.data['protocols'] + return [k['name'] for k in data] + + + def sets(self, protocol): + """The declaration of a specific set in the database protocol""" + + data = self.protocol(protocol)['sets'] + return dict(zip([k['name'] for k in data], data)) + + + def set(self, protocol, name): + """The declaration of all the protocols of the database""" + + return self.sets(protocol)[name] + + + def set_names(self, protocol): + """The names of sets in a given protocol for this database""" + + data = self.protocol(protocol)['sets'] + return [k['name'] for k in data] + + + def view(self, protocol, name, exc=None): + """Returns the database view, given the protocol and the set name + + Parameters: + + protocol (str): The name of the protocol where to retrieve the view from + + name (str): The name of the set in the protocol where to retrieve the + view from + + exc (class): If passed, must be a valid exception class that will be + used to report errors in the read-out of this database's view. + + Returns: + + The database view, which will be constructed, but not setup. You + **must** set it up before using methods ``done`` or ``next``. 
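# Illustrative sketch (not part of this changeset): loading a database
# declaration and asking for one of its views.  The database name ('mydb/1'),
# protocol ('Main') and set ('train') are hypothetical placeholders.

from beat.backend.python.database import Database

db = Database('/path/to/prefix', 'mydb/1')

print(db.protocol_names)         # e.g. ['Main']
print(db.set_names('Main'))      # e.g. ['train', 'test']
print(db.set('Main', 'train'))   # the raw set declaration (view, outputs, ...)

view = db.view('Main', 'train')  # constructed, but *not* yet set up
view.prepare_outputs()
view.setup()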
+ + """ + + if not self._name: + exc = exc or RuntimeError + raise exc("database has no name") + + if not self.valid: + message = "cannot load view for set `%s' of protocol `%s' " \ + "from invalid database (%s)" % (protocol, name, self.name) + if exc: raise exc(message) + raise RuntimeError(message) + + # loads the module only once through the lifetime of the database object + try: + if not hasattr(self, '_module'): + self._module = loader.load_module(self.name.replace(os.sep, '_'), + self.storage.code.path, {}) + except Exception as e: + if exc is not None: + type, value, traceback = sys.exc_info() + six.reraise(exc, exc(value), traceback) + else: + raise #just re-raise the user exception + + return View(self._module, self.set(protocol, name), self.prefix, + self.data['root_folder'], exc) diff --git a/beat/backend/python/dbexecution.py b/beat/backend/python/dbexecution.py new file mode 100644 index 0000000000000000000000000000000000000000..899a699257ffe5f3397088596c11e55371727488 --- /dev/null +++ b/beat/backend/python/dbexecution.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.backend.python module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + + +'''Execution utilities''' + +import os +import sys +import glob +import errno +import tempfile +import subprocess + +import logging +logger = logging.getLogger(__name__) + +import simplejson + +# from . import schema +from . import database +from . import inputs +from . import outputs +from . import data +from . import message_handler + + +class DBExecutor(object): + """Executor specialised in database views + + + Parameters: + + prefix (str): Establishes the prefix of your installation. + + data (dict, str): The piece of data representing the block to be executed. + It must validate against the schema defined for execution blocks. If a + string is passed, it is supposed to be a fully qualified absolute path to + a JSON file containing the block execution information. + + dataformat_cache (dict, optional): A dictionary mapping dataformat names to + loaded dataformats. This parameter is optional and, if passed, may + greatly speed-up database loading times as dataformats that are already + loaded may be re-used. If you use this parameter, you must guarantee that + the cache is refreshed as appropriate in case the underlying dataformats + change. 
+ + database_cache (dict, optional): A dictionary mapping database names to + loaded databases. This parameter is optional and, if passed, may + greatly speed-up database loading times as databases that are already + loaded may be re-used. If you use this parameter, you must guarantee that + the cache is refreshed as appropriate in case the underlying databases + change. + + + Attributes: + + errors (list): A list containing errors found while loading this execution + block. + + data (dict): The original data for this executor, as loaded by our JSON + decoder. + + databases (dict): A dictionary in which keys are strings with database + names and values are :py:class:`database.Database`, representing the + databases required for running this block. The dictionary may be empty + in case all inputs are taken from the file cache. + + views (dict): A dictionary in which the keys are tuples pointing to the + ``(<database-name>, <protocol>, <set>)`` and the value is a setup view + for that particular combination of details. The dictionary may be empty + in case all inputs are taken from the file cache. + + input_list (beat.core.inputs.InputList): A list of inputs that will be + served to the algorithm. + + data_sources (list): A list with all data-sources created by our execution + loader. + + """ + + def __init__(self, prefix, data, dataformat_cache=None, database_cache=None): + + self.prefix = prefix + + # some attributes + self.databases = {} + self.views = {} + self.input_list = None + self.data_sources = [] + self.handler = None + self.errors = [] + self.data = None + + # temporary caches, if the user has not set them, for performance + database_cache = database_cache if database_cache is not None else {} + self.dataformat_cache = dataformat_cache if dataformat_cache is not None else {} + + self._load(data, self.dataformat_cache, database_cache) + + + def _load(self, data, dataformat_cache, database_cache): + """Loads the block execution information""" + + # reset + self.data = None + self.errors = [] + self.databases = {} + self.views = {} + self.input_list = None + self.data_sources = [] + + if not isinstance(data, dict): #user has passed a file name + if not os.path.exists(data): + self.errors.append('File not found: %s' % data) + return + + with open(data) as f: + self.data = simplejson.load(f) + else: + self.data = data + + # this runs basic validation, including JSON loading if required + # self.data, self.errors = schema.validate('execution', data) + # if self.errors: return #don't proceed with the rest of validation + + # load databases + for name, details in self.data['inputs'].items(): + if 'database' in details: + + if details['database'] not in self.databases: + + if details['database'] in database_cache: #reuse + db = database_cache[details['database']] + else: #load it + db = database.Database(self.prefix, details['database'], + dataformat_cache) + database_cache[db.name] = db + + self.databases[details['database']] = db + + if not db.valid: + self.errors += db.errors + continue + + if not db.valid: + # do not add errors again + continue + + # create and load the required views + key = (details['database'], details['protocol'], details['set']) + if key not in self.views: + view = self.databases[details['database']].view(details['protocol'], + details['set']) + + if details['channel'] == self.data['channel']: #synchronized + start_index, end_index = self.data.get('range', (None, None)) + else: + start_index, end_index = (None, None) + view.prepare_outputs() + self.views[key] = 
(view, start_index, end_index) + + + def __enter__(self): + """Prepares inputs and outputs for the processing task + + Raises: + + IOError: in case something cannot be properly setup + + """ + + self._prepare_inputs() + + # The setup() of a database view may call isConnected() on an input + # to set the index at the right location when parallelization is enabled. + # This is why setup() should be called after initialized the inputs. + for key, (view, start_index, end_index) in self.views.items(): + + if (start_index is None) and (end_index is None): + status = view.setup() + else: + status = view.setup(force_start_index=start_index, + force_end_index=end_index) + + if not status: + raise RuntimeError("Could not setup database view `%s'" % key) + + return self + + + def __exit__(self, exc_type, exc_value, traceback): + """Closes all sinks and disconnects inputs and outputs + """ + self.input_list = None + self.data_sources = [] + + + def _prepare_inputs(self): + """Prepares all input required by the execution.""" + + self.input_list = inputs.InputList() + + # This is used for parallelization purposes + start_index, end_index = self.data.get('range', (None, None)) + + for name, details in self.data['inputs'].items(): + + if 'database' in details: #it is a dataset input + + view_key = (details['database'], details['protocol'], details['set']) + view = self.views[view_key][0] + + data_source = data.MemoryDataSource(view.done, next_callback=view.next) + self.data_sources.append(data_source) + output = view.outputs[details['output']] + + # if it's a synchronized channel, makes the output start at the right + # index, otherwise, it gets lost + if start_index is not None and \ + details['channel'] == self.data['channel']: + output.last_written_data_index = start_index - 1 + output.data_sink.data_sources.append(data_source) + + # Synchronization bits + group = self.input_list.group(details['channel']) + if group is None: + group = inputs.InputGroup( + details['channel'], + synchronization_listener=outputs.SynchronizationListener(), + restricted_access=(details['channel'] == self.data['channel']) + ) + self.input_list.add(group) + + input_db = self.databases[details['database']] + input_dataformat_name = input_db.set(details['protocol'], details['set'])['outputs'][details['output']] + group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source)) + + + def process(self, zmq_context, zmq_socket): + + self.handler = message_handler.MessageHandler(self.input_list, zmq_context, zmq_socket) + self.handler.start() + + + @property + def valid(self): + """A boolean that indicates if this executor is valid or not""" + + return not bool(self.errors) + + + def wait(self): + self.handler.join() + self.handler = None + + + def __str__(self): + return simplejson.dumps(self.data, indent=4) diff --git a/beat/backend/python/hash.py b/beat/backend/python/hash.py new file mode 100644 index 0000000000000000000000000000000000000000..82c3f7824945bd369b34fd9393390168ba170ed2 --- /dev/null +++ b/beat/backend/python/hash.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.backend.python module of the BEAT platform. 
# +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + + +"""Various functions for hashing platform contributions and others""" + + +import hashlib + + +def hashFileContents(path): + """Hashes the file contents using :py:func:`hashlib.sha256`.""" + + with open(path, 'rb') as f: + return hashlib.sha256(f.read()).hexdigest() diff --git a/beat/backend/python/outputs.py b/beat/backend/python/outputs.py index 6eef4484fa439035405b8b5928d0aee556b673b7..6e05c72ff455c1e2e1fe6b633e4960ca8f14ae84 100644 --- a/beat/backend/python/outputs.py +++ b/beat/backend/python/outputs.py @@ -36,9 +36,141 @@ import zmq from .baseformat import baseformat +class SynchronizationListener: + """A callback mechanism to keep Inputs and Outputs in groups and lists + synchronized together.""" + + def __init__(self): + self.data_index_start = -1 + self.data_index_end = -1 + + def onIntervalChanged(self, data_index_start, data_index_end): + self.data_index_start = data_index_start + self.data_index_end = data_index_end + + +#---------------------------------------------------------- + + class Output: """Represents one output of a processing block + A list of outputs implementing this interface is provided to the algorithms + (see :py:class:`beat.core.outputs.OutputList`). + + + Parameters: + + name (str): Name of the output + + data_sink (beat.core.data.DataSink): Sink of data to be used by the output, + pre-configured with the correct data format. 
+ + + Attributes: + + name (str): Name of the output (algorithm-specific) + + data_sink (beat.core.data.DataSink): Sink of data used by the output + + last_written_data_index (int): Index of the last block of data written by + the output + + nb_data_blocks_written (int): Number of data blocks written so far + + + """ + + def __init__(self, name, data_sink, synchronization_listener=None, + dataset_output=False, force_start_index=0): + + self.name = name + self.data_sink = data_sink + self._synchronization_listener = synchronization_listener + self._dataset_output = dataset_output + self.last_written_data_index = force_start_index-1 + self.nb_data_blocks_written = 0 + + + def _createData(self): + """Retrieves an uninitialized block of data corresponding to the data + format of the output + + This method must be called to correctly create a new block of data + """ + + if hasattr(self.data_sink, 'dataformat'): + return self.data_sink.dataformat.type() + else: + raise RuntimeError("The currently used data sink is not bound to " \ + "a dataformat - you cannot create uninitialized data under " \ + "these circumstances") + + + def write(self, data, end_data_index=None): + """Write a block of data on the output + + Parameters: + + data (beat.core.baseformat.baseformat): The block of data to write, or + None (if the algorithm doesn't want to write any data) + + end_data_index (int): Last index of the written data (see the section + *Inputs synchronization* of the User's Guide). If not specified, the + *current end data index* of the Inputs List is used + + """ + + if self._dataset_output: + if end_data_index is None: + end_data_index = self.last_written_data_index + 1 + elif end_data_index < self.last_written_data_index + 1: + raise KeyError("Database wants to write an `end_data_index' (%d) " \ + "which is smaller than the last written index (%d) " \ + "+1 - this is a database bug - Fix it!" % \ + (end_data_index, self.last_written_data_index)) + + elif end_data_index is not None: + if (end_data_index < self.last_written_data_index + 1) or \ + ((self._synchronization_listener is not None) and \ + (end_data_index > self._synchronization_listener.data_index_end)): + raise KeyError("Algorithm logic error on write(): `end_data_index' " \ + "is not consistent with last written index") + + elif self._synchronization_listener is not None: + end_data_index = self._synchronization_listener.data_index_end + + else: + end_data_index = self.last_written_data_index + 1 + + # if the user passes a dictionary, converts to the proper baseformat type + if isinstance(data, dict): + d = self.data_sink.dataformat.type() + d.from_dict(data, casting='safe', add_defaults=False) + data = d + + self.data_sink.write(data, self.last_written_data_index + 1, end_data_index) + + self.last_written_data_index = end_data_index + self.nb_data_blocks_written += 1 + + + def isDataMissing(self): + return not(self._dataset_output) and \ + (self._synchronization_listener is not None) and \ + (self._synchronization_listener.data_index_end != self.last_written_data_index) + + + def isConnected(self): + return self.data_sink.isConnected() + + +#---------------------------------------------------------- + + +class RemoteOutput: + """Represents one output of a processing block + A list of outputs implementing this interface is provided to the algorithms (see :py:class:`beat.backend.python.outputs.OutputList`). 
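# Illustrative sketch (not part of this changeset): an Output backed by a
# MemoryDataSink.  The dataformat name is hypothetical, and assigning it on
# the sink by hand mirrors what View.prepare_outputs() does; a real block
# would receive its outputs already wired.

from beat.backend.python.data import MemoryDataSource, MemoryDataSink
from beat.backend.python.dataformat import DataFormat
from beat.backend.python.outputs import Output

sink = MemoryDataSink()
sink.dataformat = DataFormat('/path/to/prefix', 'user/single_integer/1')
sink.setup([MemoryDataSource(done_callback=lambda: True)])

out = Output('measurements', sink, dataset_output=True)

out.write({'value': 42})             # dicts are converted through the dataformat
out.write({'value': 43}, end_data_index=3)

print(out.last_written_data_index)   # 3
print(out.nb_data_blocks_written)    # 2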
@@ -54,7 +186,6 @@ class Output: """ - def __init__(self, name, data_format, socket): self.name = name @@ -127,6 +258,9 @@ class Output: return answer == 'tru' +#---------------------------------------------------------- + + class OutputList: """Represents the list of outputs of a processing block @@ -155,7 +289,6 @@ class OutputList: """ def __init__(self): - self._outputs = [] @@ -170,14 +303,15 @@ class OutputList: if index < len(self._outputs): return self._outputs[index] return None - def __iter__(self): + def __iter__(self): for k in self._outputs: yield k - def __len__(self): + def __len__(self): return len(self._outputs) + def add(self, output): """Adds an output to the list