diff --git a/bob/learn/tensorflow/extractors/Base.py b/bob/learn/tensorflow/extractors/Base.py new file mode 100644 index 0000000000000000000000000000000000000000..3f41f882fc49b5d7edee0a89e6c76af1475c922f --- /dev/null +++ b/bob/learn/tensorflow/extractors/Base.py @@ -0,0 +1,67 @@ +import tensorflow as tf +import os +import numpy as np +import logging + +logger = logging.getLogger(__name__) + + +def normalize_checkpoint_path(path): + if os.path.splitext(path)[1] == ".meta": + filename = os.path.splitext(path)[0] + elif os.path.isdir(path): + filename = tf.train.latest_checkpoint(path) + else: + filename = path + + return filename + + +class Base: + def __init__(self, output_name, input_shape, checkpoint, scopes, + input_transform=None, output_transform=None, + input_dtype='float32', **kwargs): + + self.output_name = output_name + self.input_shape = input_shape + self.checkpoint = normalize_checkpoint_path(checkpoint) + self.scopes = scopes + self.input_transform = input_transform + self.output_transform = output_transform + self.input_dtype = input_dtype + self.session = None + super().__init__(**kwargs) + + def load(self): + self.session = tf.Session(graph=tf.Graph()) + + with self.session.as_default(), self.session.graph.as_default(): + + self.input = data = tf.placeholder(self.input_dtype, self.input_shape) + + if self.input_transform is not None: + data = self.input_transform(data) + + self.output = self.get_output(data, tf.estimator.ModeKeys.PREDICT) + + if self.output_transform is not None: + self.output = self.output_transform(self.output) + + tf.train.init_from_checkpoint( + ckpt_dir_or_file=self.checkpoint, + assignment_map=self.scopes, + ) + # global_variables_initializer must run after init_from_checkpoint + self.session.run(tf.global_variables_initializer()) + logger.info('Restored the model from %s', self.checkpoint) + + def __call__(self, data): + if self.session is None: + self.load() + + data = np.ascontiguousarray(data, dtype=self.input_dtype) + + return self.session.run(self.output, feed_dict={self.input: data}) + + def get_output(self, data, mode): + raise NotImplementedError() diff --git a/bob/learn/tensorflow/extractors/Estimator.py b/bob/learn/tensorflow/extractors/Estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..768b539f1906d2a70f5b8a71537d16a1924ac472 --- /dev/null +++ b/bob/learn/tensorflow/extractors/Estimator.py @@ -0,0 +1,16 @@ +import tensorflow as tf +from .Base import Base + + +class Estimator(Base): + def __init__(self, estimator, **kwargs): + self.estimator = estimator + kwargs['checkpoint'] = kwargs.get('checkpoint', estimator.model_dir) + super().__init__(**kwargs) + + def get_output(self, data, mode): + features = {'data': data, 'key': tf.constant(['key'])} + self.estimator_spec = self.estimator._call_model_fn( + features, None, mode, None) + self.end_points = self.estimator.end_points + return self.end_points[self.output_name] diff --git a/bob/learn/tensorflow/extractors/Generic.py b/bob/learn/tensorflow/extractors/Generic.py new file mode 100644 index 0000000000000000000000000000000000000000..3aab2573317c0916531eea2a748329147e543f87 --- /dev/null +++ b/bob/learn/tensorflow/extractors/Generic.py @@ -0,0 +1,12 @@ +from .Base import Base + + +class Generic(Base): + def __init__(self, architecture, **kwargs): + + self.architecture = architecture + super().__init__(**kwargs) + + def get_output(self, data, mode): + self.end_points = self.architecture(data, mode=mode)[1] + return self.end_points[self.output_name] diff --git a/bob/learn/tensorflow/extractors/__init__.py b/bob/learn/tensorflow/extractors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac13aeec4364c6d2481e00aa2d0c37abfb94407e --- /dev/null +++ b/bob/learn/tensorflow/extractors/__init__.py @@ -0,0 +1,27 @@ +from .Base import Base +from .Generic import Generic +from .Estimator import Estimator + + +# gets sphinx autodoc done right - don't remove it +def __appropriate__(*args): + """Says object was actually declared here, an not on the import module. + + Parameters: + + *args: An iterable of objects to modify + + Resolves `Sphinx referencing issues + <https://github.com/sphinx-doc/sphinx/issues/3048>` + """ + + for obj in args: + obj.__module__ = __name__ + + +__appropriate__( + Base, + Generic, + Estimator, +) +__all__ = [_ for _ in dir() if not _.startswith('_')]