Skip to content
Snippets Groups Projects
Commit a8bc6541 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add an extra feed option to the extractors

parent b9583f15
No related branches found
No related tags found
No related merge requests found
......@@ -20,7 +20,7 @@ def normalize_checkpoint_path(path):
class Base:
def __init__(self, output_name, input_shape, checkpoint, scopes,
input_transform=None, output_transform=None,
input_dtype='float32', **kwargs):
input_dtype='float32', extra_feed=None, **kwargs):
self.output_name = output_name
self.input_shape = input_shape
......@@ -29,6 +29,7 @@ class Base:
self.input_transform = input_transform
self.output_transform = output_transform
self.input_dtype = input_dtype
self.extra_feed = extra_feed
self.session = None
super().__init__(**kwargs)
......@@ -60,8 +61,11 @@ class Base:
self.load()
data = np.ascontiguousarray(data, dtype=self.input_dtype)
feed_dict = {self.input: data}
if self.extra_feed is not None:
feed_dict.update(self.extra_feed)
return self.session.run(self.output, feed_dict={self.input: data})
return self.session.run(self.output, feed_dict=feed_dict)
def get_output(self, data, mode):
raise NotImplementedError()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment