Commit 8577d403 authored by Philip ABBET

Refactoring: reassign some classes from beat.core

parent 14bb7f2d
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
"""Data I/O classes and functions"""
import os
import re
import glob
import simplejson as json
import select
import time
import tempfile
import abc
from functools import reduce
import logging
logger = logging.getLogger(__name__)
import six
from .hash import hashFileContents
from .dataformat import DataFormat
from .algorithm import Algorithm
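# On-disk cache layout, as implemented by the reader/writer classes below
# (summarized here for convenience; see CachedDataSource and CachedDataSink):
#
#   <basename>.<start>.<end>.data            two header lines (the encoding,
#                                            'binary' or 'json', then the name
#                                            of the dataformat), followed, for
#                                            each block, by an ASCII line
#                                            '<start> <end> <size>' and <size>
#                                            bytes of encoded data
#   <basename>.<start>.<end>.index           one '<start> <end>' line per block
#   <basename>.<start>.<end>.data.checksum   hash of the corresponding file, as
#   <basename>.<start>.<end>.index.checksum  computed by hashFileContents()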
class DataSource(object):
"""Interface of all the Data Sources
Data Sources are used to provide data to the inputs of an algorithm.
"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def next(self, load=True):
"""Retrieves the next block of data
Returns:
A tuple (*data*, *start_index*, *end_index*)
"""
pass
@abc.abstractmethod
def hasMoreData(self):
"""Indicates if there is more data to process on some of the inputs"""
pass
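# Illustrative sketch (not part of the original module): a concrete DataSource
# is typically consumed by polling hasMoreData() and pulling blocks with next()
# until it is exhausted. The `source` argument is assumed to be any DataSource
# subclass that has already been set up.
def _example_consume_source(source):
    blocks = []
    while source.hasMoreData():
        data, start_index, end_index = source.next()
        blocks.append((start_index, end_index, data))
    return blocks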
class DataSink(object):
"""Interface of all the Data Sinks
Data Sinks are used by the outputs of an algorithm to write/transmit data.
"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def write(self, data, start_data_index, end_data_index):
"""Writes a block of data
Parameters:
data (beat.core.baseformat.baseformat): The block of data to write
start_data_index (int): Start index of the written data
end_data_index (int): End index of the written data
"""
pass
@abc.abstractmethod
def isConnected(self):
pass
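# Illustrative sketch (not part of the original module): an output hands each
# block of data to its sink together with the range of indices it covers. The
# `sink` argument is assumed to be a concrete DataSink (for instance a
# CachedDataSink after a successful setup()) and `data` an instance of the
# dataformat type configured on that sink.
def _example_write_block(sink, data, start_index, end_index):
    if not sink.isConnected():
        raise RuntimeError("data sink is not connected")
    sink.write(data, start_index, end_index)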
class CachedDataSource(DataSource):
"""Data Source that load data from the Cache"""
def __init__(self):
self.filenames = None
self.cur_file = None
self.cur_file_index = None
self.encoding = None # must be 'binary' or 'json'
self.prefix = None # where to find dataformats
self.dataformat = None # the dataformat itself
self.preloaded = False
self.next_start_index = None
self.next_end_index = None
self.next_data_size = None
self.force_start_index = None
self.force_end_index = None
self._cache_size = 10 * 1024 * 1024 # 10 megabytes
self._cache = six.b('')
self._nb_bytes_read = 0
self._read_duration = 0
def _readHeader(self):
"""Read the header of the current file"""
# Read file format
encoding = self.cur_file.readline()[:-1]
if not isinstance(encoding, str): encoding = encoding.decode('utf8')
if encoding not in ('binary', 'json'):
raise RuntimeError("valid formats for data reading are 'binary' "
"or 'json': the format `%s' is invalid" % (encoding,))
self.encoding = encoding
# Read data format
dataformat_name = self.cur_file.readline()[:-1]
if not isinstance(dataformat_name, str):
dataformat_name = dataformat_name.decode('utf8')
if dataformat_name.startswith('analysis:'):
algo_name = dataformat_name.split(':')[1]
algo = Algorithm(self.prefix, algo_name)
if not algo.valid:
raise RuntimeError("the dataformat `%s' is the result of an " \
"algorithm which is not valid" % algo_name)
self.dataformat = algo.result_dataformat()
else:
self.dataformat = DataFormat(self.prefix, dataformat_name)
if not self.dataformat.valid:
raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name)
return True
def setup(self, filename, prefix, force_start_index=None,
force_end_index=None):
"""Configures the data source
Parameters:
filename (str): Name of the file to read the data from
prefix (str, path): Path to the prefix where the dataformats are stored.
force_start_index (int): The starting index (if not set or set to
``None``, the default, reads the data from the beginning of the file)
force_end_index (int): The end index (if not set or set to ``None``, the
default, reads the data until the end)
Returns:
``True``, if successful, or ``False`` otherwise.
"""
index_re = re.compile(r'^.*\.(\d+)\.(\d+)\.(data|index)(\.checksum)?$')
def file_start(f):
"""Returns the converted start indexes from a filename, otherwise 0"""
r = index_re.match(f)
if r: return int(r.group(1))
return 0
def trim_filename(l, start_index, end_index):
"""Function to trim out the useless file given a range of indices
"""
res = []
for f in l:
r = index_re.match(f)
if r:
s = int(r.group(1))
e = int(r.group(2))
if (start_index is not None and e < start_index) or \
(end_index is not None and s > end_index):
continue
res.append(f)
return res
def check_consistency(data_filenames, basename, data_ext):
"""Perform some sanity check on the data/checksum files on disk:
1. One-to-one mapping between data and checksum files
2. Checksum comparison between hash(data) and checksum files
3. Contiguous indices if they are present
"""
# Check checksum of files
checksum_filenames = sorted(glob.glob(basename + '*' + data_ext + '.checksum'), key=file_start)
# Make sure that we have a perfect match between data files and checksum
# files
checksum_filenames_noext = [os.path.splitext(f)[0] for f in checksum_filenames]
if data_filenames != checksum_filenames_noext:
raise IOError("number of data files and checksum files for `%s' " \
"does not match (%d != %d)" % (filename, len(data_filenames),
len(checksum_filenames_noext)))
# List of start/end indices, to check that they are contiguous
indices = []
for f_data, f_chck in zip(data_filenames, checksum_filenames):
expected_chksum = open(f_chck, 'rt').read().strip()
current_chksum = hashFileContents(f_data)
if expected_chksum != current_chksum:
raise IOError("data file `%s' has a checksum (%s) that differs " \
"from expected one (%s)" % (f_data, current_chksum,
expected_chksum))
r = index_re.match(f_data)
if r: indices.append((int(r.group(1)), int(r.group(2))))
indices = sorted(indices, key=lambda v: v[0])
ok_indices = True
if len(indices) > 0:
ok_indices = (indices[0][0] == 0)
if ok_indices and len(indices) > 1:
ok_indices = all(indices[i + 1][0] - indices[i][1] == 1 for i in range(len(indices) - 1))
if not ok_indices:
raise IOError("data file `%s' have missing indices." % f_data)
self.prefix = prefix
basename, data_ext = os.path.splitext(filename)
data_filenames = sorted(glob.glob(basename + '*' + data_ext),
key=file_start)
# Check consistency of the data/checksum files
check_consistency(data_filenames, basename, data_ext)
# List files to process
self.force_start_index = force_start_index
self.force_end_index = force_end_index
self.filenames = trim_filename(data_filenames, force_start_index,
force_end_index)
# Read the first file to process
self.cur_file_index = 0
try:
self.cur_file = open(self.filenames[self.cur_file_index], 'rb')
except Exception as e:
logger.warn("Could not setup `%s': %s" % (filename, e))
return False
# Reads the header of the current file
self._readHeader()
if force_start_index is not None:
# Read chunks until force_start_index is reached
while True:
# Read indexes
t1 = time.time()
line = self.cur_file.readline()
self._nb_bytes_read += len(line)
t2 = time.time()
self._read_duration += t2 - t1
(self.next_start_index, self.next_end_index, self.next_data_size) = \
[int(x) for x in line.split()]
# Seek to the next chunk of data if the start index is still too small
if self.next_start_index < force_start_index:
self.cur_file.seek(self.next_data_size, 1)
# Otherwise, read the next chunk of data (binary or json)
else:
t1 = time.time()
data = self.cur_file.read(self._cache_size - len(self._cache))
t2 = time.time()
self._nb_bytes_read += len(data)
self._read_duration += t2 - t1
self._cache += data
self.preloaded = True
break
else:
# Preload the data
self._preload()
return True
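# For example (hypothetical paths), to replay only the blocks covering
# indices 10 to 20 of an existing cache file:
#
#   source = CachedDataSource()
#   source.setup('/path/to/cache/xxx.data', '/path/to/prefix',
#                force_start_index=10, force_end_index=20)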
def close(self):
"""Closes the data source"""
if self.cur_file is not None:
self.cur_file.close()
def __del__(self):
"""Makes sure the files are close when the object is deleted"""
self.close()
def next(self):
"""Retrieve the next block of data
Returns:
A tuple (data, start_index, end_index)
"""
if self.next_start_index is None:
return (None, None, None)
# Determine if the cache already contains all the data we need
if len(self._cache) >= self.next_data_size:
encoded_data = self._cache[:self.next_data_size]
self._cache = self._cache[self.next_data_size:]
else:
t1 = time.time()
data = self.cur_file.read(self.next_data_size - len(self._cache))
t2 = time.time()
self._nb_bytes_read += len(data)
self._read_duration += t2 - t1
encoded_data = self._cache + data
self._cache = six.b('')
data = self.dataformat.type()
data.unpack(encoded_data) #checks validity
result = (data, self.next_start_index, self.next_end_index)
self._preload()
return result
def hasMoreData(self):
"""Indicates if there is more data to process on some of the inputs"""
if not(self.preloaded):
self._preload(blocking=True)
if self.force_end_index is not None and \
self.next_start_index is not None and \
self.next_start_index > self.force_end_index:
return False
return (self.next_start_index is not None)
def statistics(self):
"""Return the statistics about the number of bytes read from the cache"""
return (self._nb_bytes_read, self._read_duration)
def _preload(self, blocking=False):
# Determine if the cache already contains all the data we need
offset = self._cache.find(six.b('\n'))
if offset == -1:
# Try to read the next chunk of data
while True:
# Read in the current file
t1 = time.time()
if blocking:
(readable, writable, errors) = select.select([self.cur_file], [], [])
data = self.cur_file.read(self._cache_size - len(self._cache))
t2 = time.time()
self._nb_bytes_read += len(data)
self._read_duration += t2 - t1
self._cache += data
# If no complete index line could be read from the current file
if (len(data) == 0) or (self._cache.find(six.b('\n')) == -1):
# Read the next one if possible
if self.cur_file_index < len(self.filenames) - 1:
if self.cur_file is not None:
self.cur_file.close()
self.cur_file_index += 1
try:
self.cur_file = open(self.filenames[self.cur_file_index], 'rb')
except Exception:
return
self._readHeader()
# Otherwise, stop the parsing
else:
self.next_start_index = None
self.next_end_index = None
self.next_data_size = None
self.preloaded = blocking
return
else:
break
offset = self._cache.find(six.b('\n'))
# Extract the information about the next block of data
line = self._cache[:offset]
self._cache = self._cache[offset + 1:]
(self.next_start_index, self.next_end_index, self.next_data_size) = \
[int(x) for x in line.split()]
self.preloaded = True
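# Illustrative sketch (not part of the original module): reading every block
# back from a cached file with a CachedDataSource. The `cache_file` and
# `prefix` paths are hypothetical; `prefix` must point to the directory where
# the dataformats are stored so that the file header can be resolved.
def _example_read_cache(cache_file, prefix):
    source = CachedDataSource()
    if not source.setup(cache_file, prefix):
        raise RuntimeError("cannot set up a data source on `%s'" % cache_file)
    try:
        while source.hasMoreData():
            data, start_index, end_index = source.next()
            # `data` is an instance of the dataformat declared in the header
    finally:
        source.close()
    return source.statistics()  # (number of bytes read, read duration)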
class CachedDataSink(DataSink):
"""Data Sink that save data in the Cache
The default behavior is to save the data in a binary format.
"""
def __init__(self):
self.filename = None
self.process_id = None
self.split_id = None
self.max_size = None
self._nb_bytes_written = 0
self._write_duration = 0
self._nb_bytes_written_split = 0
self._new_file = False
self._cur_filename = None
self._cur_file = None
self._cur_indexname = None
self._cur_index = None
self._cur_start_index = None
self._cur_end_index = None
self._filenames = []
self._filenames_tmp = []
self._tmp_ext = '.tmp'
self.encoding = None
self.dataformat = None
def _curTmpFilenameWithSplit(self):
filename, data_ext = os.path.splitext(self.filename)
dirname = os.path.dirname(filename)
basename = os.path.basename(filename)
fd, tmp_file = tempfile.mkstemp(
dir=dirname,
prefix=basename+'.' + str(self.process_id)+'.'+ str(self.split_id)+'_',
suffix=data_ext + self._tmp_ext,
)
os.close(fd) # Preserve only the name
os.unlink(tmp_file)
return tmp_file
def _curFilenameWithIndices(self):
basename = os.path.basename(self.filename)
basename, data_ext = os.path.splitext(basename)
dirname = os.path.dirname(self.filename)
return os.path.join(dirname, basename + '.' + str(self._cur_start_index) + '.' + str(self._cur_end_index) + data_ext)
def _tmpIndexFilenameFromTmpFilename(self, tmp_filename):
return os.path.splitext(os.path.splitext(tmp_filename)[0])[0] + '.index' + self._tmp_ext
def _indexFilenameFromFilename(self, filename):
return os.path.splitext(filename)[0] + '.index'
def _openAndWriteHeader(self):
"""Write the header of the current file"""
# Close current file if open
self._close_current()
# Open new file in writing mode
self._cur_filename = self._curTmpFilenameWithSplit()
self._cur_indexname = \
self._tmpIndexFilenameFromTmpFilename(self._cur_filename)
self._filenames_tmp.append(self._cur_filename)
try:
self._cur_file = open(self._cur_filename, 'wb')
self._cur_index = open(self._cur_indexname, 'wt')
except Exception:
return
# Write dataformat
self._cur_file.write(six.b('%s\n%s\n' % \
(self.encoding, self.dataformat.name)))
self._cur_file.flush()
# Reset a few flags
self._cur_start_index = None
self._cur_end_index = None
self._new_file = False
self._nb_bytes_written_split = 0
def setup(self, filename, dataformat, encoding='binary', process_id=0,
max_size=0):
"""Configures the data sink
Parameters:
filename (str): Name of the file to generate
dataformat (beat.core.dataformat.DataFormat): The dataformat to be used
inside this file. All objects stored inside this file will respect that
format.
encoding (str): String defining the encoding to be used for encoding the
data. Only a few options are supported: ``binary`` (the default) or
``json`` (debugging purposes).
"""
if encoding not in ('binary', 'json'):
raise RuntimeError("valid formats for data writting are 'binary' "
"or 'json': the format `%s' is invalid" % format)
if dataformat.name == '__unnamed_dataformat__':
raise RuntimeError("cannot record data using an unnammed data format")
self.filename = filename
self.process_id = process_id
self.split_id = 0
self.max_size = max_size
self._nb_bytes_written = 0
self._write_duration = 0
self._new_file = True
self._cur_filename = None
self._cur_file = None
self._cur_indexname = None
self._cur_index = None
self._cur_start_index = None
self._cur_end_index = None
self._filenames = []
self._filenames_tmp = []
self.dataformat = dataformat
self.encoding = encoding
return True
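# For example (hypothetical names), to record a few blocks into a new cache
# file, where `block` is an instance of the configured dataformat type:
#
#   sink = CachedDataSink()
#   sink.setup('/path/to/cache/xxx.data', dataformat, encoding='binary')
#   sink.write(block, 0, 0)
#   sink.write(block, 1, 1)
#   sink.close()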
def _close_current(self):
"""Closes the data sink
"""
if self._cur_file is not None:
self._cur_file.close()
self._cur_index.close()
# If file is empty, remove it
if self._cur_start_index is None or self._cur_end_index is None:
try:
os.remove(self._cur_filename)
os.remove(self._cur_indexname)
except Exception:
return False
self._filenames_tmp.pop()
# Otherwise, append final filename to list
else:
self._filenames.append(self._curFilenameWithIndices())
self._cur_filename = None
self._cur_file = None
self._cur_indexname = None
self._cur_index = None
def close(self):
"""Move the files to final location
"""
self._close_current()
assert len(self._filenames_tmp) == len(self._filenames)
for i in range(len(self._filenames_tmp)):
os.rename(self._filenames_tmp[i], self._filenames[i])
tmp_indexname = \
self._tmpIndexFilenameFromTmpFilename(self._filenames_tmp[i])
final_indexname = self._indexFilenameFromFilename(self._filenames[i])
os.rename(tmp_indexname, final_indexname)
# creates the checksums for all data and indexes
chksum_data = hashFileContents(self._filenames[i])
with open(self._filenames[i] + '.checksum', 'wt') as f:
f.write(chksum_data)
chksum_index = hashFileContents(final_indexname)
with open(final_indexname + '.checksum', 'wt') as f:
f.write(chksum_index)
self._cur_filename = None
self._cur_file = None
self._cur_indexname = None
self._cur_index = None
self._cur_start_index = None
self._cur_end_index = None
self._filenames = []
self._filenames_tmp = []
def reset(self):
"""Move the files to final location
"""
self._close_current()
assert len(self._filenames_tmp) == len(self._filenames)
for i in range(len(self._filenames_tmp)):
try:
os.remove(self._filenames_tmp[i])
tmp_indexname = \
self._tmpIndexFilenameFromTmpFilename(self._filenames_tmp[i])
os.remove(tmp_indexname)
except Exception:
return False
self._cur_filename = None
self._cur_file = None
self._cur_indexname = None
self._cur_index = None
self._cur_start_index = None
self._cur_end_index = None
self._filenames = []
self._filenames_tmp = []
def __del__(self):
"""Make sure the files are close and renamed when the object is deleted
"""
self.close()
def write(self, data, start_data_index, end_data_index):
"""Writes a block of data to the filesystem
Parameters:
data (beat.core.baseformat.baseformat): The block of data to write
start_data_index (int): Start index of the written data
end_data_index (int): End index of the written data
"""
if isinstance(data, dict):
# the user passed a dictionary - must convert
data = self.dataformat.type(**data)
else:
# Checks that the input data conforms to the expected format
if data.__class__._name != self.dataformat.name:
raise TypeError("input data uses format `%s' while this sink "
"expects `%s'" % (data.__class__._name, self.dataformat))
# If the new_file flag is set, open a new file and write its header
if self._new_file:
self._openAndWriteHeader()
if self._cur_file is None:
raise RuntimeError("no destination file")
# encoding happens here
if self.encoding == 'binary':
encoded_data = data.pack()
else:
from .utils import NumpyJSONEncoder
encoded_data = json.dumps(data.as_dict(), indent=4, cls=NumpyJSONEncoder)
# adds a newline at the end of the encoded data, for clarity
encoded_data += six.b('\n')
informations = six.b('%d %d %d\n' % (start_data_index,
end_data_index, len(encoded_data)))
t1 = time.time()
self._cur_file.write(informations + encoded_data)
self._cur_file.flush()
indexes = '%d %d\n' % (start_data_index, end_data_index)