Skip to content
Snippets Groups Projects
Commit 96bc574c authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Use generators to save memory

parent 5bd0e8ec
No related branches found
No related tags found
1 merge request!9Use generators instead of list to concatenate the loaded training data
Pipeline #
...@@ -71,7 +71,7 @@ class ISV (GMM): ...@@ -71,7 +71,7 @@ class ISV (GMM):
"""Train Projector and Enroller at the same time""" """Train Projector and Enroller at the same time"""
[self._check_feature(feature) for client in train_features for feature in client] [self._check_feature(feature) for client in train_features for feature in client]
data1 = numpy.vstack([feature for client in train_features for feature in client]) data1 = numpy.vstack(feature for client in train_features for feature in client)
self.train_ubm(data1) self.train_ubm(data1)
# to save some memory, we might want to delete these data # to save some memory, we might want to delete these data
del data1 del data1
......
...@@ -133,10 +133,9 @@ class IVector (GMM): ...@@ -133,10 +133,9 @@ class IVector (GMM):
"""Train Projector and Enroller at the same time""" """Train Projector and Enroller at the same time"""
[self._check_feature(feature) for client in train_features for feature in client] [self._check_feature(feature) for client in train_features for feature in client]
train_features_flatten = [feature for client in train_features for feature in client]
# train UBM # train UBM
data = numpy.vstack(train_features_flatten) data = numpy.vstack(feature for client in train_features for feature in client)
self.train_ubm(data) self.train_ubm(data)
del data del data
......
...@@ -3,7 +3,7 @@ import bob.learn.em ...@@ -3,7 +3,7 @@ import bob.learn.em
import shutil import shutil
import numpy import numpy
import os import os
import functools
import logging import logging
logger = logging.getLogger("bob.bio.gmm") logger = logging.getLogger("bob.bio.gmm")
...@@ -24,7 +24,10 @@ def kmeans_initialize(algorithm, extractor, limit_data = None, force = False): ...@@ -24,7 +24,10 @@ def kmeans_initialize(algorithm, extractor, limit_data = None, force = False):
# read data # read data
logger.info("UBM training: initializing kmeans") logger.info("UBM training: initializing kmeans")
training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data) training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data)
data = numpy.vstack([read_feature(extractor, feature_file) for feature_file in training_list])
# read the features
reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(reader, training_list)
# Perform KMeans initialization # Perform KMeans initialization
kmeans_machine = bob.learn.em.KMeansMachine(algorithm.gaussians, data.shape[1]) kmeans_machine = bob.learn.em.KMeansMachine(algorithm.gaussians, data.shape[1])
...@@ -55,8 +58,11 @@ def kmeans_estep(algorithm, extractor, iteration, indices, force=False): ...@@ -55,8 +58,11 @@ def kmeans_estep(algorithm, extractor, iteration, indices, force=False):
logger.info("UBM training: KMeans E-Step round %d from range(%d, %d)", iteration, *indices) logger.info("UBM training: KMeans E-Step round %d from range(%d, %d)", iteration, *indices)
# read data # read the features
data = numpy.vstack([read_feature(extractor, training_list[index]) for index in range(indices[0], indices[1])]) reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(
reader,
(training_list[index] for index in range(indices[0], indices[1])))
# Performs the E-step # Performs the E-step
trainer = algorithm.kmeans_trainer trainer = algorithm.kmeans_trainer
...@@ -168,9 +174,11 @@ def gmm_initialize(algorithm, extractor, limit_data = None, force = False): ...@@ -168,9 +174,11 @@ def gmm_initialize(algorithm, extractor, limit_data = None, force = False):
else: else:
logger.info("UBM Training: Initializing GMM") logger.info("UBM Training: Initializing GMM")
# read features
training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data) training_list = utils.selected_elements(fs.training_list('extracted', 'train_projector'), limit_data)
data = numpy.vstack([read_feature(extractor, feature_file) for feature_file in training_list])
# read the features
reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(reader, training_list)
# get means and variances of kmeans result # get means and variances of kmeans result
kmeans_machine = bob.learn.em.KMeansMachine(bob.io.base.HDF5File(fs.kmeans_file)) kmeans_machine = bob.learn.em.KMeansMachine(bob.io.base.HDF5File(fs.kmeans_file))
...@@ -209,8 +217,12 @@ def gmm_estep(algorithm, extractor, iteration, indices, force=False): ...@@ -209,8 +217,12 @@ def gmm_estep(algorithm, extractor, iteration, indices, force=False):
logger.info("UBM training: GMM E-Step from range(%d, %d)", *indices) logger.info("UBM training: GMM E-Step from range(%d, %d)", *indices)
# read data # read the features
data = numpy.vstack([read_feature(extractor, training_list[index]) for index in range(indices[0], indices[1])]) reader = functools.partial(read_feature, extractor)
data = utils.vstack_features(
reader,
(training_list[index] for index in range(indices[0], indices[1])))
trainer = algorithm.ubm_trainer trainer = algorithm.ubm_trainer
trainer.initialize(gmm_machine, None) trainer.initialize(gmm_machine, None)
......
...@@ -69,7 +69,7 @@ def read_feature(extractor, feature_file): ...@@ -69,7 +69,7 @@ def read_feature(extractor, feature_file):
import bob.bio.video import bob.bio.video
if isinstance(extractor, bob.bio.video.extractor.Wrapper): if isinstance(extractor, bob.bio.video.extractor.Wrapper):
assert isinstance(feature, bob.bio.video.FrameContainer) assert isinstance(feature, bob.bio.video.FrameContainer)
return numpy.vstack([frame for _,frame,_ in feature]) return numpy.vstack(frame for _, frame, _ in feature)
except ImportError: except ImportError:
pass pass
return feature return feature
...@@ -10,5 +10,5 @@ bob.sp ...@@ -10,5 +10,5 @@ bob.sp
bob.learn.em bob.learn.em
bob.measure bob.measure
bob.db.base bob.db.base
bob.bio.base bob.bio.base > 3.1
matplotlib # for plotting matplotlib # for plotting
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment