Skip to content
Snippets Groups Projects
Commit a8bc6541 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add an extra feed option to the extractors

parent b9583f15
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ def normalize_checkpoint_path(path): ...@@ -20,7 +20,7 @@ def normalize_checkpoint_path(path):
class Base: class Base:
def __init__(self, output_name, input_shape, checkpoint, scopes, def __init__(self, output_name, input_shape, checkpoint, scopes,
input_transform=None, output_transform=None, input_transform=None, output_transform=None,
input_dtype='float32', **kwargs): input_dtype='float32', extra_feed=None, **kwargs):
self.output_name = output_name self.output_name = output_name
self.input_shape = input_shape self.input_shape = input_shape
...@@ -29,6 +29,7 @@ class Base: ...@@ -29,6 +29,7 @@ class Base:
self.input_transform = input_transform self.input_transform = input_transform
self.output_transform = output_transform self.output_transform = output_transform
self.input_dtype = input_dtype self.input_dtype = input_dtype
self.extra_feed = extra_feed
self.session = None self.session = None
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -60,8 +61,11 @@ class Base: ...@@ -60,8 +61,11 @@ class Base:
self.load() self.load()
data = np.ascontiguousarray(data, dtype=self.input_dtype) data = np.ascontiguousarray(data, dtype=self.input_dtype)
feed_dict = {self.input: data}
if self.extra_feed is not None:
feed_dict.update(self.extra_feed)
return self.session.run(self.output, feed_dict={self.input: data}) return self.session.run(self.output, feed_dict=feed_dict)
def get_output(self, data, mode): def get_output(self, data, mode):
raise NotImplementedError() raise NotImplementedError()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment