Skip to content
Snippets Groups Projects
Commit 96bc574c authored by Amir MOHAMMADI
Browse files

Use generators to save memory

parent 5bd0e8ec
No related branches found
Tags v3.0.2
1 merge request: !9 Use generators instead of list to concatenate the loaded training data
Pipeline #
...@@ -71,7 +71,7 @@ class ISV (GMM): ...@@ -71,7 +71,7 @@ class ISV (GMM):
"""Train Projector and Enroller at the same time""" """Train Projector and Enroller at the same time"""
[self._check_feature(feature) for client in train_features for feature in client] [self._check_feature(feature) for client in train_features for feature in client]
data1 = numpy.vstack([feature for client in train_features for feature in client]) data1 = numpy.vstack(feature for client in train_features for feature in client)
self.train_ubm(data1) self.train_ubm(data1)
# to save some memory, we might want to delete these data # to save some memory, we might want to delete these data
del data1 del data1
......
...@@ -133,10 +133,9 @@ class IVector (GMM): ...@@ -133,10 +133,9 @@ class IVector (GMM):
"""Train Projector and Enroller at the same time""" """Train Projector and Enroller at the same time"""
[self._check_feature(feature) for client in train_features for feature in client] [self._check_feature(feature) for client in train_features for feature in client]
train_features_flatten = [feature for client in train_features for feature in client]
# train UBM # train UBM
data = numpy.vstack(train_features_flatten) data = numpy.vstack(feature for client in train_features for feature in client)
self.train_ubm(data) self.train_ubm(data)
del data del data
......
...@@ -3,7 +3,7 @@ import bob.learn.em ...@@ -3,7 +3,7 @@ import bob.learn.em
import shutil import shutil
import numpy import numpy
import os import os
import functools
import logging import logging
logger = logging.getLogger("bob.bio.gmm") logger = logging.getLogger("bob.bio.gmm")
...@@ -24,7 +24,10 @@ def kmeans_initialize(algorithm, extractor, limit_data = None, force = False): ...@@ -24,7 +24,10 @@ def kmeans_initialize(algorithm, extractor, limit_data = None, force = False):
# read data # read data
logger.info("UBM training: initializing kmeans") logger.info("UBM training: initializing kmeans")
training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data) training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data)
data = numpy.vstack([read_feature(extractor, feature_file) for feature_file in training_list])
# read the features
reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(reader, training_list)
# Perform KMeans initialization # Perform KMeans initialization
kmeans_machine = bob.learn.em.KMeansMachine(algorithm.gaussians, data.shape[1]) kmeans_machine = bob.learn.em.KMeansMachine(algorithm.gaussians, data.shape[1])
...@@ -55,8 +58,11 @@ def kmeans_estep(algorithm, extractor, iteration, indices, force=False): ...@@ -55,8 +58,11 @@ def kmeans_estep(algorithm, extractor, iteration, indices, force=False):
logger.info("UBM training: KMeans E-Step round %d from range(%d, %d)", iteration, *indices) logger.info("UBM training: KMeans E-Step round %d from range(%d, %d)", iteration, *indices)
# read data # read the features
data = numpy.vstack([read_feature(extractor, training_list[index]) for index in range(indices[0], indices[1])]) reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(
reader,
(training_list[index] for index in range(indices[0], indices[1])))
# Performs the E-step # Performs the E-step
trainer = algorithm.kmeans_trainer trainer = algorithm.kmeans_trainer
...@@ -168,9 +174,11 @@ def gmm_initialize(algorithm, extractor, limit_data = None, force = False): ...@@ -168,9 +174,11 @@ def gmm_initialize(algorithm, extractor, limit_data = None, force = False):
else: else:
logger.info("UBM Training: Initializing GMM") logger.info("UBM Training: Initializing GMM")
# read features
training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data) training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data)
data = numpy.vstack([read_feature(extractor, feature_file) for feature_file in training_list])
# read the features
reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(reader, training_list)
# get means and variances of kmeans result # get means and variances of kmeans result
kmeans_machine = bob.learn.em.KMeansMachine(bob.io.base.HDF5File(fs.kmeans_file)) kmeans_machine = bob.learn.em.KMeansMachine(bob.io.base.HDF5File(fs.kmeans_file))
...@@ -209,8 +217,12 @@ def gmm_estep(algorithm, extractor, iteration, indices, force=False): ...@@ -209,8 +217,12 @@ def gmm_estep(algorithm, extractor, iteration, indices, force=False):
logger.info("UBM training: GMM E-Step from range(%d, %d)", *indices) logger.info("UBM training: GMM E-Step from range(%d, %d)", *indices)
# read data # read the features
data = numpy.vstack([read_feature(extractor, training_list[index]) for index in range(indices[0], indices[1])]) reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(
reader,
(training_list[index] for index in range(indices[0], indices[1])))
trainer = algorithm.ubm_trainer trainer = algorithm.ubm_trainer
trainer.initialize(gmm_machine, None) trainer.initialize(gmm_machine, None)
......
...@@ -69,7 +69,7 @@ def read_feature(extractor, feature_file): ...@@ -69,7 +69,7 @@ def read_feature(extractor, feature_file):
import bob.bio.video import bob.bio.video
if isinstance(extractor, bob.bio.video.extractor.Wrapper): if isinstance(extractor, bob.bio.video.extractor.Wrapper):
assert isinstance(feature, bob.bio.video.FrameContainer) assert isinstance(feature, bob.bio.video.FrameContainer)
return numpy.vstack([frame for _,frame,_ in feature]) return numpy.vstack(frame for _, frame, _ in feature)
except ImportError: except ImportError:
pass pass
return feature return feature
...@@ -10,5 +10,5 @@ bob.sp ...@@ -10,5 +10,5 @@ bob.sp
bob.learn.em bob.learn.em
bob.measure bob.measure
bob.db.base bob.db.base
bob.bio.base bob.bio.base > 3.1
matplotlib # for plotting matplotlib # for plotting
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment