Commit 718e78fe authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV

fixing tests

parent 15574ef9
Pipeline #9007 passed with stages
in 18 minutes and 51 seconds
This diff is collapsed.
......@@ -32,7 +32,7 @@ class TensorflowAlgo(Algorithm):
)
self.data_reader = DiskAudio([0], [0])
self.session = tf.Session()
# self.session = tf.Session()
self.dnn_model = None
# def __del__(self):
......@@ -47,8 +47,8 @@ class TensorflowAlgo(Algorithm):
def load_projector(self, projector_file):
logger.info("Loading pretrained model from {0}".format(projector_file))
self.dnn_model = SequenceNetwork()
self.dnn_model.load_hdf5(bob.io.base.HDF5File(projector_file), shape=[1, 6560, 1])
# self.dnn_model.load(projector_file, True)
# self.dnn_model.load_hdf5(bob.io.base.HDF5File(projector_file), shape=[1, 6560, 1])
self.dnn_model.load(projector_file, True)
def project_feature(self, feature):
......@@ -58,8 +58,8 @@ class TensorflowAlgo(Algorithm):
logger.debug(" .... And %d frames are extracted to pass into DNN model" % frames.shape[0])
frames = numpy.reshape(frames, (frames.shape[0], -1, 1))
forward_output = self.dnn_model(frames)
return tf.nn.log_softmax(tf.nn.log_softmax(forward_output)).eval(session=self.session)
# return forward_output
# return tf.nn.log_softmax(tf.nn.log_softmax(forward_output)).eval(session=self.session)
return forward_output
def project(self, feature):
"""project(feature) -> projected
......
......@@ -80,10 +80,13 @@ class ASVspoof2017PadDatabase(PadDatabase):
purposes = self.convert_purposes(purposes, ('genuine', 'spoof'), ('real', 'attack'))
if protocol == 'largetrain':
if groups == 'train':
# this configuration is for ASVspoof2017 compettiion
if 'train' in groups and 'dev' in groups:
groups = ('train', 'dev', 'eval')
elif 'train' in groups:
groups = ('train', 'dev')
if groups == 'dev':
groups = 'eval'
elif 'dev' in groups:
groups = ('eval',)
objects = self.__db.objects(protocol=protocol, groups=groups, purposes=purposes, **kwargs)
return [ASVspoof2017PadFile(f) for f in objects]
......@@ -38,15 +38,15 @@ class AudioTFExtractor(Extractor):
skip_extractor_training=True, **kwargs)
# block parameters
import tensorflow as tf
self.session = tf.Session()
# import tensorflow as tf
# self.session = tf.Session()
# self.session = Session.instance().session
self.feature_layer = feature_layer
self.data_reader = DiskAudio([0], [0])
self.dnn_model = SequenceNetwork(default_feature_layer=feature_layer)
self.dnn_model = None
def __call__(self, input_data):
"""
......@@ -69,9 +69,9 @@ class AudioTFExtractor(Extractor):
def load(self, extractor_file):
logger.info("Loading pretrained model from {0}".format(extractor_file))
self.dnn_model = SequenceNetwork()
self.dnn_model.load_hdf5(bob.io.base.HDF5File(extractor_file), shape=[1, 6560, 1])
# self.dnn_model.load(extractor_file, clear_devices=True)
self.dnn_model = SequenceNetwork(default_feature_layer=self.feature_layer)
# self.dnn_model.load_hdf5(bob.io.base.HDF5File(extractor_file), shape=[1, 6560, 1])
self.dnn_model.load(extractor_file, clear_devices=True)
#hdf5 = bob.io.base.HDF5File(extractor_file)
#self.lenet.load(hdf5, shape=(1,125,125,3), session=self.session)
......
from .dummy import DummyDatabaseSpeechSpoof
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is shortened.
Parameters:
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for obj in args: obj.__module__ = __name__
__appropriate__(
DummyDatabaseSpeechSpoof,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
\ No newline at end of file
from . import database
from .database import DummyDatabaseSpeechSpoof
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is shortened.
Parameters:
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for obj in args: obj.__module__ = __name__
__appropriate__(
DummyDatabaseSpeechSpoof,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -138,6 +138,8 @@ class DummyDatabaseSpeechSpoof(bob.pad.base.database.PadDatabase):
return return_list
def annotations(self, file):
pass
database = DummyDatabaseSpeechSpoof(
protocol='Default',
......
......@@ -10,16 +10,11 @@ eggs = bob.pad.voice
bob.bio.spear
bob.bio.gmm
bob.pad.base
bob.db.base
bob.measure
bob.db.asvspoof
bob.db.asvspoof2017
bob.db.avspoof
bob.db.voicepa
bob.extension
bob.learn.tensorflow
bob.pad.db
bob.bio.db
bob.db.cpqd_replay
gridtk
......@@ -36,11 +31,7 @@ develop = src/bob.bio.spear
src/bob.db.cpqd_replay
src/bob.pad.base
src/bob.bio.base
src/bob.db.base
src/bob.extension
src/bob.learn.tensorflow
src/bob.bio.db
src/bob.pad.db
.
; options for bob.buildout
......@@ -53,14 +44,10 @@ bob.bio.gmm = git git@gitlab.idiap.ch:bob/bob.bio.gmm.git
bob.db.asvspoof2017 = git git@gitlab.idiap.ch:bob/bob.db.asvspoof2017.git
bob.db.avspoof = git git@gitlab.idiap.ch:bob/bob.db.avspoof.git
bob.db.asvspoof = git git@gitlab.idiap.ch:bob/bob.db.asvspoof.git
bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
bob.bio.base = git git@gitlab.idiap.ch:bob/bob.bio.base.git
bob.pad.base = git git@gitlab.idiap.ch:bob/bob.pad.base.git
bob.db.voicepa = git git@gitlab.idiap.ch:bob/bob.db.voicepa.git
bob.extension = git git@gitlab.idiap.ch:bob/bob.extension.git
bob.learn.tensorflow = git branch=epoch git@gitlab.idiap.ch:bob/bob.learn.tensorflow.git
bob.bio.db = git git@gitlab.idiap.ch:bob/bob.bio.db.git
bob.pad.db = git git@gitlab.idiap.ch:bob/bob.pad.db.git
bob.bio.spear = git git@gitlab.idiap.ch:bob/bob.bio.spear.git
bob.db.cpqd_replay = git git@gitlab.idiap.ch:bob/bob.db.cpqd_replay.git
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment