# André Anjos authored
# test_roll.py 1.22 KiB
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# Tue Jun 25 20:44:40 CEST 2013
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
"""Tests on the roll/unroll functions
"""
import unittest
import bob
import numpy
class RollTest(unittest.TestCase):
    """Round-trip tests for ``bob.machine.unroll`` and ``bob.machine.roll``.

    Each test flattens a set of MLP parameters with ``unroll``, feeds the
    resulting vector to ``roll`` on a fresh destination, and checks that the
    destination ends up equal to the source.
    """

    def test01_roll(self):
        """Round-trip a randomized MLP through both ``unroll`` call forms."""
        shape = (10, 3, 8, 5)
        m = bob.machine.MLP(shape)
        m.randomize()

        # Form 1: unroll the machine object itself.
        vec = bob.machine.unroll(m)
        m2 = bob.machine.MLP(shape)
        bob.machine.roll(m2, vec)
        # assertEqual performs the same ``==`` comparison as before but
        # reports both operands when it fails.
        self.assertEqual(m, m2)

        # Form 2: unroll from the machine's weights and biases.
        m3 = bob.machine.MLP(shape)
        vec = bob.machine.unroll(m.weights, m.biases)
        bob.machine.roll(m3, vec)
        self.assertEqual(m, m3)

    def test02_roll(self):
        """Round-trip plain lists of weight/bias arrays through a vector.

        Every layer is verified after rolling, not only the first one, so a
        defect that affects later layers cannot go unnoticed.
        """
        w = [numpy.array([[2, 3.]]), numpy.array([[2, 3, 4.], [5, 6, 7]])]
        b = [numpy.array([5.]), numpy.array([7, 8.])]

        # 2 + 6 weight values plus 1 + 2 bias values -> 11 entries in total.
        # The vector is pre-allocated and handed to unroll as third argument.
        vec = numpy.ndarray(11, numpy.float64)
        bob.machine.unroll(w, b, vec)

        # Uninitialised destinations with the same shapes as the sources.
        w_ = [numpy.ndarray(src.shape, numpy.float64) for src in w]
        b_ = [numpy.ndarray(src.shape, numpy.float64) for src in b]
        bob.machine.roll(w_, b_, vec)

        for layer, (rolled, expected) in enumerate(zip(w_, w)):
            self.assertTrue(
                (rolled == expected).all(),
                "weights of layer %d differ after roll" % layer,
            )
        for layer, (rolled, expected) in enumerate(zip(b_, b)):
            self.assertTrue(
                (rolled == expected).all(),
                "biases of layer %d differ after roll" % layer,
            )