test_dnn.py 1.78 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Milos Cernak <milos.cernak@idiap.ch>
# September 1, 2017

'''Tests for Kaldi bindings'''

import pkg_resources
import numpy as np
import bob.io.audio

import bob.kaldi


def test_forward_pass():
    """Check ``bob.kaldi.nnet_forward`` against stored reference output.

    MFCC features (computed with ``normalization=False``) are extracted from
    ``data/sample16k.wav`` and fed through the network in
    ``dnn/ami.nnet.txt`` together with the feature transform in
    ``dnn/ami.feature_transform.txt``. The result must have the same shape
    as, and be numerically close to, ``data/sample16k-posteriors.txt``.
    """

    sample = pkg_resources.resource_filename(__name__, 'data/sample16k.wav')
    nnetfile = pkg_resources.resource_filename(__name__, 'dnn/ami.nnet.txt')
    transfile = pkg_resources.resource_filename(
        __name__, 'dnn/ami.feature_transform.txt')
    reference = pkg_resources.resource_filename(
        __name__, 'data/sample16k-posteriors.txt')

    data = bob.io.audio.reader(sample)

    # data.load()[0]: presumably the first channel of the audio -- confirm.
    feats = bob.kaldi.cepstral(data.load()[0], 'mfcc', data.rate,
                               normalization=False)

    # nnet_forward receives the file *contents*, so the files only need to
    # be open long enough to read them.
    with open(nnetfile) as nnetf, open(transfile) as trnf:
        dnn = nnetf.read()
        trn = trnf.read()

    ours = bob.kaldi.nnet_forward(feats, dnn, trn)
    theirs = np.loadtxt(reference)

    assert ours.shape == theirs.shape

    # Same tolerances as before (rtol=1e-3, atol=1e-5), now named so they
    # cannot be swapped by accident. assert_allclose uses the same
    # |ours - theirs| <= atol + rtol * |theirs| test as np.allclose, but
    # reports the mismatching elements on failure and is not stripped by -O.
    # equal_nan=False keeps np.allclose's NaN behaviour.
    np.testing.assert_allclose(ours, theirs, rtol=1e-03, atol=1e-05,
                               equal_nan=False)

def test_compute_dnn_vad():
    """Compare ``bob.kaldi.compute_dnn_vad`` with a stored reference.

    The DNN-based VAD is run on ``data/sample16k.wav`` and its output must be
    ``np.allclose`` (default tolerances) to ``data/sample16k-dnn-vad.txt``.
    """
    wav_path = pkg_resources.resource_filename(__name__, 'data/sample16k.wav')
    expected_path = pkg_resources.resource_filename(
        __name__, 'data/sample16k-dnn-vad.txt')

    audio = bob.io.audio.reader(wav_path)
    signal = audio.load()[0]

    computed = bob.kaldi.compute_dnn_vad(signal, audio.rate)
    expected = np.loadtxt(expected_path)

    assert np.allclose(computed, expected)

Milos CERNAK's avatar
Milos CERNAK committed
55 56 57 58 59 60 61 62 63 64 65 66
def test_compute_dnn_phone():
    """Sanity-check ``bob.kaldi.compute_dnn_phone`` on ``data/librivox.wav``.

    Every frame is decoded to its highest-scoring class (argmax over axis 1
    of the returned posteriors). The label at frame 250 is expected to be
    ``'N'``, per the original author's note: the last spoken sound, in the
    word "DOMAIN".
    """
    audio = bob.io.audio.reader(
        pkg_resources.resource_filename(__name__, 'data/librivox.wav'))

    posteriors, labels = bob.kaldi.compute_dnn_phone(
        audio.load()[0], audio.rate)

    # Max decoding: index of the best-scoring class for each frame.
    best = np.argmax(posteriors, axis=1)

    assert labels[best[250]] == 'N'