Commit 1d039f94 authored by Saeed SARFJOO's avatar Saeed SARFJOO

fix vstack_features

parent a6af4d6b
Pipeline #21731 failed with stage
in 18 minutes and 15 seconds
......@@ -262,14 +262,18 @@ def take_from_config_or_command_line(args, config, keyword, default, required=Tr
setattr(config, keyword, None)
def check_config_consumed(config):
def check_config_consumed(config, args=None):
if config is not None:
import inspect
for keyword in dir(config):
if not keyword.startswith('_') and not keyword.isupper():
attr = getattr(config, keyword)
if attr is not None and not inspect.isclass(attr) and not inspect.ismodule(attr):
logger.warn("The variable '%s' in a configuration file is not known or not supported by this application; use a '_' prefix to the variable name (e.g., '_%s') to suppress this warning", keyword, keyword)
if not args is None and hasattr(args, keyword):
setattr(args, keyword, attr)
setattr(config, keyword, None)
else:
logger.warn("The variable '%s' in a configuration file is not known or not supported by this application; use a '_' prefix to the variable name (e.g., '_%s') to suppress this warning", keyword, keyword)
def parse_config_file(parsers, args, args_dictionary, keywords, skips):
......@@ -306,7 +310,7 @@ def parse_config_file(parsers, args, args_dictionary, keywords, skips):
parser.get_default(keyword), required=False, is_resource=False)
# check that all variables in the config file are consumed by the above options
check_config_consumed(config)
check_config_consumed(config, args=args)
# evaluate skips
if skips is not None and args.execute_only is not None:
......
......@@ -5,6 +5,8 @@ import collections # this is needed for the sphinx documentation
import functools # this is needed for the sphinx documentation
import numpy
import logging
import types
import itertools
logger = logging.getLogger("bob.bio.base")
from .. import database
......@@ -44,6 +46,19 @@ def filter_none(data, split_by_client=False):
return existing_data
# def check_file(filename, force, expected_file_size=1):
# """Checks if the file with the given ``filename`` exists and has size greater or equal to ``expected_file_size``.
# If the file is to small, **or** if the ``force`` option is set to ``True``, the file is removed.
# This function returns ``True`` is the file exists (and has not been removed), otherwise ``False``"""
# if os.path.exists(filename):
# if force or os.path.getsize(filename) < expected_file_size:
# logger.debug(" .. Removing old file '%s'.", filename)
# os.remove(filename)
# return False
# else:
# return True
# return False
def check_file(filename, force, expected_file_size=1):
"""Checks if the file with the given ``filename`` exists and has size greater or equal to ``expected_file_size``.
If the file is to small, **or** if the ``force`` option is set to ``True``, the file is removed.
......@@ -55,7 +70,10 @@ def check_file(filename, force, expected_file_size=1):
return False
else:
return True
return False
elif os.path.exists(filename.replace('/preprocessed/', '/extracted/')): # sss #
return True
else:
return False
def read_original_data(biofile, directory, extension):
......@@ -224,7 +242,7 @@ def _generate_features(reader, paths, same_size=False,
yield value
def vstack_features(reader, paths, same_size=False, allow_missing_files=False):
def vstack_features(reader, paths, same_size=False, allow_missing_files=False, remove_nans=False, use_iterable=False):
"""Stacks all features in a memory efficient way.
Parameters
......@@ -247,7 +265,10 @@ def vstack_features(reader, paths, same_size=False, allow_missing_files=False):
allow_missing_files : :obj:`bool`, optional
If ``True``, it assumes that the items inside paths are actual files and
ignores the ones that do not exist.
remove_nans : :obj:`bool`, optional
If ``True``, it will remove the NaN samples from the data.
use_iterable : :obj:`bool`, optional
If ``True``, it uses iterable function for loading data.
Returns
-------
numpy.ndarray
......@@ -305,15 +326,59 @@ def vstack_features(reader, paths, same_size=False, allow_missing_files=False):
if same_size and allow_missing_files:
raise ValueError("Both same_size and allow_missing_files cannot be True at"
" the same time.")
iterable = _generate_features(reader, paths, same_size, allow_missing_files)
dtype, shape = next(iterable)
if same_size:
total_size = int(len(paths) * numpy.prod(shape))
all_features = numpy.fromiter(iterable, dtype, total_size)
else:
all_features = numpy.fromiter(iterable, dtype)
if use_iterable:
iterable = _generate_features(reader, paths, same_size, allow_missing_files)
dtype, shape = next(iterable)
if same_size:
total_size = int(len(paths) * numpy.prod(shape))
all_features = numpy.fromiter(iterable, dtype, total_size)
else:
all_features = numpy.fromiter(iterable, dtype)
# the shape is assumed to be (n_samples, ...) it can be (5, 2) or (5, 3, 4).
shape = list(shape)
shape[0] = -1
return numpy.reshape(all_features, shape, order='C')
# the shape is assumed to be (n_samples, ...) it can be (5, 2) or (5, 3, 4).
shape = list(shape)
shape[0] = -1
feats = numpy.reshape(all_features, shape, order='C')
else:
if isinstance(paths, types.GeneratorType):
paths_2nd, paths = itertools.tee(paths)
else:
paths_2nd = paths
total_sample = 0
files_samples = []
shape = [1]
dtype = numpy.float64
for i, feat_file in enumerate(paths):
f = bob.io.base.HDF5File(feat_file)
cur_file_sample = 0
for j, key in enumerate(f.keys()):
f_info = f.describe(key)
cur_type = f_info[1][0][0]
cur_shape = list(f_info[1][0][1])
total_sample += cur_shape[0]
cur_file_sample += cur_shape[0]
cur_shape = cur_shape[1:]
if i==0 and j == 0:
dtype = cur_type
shape = cur_shape
else:
assert shape == cur_shape
assert dtype == cur_type
files_samples.append(cur_file_sample)
f.close()
if same_size:
files_samples = files_samples * len(paths)
total_sample = total_sample * len(paths)
break
feats = numpy.empty([total_sample]+ shape, dtype=dtype)
data_idx = 0
for k, feat_file in enumerate(paths_2nd):
feats[data_idx:data_idx+files_samples[k]] = reader(feat_file)
data_idx+=files_samples[k]
if remove_nans:
feat_idx = [numpy.sum(numpy.isnan(feats[id])) == 0 for id in range(feats.shape[0])]
feats = feats[feat_idx]
return feats
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment