diff --git a/bob/learn/tensorflow/extractors/Base.py b/bob/learn/tensorflow/extractors/Base.py index 3f41f882fc49b5d7edee0a89e6c76af1475c922f..c4e1064ff9278b9db778149dc4d2717f70d0ca06 100644 --- a/bob/learn/tensorflow/extractors/Base.py +++ b/bob/learn/tensorflow/extractors/Base.py @@ -20,7 +20,7 @@ def normalize_checkpoint_path(path): class Base: def __init__(self, output_name, input_shape, checkpoint, scopes, input_transform=None, output_transform=None, - input_dtype='float32', **kwargs): + input_dtype='float32', extra_feed=None, **kwargs): self.output_name = output_name self.input_shape = input_shape @@ -29,6 +29,7 @@ class Base: self.input_transform = input_transform self.output_transform = output_transform self.input_dtype = input_dtype + self.extra_feed = extra_feed self.session = None super().__init__(**kwargs) @@ -60,8 +61,11 @@ class Base: self.load() data = np.ascontiguousarray(data, dtype=self.input_dtype) + feed_dict = {self.input: data} + if self.extra_feed is not None: + feed_dict.update(self.extra_feed) - return self.session.run(self.output, feed_dict={self.input: data}) + return self.session.run(self.output, feed_dict=feed_dict) def get_output(self, data, mode): raise NotImplementedError()