Commit 1e21e10c authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

remove the extractors folder

parent a28815cd
import tensorflow as tf
import os
import numpy as np
import logging
logger = logging.getLogger(__name__)
def normalize_checkpoint_path(path):
if os.path.splitext(path)[1] == ".meta":
filename = os.path.splitext(path)[0]
elif os.path.isdir(path):
filename = tf.train.latest_checkpoint(path)
else:
filename = path
return filename
class Base:
def __init__(self, output_name, input_shape, checkpoint, scopes,
input_transform=None, output_transform=None,
input_dtype='float32', extra_feed=None, **kwargs):
self.output_name = output_name
self.input_shape = input_shape
self.checkpoint = normalize_checkpoint_path(checkpoint)
self.scopes = scopes
self.input_transform = input_transform
self.output_transform = output_transform
self.input_dtype = input_dtype
self.extra_feed = extra_feed
self.session = None
super().__init__(**kwargs)
def load(self):
self.session = tf.Session(graph=tf.Graph())
with self.session.as_default(), self.session.graph.as_default():
self.input = data = tf.placeholder(self.input_dtype, self.input_shape)
if self.input_transform is not None:
data = self.input_transform(data)
self.output = self.get_output(data, tf.estimator.ModeKeys.PREDICT)
if self.output_transform is not None:
self.output = self.output_transform(self.output)
tf.train.init_from_checkpoint(
ckpt_dir_or_file=self.checkpoint,
assignment_map=self.scopes,
)
# global_variables_initializer must run after init_from_checkpoint
self.session.run(tf.global_variables_initializer())
logger.info('Restored the model from %s', self.checkpoint)
def __call__(self, data):
if self.session is None:
self.load()
data = np.ascontiguousarray(data, dtype=self.input_dtype)
feed_dict = {self.input: data}
if self.extra_feed is not None:
feed_dict.update(self.extra_feed)
return self.session.run(self.output, feed_dict=feed_dict)
def get_output(self, data, mode):
raise NotImplementedError()
import tensorflow as tf
from .Base import Base
class Estimator(Base):
def __init__(self, estimator, **kwargs):
self.estimator = estimator
kwargs['checkpoint'] = kwargs.get('checkpoint', estimator.model_dir)
super().__init__(**kwargs)
def get_output(self, data, mode):
features = {'data': data, 'key': tf.constant(['key'])}
self.estimator_spec = self.estimator._call_model_fn(
features, None, mode, None)
self.end_points = self.estimator.end_points
return self.end_points[self.output_name]
from .Base import Base
class Generic(Base):
def __init__(self, architecture, **kwargs):
self.architecture = architecture
super().__init__(**kwargs)
def get_output(self, data, mode):
self.end_points = self.architecture(data, mode=mode)[1]
return self.end_points[self.output_name]
from .Base import Base, normalize_checkpoint_path
from .Generic import Generic
from .Estimator import Estimator
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, an not on the import module.
Parameters:
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for obj in args:
obj.__module__ = __name__
__appropriate__(
Base,
Generic,
Estimator,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment