Commit a8bc6541 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add an extra feed option to the extractors

parent b9583f15
......@@ -20,7 +20,7 @@ def normalize_checkpoint_path(path):
class Base:
def __init__(self, output_name, input_shape, checkpoint, scopes,
input_transform=None, output_transform=None,
input_dtype='float32', **kwargs):
input_dtype='float32', extra_feed=None, **kwargs):
self.output_name = output_name
self.input_shape = input_shape
......@@ -29,6 +29,7 @@ class Base:
self.input_transform = input_transform
self.output_transform = output_transform
self.input_dtype = input_dtype
self.extra_feed = extra_feed
self.session = None
super().__init__(**kwargs)
......@@ -60,8 +61,11 @@ class Base:
self.load()
data = np.ascontiguousarray(data, dtype=self.input_dtype)
feed_dict = {self.input: data}
if self.extra_feed is not None:
feed_dict.update(self.extra_feed)
return self.session.run(self.output, feed_dict={self.input: data})
return self.session.run(self.output, feed_dict=feed_dict)
def get_output(self, data, mode):
raise NotImplementedError()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment