Commit 751a2488 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

asvspoo2017 corrections, tensorflow gmm for hdf5

parent bc6bc98a
...@@ -32,7 +32,7 @@ class TensorflowAlgo(Algorithm): ...@@ -32,7 +32,7 @@ class TensorflowAlgo(Algorithm):
) )
self.data_reader = DiskAudio([0], [0]) self.data_reader = DiskAudio([0], [0])
# self.session = tf.Session() self.session = tf.Session()
self.dnn_model = None self.dnn_model = None
# def __del__(self): # def __del__(self):
...@@ -47,8 +47,8 @@ class TensorflowAlgo(Algorithm): ...@@ -47,8 +47,8 @@ class TensorflowAlgo(Algorithm):
def load_projector(self, projector_file): def load_projector(self, projector_file):
logger.info("Loading pretrained model from {0}".format(projector_file)) logger.info("Loading pretrained model from {0}".format(projector_file))
self.dnn_model = SequenceNetwork() self.dnn_model = SequenceNetwork()
# self.dnn_model.load(bob.io.base.HDF5File(projector_file), session=self.session) self.dnn_model.load_hdf5(bob.io.base.HDF5File(projector_file), shape=[1, 6560, 1])
self.dnn_model.load(projector_file, True) # self.dnn_model.load(projector_file, True)
def project_feature(self, feature): def project_feature(self, feature):
...@@ -58,8 +58,8 @@ class TensorflowAlgo(Algorithm): ...@@ -58,8 +58,8 @@ class TensorflowAlgo(Algorithm):
logger.debug(" .... And %d frames are extracted to pass into DNN model" % frames.shape[0]) logger.debug(" .... And %d frames are extracted to pass into DNN model" % frames.shape[0])
frames = numpy.reshape(frames, (frames.shape[0], -1, 1)) frames = numpy.reshape(frames, (frames.shape[0], -1, 1))
forward_output = self.dnn_model(frames) forward_output = self.dnn_model(frames)
# return tf.nn.log_softmax(tf.nn.log_softmax(forward_output)).eval(session=self.session) return tf.nn.log_softmax(tf.nn.log_softmax(forward_output)).eval(session=self.session)
return forward_output # return forward_output
def project(self, feature): def project(self, feature):
"""project(feature) -> projected """project(feature) -> projected
......
...@@ -48,7 +48,7 @@ class ASVspoof2017PadDatabase(PadDatabase): ...@@ -48,7 +48,7 @@ class ASVspoof2017PadDatabase(PadDatabase):
if names is None: if names is None:
return None return None
mapping = dict(zip(low_level_names, high_level_names)) mapping = dict(zip(high_level_names, low_level_names))
if isinstance(names, str): if isinstance(names, str):
return mapping.get(names) return mapping.get(names)
return [mapping[g] for g in names] return [mapping[g] for g in names]
...@@ -78,5 +78,6 @@ class ASVspoof2017PadDatabase(PadDatabase): ...@@ -78,5 +78,6 @@ class ASVspoof2017PadDatabase(PadDatabase):
Returns: A set of Files with the specified properties. Returns: A set of Files with the specified properties.
""" """
purposes = self.convert_purposes(purposes, ('genuine', 'spoof'), ('real', 'attack')) purposes = self.convert_purposes(purposes, ('genuine', 'spoof'), ('real', 'attack'))
objects = self.__db.objects(protocol=protocol, groups=groups, purposes=purposes, **kwargs) objects = self.__db.objects(protocol=protocol, groups=groups, purposes=purposes, **kwargs)
return [ASVspoof2017PadFile(f) for f in objects] return [ASVspoof2017PadFile(f) for f in objects]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment