Commit 0403717b authored by Philip ABBET's avatar Philip ABBET
Browse files

[unittests] Refactoring of the 'test_cacheddata.py' file

parent a8871304
......@@ -3,7 +3,7 @@
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
......@@ -26,234 +26,170 @@
###############################################################################
import unittest
import os
import glob
import tempfile
import six
import numpy
import nose.tools
from ..data import CachedDataSink, CachedDataSource, foundSplitRanges
from ..hash import hashFileContents
from ..dataformat import DataFormat
from . import prefix
testfile = None
def create_tempfile():
global testfile
testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix='.data')
testfile.close() #preserve only name
#----------------------------------------------------------
class TestCachedDataBase(unittest.TestCase):
def erase_tempfiles():
global testfile
basename, data_ext = os.path.splitext(testfile.name)
filenames = [testfile.name]
filenames += glob.glob(basename + '*' + data_ext)
filenames += glob.glob(basename + '*' + data_ext + '.checksum')
filenames += glob.glob(basename + '*.index')
filenames += glob.glob(basename + '*.index.checksum')
for filename in filenames:
if os.path.exists(filename): os.unlink(filename)
def setUp(self):
testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix='.data')
testfile.close() # preserve only the name
self.filename = testfile.name
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_data_sink_creation():
dataformat = DataFormat(prefix, 'user/integers/1')
assert dataformat.valid
def tearDown(self):
basename, ext = os.path.splitext(self.filename)
filenames = [self.filename]
filenames += glob.glob(basename + '*')
for filename in filenames:
if os.path.exists(filename):
os.unlink(filename)
data_sink = CachedDataSink()
assert data_sink.setup(testfile.name, dataformat)
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_data_source_creation():
def writeData(self, dataformat_name, start_index=0, end_index=10):
dataformat = DataFormat(prefix, dataformat_name)
self.assertTrue(dataformat.valid)
f = open(testfile.name, 'wb')
f.write(b'json\nuser/integers/1\n')
f.close()
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(self.filename, dataformat))
chksum_data = hashFileContents(testfile.name)
all_data = []
for i in range(start_index, end_index + 1):
data = dataformat.type()
data_sink.write(data, i, i)
all_data.append(data)
with open(testfile.name + '.checksum', 'wt') as f: f.write(chksum_data)
(nb_bytes, duration) = data_sink.statistics()
self.assertTrue(nb_bytes > 0)
self.assertTrue(duration > 0)
data_source = CachedDataSource()
data_sink.close()
del data_sink
assert data_source.setup(testfile.name, prefix)
assert data_source.dataformat.valid
assert not data_source.hasMoreData()
return all_data
(data, start_index, end_index) = data_source.next()
assert data is None
assert start_index is None
assert end_index is None
data_source.close()
#----------------------------------------------------------
def test_cached_data_split():
l = [[0,2,4,6,8,10,12],[0,3,6,9,12]]
n_split = 2
ref = [(0, 5), (6, 11)]
res = foundSplitRanges(l, n_split)
nose.tools.eq_(res, ref)
class TestDataSink(TestCachedDataBase):
l = [[0,2,4,6,8,10,12,15],[0,3,6,9,12,15]]
n_split = 5
ref = [(0, 5), (6, 11), (12, 14)]
res = foundSplitRanges(l, n_split)
nose.tools.eq_(res, ref)
def test_creation(self):
dataformat = DataFormat(prefix, 'user/integers/1')
self.assertTrue(dataformat.valid)
def serialization(format_name, data_modifier=None, data_tester=None):
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(self.filename, dataformat))
dataformat = DataFormat(prefix, format_name)
assert dataformat.valid
data_sink = CachedDataSink()
assert data_sink.setup(testfile.name, dataformat)
#----------------------------------------------------------
for i in six.moves.range(0, 5):
data = dataformat.type()
if data_modifier is not None: data_modifier(data)
data_sink.write(data, i, i)
(nb_bytes, duration) = data_sink.statistics()
assert nb_bytes > 0
assert duration > 0
class TestDataSource(TestCachedDataBase):
data_sink.close()
del data_sink
def test_creation(self):
self.writeData('user/integers/1')
data_source = CachedDataSource()
assert data_source.setup(testfile.name, prefix)
data_source = CachedDataSource()
for i in six.moves.range(0, 5):
assert data_source.hasMoreData()
self.assertTrue(data_source.setup(self.filename, prefix))
self.assertTrue(data_source.dataformat.valid)
self.assertTrue(data_source.hasMoreData())
(data, start_index, end_index) = data_source.next()
assert data is not None
nose.tools.eq_(start_index, i)
nose.tools.eq_(end_index, i)
data_source.close()
if data_tester is not None: data_tester(data)
assert not data_source.hasMoreData()
def perform_deserialization(self, dataformat_name, start_index=0, end_index=10):
reference = self.writeData(dataformat_name) # Always generate 10 data units
(nb_bytes, duration) = data_source.statistics()
assert nb_bytes > 0
assert duration > 0
data_source = CachedDataSource()
data_source.close()
self.assertTrue(data_source.setup(self.filename, prefix,
force_start_index=start_index, force_end_index=end_index))
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_integers():
self.assertTrue(data_source.dataformat.valid)
def data_modifier(data):
data.value8 = numpy.int8(2**6)
data.value16 = numpy.int16(2**14)
data.value32 = numpy.int32(2**30)
data.value64 = numpy.int64(2**62)
for i in range(start_index, end_index + 1):
self.assertTrue(data_source.hasMoreData())
def data_tester(data):
nose.tools.eq_(data.value8, numpy.int8(2**6))
nose.tools.eq_(data.value16, numpy.int16(2**14))
nose.tools.eq_(data.value32, numpy.int32(2**30))
nose.tools.eq_(data.value64, numpy.int64(2**62))
(data, start, end) = data_source.next()
self.assertTrue(data is not None)
serialization('user/integers/1',
data_modifier=data_modifier, data_tester=data_tester)
self.assertEqual(i, start)
self.assertEqual(i, end)
self.assertEqual(reference[i].as_dict(), data.as_dict())
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_objects():
serialization('user/two_objects/1')
self.assertFalse(data_source.hasMoreData())
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_hierarchy_of_objects():
serialization('user/hierarchy_of_objects/1')
(nb_bytes, duration) = data_source.statistics()
self.assertTrue(nb_bytes > 0)
self.assertTrue(duration > 0)
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_3d_array_of_integers():
serialization('user/3d_array_of_integers/1')
data_source.close()
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_3d_array_of_objects():
serialization('user/3d_array_of_objects/1')
def splitting(format_name, data_modifier=None, data_tester=None):
def test_integers(self):
self.perform_deserialization('user/integers/1')
dataformat = DataFormat(prefix, format_name)
assert dataformat.valid
data_sink = CachedDataSink()
assert data_sink.setup(testfile.name, dataformat)
def test_objects(self):
self.perform_deserialization('user/two_objects/1')
for i in six.moves.range(0, 10):
data = dataformat.type()
if data_modifier is not None: data_modifier(data)
data_sink.write(data, i, i)
(nb_bytes, duration) = data_sink.statistics()
assert nb_bytes > 0
assert duration > 0
def test_hierarchy_of_objects(self):
self.perform_deserialization('user/hierarchy_of_objects/1')
data_sink.close()
force_start_index = 3
force_end_index = 8
data_source = CachedDataSource()
assert data_source.setup(testfile.name, prefix, force_start_index=force_start_index, force_end_index=force_end_index)
def test_3d_array_of_integers(self):
self.perform_deserialization('user/3d_array_of_integers/1')
for i in six.moves.range(force_start_index, force_end_index+1):
assert data_source.hasMoreData()
(data, start_index, end_index) = data_source.next()
assert data is not None
nose.tools.eq_(start_index, i)
nose.tools.eq_(end_index, i)
def test_3d_array_of_objects(self):
self.perform_deserialization('user/3d_array_of_objects/1')
if data_tester is not None: data_tester(data)
assert not data_source.hasMoreData()
def test_integers_slice_1(self):
self.perform_deserialization('user/integers/1', 0, 4)
(nb_bytes, duration) = data_source.statistics()
assert nb_bytes > 0
assert duration > 0
data_source.close()
def test_integers_slice_2(self):
self.perform_deserialization('user/integers/1', 3, 6)
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_integers_splitting():
def test_integers_slice_3(self):
self.perform_deserialization('user/integers/1', 7, 10)
def data_modifier(data):
data.value8 = numpy.int8(2**6)
data.value16 = numpy.int16(2**14)
data.value32 = numpy.int32(2**30)
data.value64 = numpy.int64(2**62)
def data_tester(data):
nose.tools.eq_(data.value8, numpy.int8(2**6))
nose.tools.eq_(data.value16, numpy.int16(2**14))
nose.tools.eq_(data.value32, numpy.int32(2**30))
nose.tools.eq_(data.value64, numpy.int64(2**62))
#----------------------------------------------------------
splitting('user/integers/1',
data_modifier=data_modifier, data_tester=data_tester)
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_objects_splitting():
splitting('user/two_objects/1')
class TestFoundSplitRanges(unittest.TestCase):
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_hierarchy_of_objects_splitting():
splitting('user/hierarchy_of_objects/1')
def test_2_splits(self):
l = [[0,2,4,6,8,10,12], [0,3,6,9,12]]
n_split = 2
ref = [(0, 5), (6, 11)]
res = foundSplitRanges(l, n_split)
self.assertEqual(res, ref)
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_3d_array_of_integers_splitting():
splitting('user/3d_array_of_integers/1')
@nose.tools.with_setup(create_tempfile, erase_tempfiles)
def test_3d_array_of_objects_splitting():
splitting('user/3d_array_of_objects/1')
def test_5_splits(self):
l = [[0,2,4,6,8,10,12,15], [0,3,6,9,12,15]]
n_split = 5
ref = [(0, 5), (6, 11), (12, 14)]
res = foundSplitRanges(l, n_split)
self.assertEqual(res, ref)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment