Commit c77e386e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add several extractors which are usefull at inference time

parent 13ae3919
import tensorflow as tf
import os
import numpy as np
import logging
logger = logging.getLogger(__name__)
def normalize_checkpoint_path(path):
if os.path.splitext(path)[1] == ".meta":
filename = os.path.splitext(path)[0]
elif os.path.isdir(path):
filename = tf.train.latest_checkpoint(path)
filename = path
return filename
class Base:
def __init__(self, output_name, input_shape, checkpoint, scopes,
input_transform=None, output_transform=None,
input_dtype='float32', **kwargs):
self.output_name = output_name
self.input_shape = input_shape
self.checkpoint = normalize_checkpoint_path(checkpoint)
self.scopes = scopes
self.input_transform = input_transform
self.output_transform = output_transform
self.input_dtype = input_dtype
self.session = None
def load(self):
self.session = tf.Session(graph=tf.Graph())
with self.session.as_default(), self.session.graph.as_default():
self.input = data = tf.placeholder(self.input_dtype, self.input_shape)
if self.input_transform is not None:
data = self.input_transform(data)
self.output = self.get_output(data, tf.estimator.ModeKeys.PREDICT)
if self.output_transform is not None:
self.output = self.output_transform(self.output)
# global_variables_initializer must run after init_from_checkpoint'Restored the model from %s', self.checkpoint)
def __call__(self, data):
if self.session is None:
data = np.ascontiguousarray(data, dtype=self.input_dtype)
return, feed_dict={self.input: data})
def get_output(self, data, mode):
raise NotImplementedError()
import tensorflow as tf
from .Base import Base
class Estimator(Base):
def __init__(self, estimator, **kwargs):
self.estimator = estimator
kwargs['checkpoint'] = kwargs.get('checkpoint', estimator.model_dir)
def get_output(self, data, mode):
features = {'data': data, 'key': tf.constant(['key'])}
self.estimator_spec = self.estimator._call_model_fn(
features, None, mode, None)
self.end_points = self.estimator.end_points
return self.end_points[self.output_name]
from .Base import Base
class Generic(Base):
def __init__(self, architecture, **kwargs):
self.architecture = architecture
def get_output(self, data, mode):
self.end_points = self.architecture(data, mode=mode)[1]
return self.end_points[self.output_name]
from .Base import Base
from .Generic import Generic
from .Estimator import Estimator
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, an not on the import module.
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
for obj in args:
obj.__module__ = __name__
__all__ = [_ for _ in dir() if not _.startswith('_')]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment