Fixed issue with InceptionNetowrks

parent 2d167d1c
......@@ -80,6 +80,14 @@ class TransformTensorflow(TransformerMixin, BaseEstimator):
d["model"] = None
return d
def inference(self, X):
if self.preprocessor is not None:
X = self.preprocessor(tf.cast(X, "float32"))
prelogits = self.model.predict_on_batch(X)
embeddings = tf.math.l2_normalize(prelogits, axis=-1)
return embeddings
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment