Skip to content
Snippets Groups Projects
Commit d773cb99 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Added tfdebug

parent 26a4c0a2
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
import os import os
from tensorflow.python import debug as tf_debug
class Extractor(object): class Extractor(object):
...@@ -14,7 +15,7 @@ class Extractor(object): ...@@ -14,7 +15,7 @@ class Extractor(object):
""" """
def __init__(self, checkpoint_filename, input_tensor, graph): def __init__(self, checkpoint_filename, input_tensor, graph, debug=False):
"""Loads the tensorflow model """Loads the tensorflow model
Parameters Parameters
...@@ -34,17 +35,21 @@ class Extractor(object): ...@@ -34,17 +35,21 @@ class Extractor(object):
self.graph = graph self.graph = graph
# Initializing the variables of the current graph # Initializing the variables of the current graph
self.session = tf.Session() self.session = tf.Session()
self.session.run(tf.global_variables_initializer()) self.session.run(tf.global_variables_initializer())
# Loading the last checkpoint and overwriting the current variables # Loading the last checkpoint and overwriting the current variables
saver = tf.train.Saver() saver = tf.train.Saver()
if os.path.splitext(checkpoint_filename)[1] == ".meta": if os.path.isdir(checkpoint_filename):
saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(checkpoint_filename))) saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_filename))
else: else:
saver.restore(self.session, checkpoint_filename) saver.restore(self.session, checkpoint_filename)
# Activating the debug
if debug:
self.session = tf_debug.LocalCLIDebugWrapperSession(self.session)
def __del__(self): def __del__(self):
tf.reset_default_graph() tf.reset_default_graph()
...@@ -65,6 +70,5 @@ class Extractor(object): ...@@ -65,6 +70,5 @@ class Extractor(object):
The features. The features.
""" """
return self.session.run(self.graph, feed_dict={self.input_tensor: data}) return self.session.run(self.graph, feed_dict={self.input_tensor: data})
#!/usr/bin/env python #!/usr/bin/env python
from .Extractor import Extractor
from .InceptionResNet_v1 import InceptionResNet_v1
def scratch_network(inputs, end_point="fc1", reuse = False): def scratch_network(inputs, end_point="fc1", reuse = False):
import tensorflow as tf import tensorflow as tf
...@@ -53,7 +57,6 @@ def download_file(url, out_file): ...@@ -53,7 +57,6 @@ def download_file(url, out_file):
with open(out_file, 'wb') as f: with open(out_file, 'wb') as f:
copyfileobj(response, f) copyfileobj(response, f)
from .Extractor import Extractor
def get_config(): def get_config():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment