Skip to content
Snippets Groups Projects
Commit 25eb1a20 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

propagating allow missing files

parent f9fa293a
Branches
Tags
No related merge requests found
Pipeline #
......@@ -162,7 +162,7 @@ def execute(args):
force = args.force)
# train the feature projector
elif args.sub_task == 'kmeans-e-step':
elif args.sub_task == 'kmeans-e-step':
tools.kmeans_estep(
algorithm,
args.extractor,
......
......@@ -313,8 +313,9 @@ def gmm_project(algorithm, extractor, indices, force=False, allow_missing_files
projected_file = projected_files[i]
if not utils.check_file(projected_file, force):
# load feature
feature = read_feature(extractor, feature_file)
feature = read_feature(extractor, feature_file, allow_missing_files=allow_missing_files)
# project feature
projected = algorithm.project_ubm(feature)
# write it
......
......@@ -7,7 +7,7 @@ import os
from bob.bio.base.tools.FileSelector import FileSelector
from bob.bio.base import utils, tools
def train_isv(algorithm, force=False):
def train_isv(algorithm, force=False, allow_missing_files=False):
"""Finally, the UBM is used to train the ISV projector/enroller."""
fs = FileSelector.instance()
......@@ -19,7 +19,21 @@ def train_isv(algorithm, force=False):
# read training data
training_list = fs.training_list('projected_gmm', 'train_projector', arrange_by_client = True)
train_gmm_stats = [[algorithm.read_gmm_stats(filename) for filename in client_files] for client_files in training_list]
train_gmm_stats = []
for client_files in training_list:
client_stats = []
for filename in client_files:
if not os.path.exists(filename):
if allow_missing_files:
logger.debug("... Cannot find the file %s; skipping", filename)
else:
raise RuntimeError("Cannot find the file '%s' " % filename)
client_stats.append(algorithm.read_gmm_stats(filename))
train_gmm_stats.append(client_stats)
#train_gmm_stats = [[algorithm.read_gmm_stats(filename) for filename in client_files] for client_files in training_list]
# perform ISV training
logger.info("ISV training: training ISV with %d clients", len(train_gmm_stats))
......
......@@ -10,7 +10,7 @@ from bob.bio.base import utils, tools
def ivector_estep(algorithm, iteration, indices, force=False):
def ivector_estep(algorithm, iteration, indices, force=False, allow_missing_files = False):
"""Performs a single E-step of the IVector algorithm (parallel)"""
fs = FileSelector.instance()
stats_file = fs.ivector_stats_file(iteration, indices[0], indices[1])
......@@ -38,7 +38,17 @@ def ivector_estep(algorithm, iteration, indices, force=False):
# Load data
training_list = fs.training_list('projected_gmm', 'train_projector')
data = [algorithm.read_gmm_stats(training_list[i]) for i in range(indices[0], indices[1])]
data = []
for i in range(indices[0], indices[1]):
filename = training_list[i]
if not os.path.exists(filename):
if allow_missing_files:
logger.debug("... Cannot find the file %s; skipping", filename)
else:
raise RuntimeError("Cannot find the file '%s' " % filename)
data.append(algorithm.read_gmm_stats(filename))
#data = [algorithm.read_gmm_stats(training_list[i]) for i in range(indices[0], indices[1])]
# Perform the E-step
trainer.e_step(tv, data)
......@@ -134,7 +144,7 @@ def ivector_mstep(algorithm, iteration, number_of_parallel_jobs, force=False, cl
shutil.rmtree(old_dir)
def ivector_project(algorithm, indices, force=False):
def ivector_project(algorithm, indices, force=False, allow_missing_files=False):
"""Performs IVector projection"""
# read UBM and TV into the IVector class
fs = FileSelector.instance()
......@@ -150,6 +160,13 @@ def ivector_project(algorithm, indices, force=False):
gmm_stats_file = gmm_stats_files[i]
ivector_file = ivector_files[i]
if not utils.check_file(ivector_file, force):
if not os.path.exists(gmm_stats_file):
if allow_missing_files:
logger.debug("... Cannot find the file %s; skipping", gmm_stats_file)
else:
raise RuntimeError("Cannot find the file '%s' " % gmm_stats_file)
# load feature
feature = algorithm.read_gmm_stats(gmm_stats_file)
# project feature
......
import bob.bio.base
import numpy
import os
def add_jobs(args, submitter, local_job_adder):
"""Adds all (desired) jobs of the tool chain to the grid, or to the local list to be executed."""
......@@ -63,7 +64,14 @@ def base(algorithm):
"""Returns the base algorithm, if it is a video extension, otherwise returns the algorithm itself"""
return algorithm.algorithm if is_video_extension(algorithm) else algorithm
def read_feature(extractor, feature_file):
def read_feature(extractor, feature_file, allow_missing_files = False):
if not os.path.exists(feature_file):
if allow_missing_files:
logger.debug("... Cannot find preprocessed data file %s; skipping", feature_file)
else:
raise RuntimeError("Cannot find file '%s' " % feature_file)
feature = extractor.read_feature(feature_file)
try:
import bob.bio.video
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment