Skip to content
Snippets Groups Projects
Commit 1a6db853 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

This is an experiment to approach #106

1 - Removed the keyword argument "extractor_file" from Extractor.train
so we remove the model persistence from it
2 - Create the method Extractor.save to handle the persistence task
3 - Make tools.extractor save the model.
parent 18bcb692
Branches big-issue-106
No related tags found
1 merge request!177WIP: This is an experiment to approach bob.bio.base#106
Pipeline #32352 failed
...@@ -142,7 +142,7 @@ class Extractor (object): ...@@ -142,7 +142,7 @@ class Extractor (object):
pass pass
def train(self, training_data, extractor_file): def train(self, training_data):
"""This function can be overwritten to train the feature extractor. """This function can be overwritten to train the feature extractor.
If you do this, please also register the function by calling this base class constructor If you do this, please also register the function by calling this base class constructor
and enabling the training by ``requires_training = True``. and enabling the training by ``requires_training = True``.
...@@ -154,8 +154,16 @@ class Extractor (object): ...@@ -154,8 +154,16 @@ class Extractor (object):
Data will be provided in a single list, if ``split_training_features_by_client = False`` was specified in the constructor, Data will be provided in a single list, if ``split_training_features_by_client = False`` was specified in the constructor,
otherwise the data will be split into lists, each of which contains the data of a single (training-)client. otherwise the data will be split into lists, each of which contains the data of a single (training-)client.
"""
raise NotImplementedError("Please overwrite this function in your derived class, or unset the 'requires_training' option in the constructor.")
def save(self, extractor_file):
"""
extractor_file : str extractor_file : str
The file to write. The file to write.
This file should be readable with the :py:meth:`load` function. This file should be readable with the :py:meth:`load` function.
"""
"""
raise NotImplementedError("Please overwrite this function in your derived class, or unset the 'requires_training' option in the constructor.") raise NotImplementedError("Please overwrite this function in your derived class, or unset the 'requires_training' option in the constructor.")
...@@ -52,8 +52,8 @@ def train_extractor(extractor, preprocessor, allow_missing_files = False, force ...@@ -52,8 +52,8 @@ def train_extractor(extractor, preprocessor, allow_missing_files = False, force
else: else:
logger.info("- Extraction: training extractor '%s' using %d training files:", fs.extractor_file, len(train_files)) logger.info("- Extraction: training extractor '%s' using %d training files:", fs.extractor_file, len(train_files))
# train model # train model
extractor.train(train_data, fs.extractor_file) extractor.train(train_data)
extractor.save(fs.extractor_file)
def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_files = False, force = False): def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_files = False, force = False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment