Commit c77e386e authored by Amir MOHAMMADI

Add several extractors which are useful at inference time

parent 13ae3919
Merge request !75: A lot of new features
# Base.py
import os
import logging

import numpy as np
import tensorflow as tf

logger = logging.getLogger(__name__)


def normalize_checkpoint_path(path):
    if os.path.splitext(path)[1] == ".meta":
        filename = os.path.splitext(path)[0]
    elif os.path.isdir(path):
        filename = tf.train.latest_checkpoint(path)
    else:
        filename = path
    return filename


class Base:
    def __init__(self, output_name, input_shape, checkpoint, scopes,
                 input_transform=None, output_transform=None,
                 input_dtype='float32', **kwargs):
        self.output_name = output_name
        self.input_shape = input_shape
        self.checkpoint = normalize_checkpoint_path(checkpoint)
        self.scopes = scopes
        self.input_transform = input_transform
        self.output_transform = output_transform
        self.input_dtype = input_dtype
        self.session = None
        super().__init__(**kwargs)

    def load(self):
        self.session = tf.Session(graph=tf.Graph())

        with self.session.as_default(), self.session.graph.as_default():

            self.input = data = tf.placeholder(self.input_dtype, self.input_shape)

            if self.input_transform is not None:
                data = self.input_transform(data)

            self.output = self.get_output(data, tf.estimator.ModeKeys.PREDICT)

            if self.output_transform is not None:
                self.output = self.output_transform(self.output)

            tf.train.init_from_checkpoint(
                ckpt_dir_or_file=self.checkpoint,
                assignment_map=self.scopes,
            )
            # global_variables_initializer must run after init_from_checkpoint
            self.session.run(tf.global_variables_initializer())
            logger.info('Restored the model from %s', self.checkpoint)

    def __call__(self, data):
        # Build the graph and restore the checkpoint lazily, on first use.
        if self.session is None:
            self.load()
        data = np.ascontiguousarray(data, dtype=self.input_dtype)
        return self.session.run(self.output, feed_dict={self.input: data})

    def get_output(self, data, mode):
        raise NotImplementedError()
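For reference, here is a quick illustration (not part of the commit, using hypothetical paths) of the three checkpoint forms that normalize_checkpoint_path accepts; Base restores variables from the resulting prefix with tf.train.init_from_checkpoint the first time the extractor is called.

# --- usage sketch, not part of the commit; the paths are hypothetical ---
normalize_checkpoint_path("/tmp/model_dir/model.ckpt-5000.meta")
# -> "/tmp/model_dir/model.ckpt-5000"  (the ".meta" extension is stripped)
normalize_checkpoint_path("/tmp/model_dir")
# -> whatever tf.train.latest_checkpoint("/tmp/model_dir") returns (or None)
normalize_checkpoint_path("/tmp/model_dir/model.ckpt-5000")
# -> returned unchanged (already a checkpoint prefix)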
# Estimator.py
import tensorflow as tf

from .Base import Base


class Estimator(Base):
    def __init__(self, estimator, **kwargs):
        self.estimator = estimator
        kwargs['checkpoint'] = kwargs.get('checkpoint', estimator.model_dir)
        super().__init__(**kwargs)

    def get_output(self, data, mode):
        features = {'data': data, 'key': tf.constant(['key'])}
        self.estimator_spec = self.estimator._call_model_fn(
            features, None, mode, None)
        self.end_points = self.estimator.end_points
        return self.end_points[self.output_name]
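A minimal usage sketch (not part of the commit): it assumes my_estimator is a tf.estimator.Estimator whose model_fn also stores its end_points dictionary on the estimator instance, which is what get_output above reads, and that images is a NumPy array matching input_shape. The output_name and shapes below are made up.

# --- usage sketch, not part of the commit ---
extractor = Estimator(
    estimator=my_estimator,            # checkpoint defaults to my_estimator.model_dir
    output_name="embeddings",          # hypothetical key in my_estimator.end_points
    input_shape=[None, 160, 160, 3],
    scopes={"/": "/"},                 # restore every variable from the checkpoint
)
embeddings = extractor(images)         # images: a numpy array matching input_shape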
# Generic.py
from .Base import Base


class Generic(Base):
    def __init__(self, architecture, **kwargs):
        self.architecture = architecture
        super().__init__(**kwargs)

    def get_output(self, data, mode):
        self.end_points = self.architecture(data, mode=mode)[1]
        return self.end_points[self.output_name]
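A minimal, self-contained sketch (not part of the commit) of how Generic might be used. toy_architecture, the checkpoint path, and the scope names are all made up; the only requirement is that the architecture callable returns a (net, end_points) pair, as get_output above expects.

# --- usage sketch, not part of the commit ---
import numpy as np
import tensorflow as tf


def toy_architecture(inputs, mode=tf.estimator.ModeKeys.PREDICT, **kwargs):
    # Hypothetical network returning (net, end_points).
    end_points = {}
    net = tf.layers.flatten(inputs)
    net = end_points["fc1"] = tf.layers.dense(net, 10, name="fc1")
    return net, end_points


extractor = Generic(
    architecture=toy_architecture,
    output_name="fc1",                    # key into the returned end_points dict
    input_shape=[None, 28, 28, 1],
    checkpoint="/path/to/model_dir",      # directory, .meta file, or checkpoint prefix
    scopes={"fc1/": "fc1/"},              # assignment_map for tf.train.init_from_checkpoint
)

images = np.random.rand(2, 28, 28, 1).astype("float32")
embeddings = extractor(images)            # first call builds the graph and restores variables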
# __init__.py
from .Base import Base
from .Generic import Generic
from .Estimator import Estimator


# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
    """Says object was actually declared here, and not in the import module.

    Parameters:

        *args: An iterable of objects to modify

    Resolves `Sphinx referencing issues
    <https://github.com/sphinx-doc/sphinx/issues/3048>`
    """
    for obj in args:
        obj.__module__ = __name__


__appropriate__(
    Base,
    Generic,
    Estimator,
)

__all__ = [_ for _ in dir() if not _.startswith('_')]
  • Owner

    With this commit we should remove some code from bob.ip.tensorflow_extractor to avoid code repetition.

  • Author Owner

    so you are going to make bob.ip... depend on bob.learn...?

  • Owner

    It makes more sense; bob.learn.tensorflow is a very stable package.

    Does this make sense to you?
